diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 24fdce9d59..8ebfd82222 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -14,7 +14,10 @@ from __future__ import annotations +import asyncio from typing import Any +from typing import Awaitable +from typing import Callable from typing import Optional import uuid @@ -213,6 +216,19 @@ class InvocationContext(BaseModel): of this invocation. """ + _tool_call_cache_lock: asyncio.Lock = PrivateAttr( + default_factory=asyncio.Lock + ) + _tool_call_cache: dict[tuple[Any, ...], asyncio.Task] = PrivateAttr( + default_factory=dict + ) + """Caches tool call results within a single invocation. + + This is used to prevent redundant tool execution when the model repeats the + same function call (same tool name + same args) multiple times during a single + invocation. + """ + @property def is_resumable(self) -> bool: """Returns whether the current invocation is resumable.""" @@ -221,6 +237,76 @@ def is_resumable(self) -> bool: and self.resumability_config.is_resumable ) + @staticmethod + def _canonicalize_tool_args(value: Any) -> Any: + """Converts a JSON-like structure into a stable, hashable representation.""" + if isinstance(value, dict): + return tuple( + (k, InvocationContext._canonicalize_tool_args(v)) + for k, v in sorted(value.items()) + ) + if isinstance(value, list): + return tuple(InvocationContext._canonicalize_tool_args(v) for v in value) + if isinstance(value, (str, int, float, bool)) or value is None: + return value + # Fallback: keep it hashable and stable. + return repr(value) + + def _tool_call_cache_key( + self, *, tool_name: str, tool_args: dict[str, Any] + ) -> tuple[Any, ...]: + """Builds a cache key for a tool call within this invocation.""" + return ( + self.branch, + tool_name, + InvocationContext._canonicalize_tool_args(tool_args), + ) + + async def get_or_execute_deduped_tool_call( + self, + *, + tool_name: str, + tool_args: dict[str, Any], + execute: Callable[[], Awaitable[Any]], + dedupe: bool = False, + ) -> tuple[Any, bool]: + """Returns cached tool result for identical calls, otherwise executes once. + + Args: + tool_name: Tool name. + tool_args: Tool arguments from the model. + execute: A coroutine factory that executes the tool and returns its + response. + + Returns: + A tuple of (tool_result, cache_hit). + """ + if not dedupe: + return await execute(), False + + key = self._tool_call_cache_key(tool_name=tool_name, tool_args=tool_args) + + async with self._tool_call_cache_lock: + task = self._tool_call_cache.get(key) + if task is None: + task = asyncio.create_task(execute()) + self._tool_call_cache[key] = task + cache_hit = False + else: + cache_hit = True + + try: + result = await task + except Exception: + # If the execution failed, remove from cache so subsequent calls can + # retry instead of returning a cached exception forever. + async with self._tool_call_cache_lock: + if self._tool_call_cache.get(key) is task: + self._tool_call_cache.pop(key, None) + raise + + return result, cache_hit + def set_agent_state( self, agent_name: str, diff --git a/src/google/adk/agents/run_config.py b/src/google/adk/agents/run_config.py index ae210ef471..9a6f740577 100644 --- a/src/google/adk/agents/run_config.py +++ b/src/google/adk/agents/run_config.py @@ -251,6 +251,19 @@ class RunConfig(BaseModel): - Less than or equal to 0: This allows for unbounded number of llm calls. """ + dedupe_tool_calls: bool = False + """ + Whether to deduplicate identical tool calls (same tool name + same arguments) + within a single invocation. + + This helps prevent redundant tool execution when the model repeats the same + function call multiple times (for example, when a tool is slow or the model + does not follow the instruction to call a tool only once). + + Note: Only the tool result is reused; tool side effects (state/artifact + deltas) are only applied once from the first execution. + """ + custom_metadata: Optional[dict[str, Any]] = None """Custom metadata for the current invocation.""" diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index ffe1657be1..233f454d75 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -340,75 +340,92 @@ async def _run_on_tool_error_callbacks( async def _run_with_trace(): nonlocal function_args - # Step 1: Check if plugin before_tool_callback overrides the function - # response. - function_response = ( - await invocation_context.plugin_manager.run_before_tool_callback( - tool=tool, tool_args=function_args, tool_context=tool_context - ) - ) + async def _execute_tool_pipeline() -> Any: + """Executes tool call pipeline once; result can be cached by invocation.""" + # Step 1: Check if plugin before_tool_callback overrides the function + # response. + function_response = ( + await invocation_context.plugin_manager.run_before_tool_callback( + tool=tool, tool_args=function_args, tool_context=tool_context + ) + ) - # Step 2: If no overrides are provided from the plugins, further run the - # canonical callback. - if function_response is None: - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break + # Step 2: If no overrides are provided from the plugins, further run the + # canonical callback. + if function_response is None: + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response + if function_response: + break + + # Step 3: Otherwise, proceed calling the tool normally. + if function_response is None: + try: + function_response = await __call_tool_async( + tool, args=function_args, tool_context=tool_context + ) + except Exception as tool_error: + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + if error_response is not None: + function_response = error_response + else: + raise tool_error - # Step 3: Otherwise, proceed calling the tool normally. - if function_response is None: - try: - function_response = await __call_tool_async( - tool, args=function_args, tool_context=tool_context - ) - except Exception as tool_error: - error_response = await _run_on_tool_error_callbacks( - tool=tool, - tool_args=function_args, - tool_context=tool_context, - error=tool_error, - ) - if error_response is not None: - function_response = error_response - else: - raise tool_error + # Step 4: Check if plugin after_tool_callback overrides the function + # response. + altered_function_response = ( + await invocation_context.plugin_manager.run_after_tool_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + result=function_response, + ) + ) - # Step 4: Check if plugin after_tool_callback overrides the function - # response. - altered_function_response = ( - await invocation_context.plugin_manager.run_after_tool_callback( - tool=tool, + # Step 5: If no overrides are provided from the plugins, further run the + # canonical after_tool_callbacks. + if altered_function_response is None: + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response: + break + + # Step 6: If alternative response exists from after_tool_callback, use it + # instead of the original function response. + if altered_function_response is not None: + function_response = altered_function_response + + return function_response + + should_dedupe = bool( + invocation_context.run_config + and invocation_context.run_config.dedupe_tool_calls + ) or tool.is_long_running + function_response, cache_hit = ( + await invocation_context.get_or_execute_deduped_tool_call( + tool_name=tool.name, tool_args=function_args, - tool_context=tool_context, - result=function_response, + execute=_execute_tool_pipeline, + dedupe=should_dedupe, ) ) - # Step 5: If no overrides are provided from the plugins, further run the - # canonical after_tool_callbacks. - if altered_function_response is None: - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break - - # Step 6: If alternative response exists from after_tool_callback, use it - # instead of the original function response. - if altered_function_response is not None: - function_response = altered_function_response - if tool.is_long_running: # Allow long running function to return None to not provide function # response. @@ -423,6 +440,11 @@ async def _run_with_trace(): function_response_event = __build_response_event( tool, function_response, tool_context, invocation_context ) + if cache_hit: + function_response_event.custom_metadata = ( + function_response_event.custom_metadata or {} + ) + function_response_event.custom_metadata['adk_tool_call_cache_hit'] = True return function_response_event with tracer.start_as_current_span(f'execute_tool {tool.name}'): @@ -517,48 +539,69 @@ async def _execute_single_function_call_live( async def _run_with_trace(): nonlocal function_args - # Do not use "args" as the variable name, because it is a reserved keyword - # in python debugger. - # Make a deep copy to avoid being modified. - function_response = None + async def _execute_tool_pipeline() -> Any: + # Do not use "args" as the variable name, because it is a reserved keyword + # in python debugger. + # Make a deep copy to avoid being modified. + function_response = None - # Handle before_tool_callbacks - iterate through the canonical callback - # list - for callback in agent.canonical_before_tool_callbacks: - function_response = callback( - tool=tool, args=function_args, tool_context=tool_context - ) - if inspect.isawaitable(function_response): - function_response = await function_response - if function_response: - break - - if function_response is None: - function_response = await _process_function_live_helper( - tool, - tool_context, - function_call, - function_args, - invocation_context, - streaming_lock, - ) + # Handle before_tool_callbacks - iterate through the canonical callback + # list. + for callback in agent.canonical_before_tool_callbacks: + function_response = callback( + tool=tool, args=function_args, tool_context=tool_context + ) + if inspect.isawaitable(function_response): + function_response = await function_response + if function_response: + break - # Calls after_tool_callback if it exists. - altered_function_response = None - for callback in agent.canonical_after_tool_callbacks: - altered_function_response = callback( - tool=tool, - args=function_args, - tool_context=tool_context, - tool_response=function_response, - ) - if inspect.isawaitable(altered_function_response): - altered_function_response = await altered_function_response - if altered_function_response: - break + if function_response is None: + function_response = await _process_function_live_helper( + tool, + tool_context, + function_call, + function_args, + invocation_context, + streaming_lock, + ) - if altered_function_response is not None: - function_response = altered_function_response + # Calls after_tool_callback if it exists. + altered_function_response = None + for callback in agent.canonical_after_tool_callbacks: + altered_function_response = callback( + tool=tool, + args=function_args, + tool_context=tool_context, + tool_response=function_response, + ) + if inspect.isawaitable(altered_function_response): + altered_function_response = await altered_function_response + if altered_function_response: + break + + if altered_function_response is not None: + function_response = altered_function_response + + return function_response + + # Never cache stop_streaming calls (control operation). + if function_call.name == 'stop_streaming': + function_response = await _execute_tool_pipeline() + cache_hit = False + else: + should_dedupe = bool( + invocation_context.run_config + and invocation_context.run_config.dedupe_tool_calls + ) or tool.is_long_running + function_response, cache_hit = ( + await invocation_context.get_or_execute_deduped_tool_call( + tool_name=tool.name, + tool_args=function_args, + execute=_execute_tool_pipeline, + dedupe=should_dedupe, + ) + ) if tool.is_long_running: # Allow async function to return None to not provide function response. @@ -573,6 +616,11 @@ async def _run_with_trace(): function_response_event = __build_response_event( tool, function_response, tool_context, invocation_context ) + if cache_hit: + function_response_event.custom_metadata = ( + function_response_event.custom_metadata or {} + ) + function_response_event.custom_metadata['adk_tool_call_cache_hit'] = True return function_response_event with tracer.start_as_current_span(f'execute_tool {tool.name}'): diff --git a/tests/unittests/flows/llm_flows/test_tool_call_deduplication.py b/tests/unittests/flows/llm_flows/test_tool_call_deduplication.py new file mode 100644 index 0000000000..fb51581ebb --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_tool_call_deduplication.py @@ -0,0 +1,123 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for tool call de-duplication (Issue #3940).""" + +from google.adk.agents.llm_agent import Agent +from google.adk.agents.run_config import RunConfig +from google.genai import types +import pytest + +from ... import testing_utils + + +def _function_call(name: str, args: dict) -> types.Part: + return types.Part.from_function_call(name=name, args=args) + + +@pytest.mark.asyncio +async def test_dedupe_identical_tool_calls_across_steps(): + """Identical tool calls should execute once and reuse the cached result.""" + responses = [ + _function_call("test_tool", {"x": 1}), + _function_call("test_tool", {"x": 1}), + "done", + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + call_count = 0 + + def test_tool(x: int) -> dict: + nonlocal call_count + call_count += 1 + return {"result": call_count} + + agent = Agent(name="root_agent", model=mock_model, tools=[test_tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + run_config = RunConfig(dedupe_tool_calls=True) + events = [] + async for event in runner.runner.run_async( + user_id=runner.session.user_id, + session_id=runner.session.id, + new_message=testing_utils.get_user_content("run"), + run_config=run_config, + ): + events.append(event) + simplified = testing_utils.simplify_events(events) + + # Tool should execute exactly once even though the model calls it twice. + assert call_count == 1 + + # Both tool responses should contain the same cached payload. + tool_responses = [ + content + for _, content in simplified + if isinstance(content, types.Part) and content.function_response + ] + assert len(tool_responses) == 2 + assert tool_responses[0].function_response.response == {"result": 1} + assert tool_responses[1].function_response.response == {"result": 1} + + +def test_dedupe_identical_tool_calls_within_one_step(): + """Identical tool calls within the same step should execute once.""" + responses = [ + [ + _function_call("test_tool", {"x": 1}), + _function_call("test_tool", {"x": 1}), + ], + "done", + ] + mock_model = testing_utils.MockModel.create(responses=responses) + + call_count = 0 + + def test_tool(x: int) -> dict: + nonlocal call_count + call_count += 1 + return {"result": call_count} + + agent = Agent(name="root_agent", model=mock_model, tools=[test_tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + run_config = RunConfig(dedupe_tool_calls=True) + events = list( + runner.runner.run( + user_id=runner.session.user_id, + session_id=runner.session.id, + new_message=testing_utils.get_user_content("run"), + run_config=run_config, + ) + ) + simplified = testing_utils.simplify_events(events) + + assert call_count == 1 + + # The merged tool response event contains 2 function_response parts. + merged_parts = [ + content + for _, content in simplified + if isinstance(content, list) + and all(isinstance(p, types.Part) for p in content) + and any(p.function_response for p in content) + ] + assert len(merged_parts) == 1 + function_responses = [ + p.function_response.response + for p in merged_parts[0] + if p.function_response is not None + ] + assert function_responses == [{"result": 1}, {"result": 1}] +