diff --git a/contributing/samples/memory_chroma/README.md b/contributing/samples/memory_chroma/README.md new file mode 100644 index 0000000000..81b5a0a565 --- /dev/null +++ b/contributing/samples/memory_chroma/README.md @@ -0,0 +1,65 @@ +# ChromaDB Memory Service Example + +This example demonstrates using `ChromaMemoryService` for semantic memory search +with embeddings generated by Ollama. + +## Prerequisites + +1. **Ollama Server Running** + ```bash + ollama serve + ``` + +2. **Embedding Model Pulled** + ```bash + ollama pull nomic-embed-text + ``` + +3. **Dependencies Installed** + ```bash + pip install chromadb + # Or with uv: + uv pip install chromadb + ``` + +## Running the Example + +```bash +cd contributing/samples/memory_chroma +python main.py +``` + +## What This Demo Does + +1. **Session 1**: Creates memories by having a conversation with the agent + - User introduces themselves as "Jack" + - User mentions they like badminton + - User mentions what they ate recently + +2. **Memory Storage**: The session is saved to ChromaDB with semantic embeddings + - Data persists to `./chroma_db` directory + - Embeddings are generated using Ollama's `nomic-embed-text` model + +3. **Session 2**: Queries the memories using semantic search + - User asks about their hobbies (agent should recall "badminton") + - User asks about what they ate (agent should recall "burger") + +## Key Differences from InMemoryMemoryService + +| Feature | InMemory | ChromaDB | +|---------|----------|----------| +| Search Type | Keyword matching | **Semantic similarity** | +| Persistence | No (lost on restart) | **Yes (disk)** | +| Synonyms | No | **Yes** | +| Performance | Fast | Fast (with HNSW index) | + +## Customization + +You can change the embedding model by modifying the `OllamaEmbeddingProvider`: + +```python +embedding_provider = OllamaEmbeddingProvider( + model="mxbai-embed-large", # Higher quality but slower + host="http://remote-server:11434", # Remote Ollama server +) +``` diff --git a/contributing/samples/memory_chroma/__init__.py b/contributing/samples/memory_chroma/__init__.py new file mode 100644 index 0000000000..f28970b9ad --- /dev/null +++ b/contributing/samples/memory_chroma/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample package for ChromaMemoryService demonstration.""" diff --git a/contributing/samples/memory_chroma/agent.py b/contributing/samples/memory_chroma/agent.py new file mode 100644 index 0000000000..acbb9bca99 --- /dev/null +++ b/contributing/samples/memory_chroma/agent.py @@ -0,0 +1,45 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Agent definition for ChromaMemoryService demo.""" + +from datetime import datetime + +from google.adk import Agent +from google.adk.agents.callback_context import CallbackContext +from google.adk.tools.load_memory_tool import load_memory_tool +from google.adk.tools.preload_memory_tool import preload_memory_tool + + +def update_current_time(callback_context: CallbackContext): + callback_context.state["_time"] = datetime.now().isoformat() + + +root_agent = Agent( + model="gemini-2.0-flash-001", + name="chroma_memory_agent", + description="Agent with ChromaDB-backed semantic memory.", + before_agent_callback=update_current_time, + instruction="""\ +You are an agent that helps users answer questions. +You have access to a semantic memory system that stores past conversations. +Use the memory tools to recall information from previous sessions. + +Current time: {_time} +""", + tools=[ + load_memory_tool, + preload_memory_tool, + ], +) diff --git a/contributing/samples/memory_chroma/main.py b/contributing/samples/memory_chroma/main.py new file mode 100644 index 0000000000..9b299134a3 --- /dev/null +++ b/contributing/samples/memory_chroma/main.py @@ -0,0 +1,139 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Demo script for ChromaMemoryService with OllamaEmbeddingProvider. + +This example demonstrates using ChromaDB for semantic memory search +with embeddings generated by Ollama. + +Prerequisites: + 1. Ollama server running: `ollama serve` + 2. Embedding model pulled: `ollama pull nomic-embed-text` + 3. Dependencies installed: `pip install chromadb` + +Usage: + python main.py +""" + +import asyncio +from datetime import datetime +from datetime import timedelta +from typing import cast + +import agent +from dotenv import load_dotenv +from google.adk.cli.utils import logs +from google.adk.memory import ChromaMemoryService +from google.adk.memory import OllamaEmbeddingProvider +from google.adk.runners import InMemoryRunner +from google.adk.sessions.session import Session +from google.genai import types + +load_dotenv(override=True) +logs.log_to_tmp_folder() + + +async def main(): + app_name = "my_app" + user_id_1 = "user1" + + # Initialize the ChromaMemoryService with Ollama embeddings + embedding_provider = OllamaEmbeddingProvider( + model="nomic-embed-text", # Or another embedding model you have + ) + memory_service = ChromaMemoryService( + embedding_provider=embedding_provider, + collection_name="demo_memory", + persist_directory="./chroma_db", # Persist to disk + ) + + runner = InMemoryRunner( + app_name=app_name, + agent=agent.root_agent, + memory_service=memory_service, + ) + + async def run_prompt(session: Session, new_message: str) -> Session: + content = types.Content( + role="user", parts=[types.Part.from_text(text=new_message)] + ) + print("** User says:", content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if not event.content or not event.content.parts: + continue + if event.content.parts[0].text: + print(f"** {event.author}: {event.content.parts[0].text}") + elif event.content.parts[0].function_call: + print( + f"** {event.author}: fc /" + f" {event.content.parts[0].function_call.name} /" + f" {event.content.parts[0].function_call.args}\n" + ) + elif event.content.parts[0].function_response: + print( + f"** {event.author}: fr /" + f" {event.content.parts[0].function_response.name} /" + f" {event.content.parts[0].function_response.response}\n" + ) + + return cast( + Session, + await runner.session_service.get_session( + app_name=app_name, user_id=user_id_1, session_id=session.id + ), + ) + + # Session 1: Create memories + session_1 = await runner.session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + print(f"----Session to create memory: {session_1.id} ----------------------") + session_1 = await run_prompt(session_1, "Hi") + session_1 = await run_prompt(session_1, "My name is Jack") + session_1 = await run_prompt(session_1, "I like badminton.") + session_1 = await run_prompt( + session_1, + f"I ate a burger on {(datetime.now() - timedelta(days=1)).date()}.", + ) + session_1 = await run_prompt( + session_1, + f"I ate a banana on {(datetime.now() - timedelta(days=2)).date()}.", + ) + + print("Saving session to ChromaDB memory service...") + await memory_service.add_session_to_memory(session_1) + print("Session saved! Data persisted to ./chroma_db") + print("-------------------------------------------------------------------") + + # Session 2: Query memories using semantic search + session_2 = await runner.session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + print(f"----Session to use memory: {session_2.id} ----------------------") + session_2 = await run_prompt(session_2, "Hi") + session_2 = await run_prompt(session_2, "What do I like to do?") + # Expected: The agent should recall "badminton" from semantic search + session_2 = await run_prompt(session_2, "When did I say that?") + session_2 = await run_prompt(session_2, "What did I eat yesterday?") + # Expected: The agent should recall "burger" from semantic search + print("-------------------------------------------------------------------") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 5e3e9ec3d6..ce5a479c84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -165,6 +165,10 @@ otel-gcp = ["opentelemetry-instrumentation-google-genai>=0.3b0, <1.0.0"] toolbox = ["toolbox-adk>=0.5.7, <0.6.0"] +chroma = [ + "chromadb>=0.4.0, <1.0.0", # For ChromaMemoryService +] + [tool.pyink] # Format py files following Google style-guide line-length = 80 diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index c47fb8ec40..a0a1eeccc8 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -35,3 +35,20 @@ ' VertexAiRagMemoryService please install it. If not, you can ignore this' ' warning.' ) + +try: + from .chroma_memory_service import ChromaMemoryService + from .embeddings import BaseEmbeddingProvider + from .embeddings import OllamaEmbeddingProvider + + __all__.extend([ + 'ChromaMemoryService', + 'BaseEmbeddingProvider', + 'OllamaEmbeddingProvider', + ]) +except ImportError: + logger.debug( + 'chromadb is not installed. If you want to use the ChromaMemoryService' + ' please install it with: pip install \'google-adk[chroma]\'. If not, you can' + ' ignore this warning.' + ) diff --git a/src/google/adk/memory/chroma_memory_service.py b/src/google/adk/memory/chroma_memory_service.py new file mode 100644 index 0000000000..30dcee6672 --- /dev/null +++ b/src/google/adk/memory/chroma_memory_service.py @@ -0,0 +1,208 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ChromaDB-based memory service with semantic search capabilities.""" + +from __future__ import annotations + +import hashlib +import logging +from typing import Optional +from typing import TYPE_CHECKING + +from google.genai import types + +from typing_extensions import override + +from . import _utils +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .embeddings.base_embedding_provider import BaseEmbeddingProvider +from .memory_entry import MemoryEntry + +if TYPE_CHECKING: + from ..events.event import Event + from ..sessions.session import Session + +logger = logging.getLogger("google_adk." + __name__) + + +def _user_key(app_name: str, user_id: str) -> str: + """Generate a unique key for a user within an app.""" + return f"{app_name}/{user_id}" + + +def _event_id(session_id: str, event_id: str) -> str: + """Generate a unique document ID for an event.""" + return hashlib.sha256(f"{session_id}/{event_id}".encode()).hexdigest()[:32] + + +class ChromaMemoryService(BaseMemoryService): + """A memory service that uses ChromaDB for semantic search. + + This service stores session events as documents in a ChromaDB collection + and uses vector embeddings for semantic similarity search. + + Example: + >>> from google.adk.memory.embeddings import OllamaEmbeddingProvider + >>> embedding_provider = OllamaEmbeddingProvider(model="nomic-embed-text") + >>> memory = ChromaMemoryService( + ... embedding_provider=embedding_provider, + ... persist_directory="./memory_db" + ... ) + """ + + def __init__( + self, + embedding_provider: BaseEmbeddingProvider, + collection_name: str = "adk_memory", + persist_directory: Optional[str] = None, + ): + """Initialize the ChromaMemoryService. + + Args: + embedding_provider: The embedding provider to use for generating + vector representations of text. + collection_name: The name of the ChromaDB collection to use. + persist_directory: Optional directory path for persisting the + ChromaDB data. If None, data is stored in memory only. + """ + try: + import chromadb + except ImportError as exc: + raise ImportError( + "chromadb is required for ChromaMemoryService. " + "Install it with: pip install chromadb" + ) from exc + + self._embedding_provider = embedding_provider + self._collection_name = collection_name + + if persist_directory: + self._client = chromadb.PersistentClient(path=persist_directory) + else: + self._client = chromadb.Client() + + self._collection = self._client.get_or_create_collection( + name=collection_name, + metadata={"hnsw:space": "cosine"}, + ) + + @override + async def add_session_to_memory(self, session: "Session"): + """Add a session's events to the ChromaDB collection. + + Each event with text content is stored as a separate document with + its embedding, along with metadata for filtering. + + Args: + session: The session to add to memory. + """ + user_key = _user_key(session.app_name, session.user_id) + + documents: list[str] = [] + metadatas: list[dict] = [] + ids: list[str] = [] + + for event in session.events: + if not event.content or not event.content.parts: + continue + + text_parts = [part.text for part in event.content.parts if part.text] + if not text_parts: + continue + + document_text = " ".join(text_parts) + documents.append(document_text) + metadatas.append({ + "user_key": user_key, + "app_name": session.app_name, + "user_id": session.user_id, + "session_id": session.id, + "event_id": event.id, + "author": event.author or "", + "timestamp": event.timestamp or 0, + }) + ids.append(_event_id(session.id, event.id)) + + if not documents: + return + + # Generate embeddings + embeddings = await self._embedding_provider.embed(documents) + + # Upsert to ChromaDB (update if exists, insert otherwise) + self._collection.upsert( + ids=ids, + embeddings=embeddings, + documents=documents, + metadatas=metadatas, + ) + + logger.debug( + "Added %d events from session %s to ChromaDB", + len(documents), + session.id, + ) + + @override + async def search_memory( + self, + *, + app_name: str, + user_id: str, + query: str, + ) -> SearchMemoryResponse: + """Search for memories semantically similar to the query. + + Args: + app_name: The name of the application. + user_id: The id of the user. + query: The query to search for. + + Returns: + A SearchMemoryResponse containing the matching memories. + """ + + user_key = _user_key(app_name, user_id) + + # Generate embedding for query + query_embeddings = await self._embedding_provider.embed([query]) + if not query_embeddings: + return SearchMemoryResponse() + + # Search ChromaDB with user filtering + results = self._collection.query( + query_embeddings=query_embeddings, + n_results=10, + where={"user_key": user_key}, + include=["documents", "metadatas"], + ) + + memories: list[MemoryEntry] = [] + + if results["documents"] and results["metadatas"]: + for doc, metadata in zip( + results["documents"][0], results["metadatas"][0] + ): + content = types.Content(parts=[types.Part(text=doc)]) + memories.append( + MemoryEntry( + content=content, + author=metadata.get("author", ""), + timestamp=_utils.format_timestamp(metadata.get("timestamp", 0)), + ) + ) + + return SearchMemoryResponse(memories=memories) diff --git a/src/google/adk/memory/embeddings/__init__.py b/src/google/adk/memory/embeddings/__init__.py new file mode 100644 index 0000000000..e8d7ecf31c --- /dev/null +++ b/src/google/adk/memory/embeddings/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Embedding providers for memory services.""" + +from .base_embedding_provider import BaseEmbeddingProvider +from .ollama_embedding_provider import OllamaEmbeddingProvider + +__all__ = [ + "BaseEmbeddingProvider", + "OllamaEmbeddingProvider", +] diff --git a/src/google/adk/memory/embeddings/base_embedding_provider.py b/src/google/adk/memory/embeddings/base_embedding_provider.py new file mode 100644 index 0000000000..72ab55a184 --- /dev/null +++ b/src/google/adk/memory/embeddings/base_embedding_provider.py @@ -0,0 +1,44 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base class for embedding providers.""" + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod + + +class BaseEmbeddingProvider(ABC): + """Abstract base class for embedding providers. + + Embedding providers are responsible for converting text into vector + representations for use in semantic search. + """ + + @abstractmethod + async def embed(self, texts: list[str]) -> list[list[float]]: + """Generate embeddings for a list of texts. + + Args: + texts: A list of strings to embed. + + Returns: + A list of embeddings, where each embedding is a list of floats. + """ + + @property + @abstractmethod + def dimension(self) -> int: + """Return the dimension of the embedding vectors.""" diff --git a/src/google/adk/memory/embeddings/ollama_embedding_provider.py b/src/google/adk/memory/embeddings/ollama_embedding_provider.py new file mode 100644 index 0000000000..89ad6121d7 --- /dev/null +++ b/src/google/adk/memory/embeddings/ollama_embedding_provider.py @@ -0,0 +1,134 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ollama embedding provider for ChromaMemoryService.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Optional +import requests + +from .base_embedding_provider import BaseEmbeddingProvider + +logger = logging.getLogger("google_adk." + __name__) + +_EMBED_ENDPOINT = "/api/embed" + + +class OllamaEmbeddingProvider(BaseEmbeddingProvider): + """Embedding provider using Ollama's embedding API. + + This provider uses Ollama's `/api/embed` endpoint to generate embeddings. + It requires an Ollama server running with an embedding model available. + + Example: + >>> provider = OllamaEmbeddingProvider(model="nomic-embed-text") + >>> embeddings = await provider.embed(["Hello, world!"]) + """ + + def __init__( + self, + model: str = "nomic-embed-text", + host: Optional[str] = None, + request_timeout: float = 60.0, + ): + """Initialize the Ollama embedding provider. + + Args: + model: The name of the Ollama embedding model to use. + Popular options: "nomic-embed-text", "mxbai-embed-large", + "all-minilm". + host: The base URL of the Ollama server. Defaults to + http://localhost:11434 or OLLAMA_API_BASE env var. + request_timeout: Timeout in seconds for embedding requests. + """ + import os + + self._model = model + self._host = host or os.environ.get( + "OLLAMA_API_BASE", "http://localhost:11434" + ) + self._request_timeout = request_timeout + self._dimension: Optional[int] = None + + @property + def dimension(self) -> int: + """Return the dimension of the embedding vectors. + + The dimension is determined by the first embedding request. + """ + if self._dimension is None: + raise ValueError( + "Dimension is not available until the first embedding is generated." + ) + return self._dimension + + async def embed(self, texts: list[str]) -> list[list[float]]: + """Generate embeddings for a list of texts using Ollama. + + Args: + texts: A list of strings to embed. + + Returns: + A list of embeddings, where each embedding is a list of floats. + + Raises: + RuntimeError: If the Ollama API call fails. + """ + if not texts: + return [] + + try: + response_json = await asyncio.to_thread(self._post_embed, texts) + except RuntimeError as exc: + logger.error("Failed to generate embeddings from Ollama: %s", exc) + raise + + embeddings = response_json.get("embeddings", []) + + # Set dimension from first embedding if not already set + if embeddings and self._dimension is None: + self._dimension = len(embeddings[0]) + + return embeddings + + def _post_embed(self, texts: list[str]) -> dict: + """Perform a blocking POST /api/embed call to Ollama. + + Args: + texts: A list of strings to embed. + + Returns: + The JSON response from Ollama. + + Raises: + RuntimeError: If the request fails. + """ + url = self._host.rstrip("/") + _EMBED_ENDPOINT + payload = { + "model": self._model, + "input": texts, + } + try: + response = requests.post( + url, + json=payload, + timeout=self._request_timeout, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as exc: + raise RuntimeError(f"Failed to connect to Ollama: {exc}") from exc diff --git a/tests/unittests/memory/embeddings/test_ollama_embedding_provider.py b/tests/unittests/memory/embeddings/test_ollama_embedding_provider.py new file mode 100644 index 0000000000..9c398df21f --- /dev/null +++ b/tests/unittests/memory/embeddings/test_ollama_embedding_provider.py @@ -0,0 +1,93 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for OllamaEmbeddingProvider.""" + +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.memory.embeddings.ollama_embedding_provider import OllamaEmbeddingProvider +import pytest +import requests + + +@pytest.fixture +def provider(): + return OllamaEmbeddingProvider(model="test-model", host="http://test-host") + + +@patch("requests.post") +def test_embed_success(mock_post, provider): + """Test successful embedding generation.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + } + mock_post.return_value = mock_response + + # Run the async method synchronously for testing logic + # Since we mocked the synchronous _post_embed call via requests.post, + # we can verify the result. + # However, OllamaEmbeddingProvider.embed is async and uses asyncio.to_thread. + # We need to run it in an async loop or trust pytest-asyncio. + import asyncio + + embeddings = asyncio.run(provider.embed(["text1", "text2"])) + + assert len(embeddings) == 2 + assert embeddings[0] == [0.1, 0.2, 0.3] + assert embeddings[1] == [0.4, 0.5, 0.6] + assert provider.dimension == 3 + + mock_post.assert_called_once() + args, kwargs = mock_post.call_args + assert args[0] == "http://test-host/api/embed" + assert kwargs["json"] == { + "model": "test-model", + "input": ["text1", "text2"], + } + + +@patch("requests.post") +def test_embed_http_error(mock_post, provider): + """Test handling of HTTP errors.""" + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError( + "404 Client Error" + ) + mock_post.return_value = mock_response + + import asyncio + + with pytest.raises(RuntimeError, match="Failed to connect to Ollama"): + asyncio.run(provider.embed(["text"])) + + +@patch("requests.post") +def test_embed_connection_error(mock_post, provider): + """Test handling of connection errors.""" + mock_post.side_effect = requests.exceptions.ConnectionError( + "Connection refused" + ) + + import asyncio + + with pytest.raises(RuntimeError, match="Failed to connect to Ollama"): + asyncio.run(provider.embed(["text"])) + + +def test_dimension_property(provider): + """Test dimension property raises error if not set.""" + with pytest.raises(ValueError, match="Dimension is not available"): + _ = provider.dimension diff --git a/tests/unittests/memory/test_chroma_memory_service.py b/tests/unittests/memory/test_chroma_memory_service.py new file mode 100644 index 0000000000..9f0b45e241 --- /dev/null +++ b/tests/unittests/memory/test_chroma_memory_service.py @@ -0,0 +1,234 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ChromaMemoryService.""" + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.events.event import Event +from google.adk.memory.chroma_memory_service import ChromaMemoryService +from google.adk.memory.embeddings.base_embedding_provider import BaseEmbeddingProvider +from google.adk.sessions.session import Session +from google.genai import types +import pytest + +MOCK_APP_NAME = "test-app" +MOCK_USER_ID = "test-user" +MOCK_OTHER_USER_ID = "another-user" + + +class MockEmbeddingProvider(BaseEmbeddingProvider): + """A mock embedding provider for testing.""" + + def __init__(self, dimension: int = 384): + self._dimension = dimension + + @property + def dimension(self) -> int: + return self._dimension + + async def embed(self, texts: list[str]) -> list[list[float]]: + """Return deterministic mock embeddings based on text content.""" + embeddings = [] + for text in texts: + # Create a simple deterministic embedding based on hash + base_value = hash(text) % 1000 / 1000.0 + embedding = [base_value + i * 0.001 for i in range(self._dimension)] + embeddings.append(embedding) + return embeddings + + +MOCK_SESSION_1 = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id="session-1", + last_update_time=1000, + events=[ + Event( + id="event-1a", + invocation_id="inv-1", + author="user", + timestamp=12345, + content=types.Content( + parts=[types.Part(text="The ADK is a great toolkit.")] + ), + ), + # Event with no content, should be ignored by the service + Event( + id="event-1b", + invocation_id="inv-2", + author="user", + timestamp=12346, + ), + Event( + id="event-1c", + invocation_id="inv-3", + author="model", + timestamp=12347, + content=types.Content( + parts=[ + types.Part( + text="I agree. The Agent Development Kit (ADK) rocks!" + ) + ] + ), + ), + ], +) + +MOCK_SESSION_2 = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id="session-2", + last_update_time=2000, + events=[ + Event( + id="event-2a", + invocation_id="inv-4", + author="user", + timestamp=54321, + content=types.Content( + parts=[types.Part(text="I like to code in Python.")] + ), + ), + ], +) + +MOCK_SESSION_DIFFERENT_USER = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_OTHER_USER_ID, + id="session-3", + last_update_time=3000, + events=[ + Event( + id="event-3a", + invocation_id="inv-5", + author="user", + timestamp=60000, + content=types.Content(parts=[types.Part(text="This is a secret.")]), + ), + ], +) + +MOCK_SESSION_WITH_NO_EVENTS = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id="session-4", + last_update_time=4000, +) + + +@pytest.fixture +def embedding_provider(): + """Create a mock embedding provider.""" + return MockEmbeddingProvider(dimension=384) + + +@pytest.fixture +def memory_service(embedding_provider, request): + """Create a ChromaMemoryService with in-memory storage and unique collection.""" + # Use test function name to create unique collection per test + collection_name = f"test_{request.node.name}" + return ChromaMemoryService( + embedding_provider=embedding_provider, + collection_name=collection_name, + ) + + +@pytest.mark.asyncio +async def test_add_session_to_memory(memory_service): + """Tests that a session with events is correctly added to memory.""" + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + # Verify documents were added to the collection + count = memory_service._collection.count() + # Should have 2 events (one has no content and is filtered) + assert count == 2 + + +@pytest.mark.asyncio +async def test_add_session_with_no_events_to_memory(memory_service): + """Tests that adding a session with no events does not cause an error.""" + await memory_service.add_session_to_memory(MOCK_SESSION_WITH_NO_EVENTS) + + # Verify no documents were added + count = memory_service._collection.count() + assert count == 0 + + +@pytest.mark.asyncio +async def test_search_memory_returns_results(memory_service): + """Tests that search returns relevant results.""" + await memory_service.add_session_to_memory(MOCK_SESSION_1) + await memory_service.add_session_to_memory(MOCK_SESSION_2) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="ADK toolkit" + ) + + # Should return results (exact matching depends on embedding similarity) + assert len(result.memories) > 0 + + +@pytest.mark.asyncio +async def test_search_memory_no_match(memory_service): + """Tests search with no matching user returns empty results.""" + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id="nonexistent-user", query="ADK" + ) + + assert not result.memories + + +@pytest.mark.asyncio +async def test_search_memory_is_scoped_by_user(memory_service): + """Tests that search results are correctly scoped to the user_id.""" + await memory_service.add_session_to_memory(MOCK_SESSION_1) + await memory_service.add_session_to_memory(MOCK_SESSION_DIFFERENT_USER) + + # Verify that searching as MOCK_OTHER_USER_ID returns the secret + result_other_user = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_OTHER_USER_ID, query="secret" + ) + assert len(result_other_user.memories) == 1 + assert ( + result_other_user.memories[0].content.parts[0].text == "This is a secret." + ) + + # Verify that searching as MOCK_USER_ID does NOT return the secret + # (it should return MOCK_USER_ID's data, not MOCK_OTHER_USER_ID's) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query="secret" + ) + # Results should only contain MOCK_USER_ID's content, not the secret + for memory in result.memories: + assert "secret" not in memory.content.parts[0].text.lower() + + +@pytest.mark.asyncio +async def test_upsert_updates_existing_documents(memory_service): + """Tests that adding the same session twice updates existing documents.""" + await memory_service.add_session_to_memory(MOCK_SESSION_1) + initial_count = memory_service._collection.count() + + # Add the same session again + await memory_service.add_session_to_memory(MOCK_SESSION_1) + final_count = memory_service._collection.count() + + # Count should remain the same (upsert, not duplicate) + assert initial_count == final_count