diff --git a/src/google/adk/cli/utils/local_storage.py b/src/google/adk/cli/utils/local_storage.py index 12207e8070..85115d81cd 100644 --- a/src/google/adk/cli/utils/local_storage.py +++ b/src/google/adk/cli/utils/local_storage.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utilities for local .adk folder persistence.""" + from __future__ import annotations import asyncio import logging from pathlib import Path +from typing import Any from typing import Mapping from typing import Optional @@ -27,6 +29,7 @@ from ...events.event import Event from ...sessions.base_session_service import BaseSessionService from ...sessions.base_session_service import GetSessionConfig +from ...sessions.base_session_service import ListSessionsConfig from ...sessions.base_session_service import ListSessionsResponse from ...sessions.session import Session from .dot_adk_folder import dot_adk_folder_for_agent @@ -155,8 +158,10 @@ async def create_session( *, app_name: str, user_id: str, - state: Optional[dict[str, object]] = None, + state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, + labels: Optional[dict[str, str]] = None, ) -> Session: service = await self._get_service(app_name) return await service.create_session( @@ -164,6 +169,8 @@ async def create_session( user_id=user_id, state=state, session_id=session_id, + display_name=display_name, + labels=labels, ) @override @@ -189,9 +196,12 @@ async def list_sessions( *, app_name: str, user_id: Optional[str] = None, + config: Optional[ListSessionsConfig] = None, ) -> ListSessionsResponse: service = await self._get_service(app_name) - return await service.list_sessions(app_name=app_name, user_id=user_id) + return await service.list_sessions( + app_name=app_name, user_id=user_id, config=config + ) @override async def delete_session( diff --git a/src/google/adk/sessions/__init__.py b/src/google/adk/sessions/__init__.py index cb0df86bd2..b8f58629ae 100644 --- a/src/google/adk/sessions/__init__.py +++ b/src/google/adk/sessions/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from .base_session_service import BaseSessionService +from .base_session_service import GetSessionConfig +from .base_session_service import ListSessionsConfig +from .base_session_service import ListSessionsResponse from .in_memory_session_service import InMemorySessionService from .session import Session from .state import State @@ -20,7 +23,10 @@ __all__ = [ 'BaseSessionService', 'DatabaseSessionService', + 'GetSessionConfig', 'InMemorySessionService', + 'ListSessionsConfig', + 'ListSessionsResponse', 'Session', 'State', 'VertexAiSessionService', diff --git a/src/google/adk/sessions/_session_util.py b/src/google/adk/sessions/_session_util.py index 0b2f99eef2..0e817581fa 100644 --- a/src/google/adk/sessions/_session_util.py +++ b/src/google/adk/sessions/_session_util.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utility functions for session service.""" + from __future__ import annotations from typing import Any diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index f2f6f9f22d..24ffa679b6 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -33,6 +33,14 @@ class GetSessionConfig(BaseModel): after_timestamp: Optional[float] = None +class ListSessionsConfig(BaseModel): + """The configuration of listing sessions.""" + + labels: Optional[dict[str, str]] = None + """Filter sessions by labels. Only sessions that have all the specified + labels will be returned.""" + + class ListSessionsResponse(BaseModel): """The response of listing sessions. @@ -56,6 +64,8 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, + labels: Optional[dict[str, str]] = None, ) -> Session: """Creates a new session. @@ -65,6 +75,9 @@ async def create_session( state: the initial state of the session. session_id: the client-provided id of the session. If not provided, a generated ID will be used. + display_name: optional display name for the session. + labels: optional labels with user-defined metadata to organize sessions. + Label keys and values can be no longer than 64 characters. Returns: session: The newly created session instance. @@ -83,7 +96,11 @@ async def get_session( @abc.abstractmethod async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + config: Optional[ListSessionsConfig] = None, ) -> ListSessionsResponse: """Lists all the sessions for a user. @@ -91,6 +108,7 @@ async def list_sessions( app_name: The name of the app. user_id: The ID of the user. If not provided, lists all sessions for all users. + config: Optional configuration for filtering sessions. Returns: A ListSessionsResponse containing the sessions. diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 863bbfa861..a427ee9284 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -41,6 +41,7 @@ from ..events.event import Event from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig +from .base_session_service import ListSessionsConfig from .base_session_service import ListSessionsResponse from .migration import _schema_check_utils from .schemas.v0 import Base as BaseV0 @@ -229,6 +230,8 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, + labels: Optional[dict[str, str]] = None, ) -> Session: # 1. Populate states. # 2. Build storage session object @@ -280,6 +283,8 @@ async def create_session( user_id=user_id, id=session_id, state=session_state, + display_name=display_name, + labels=labels or {}, ) sql_session.add(storage_session) await sql_session.commit() @@ -355,7 +360,11 @@ async def get_session( @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + config: Optional[ListSessionsConfig] = None, ) -> ListSessionsResponse: await self._prepare_tables() schema = self._get_schema_classes() @@ -366,6 +375,21 @@ async def list_sessions( if user_id is not None: stmt = stmt.filter(schema.StorageSession.user_id == user_id) + labels_filter = config.labels if config else None + + # Apply label filter at database level for backends with native JSON + # support (PostgreSQL JSONB). For other backends, filter in Python. + apply_python_filter = False + if labels_filter: + if self.db_engine.dialect.name == "postgresql": + # PostgreSQL JSONB supports efficient containment queries via @> + stmt = stmt.filter( + schema.StorageSession.labels.contains(labels_filter) + ) + else: + # For other backends (SQLite, MySQL), filter in Python after fetching + apply_python_filter = True + result = await sql_session.execute(stmt) results = result.scalars().all() @@ -394,6 +418,14 @@ async def list_sessions( sessions = [] for storage_session in results: + # Apply Python-level label filter for non-PostgreSQL backends + if apply_python_filter and labels_filter: + session_labels = storage_session.labels or {} + if not all( + session_labels.get(k) == v for k, v in labels_filter.items() + ): + continue + session_state = storage_session.state user_state = user_states_map.get(storage_session.user_id, {}) merged_state = _merge_state(app_state, user_state, session_state) @@ -436,8 +468,8 @@ async def append_event(self, session: Session, event: Event) -> Event: if storage_session.update_timestamp_tz > session.last_update_time: raise ValueError( "The last_update_time provided in the session object" - f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is" - " earlier than the update_time in the storage_session" + f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'}" + " is earlier than the update_time in the storage_session" f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}." " Please check if it is a stale session." ) diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 6ba7f0bb01..24356bb17a 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -27,6 +27,7 @@ from ..events.event import Event from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig +from .base_session_service import ListSessionsConfig from .base_session_service import ListSessionsResponse from .session import Session from .state import State @@ -58,12 +59,16 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, + labels: Optional[dict[str, str]] = None, ) -> Session: return self._create_session_impl( app_name=app_name, user_id=user_id, state=state, session_id=session_id, + display_name=display_name, + labels=labels, ) def create_session_sync( @@ -73,6 +78,8 @@ def create_session_sync( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, + labels: Optional[dict[str, str]] = None, ) -> Session: logger.warning('Deprecated. Please migrate to the async method.') return self._create_session_impl( @@ -80,6 +87,8 @@ def create_session_sync( user_id=user_id, state=state, session_id=session_id, + display_name=display_name, + labels=labels, ) def _create_session_impl( @@ -89,6 +98,8 @@ def _create_session_impl( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, + labels: Optional[dict[str, str]] = None, ) -> Session: if session_id and self._get_session_impl( app_name=app_name, user_id=user_id, session_id=session_id @@ -116,6 +127,8 @@ def _create_session_impl( id=session_id, state=session_state or {}, last_update_time=time.time(), + display_name=display_name, + labels=labels or {}, ) if app_name not in self.sessions: @@ -218,20 +231,44 @@ def _merge_state( ][key] return copied_session + def _matches_labels( + self, session: Session, labels: Optional[dict[str, str]] + ) -> bool: + """Checks if a session has all the specified labels.""" + if not labels: + return True + return all(session.labels.get(k) == v for k, v in labels.items()) + @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + config: Optional[ListSessionsConfig] = None, ) -> ListSessionsResponse: - return self._list_sessions_impl(app_name=app_name, user_id=user_id) + return self._list_sessions_impl( + app_name=app_name, user_id=user_id, config=config + ) def list_sessions_sync( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + config: Optional[ListSessionsConfig] = None, ) -> ListSessionsResponse: logger.warning('Deprecated. Please migrate to the async method.') - return self._list_sessions_impl(app_name=app_name, user_id=user_id) + return self._list_sessions_impl( + app_name=app_name, user_id=user_id, config=config + ) def _list_sessions_impl( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + config: Optional[ListSessionsConfig] = None, ) -> ListSessionsResponse: empty_response = ListSessionsResponse() if app_name not in self.sessions: @@ -240,17 +277,22 @@ def _list_sessions_impl( return empty_response sessions_without_events = [] + labels_filter = config.labels if config else None if user_id is None: for user_id in self.sessions[app_name]: for session_id in self.sessions[app_name][user_id]: session = self.sessions[app_name][user_id][session_id] + if not self._matches_labels(session, labels_filter): + continue copied_session = copy.deepcopy(session) copied_session.events = [] copied_session = self._merge_state(app_name, user_id, copied_session) sessions_without_events.append(copied_session) else: for session in self.sessions[app_name][user_id].values(): + if not self._matches_labels(session, labels_filter): + continue copied_session = copy.deepcopy(session) copied_session.events = [] copied_session = self._merge_state(app_name, user_id, copied_session) diff --git a/src/google/adk/sessions/migration/migration_runner.py b/src/google/adk/sessions/migration/migration_runner.py index 0a3a45f676..56e0545741 100644 --- a/src/google/adk/sessions/migration/migration_runner.py +++ b/src/google/adk/sessions/migration/migration_runner.py @@ -13,6 +13,7 @@ # limitations under the License. """Migration runner to upgrade schemas to the latest version.""" + from __future__ import annotations import logging diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py index a69c29243a..5381984c33 100644 --- a/src/google/adk/sessions/schemas/v0.py +++ b/src/google/adk/sessions/schemas/v0.py @@ -123,6 +123,13 @@ class StorageSession(Base): PreciseTimestamp, default=func.now(), onupdate=func.now() ) + display_name: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + labels: Mapped[MutableDict[str, str]] = mapped_column( + MutableDict.as_mutable(DynamicJSON), default={} + ) + storage_events: Mapped[list[StorageEvent]] = relationship( "StorageEvent", back_populates="storage_session", @@ -164,6 +171,8 @@ def to_session( state=state, events=events, last_update_time=self.update_timestamp_tz, + display_name=self.display_name, + labels=self.labels or {}, ) diff --git a/src/google/adk/sessions/schemas/v1.py b/src/google/adk/sessions/schemas/v1.py index df309287fa..cde050629d 100644 --- a/src/google/adk/sessions/schemas/v1.py +++ b/src/google/adk/sessions/schemas/v1.py @@ -96,6 +96,13 @@ class StorageSession(Base): PreciseTimestamp, default=func.now(), onupdate=func.now() ) + display_name: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + labels: Mapped[MutableDict[str, str]] = mapped_column( + MutableDict.as_mutable(DynamicJSON), default={} + ) + storage_events: Mapped[list[StorageEvent]] = relationship( "StorageEvent", back_populates="storage_session", @@ -139,6 +146,8 @@ def to_session( state=state, events=events, last_update_time=self.update_timestamp_tz, + display_name=self.display_name, + labels=self.labels or {}, ) diff --git a/src/google/adk/sessions/session.py b/src/google/adk/sessions/session.py index e674dd3778..eb740962d7 100644 --- a/src/google/adk/sessions/session.py +++ b/src/google/adk/sessions/session.py @@ -15,6 +15,7 @@ from __future__ import annotations from typing import Any +from typing import Optional from pydantic import alias_generators from pydantic import BaseModel @@ -48,3 +49,8 @@ class Session(BaseModel): call/response, etc.""" last_update_time: float = 0.0 """The last update time of the session.""" + display_name: Optional[str] = None + """Optional display name for the session.""" + labels: dict[str, str] = Field(default_factory=dict) + """Labels with user-defined metadata to organize sessions. + Label keys and values can be no longer than 64 characters.""" diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 1d9516ec73..0cdb07695f 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -34,6 +34,7 @@ from ..events.event import Event from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig +from .base_session_service import ListSessionsConfig from .base_session_service import ListSessionsResponse from .session import Session from .state import State @@ -68,6 +69,8 @@ state TEXT NOT NULL, create_time REAL NOT NULL, update_time REAL NOT NULL, + display_name TEXT, + labels TEXT NOT NULL DEFAULT '{}', PRIMARY KEY (app_name, user_id, id) ); """ @@ -161,12 +164,15 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, + labels: Optional[dict[str, str]] = None, ) -> Session: if session_id: session_id = session_id.strip() if not session_id: session_id = str(uuid.uuid4()) now = time.time() + labels = labels or {} async with self._get_db_connection() as db: # Check if session_id already exists @@ -200,8 +206,8 @@ async def create_session( # Store the session await db.execute( """ - INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time) - VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time, display_name, labels) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, ( app_name, @@ -210,6 +216,8 @@ async def create_session( json.dumps(session_state), now, now, + display_name, + json.dumps(labels), ), ) await db.commit() @@ -225,6 +233,8 @@ async def create_session( state=merged_state, events=[], last_update_time=now, + display_name=display_name, + labels=labels, ) @override @@ -238,8 +248,8 @@ async def get_session( ) -> Optional[Session]: async with self._get_db_connection() as db: async with db.execute( - "SELECT state, update_time FROM sessions WHERE app_name=? AND" - " user_id=? AND id=?", + "SELECT state, update_time, display_name, labels FROM sessions WHERE" + " app_name=? AND user_id=? AND id=?", (app_name, user_id, session_id), ) as cursor: session_row = await cursor.fetchone() @@ -247,6 +257,10 @@ async def get_session( return None session_state = json.loads(session_row["state"]) last_update_time = session_row["update_time"] + display_name = session_row["display_name"] + labels = ( + json.loads(session_row["labels"]) if session_row["labels"] else {} + ) # Build events query query_parts = [ @@ -288,27 +302,42 @@ async def get_session( state=merged_state, events=events, last_update_time=last_update_time, + display_name=display_name, + labels=labels, ) @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + config: Optional[ListSessionsConfig] = None, ) -> ListSessionsResponse: sessions_list = [] async with self._get_db_connection() as db: - # Fetch sessions + # Build query with filters applied at database level + query_parts = [ + "SELECT id, user_id, state, update_time, display_name, labels FROM" + " sessions", + "WHERE app_name=?", + ] + params: list[Any] = [app_name] + if user_id: - session_rows = await db.execute_fetchall( - "SELECT id, user_id, state, update_time FROM sessions WHERE" - " app_name=? AND user_id=?", - (app_name, user_id), - ) - else: - session_rows = await db.execute_fetchall( - "SELECT id, user_id, state, update_time FROM sessions WHERE" - " app_name=?", - (app_name,), - ) + query_parts.append("AND user_id=?") + params.append(user_id) + + # Apply label filter at database level using json_extract + labels_filter = config.labels if config else None + if labels_filter: + for key, value in labels_filter.items(): + query_parts.append("AND json_extract(labels, ?)=?") + params.extend([f"$.{key}", value]) + + session_rows = await db.execute_fetchall( + " ".join(query_parts), tuple(params) + ) # Fetch app state app_state = await self._get_app_state(db, app_name) @@ -333,6 +362,8 @@ async def list_sessions( session_state = json.loads(row["state"]) user_state = user_states_map.get(session_user_id, {}) merged_state = _merge_state(app_state, user_state, session_state) + labels = json.loads(row["labels"]) if row["labels"] else {} + sessions_list.append( Session( app_name=app_name, @@ -341,6 +372,8 @@ async def list_sessions( state=merged_state, events=[], last_update_time=row["update_time"], + display_name=row["display_name"], + labels=labels, ) ) return ListSessionsResponse(sessions=sessions_list) @@ -461,8 +494,22 @@ async def _get_db_connection(self): db.row_factory = aiosqlite.Row await db.execute(PRAGMA_FOREIGN_KEYS) await db.executescript(CREATE_SCHEMA_SQL) + # Ensure new columns exist for existing databases + await self._ensure_new_columns(db) yield db + async def _ensure_new_columns(self, db: aiosqlite.Connection) -> None: + """Ensures display_name and labels columns exist in the sessions table.""" + async with db.execute("PRAGMA table_info(sessions)") as cursor: + columns = [row[1] async for row in cursor] + + if "display_name" not in columns: + await db.execute("ALTER TABLE sessions ADD COLUMN display_name TEXT") + if "labels" not in columns: + await db.execute( + "ALTER TABLE sessions ADD COLUMN labels TEXT NOT NULL DEFAULT '{}'" + ) + async def _get_state( self, db: aiosqlite.Connection, query: str, params: tuple ) -> dict[str, Any]: diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 3f9e514e03..345d33aea2 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -36,6 +36,7 @@ from ..utils.vertex_ai_utils import get_express_mode_api_key from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig +from .base_session_service import ListSessionsConfig from .base_session_service import ListSessionsResponse from .session import Session @@ -84,6 +85,8 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, + labels: Optional[dict[str, str]] = None, **kwargs: Any, ) -> Session: """Creates a new session. @@ -93,6 +96,8 @@ async def create_session( user_id: The ID of the user. state: The initial state of the session. session_id: The ID of the session. + display_name: Optional display name for the session. + labels: Optional labels with user-defined metadata to organize sessions. **kwargs: Additional arguments to pass to the session creation. E.g. set expire_time='2025-10-01T00:00:00Z' to set the session expiration time. See https://cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1beta1/projects.locations.reasoningEngines.sessions @@ -110,6 +115,10 @@ async def create_session( reasoning_engine_id = self._get_reasoning_engine_id(app_name) config = {'session_state': state} if state else {} + if display_name: + config['display_name'] = display_name + if labels: + config['labels'] = labels config.update(kwargs) async with self._get_api_client() as api_client: api_response = await api_client.agent_engines.sessions.create( @@ -127,6 +136,8 @@ async def create_session( id=session_id, state=getattr(get_session_response, 'session_state', None) or {}, last_update_time=get_session_response.update_time.timestamp(), + display_name=getattr(get_session_response, 'display_name', None), + labels=getattr(get_session_response, 'labels', None) or {}, ) return session @@ -184,6 +195,8 @@ async def get_session( id=session_id, state=getattr(get_session_response, 'session_state', None) or {}, last_update_time=update_timestamp, + display_name=getattr(get_session_response, 'display_name', None), + labels=getattr(get_session_response, 'labels', None) or {}, ) # Preserve the entire event stream that Vertex returns rather than trying # to discard events written milliseconds after the session resource was @@ -201,18 +214,30 @@ async def get_session( @override async def list_sessions( - self, *, app_name: str, user_id: Optional[str] = None + self, + *, + app_name: str, + user_id: Optional[str] = None, + config: Optional[ListSessionsConfig] = None, ) -> ListSessionsResponse: reasoning_engine_id = self._get_reasoning_engine_id(app_name) async with self._get_api_client() as api_client: sessions = [] - config = {} + api_config = {} + filter_parts = [] if user_id is not None: - config['filter'] = f'user_id="{user_id}"' + filter_parts.append(f'user_id="{user_id}"') + # Add labels filter if specified + if config and config.labels: + for key, value in config.labels.items(): + filter_parts.append(f'labels.{key}="{value}"') + if filter_parts: + api_config['filter'] = ' AND '.join(filter_parts) + sessions_iterator = await api_client.agent_engines.sessions.list( name=f'reasoningEngines/{reasoning_engine_id}', - config=config, + config=api_config, ) for api_session in sessions_iterator: @@ -223,6 +248,8 @@ async def list_sessions( id=api_session.name.split('/')[-1], state=getattr(api_session, 'session_state', None) or {}, last_update_time=api_session.update_time.timestamp(), + display_name=getattr(api_session, 'display_name', None), + labels=getattr(api_session, 'labels', None) or {}, ) ) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 96d2f38726..0c1aee6ca4 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -603,3 +603,123 @@ async def test_partial_events_are_not_persisted(session_service): app_name=app_name, user_id=user_id, session_id=session.id ) assert len(session_got.events) == 0 + + +# === Tests for display_name and labels === + + +@pytest.mark.asyncio +async def test_create_session_with_display_name_and_labels(session_service): + app_name = 'my_app' + user_id = 'test_user' + display_name = 'Test Session' + labels = {'env': 'test', 'team': 'ai'} + + session = await session_service.create_session( + app_name=app_name, + user_id=user_id, + display_name=display_name, + labels=labels, + ) + + assert session.display_name == display_name + assert session.labels == labels + + # Verify persisted session + got_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert got_session.display_name == display_name + assert got_session.labels == labels + + +@pytest.mark.asyncio +async def test_list_sessions_filter_by_labels(session_service): + from google.adk.sessions.base_session_service import ListSessionsConfig + + app_name = 'my_app' + user_id = 'test_user' + + # Create sessions with different labels + await session_service.create_session( + app_name=app_name, + user_id=user_id, + session_id='dev_session', + display_name='Dev Session', + labels={'env': 'dev', 'version': '1.0'}, + ) + await session_service.create_session( + app_name=app_name, + user_id=user_id, + session_id='prod_session', + display_name='Prod Session', + labels={'env': 'prod', 'version': '2.0'}, + ) + await session_service.create_session( + app_name=app_name, + user_id=user_id, + session_id='test_session', + display_name='Test Session', + labels={'env': 'dev', 'version': '2.0'}, + ) + + # List all sessions (no filter) + all_response = await session_service.list_sessions( + app_name=app_name, user_id=user_id + ) + assert len(all_response.sessions) == 3 + + # Filter by env=dev + config = ListSessionsConfig(labels={'env': 'dev'}) + dev_response = await session_service.list_sessions( + app_name=app_name, user_id=user_id, config=config + ) + assert len(dev_response.sessions) == 2 + assert {s.id for s in dev_response.sessions} == { + 'dev_session', + 'test_session', + } + + # Filter by env=prod + config = ListSessionsConfig(labels={'env': 'prod'}) + prod_response = await session_service.list_sessions( + app_name=app_name, user_id=user_id, config=config + ) + assert len(prod_response.sessions) == 1 + assert prod_response.sessions[0].id == 'prod_session' + + # Filter by multiple labels + config = ListSessionsConfig(labels={'env': 'dev', 'version': '2.0'}) + filtered_response = await session_service.list_sessions( + app_name=app_name, user_id=user_id, config=config + ) + assert len(filtered_response.sessions) == 1 + assert filtered_response.sessions[0].id == 'test_session' + + # Filter by non-existent label value + config = ListSessionsConfig(labels={'env': 'staging'}) + empty_response = await session_service.list_sessions( + app_name=app_name, user_id=user_id, config=config + ) + assert len(empty_response.sessions) == 0 + + +@pytest.mark.asyncio +async def test_session_default_display_name_and_labels(session_service): + app_name = 'my_app' + user_id = 'test_user' + + session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) + + # Default values + assert session.display_name is None + assert session.labels == {} + + # Verify persisted session + got_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert got_session.display_name is None + assert got_session.labels == {}