diff --git a/.gitignore b/.gitignore index 884098a30..e37e092cc 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,8 @@ datajoint.json # Test outputs *_test_summary.txt + +# Swap files +*.swp +*.swo +*~ diff --git a/docker-compose.yaml b/docker-compose.yaml index 2c48ffd10..23fd773c1 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -24,6 +24,19 @@ services: timeout: 30s retries: 5 interval: 15s + postgres: + image: postgres:${POSTGRES_VER:-15} + environment: + - POSTGRES_PASSWORD=${PG_PASS:-password} + - POSTGRES_USER=${PG_USER:-postgres} + - POSTGRES_DB=${PG_DB:-test} + ports: + - "5432:5432" + healthcheck: + test: [ "CMD-SHELL", "pg_isready -U postgres" ] + timeout: 30s + retries: 5 + interval: 15s minio: image: minio/minio:${MINIO_VER:-RELEASE.2025-02-28T09-55-16Z} environment: @@ -52,6 +65,8 @@ services: depends_on: db: condition: service_healthy + postgres: + condition: service_healthy minio: condition: service_healthy environment: @@ -61,6 +76,10 @@ services: - DJ_TEST_HOST=db - DJ_TEST_USER=datajoint - DJ_TEST_PASSWORD=datajoint + - DJ_PG_HOST=postgres + - DJ_PG_USER=postgres + - DJ_PG_PASS=password + - DJ_PG_PORT=5432 - S3_ENDPOINT=minio:9000 - S3_ACCESS_KEY=datajoint - S3_SECRET_KEY=datajoint diff --git a/pyproject.toml b/pyproject.toml index ab7603535..20832342b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,9 @@ [build-system] -requires = ["setuptools>=60"] -build-backend = "setuptools.build_meta" +requires = ["hatchling"] +build-backend = "hatchling.build" [project] name = "datajoint" -# dynamically set in tools.setuptools.dynamic dynamic = ["version"] dependencies = [ "numpy", @@ -74,7 +73,7 @@ Repository = "https://github.com/datajoint/datajoint-python" "Bug Tracker" = "https://github.com/datajoint/datajoint-python/issues" "Release Notes" = "https://github.com/datajoint/datajoint-python/releases" -[project.entry-points."console_scripts"] +[project.scripts] dj = "datajoint.cli:cli" datajoint = "datajoint.cli:cli" @@ -87,7 +86,7 @@ test = [ "matplotlib", "ipython", "graphviz", - "testcontainers[mysql,minio]>=4.0", + "testcontainers[mysql,minio,postgres]>=4.0", "polars>=0.20.0", "pyarrow>=14.0.0", ] @@ -96,6 +95,7 @@ test = [ s3 = ["s3fs>=2023.1.0"] gcs = ["gcsfs>=2023.1.0"] azure = ["adlfs>=2023.1.0"] +postgres = ["psycopg2-binary>=2.9.0"] polars = ["polars>=0.20.0"] arrow = ["pyarrow>=14.0.0"] viz = ["matplotlib", "ipython"] @@ -107,7 +107,8 @@ test = [ "matplotlib", "ipython", "s3fs>=2023.1.0", - "testcontainers[mysql,minio]>=4.0", + "testcontainers[mysql,minio,postgres]>=4.0", + "psycopg2-binary>=2.9.0", "polars>=0.20.0", "pyarrow>=14.0.0", ] @@ -211,12 +212,11 @@ module = [ ] ignore_errors = true -[tool.setuptools] -packages = ["datajoint"] -package-dir = {"" = "src"} +[tool.hatch.version] +path = "src/datajoint/version.py" -[tool.setuptools.dynamic] -version = { attr = "datajoint.version.__version__"} +[tool.hatch.build.targets.wheel] +packages = ["src/datajoint"] [tool.codespell] skip = ".git,*.pdf,*.svg,*.csv,*.ipynb,*.drawio" @@ -229,6 +229,9 @@ ignore-words-list = "rever,numer,astroid" markers = [ "requires_mysql: marks tests as requiring MySQL database (deselect with '-m \"not requires_mysql\"')", "requires_minio: marks tests as requiring MinIO object storage (deselect with '-m \"not requires_minio\"')", + "mysql: marks tests that run on MySQL backend (select with '-m mysql')", + "postgresql: marks tests that run on PostgreSQL backend (select with '-m postgresql')", + "backend_agnostic: marks tests that should pass on all backends 
(auto-marked for parameterized tests)", ] diff --git a/src/datajoint/adapters/__init__.py b/src/datajoint/adapters/__init__.py new file mode 100644 index 000000000..5115a982a --- /dev/null +++ b/src/datajoint/adapters/__init__.py @@ -0,0 +1,54 @@ +""" +Database adapter registry for DataJoint. + +This module provides the adapter factory function and exports all adapters. +""" + +from __future__ import annotations + +from .base import DatabaseAdapter +from .mysql import MySQLAdapter +from .postgres import PostgreSQLAdapter + +__all__ = ["DatabaseAdapter", "MySQLAdapter", "PostgreSQLAdapter", "get_adapter"] + +# Adapter registry mapping backend names to adapter classes +ADAPTERS: dict[str, type[DatabaseAdapter]] = { + "mysql": MySQLAdapter, + "postgresql": PostgreSQLAdapter, + "postgres": PostgreSQLAdapter, # Alias for postgresql +} + + +def get_adapter(backend: str) -> DatabaseAdapter: + """ + Get adapter instance for the specified database backend. + + Parameters + ---------- + backend : str + Backend name: 'mysql', 'postgresql', or 'postgres'. + + Returns + ------- + DatabaseAdapter + Adapter instance for the specified backend. + + Raises + ------ + ValueError + If the backend is not supported. + + Examples + -------- + >>> from datajoint.adapters import get_adapter + >>> mysql_adapter = get_adapter('mysql') + >>> postgres_adapter = get_adapter('postgresql') + """ + backend_lower = backend.lower() + + if backend_lower not in ADAPTERS: + supported = sorted(set(ADAPTERS.keys())) + raise ValueError(f"Unknown database backend: {backend}. " f"Supported backends: {', '.join(supported)}") + + return ADAPTERS[backend_lower]() diff --git a/src/datajoint/adapters/base.py b/src/datajoint/adapters/base.py new file mode 100644 index 000000000..35b32ed5f --- /dev/null +++ b/src/datajoint/adapters/base.py @@ -0,0 +1,1169 @@ +""" +Abstract base class for database backend adapters. + +This module defines the interface that all database adapters must implement +to support multiple database backends (MySQL, PostgreSQL, etc.) in DataJoint. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + + +class DatabaseAdapter(ABC): + """ + Abstract base class for database backend adapters. + + Adapters provide database-specific implementations for SQL generation, + type mapping, error translation, and connection management. + """ + + # ========================================================================= + # Connection Management + # ========================================================================= + + @abstractmethod + def connect( + self, + host: str, + port: int, + user: str, + password: str, + **kwargs: Any, + ) -> Any: + """ + Establish database connection. + + Parameters + ---------- + host : str + Database server hostname. + port : int + Database server port. + user : str + Username for authentication. + password : str + Password for authentication. + **kwargs : Any + Additional backend-specific connection parameters. + + Returns + ------- + Any + Database connection object (backend-specific). + """ + ... + + @abstractmethod + def close(self, connection: Any) -> None: + """ + Close the database connection. + + Parameters + ---------- + connection : Any + Database connection object to close. + """ + ... + + @abstractmethod + def ping(self, connection: Any) -> bool: + """ + Check if connection is alive. + + Parameters + ---------- + connection : Any + Database connection object to check. 
+ + Returns + ------- + bool + True if connection is alive, False otherwise. + """ + ... + + @abstractmethod + def get_connection_id(self, connection: Any) -> int: + """ + Get the current connection/backend process ID. + + Parameters + ---------- + connection : Any + Database connection object. + + Returns + ------- + int + Connection or process ID. + """ + ... + + @property + @abstractmethod + def default_port(self) -> int: + """ + Default port for this database backend. + + Returns + ------- + int + Default port number (3306 for MySQL, 5432 for PostgreSQL). + """ + ... + + @property + @abstractmethod + def backend(self) -> str: + """ + Backend identifier string. + + Returns + ------- + str + Backend name: 'mysql' or 'postgresql'. + """ + ... + + @abstractmethod + def get_cursor(self, connection: Any, as_dict: bool = False) -> Any: + """ + Get a cursor from the database connection. + + Parameters + ---------- + connection : Any + Database connection object. + as_dict : bool, optional + If True, return cursor that yields rows as dictionaries. + If False, return cursor that yields rows as tuples. + Default False. + + Returns + ------- + Any + Database cursor object (backend-specific). + """ + ... + + # ========================================================================= + # SQL Syntax + # ========================================================================= + + @abstractmethod + def quote_identifier(self, name: str) -> str: + """ + Quote an identifier (table/column name) for this backend. + + Parameters + ---------- + name : str + Identifier to quote. + + Returns + ------- + str + Quoted identifier (e.g., `name` for MySQL, "name" for PostgreSQL). + """ + ... + + @abstractmethod + def quote_string(self, value: str) -> str: + """ + Quote a string literal for this backend. + + Parameters + ---------- + value : str + String value to quote. + + Returns + ------- + str + Quoted string literal with proper escaping. + """ + ... + + @abstractmethod + def get_master_table_name(self, part_table: str) -> str | None: + """ + Extract master table name from a part table name. + + Parameters + ---------- + part_table : str + Full table name (e.g., `schema`.`master__part` for MySQL, + "schema"."master__part" for PostgreSQL). + + Returns + ------- + str or None + Master table name if part_table is a part table, None otherwise. + """ + ... + + @property + @abstractmethod + def parameter_placeholder(self) -> str: + """ + Parameter placeholder style for this backend. + + Returns + ------- + str + Placeholder string (e.g., '%s' for MySQL/psycopg2, '?' for SQLite). + """ + ... + + # ========================================================================= + # Type Mapping + # ========================================================================= + + @abstractmethod + def core_type_to_sql(self, core_type: str) -> str: + """ + Convert a DataJoint core type to backend SQL type. + + Parameters + ---------- + core_type : str + DataJoint core type (e.g., 'int64', 'float32', 'uuid'). + + Returns + ------- + str + Backend SQL type (e.g., 'bigint', 'float', 'binary(16)'). + + Raises + ------ + ValueError + If core_type is not a valid DataJoint core type. + """ + ... + + @abstractmethod + def sql_type_to_core(self, sql_type: str) -> str | None: + """ + Convert a backend SQL type to DataJoint core type (if mappable). + + Parameters + ---------- + sql_type : str + Backend SQL type. + + Returns + ------- + str or None + DataJoint core type if mappable, None otherwise. + """ + ... 
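A minimal sketch of how the adapter registry and the type-mapping interface above compose; the expected strings follow the CORE_TYPE_MAP tables in the concrete MySQL/PostgreSQL adapters later in this diff (and the PostgreSQL adapter requires psycopg2 to instantiate), so they are illustrative rather than guaranteed:

```python
from datajoint.adapters import get_adapter

mysql = get_adapter("mysql")
pg = get_adapter("postgres")  # "postgres" is an alias for "postgresql"

# Core type -> backend SQL type (values per each adapter's CORE_TYPE_MAP)
assert mysql.core_type_to_sql("int64") == "bigint"
assert mysql.core_type_to_sql("uuid") == "binary(16)"
assert pg.core_type_to_sql("uuid") == "uuid"  # native uuid type on PostgreSQL

# Reverse mapping used during introspection; returns None when not a core type
assert mysql.sql_type_to_core("bigint") == "int64"
assert mysql.sql_type_to_core("mediumtext") is None

# Unknown backends raise ValueError listing the supported names:
# get_adapter("sqlite")  -> ValueError: Unknown database backend: sqlite. ...
```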
+ + # ========================================================================= + # DDL Generation + # ========================================================================= + + @abstractmethod + def create_schema_sql(self, schema_name: str) -> str: + """ + Generate CREATE SCHEMA/DATABASE statement. + + Parameters + ---------- + schema_name : str + Name of schema/database to create. + + Returns + ------- + str + CREATE SCHEMA/DATABASE SQL statement. + """ + ... + + @abstractmethod + def drop_schema_sql(self, schema_name: str, if_exists: bool = True) -> str: + """ + Generate DROP SCHEMA/DATABASE statement. + + Parameters + ---------- + schema_name : str + Name of schema/database to drop. + if_exists : bool, optional + Include IF EXISTS clause. Default True. + + Returns + ------- + str + DROP SCHEMA/DATABASE SQL statement. + """ + ... + + @abstractmethod + def create_table_sql( + self, + table_name: str, + columns: list[dict[str, Any]], + primary_key: list[str], + foreign_keys: list[dict[str, Any]], + indexes: list[dict[str, Any]], + comment: str | None = None, + ) -> str: + """ + Generate CREATE TABLE statement. + + Parameters + ---------- + table_name : str + Name of table to create. + columns : list[dict] + Column definitions with keys: name, type, nullable, default, comment. + primary_key : list[str] + List of primary key column names. + foreign_keys : list[dict] + Foreign key definitions with keys: columns, ref_table, ref_columns. + indexes : list[dict] + Index definitions with keys: columns, unique. + comment : str, optional + Table comment. + + Returns + ------- + str + CREATE TABLE SQL statement. + """ + ... + + @abstractmethod + def drop_table_sql(self, table_name: str, if_exists: bool = True) -> str: + """ + Generate DROP TABLE statement. + + Parameters + ---------- + table_name : str + Name of table to drop. + if_exists : bool, optional + Include IF EXISTS clause. Default True. + + Returns + ------- + str + DROP TABLE SQL statement. + """ + ... + + @abstractmethod + def alter_table_sql( + self, + table_name: str, + add_columns: list[dict[str, Any]] | None = None, + drop_columns: list[str] | None = None, + modify_columns: list[dict[str, Any]] | None = None, + ) -> str: + """ + Generate ALTER TABLE statement. + + Parameters + ---------- + table_name : str + Name of table to alter. + add_columns : list[dict], optional + Columns to add with keys: name, type, nullable, default, comment. + drop_columns : list[str], optional + Column names to drop. + modify_columns : list[dict], optional + Columns to modify with keys: name, type, nullable, default, comment. + + Returns + ------- + str + ALTER TABLE SQL statement. + """ + ... + + @abstractmethod + def add_comment_sql( + self, + object_type: str, + object_name: str, + comment: str, + ) -> str | None: + """ + Generate comment statement (may be None if embedded in CREATE). + + Parameters + ---------- + object_type : str + Type of object ('table', 'column'). + object_name : str + Fully qualified object name. + comment : str + Comment text. + + Returns + ------- + str or None + COMMENT statement, or None if comments are inline in CREATE. + """ + ... + + # ========================================================================= + # DML Generation + # ========================================================================= + + @abstractmethod + def insert_sql( + self, + table_name: str, + columns: list[str], + on_duplicate: str | None = None, + ) -> str: + """ + Generate INSERT statement. 
+ + Parameters + ---------- + table_name : str + Name of table to insert into. + columns : list[str] + Column names to insert. + on_duplicate : str, optional + Duplicate handling: 'ignore', 'replace', 'update', or None. + + Returns + ------- + str + INSERT SQL statement with parameter placeholders. + """ + ... + + @abstractmethod + def update_sql( + self, + table_name: str, + set_columns: list[str], + where_columns: list[str], + ) -> str: + """ + Generate UPDATE statement. + + Parameters + ---------- + table_name : str + Name of table to update. + set_columns : list[str] + Column names to set. + where_columns : list[str] + Column names for WHERE clause. + + Returns + ------- + str + UPDATE SQL statement with parameter placeholders. + """ + ... + + @abstractmethod + def delete_sql(self, table_name: str) -> str: + """ + Generate DELETE statement (WHERE clause added separately). + + Parameters + ---------- + table_name : str + Name of table to delete from. + + Returns + ------- + str + DELETE SQL statement without WHERE clause. + """ + ... + + @abstractmethod + def upsert_on_duplicate_sql( + self, + table_name: str, + columns: list[str], + primary_key: list[str], + num_rows: int, + ) -> str: + """ + Generate INSERT ... ON DUPLICATE KEY UPDATE (MySQL) or + INSERT ... ON CONFLICT ... DO UPDATE (PostgreSQL) statement. + + Parameters + ---------- + table_name : str + Fully qualified table name (with quotes). + columns : list[str] + Column names to insert (unquoted). + primary_key : list[str] + Primary key column names (unquoted) for conflict detection. + num_rows : int + Number of rows to insert (for generating placeholders). + + Returns + ------- + str + Upsert SQL statement with placeholders. + + Examples + -------- + MySQL: + INSERT INTO `table` (a, b, c) VALUES (%s, %s, %s), (%s, %s, %s) + ON DUPLICATE KEY UPDATE a = VALUES(a), b = VALUES(b), c = VALUES(c) + + PostgreSQL: + INSERT INTO "table" (a, b, c) VALUES (%s, %s, %s), (%s, %s, %s) + ON CONFLICT (a) DO UPDATE SET b = EXCLUDED.b, c = EXCLUDED.c + """ + ... + + @abstractmethod + def skip_duplicates_clause( + self, + full_table_name: str, + primary_key: list[str], + ) -> str: + """ + Generate clause to skip duplicate key insertions. + + For MySQL: ON DUPLICATE KEY UPDATE pk=table.pk (no-op update) + For PostgreSQL: ON CONFLICT (pk_cols) DO NOTHING + + Parameters + ---------- + full_table_name : str + Fully qualified table name (with quotes). + primary_key : list[str] + Primary key column names (unquoted). + + Returns + ------- + str + SQL clause to append to INSERT statement. + """ + ... + + @property + def supports_inline_indexes(self) -> bool: + """ + Whether this backend supports inline INDEX in CREATE TABLE. + + MySQL supports inline index definitions in CREATE TABLE. + PostgreSQL requires separate CREATE INDEX statements. + + Returns + ------- + bool + True for MySQL, False for PostgreSQL. + """ + return True # Default for MySQL, override in PostgreSQL + + def create_index_ddl( + self, + full_table_name: str, + columns: list[str], + unique: bool = False, + index_name: str | None = None, + ) -> str: + """ + Generate CREATE INDEX statement. + + Parameters + ---------- + full_table_name : str + Fully qualified table name (with quotes). + columns : list[str] + Column names to index (unquoted). + unique : bool, optional + If True, create a unique index. + index_name : str, optional + Custom index name. If None, auto-generate from table/columns. + + Returns + ------- + str + CREATE INDEX SQL statement. 
+ """ + quoted_cols = ", ".join(self.quote_identifier(col) for col in columns) + # Generate index name from table and columns if not provided + if index_name is None: + # Extract table name from full_table_name for index naming + table_part = full_table_name.split(".")[-1].strip('`"') + col_part = "_".join(columns)[:30] # Truncate for long column lists + index_name = f"idx_{table_part}_{col_part}" + unique_clause = "UNIQUE " if unique else "" + return f"CREATE {unique_clause}INDEX {self.quote_identifier(index_name)} ON {full_table_name} ({quoted_cols})" + + # ========================================================================= + # Introspection + # ========================================================================= + + @abstractmethod + def list_schemas_sql(self) -> str: + """ + Generate query to list all schemas/databases. + + Returns + ------- + str + SQL query to list schemas. + """ + ... + + @abstractmethod + def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str: + """ + Generate query to list tables in a schema. + + Parameters + ---------- + schema_name : str + Name of schema to list tables from. + pattern : str, optional + LIKE pattern to filter table names. Use %% for % in SQL. + + Returns + ------- + str + SQL query to list tables. + """ + ... + + @abstractmethod + def get_table_info_sql(self, schema_name: str, table_name: str) -> str: + """ + Generate query to get table metadata (comment, engine, etc.). + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query to get table info. + """ + ... + + @abstractmethod + def get_columns_sql(self, schema_name: str, table_name: str) -> str: + """ + Generate query to get column definitions. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query to get column definitions. + """ + ... + + @abstractmethod + def get_primary_key_sql(self, schema_name: str, table_name: str) -> str: + """ + Generate query to get primary key columns. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query to get primary key columns. + """ + ... + + @abstractmethod + def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str: + """ + Generate query to get foreign key constraints. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query to get foreign key constraints. + """ + ... + + @abstractmethod + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: + """ + Generate query to get foreign key constraint details from information_schema. + + Used during cascade delete to determine FK columns when error message + doesn't provide full details. + + Parameters + ---------- + constraint_name : str + Name of the foreign key constraint. + schema_name : str + Schema/database name of the child table. + table_name : str + Name of the child table. + + Returns + ------- + str + SQL query that returns rows with columns: + - fk_attrs: foreign key column name in child table + - parent: parent table name (quoted, with schema) + - pk_attrs: referenced column name in parent table + """ + ... 
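The create_index_ddl() helper above is the one concrete method in the base class; a short sketch of its output, run through the MySQL adapter so quote_identifier() is concrete (the `lab`.`session` table name and columns are hypothetical):

```python
from datajoint.adapters import get_adapter

adapter = get_adapter("mysql")

# Index name is auto-generated as idx_<table>_<columns joined by "_">,
# with the column part truncated to 30 characters.
ddl = adapter.create_index_ddl(
    "`lab`.`session`",  # hypothetical, already-quoted table name
    ["subject_id", "session_date"],
    unique=True,
)
print(ddl)
# CREATE UNIQUE INDEX `idx_session_subject_id_session_date` ON `lab`.`session` (`subject_id`, `session_date`)
```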
+ + @abstractmethod + def parse_foreign_key_error(self, error_message: str) -> dict[str, str | list[str] | None] | None: + """ + Parse a foreign key violation error message to extract constraint details. + + Used during cascade delete to identify which child table is preventing + deletion and what columns are involved. + + Parameters + ---------- + error_message : str + The error message from a foreign key constraint violation. + + Returns + ------- + dict or None + Dictionary with keys if successfully parsed: + - child: child table name (quoted with schema if available) + - name: constraint name (quoted) + - fk_attrs: list of foreign key column names (may be None if not in message) + - parent: parent table name (quoted, may be None if not in message) + - pk_attrs: list of parent key column names (may be None if not in message) + + Returns None if error message doesn't match FK violation pattern. + + Examples + -------- + MySQL error: + "Cannot delete or update a parent row: a foreign key constraint fails + (`schema`.`child`, CONSTRAINT `fk_name` FOREIGN KEY (`child_col`) + REFERENCES `parent` (`parent_col`))" + + PostgreSQL error: + "update or delete on table \"parent\" violates foreign key constraint + \"child_parent_id_fkey\" on table \"child\" + DETAIL: Key (parent_id)=(1) is still referenced from table \"child\"." + """ + ... + + @abstractmethod + def get_indexes_sql(self, schema_name: str, table_name: str) -> str: + """ + Generate query to get index definitions. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query to get index definitions. + """ + ... + + @abstractmethod + def parse_column_info(self, row: dict[str, Any]) -> dict[str, Any]: + """ + Parse a column info row into standardized format. + + Parameters + ---------- + row : dict + Raw column info row from database introspection query. + + Returns + ------- + dict + Standardized column info with keys: name, type, nullable, + default, comment, etc. + """ + ... + + # ========================================================================= + # Transactions + # ========================================================================= + + @abstractmethod + def start_transaction_sql(self, isolation_level: str | None = None) -> str: + """ + Generate START TRANSACTION statement. + + Parameters + ---------- + isolation_level : str, optional + Transaction isolation level. + + Returns + ------- + str + START TRANSACTION SQL statement. + """ + ... + + @abstractmethod + def commit_sql(self) -> str: + """ + Generate COMMIT statement. + + Returns + ------- + str + COMMIT SQL statement. + """ + ... + + @abstractmethod + def rollback_sql(self) -> str: + """ + Generate ROLLBACK statement. + + Returns + ------- + str + ROLLBACK SQL statement. + """ + ... + + # ========================================================================= + # Functions and Expressions + # ========================================================================= + + @abstractmethod + def current_timestamp_expr(self, precision: int | None = None) -> str: + """ + Expression for current timestamp. + + Parameters + ---------- + precision : int, optional + Fractional seconds precision (0-6). + + Returns + ------- + str + SQL expression for current timestamp. + """ + ... + + @abstractmethod + def interval_expr(self, value: int, unit: str) -> str: + """ + Expression for time interval. + + Parameters + ---------- + value : int + Interval value. 
+ unit : str + Time unit ('second', 'minute', 'hour', 'day', etc.). + + Returns + ------- + str + SQL expression for interval (e.g., 'INTERVAL 5 SECOND' for MySQL, + "INTERVAL '5 seconds'" for PostgreSQL). + """ + ... + + @abstractmethod + def current_user_expr(self) -> str: + """ + SQL expression to get the current user. + + Returns + ------- + str + SQL expression for current user (e.g., 'user()' for MySQL, + 'current_user' for PostgreSQL). + """ + ... + + @abstractmethod + def json_path_expr(self, column: str, path: str, return_type: str | None = None) -> str: + """ + Generate JSON path extraction expression. + + Parameters + ---------- + column : str + Column name containing JSON data. + path : str + JSON path (e.g., 'field' or 'nested.field'). + return_type : str, optional + Return type specification (MySQL-specific). + + Returns + ------- + str + Database-specific JSON extraction SQL expression. + + Examples + -------- + MySQL: json_value(`column`, _utf8mb4'$.path' returning type) + PostgreSQL: jsonb_extract_path_text("column", 'path_part1', 'path_part2') + """ + ... + + def translate_expression(self, expr: str) -> str: + """ + Translate SQL expression for backend compatibility. + + Converts database-specific function calls to the equivalent syntax + for the current backend. This enables portable DataJoint code that + uses common aggregate functions. + + Translations performed: + - GROUP_CONCAT(col) ↔ STRING_AGG(col, ',') + + Parameters + ---------- + expr : str + SQL expression that may contain function calls. + + Returns + ------- + str + Translated expression for the current backend. + + Notes + ----- + The base implementation returns the expression unchanged. + Subclasses override to provide backend-specific translations. + """ + return expr + + # ========================================================================= + # DDL Generation + # ========================================================================= + + @abstractmethod + def format_column_definition( + self, + name: str, + sql_type: str, + nullable: bool = False, + default: str | None = None, + comment: str | None = None, + ) -> str: + """ + Format a column definition for DDL. + + Parameters + ---------- + name : str + Column name. + sql_type : str + SQL type (already backend-specific, e.g., 'bigint', 'varchar(255)'). + nullable : bool, optional + Whether column is nullable. Default False. + default : str | None, optional + Default value expression (e.g., 'NULL', '"value"', 'CURRENT_TIMESTAMP'). + comment : str | None, optional + Column comment. + + Returns + ------- + str + Formatted column definition (without trailing comma). + + Examples + -------- + MySQL: `name` bigint NOT NULL COMMENT "user ID" + PostgreSQL: "name" bigint NOT NULL + """ + ... + + @abstractmethod + def table_options_clause(self, comment: str | None = None) -> str: + """ + Generate table options clause (ENGINE, etc.) for CREATE TABLE. + + Parameters + ---------- + comment : str | None, optional + Table-level comment. + + Returns + ------- + str + Table options clause (e.g., 'ENGINE=InnoDB, COMMENT "..."' for MySQL). + + Examples + -------- + MySQL: ENGINE=InnoDB, COMMENT "experiment sessions" + PostgreSQL: (empty string, comments handled separately) + """ + ... + + @abstractmethod + def table_comment_ddl(self, full_table_name: str, comment: str) -> str | None: + """ + Generate DDL for table-level comment (if separate from CREATE TABLE). + + Parameters + ---------- + full_table_name : str + Fully qualified table name (quoted). 
+ comment : str + Table comment. + + Returns + ------- + str or None + DDL statement for table comment, or None if handled inline. + + Examples + -------- + MySQL: None (inline) + PostgreSQL: COMMENT ON TABLE "schema"."table" IS 'comment text' + """ + ... + + @abstractmethod + def column_comment_ddl(self, full_table_name: str, column_name: str, comment: str) -> str | None: + """ + Generate DDL for column-level comment (if separate from CREATE TABLE). + + Parameters + ---------- + full_table_name : str + Fully qualified table name (quoted). + column_name : str + Column name (unquoted). + comment : str + Column comment. + + Returns + ------- + str or None + DDL statement for column comment, or None if handled inline. + + Examples + -------- + MySQL: None (inline) + PostgreSQL: COMMENT ON COLUMN "schema"."table"."column" IS 'comment text' + """ + ... + + @abstractmethod + def enum_type_ddl(self, type_name: str, values: list[str]) -> str | None: + """ + Generate DDL for enum type definition (if needed before CREATE TABLE). + + Parameters + ---------- + type_name : str + Enum type name. + values : list[str] + Enum values. + + Returns + ------- + str or None + DDL statement for enum type, or None if handled inline. + + Examples + -------- + MySQL: None (inline enum('val1', 'val2')) + PostgreSQL: CREATE TYPE "type_name" AS ENUM ('val1', 'val2') + """ + ... + + @abstractmethod + def job_metadata_columns(self) -> list[str]: + """ + Return job metadata column definitions for Computed/Imported tables. + + Returns + ------- + list[str] + List of column definition strings (fully formatted with quotes). + + Examples + -------- + MySQL: + ["`_job_start_time` datetime(3) DEFAULT NULL", + "`_job_duration` float DEFAULT NULL", + "`_job_version` varchar(64) DEFAULT ''"] + PostgreSQL: + ['"_job_start_time" timestamp DEFAULT NULL', + '"_job_duration" real DEFAULT NULL', + '"_job_version" varchar(64) DEFAULT \'\''] + """ + ... + + # ========================================================================= + # Error Translation + # ========================================================================= + + @abstractmethod + def translate_error(self, error: Exception, query: str = "") -> Exception: + """ + Translate backend-specific error to DataJoint error. + + Parameters + ---------- + error : Exception + Backend-specific exception. + + Returns + ------- + Exception + DataJoint exception or original error if no mapping exists. + """ + ... + + # ========================================================================= + # Native Type Validation + # ========================================================================= + + @abstractmethod + def validate_native_type(self, type_str: str) -> bool: + """ + Check if a native type string is valid for this backend. + + Parameters + ---------- + type_str : str + Native type string to validate. + + Returns + ------- + bool + True if valid for this backend, False otherwise. + """ + ... diff --git a/src/datajoint/adapters/mysql.py b/src/datajoint/adapters/mysql.py new file mode 100644 index 000000000..88339335f --- /dev/null +++ b/src/datajoint/adapters/mysql.py @@ -0,0 +1,1094 @@ +""" +MySQL database adapter for DataJoint. + +This module provides MySQL-specific implementations for SQL generation, +type mapping, error translation, and connection management. +""" + +from __future__ import annotations + +from typing import Any + +import pymysql as client + +from .. 
import errors +from .base import DatabaseAdapter + +# Core type mapping: DataJoint core types → MySQL types +CORE_TYPE_MAP = { + "int64": "bigint", + "int32": "int", + "int16": "smallint", + "int8": "tinyint", + "float32": "float", + "float64": "double", + "bool": "tinyint", + "uuid": "binary(16)", + "bytes": "longblob", + "json": "json", + "date": "date", + # datetime, char, varchar, decimal, enum require parameters - handled in method +} + +# Reverse mapping: MySQL types → DataJoint core types (for introspection) +SQL_TO_CORE_MAP = { + "bigint": "int64", + "int": "int32", + "smallint": "int16", + "tinyint": "int8", # Could be bool, need context + "float": "float32", + "double": "float64", + "binary(16)": "uuid", + "longblob": "bytes", + "json": "json", + "date": "date", +} + + +class MySQLAdapter(DatabaseAdapter): + """MySQL database adapter implementation.""" + + # ========================================================================= + # Connection Management + # ========================================================================= + + def connect( + self, + host: str, + port: int, + user: str, + password: str, + **kwargs: Any, + ) -> Any: + """ + Establish MySQL connection. + + Parameters + ---------- + host : str + MySQL server hostname. + port : int + MySQL server port. + user : str + Username for authentication. + password : str + Password for authentication. + **kwargs : Any + Additional MySQL-specific parameters: + - init_command: SQL initialization command + - ssl: TLS/SSL configuration dict (deprecated, use use_tls) + - use_tls: bool or dict - DataJoint's SSL parameter (preferred) + - charset: Character set (default from kwargs) + + Returns + ------- + pymysql.Connection + MySQL connection object. + """ + init_command = kwargs.get("init_command") + # Handle both ssl (old) and use_tls (new) parameter names + ssl_config = kwargs.get("use_tls", kwargs.get("ssl")) + # Convert boolean True to dict for PyMySQL (PyMySQL expects dict or SSLContext) + if ssl_config is True: + ssl_config = {} # Enable SSL with default settings + charset = kwargs.get("charset", "") + + # Prepare connection parameters + conn_params = { + "host": host, + "port": port, + "user": user, + "passwd": password, + "init_command": init_command, + "sql_mode": "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," + "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", + "charset": charset, + "autocommit": True, # DataJoint manages transactions explicitly + } + + # Handle SSL configuration + if ssl_config is False: + # Explicitly disable SSL + conn_params["ssl_disabled"] = True + elif ssl_config is not None: + # Enable SSL with config dict (can be empty for defaults) + conn_params["ssl"] = ssl_config + # Explicitly enable SSL by setting ssl_disabled=False + conn_params["ssl_disabled"] = False + + return client.connect(**conn_params) + + def close(self, connection: Any) -> None: + """Close the MySQL connection.""" + connection.close() + + def ping(self, connection: Any) -> bool: + """ + Check if MySQL connection is alive. + + Returns + ------- + bool + True if connection is alive. + """ + try: + connection.ping(reconnect=False) + return True + except Exception: + return False + + def get_connection_id(self, connection: Any) -> int: + """ + Get MySQL connection ID. + + Returns + ------- + int + MySQL connection_id(). 
+ """ + cursor = connection.cursor() + cursor.execute("SELECT connection_id()") + return cursor.fetchone()[0] + + @property + def default_port(self) -> int: + """MySQL default port 3306.""" + return 3306 + + @property + def backend(self) -> str: + """Backend identifier: 'mysql'.""" + return "mysql" + + def get_cursor(self, connection: Any, as_dict: bool = False) -> Any: + """ + Get a cursor from MySQL connection. + + Parameters + ---------- + connection : Any + pymysql connection object. + as_dict : bool, optional + If True, return DictCursor that yields rows as dictionaries. + If False, return standard Cursor that yields rows as tuples. + Default False. + + Returns + ------- + Any + pymysql cursor object. + """ + import pymysql + + cursor_class = pymysql.cursors.DictCursor if as_dict else pymysql.cursors.Cursor + return connection.cursor(cursor=cursor_class) + + # ========================================================================= + # SQL Syntax + # ========================================================================= + + def quote_identifier(self, name: str) -> str: + """ + Quote identifier with backticks for MySQL. + + Parameters + ---------- + name : str + Identifier to quote. + + Returns + ------- + str + Backtick-quoted identifier: `name` + """ + return f"`{name}`" + + def quote_string(self, value: str) -> str: + """ + Quote string literal for MySQL with escaping. + + Parameters + ---------- + value : str + String value to quote. + + Returns + ------- + str + Quoted and escaped string literal. + """ + # Use pymysql's escape_string for proper escaping + escaped = client.converters.escape_string(value) + return f"'{escaped}'" + + def get_master_table_name(self, part_table: str) -> str | None: + """Extract master table name from part table (MySQL backtick format).""" + import re + + # MySQL format: `schema`.`master__part` + match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", part_table) + return match["master"] + "`" if match else None + + @property + def parameter_placeholder(self) -> str: + """MySQL/pymysql uses %s placeholders.""" + return "%s" + + # ========================================================================= + # Type Mapping + # ========================================================================= + + def core_type_to_sql(self, core_type: str) -> str: + """ + Convert DataJoint core type to MySQL type. + + Parameters + ---------- + core_type : str + DataJoint core type, possibly with parameters: + - int64, float32, bool, uuid, bytes, json, date + - datetime or datetime(n) + - char(n), varchar(n) + - decimal(p,s) + - enum('a','b','c') + + Returns + ------- + str + MySQL SQL type. + + Raises + ------ + ValueError + If core_type is not recognized. + """ + # Handle simple types without parameters + if core_type in CORE_TYPE_MAP: + return CORE_TYPE_MAP[core_type] + + # Handle parametrized types + if core_type.startswith("datetime"): + # datetime or datetime(precision) + return core_type # MySQL supports datetime(n) directly + + if core_type.startswith("char("): + # char(n) + return core_type + + if core_type.startswith("varchar("): + # varchar(n) + return core_type + + if core_type.startswith("decimal("): + # decimal(precision, scale) + return core_type + + if core_type.startswith("enum("): + # enum('value1', 'value2', ...) + return core_type + + raise ValueError(f"Unknown core type: {core_type}") + + def sql_type_to_core(self, sql_type: str) -> str | None: + """ + Convert MySQL type to DataJoint core type (if mappable).
+ + Parameters + ---------- + sql_type : str + MySQL SQL type. + + Returns + ------- + str or None + DataJoint core type if mappable, None otherwise. + """ + # Normalize type string (lowercase, strip spaces) + sql_type_lower = sql_type.lower().strip() + + # Direct mapping + if sql_type_lower in SQL_TO_CORE_MAP: + return SQL_TO_CORE_MAP[sql_type_lower] + + # Handle parametrized types + if sql_type_lower.startswith("datetime"): + return sql_type # Keep precision + + if sql_type_lower.startswith("char("): + return sql_type # Keep size + + if sql_type_lower.startswith("varchar("): + return sql_type # Keep size + + if sql_type_lower.startswith("decimal("): + return sql_type # Keep precision/scale + + if sql_type_lower.startswith("enum("): + return sql_type # Keep values + + # Not a mappable core type + return None + + # ========================================================================= + # DDL Generation + # ========================================================================= + + def create_schema_sql(self, schema_name: str) -> str: + """ + Generate CREATE DATABASE statement for MySQL. + + Parameters + ---------- + schema_name : str + Database name. + + Returns + ------- + str + CREATE DATABASE SQL. + """ + return f"CREATE DATABASE {self.quote_identifier(schema_name)}" + + def drop_schema_sql(self, schema_name: str, if_exists: bool = True) -> str: + """ + Generate DROP DATABASE statement for MySQL. + + Parameters + ---------- + schema_name : str + Database name. + if_exists : bool + Include IF EXISTS clause. + + Returns + ------- + str + DROP DATABASE SQL. + """ + if_exists_clause = "IF EXISTS " if if_exists else "" + return f"DROP DATABASE {if_exists_clause}{self.quote_identifier(schema_name)}" + + def create_table_sql( + self, + table_name: str, + columns: list[dict[str, Any]], + primary_key: list[str], + foreign_keys: list[dict[str, Any]], + indexes: list[dict[str, Any]], + comment: str | None = None, + ) -> str: + """ + Generate CREATE TABLE statement for MySQL. + + Parameters + ---------- + table_name : str + Fully qualified table name (schema.table). + columns : list[dict] + Column defs: [{name, type, nullable, default, comment}, ...] + primary_key : list[str] + Primary key column names. + foreign_keys : list[dict] + FK defs: [{columns, ref_table, ref_columns}, ...] + indexes : list[dict] + Index defs: [{columns, unique}, ...] + comment : str, optional + Table comment. + + Returns + ------- + str + CREATE TABLE SQL statement. 
+ """ + lines = [] + + # Column definitions + for col in columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = "NULL" if col.get("nullable", False) else "NOT NULL" + default = f" DEFAULT {col['default']}" if "default" in col else "" + col_comment = f" COMMENT {self.quote_string(col['comment'])}" if "comment" in col else "" + lines.append(f"{col_name} {col_type} {nullable}{default}{col_comment}") + + # Primary key + if primary_key: + pk_cols = ", ".join(self.quote_identifier(col) for col in primary_key) + lines.append(f"PRIMARY KEY ({pk_cols})") + + # Foreign keys + for fk in foreign_keys: + fk_cols = ", ".join(self.quote_identifier(col) for col in fk["columns"]) + ref_cols = ", ".join(self.quote_identifier(col) for col in fk["ref_columns"]) + lines.append( + f"FOREIGN KEY ({fk_cols}) REFERENCES {fk['ref_table']} ({ref_cols}) " f"ON UPDATE CASCADE ON DELETE RESTRICT" + ) + + # Indexes + for idx in indexes: + unique = "UNIQUE " if idx.get("unique", False) else "" + idx_cols = ", ".join(self.quote_identifier(col) for col in idx["columns"]) + lines.append(f"{unique}INDEX ({idx_cols})") + + # Assemble CREATE TABLE + table_def = ",\n ".join(lines) + comment_clause = f" COMMENT={self.quote_string(comment)}" if comment else "" + return f"CREATE TABLE IF NOT EXISTS {table_name} (\n {table_def}\n) ENGINE=InnoDB{comment_clause}" + + def drop_table_sql(self, table_name: str, if_exists: bool = True) -> str: + """Generate DROP TABLE statement for MySQL.""" + if_exists_clause = "IF EXISTS " if if_exists else "" + return f"DROP TABLE {if_exists_clause}{table_name}" + + def alter_table_sql( + self, + table_name: str, + add_columns: list[dict[str, Any]] | None = None, + drop_columns: list[str] | None = None, + modify_columns: list[dict[str, Any]] | None = None, + ) -> str: + """ + Generate ALTER TABLE statement for MySQL. + + Parameters + ---------- + table_name : str + Table name. + add_columns : list[dict], optional + Columns to add. + drop_columns : list[str], optional + Column names to drop. + modify_columns : list[dict], optional + Columns to modify. + + Returns + ------- + str + ALTER TABLE SQL statement. + """ + clauses = [] + + if add_columns: + for col in add_columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = "NULL" if col.get("nullable", False) else "NOT NULL" + clauses.append(f"ADD {col_name} {col_type} {nullable}") + + if drop_columns: + for col_name in drop_columns: + clauses.append(f"DROP {self.quote_identifier(col_name)}") + + if modify_columns: + for col in modify_columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = "NULL" if col.get("nullable", False) else "NOT NULL" + clauses.append(f"MODIFY {col_name} {col_type} {nullable}") + + return f"ALTER TABLE {table_name} {', '.join(clauses)}" + + def add_comment_sql( + self, + object_type: str, + object_name: str, + comment: str, + ) -> str | None: + """ + MySQL embeds comments in CREATE/ALTER, not separate statements. + + Returns None since comments are inline. + """ + return None + + # ========================================================================= + # DML Generation + # ========================================================================= + + def insert_sql( + self, + table_name: str, + columns: list[str], + on_duplicate: str | None = None, + ) -> str: + """ + Generate INSERT statement for MySQL. + + Parameters + ---------- + table_name : str + Table name. + columns : list[str] + Column names. 
+ on_duplicate : str, optional + 'ignore', 'replace', or 'update'. + + Returns + ------- + str + INSERT SQL with placeholders. + """ + cols = ", ".join(self.quote_identifier(col) for col in columns) + placeholders = ", ".join([self.parameter_placeholder] * len(columns)) + + if on_duplicate == "ignore": + return f"INSERT IGNORE INTO {table_name} ({cols}) VALUES ({placeholders})" + elif on_duplicate == "replace": + return f"REPLACE INTO {table_name} ({cols}) VALUES ({placeholders})" + elif on_duplicate == "update": + # ON DUPLICATE KEY UPDATE col=VALUES(col) + updates = ", ".join(f"{self.quote_identifier(col)}=VALUES({self.quote_identifier(col)})" for col in columns) + return f"INSERT INTO {table_name} ({cols}) VALUES ({placeholders}) ON DUPLICATE KEY UPDATE {updates}" + else: + return f"INSERT INTO {table_name} ({cols}) VALUES ({placeholders})" + + def update_sql( + self, + table_name: str, + set_columns: list[str], + where_columns: list[str], + ) -> str: + """Generate UPDATE statement for MySQL.""" + set_clause = ", ".join(f"{self.quote_identifier(col)} = {self.parameter_placeholder}" for col in set_columns) + where_clause = " AND ".join(f"{self.quote_identifier(col)} = {self.parameter_placeholder}" for col in where_columns) + return f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}" + + def delete_sql(self, table_name: str) -> str: + """Generate DELETE statement for MySQL (WHERE added separately).""" + return f"DELETE FROM {table_name}" + + def upsert_on_duplicate_sql( + self, + table_name: str, + columns: list[str], + primary_key: list[str], + num_rows: int, + ) -> str: + """Generate INSERT ... ON DUPLICATE KEY UPDATE statement for MySQL.""" + # Build column list + col_list = ", ".join(columns) + + # Build placeholders for VALUES + placeholders = ", ".join(["(%s)" % ", ".join(["%s"] * len(columns))] * num_rows) + + # Build UPDATE clause (all columns) + update_clauses = ", ".join(f"{col} = VALUES({col})" for col in columns) + + return f""" + INSERT INTO {table_name} ({col_list}) + VALUES {placeholders} + ON DUPLICATE KEY UPDATE {update_clauses} + """ + + def skip_duplicates_clause( + self, + full_table_name: str, + primary_key: list[str], + ) -> str: + """ + Generate clause to skip duplicate key insertions for MySQL. + + Uses ON DUPLICATE KEY UPDATE with a no-op update (pk=pk) to effectively + skip duplicates without raising an error. + + Parameters + ---------- + full_table_name : str + Fully qualified table name (with quotes). + primary_key : list[str] + Primary key column names (unquoted). + + Returns + ------- + str + MySQL ON DUPLICATE KEY UPDATE clause. 
+ """ + quoted_pk = self.quote_identifier(primary_key[0]) + return f" ON DUPLICATE KEY UPDATE {quoted_pk}={full_table_name}.{quoted_pk}" + + # ========================================================================= + # Introspection + # ========================================================================= + + def list_schemas_sql(self) -> str: + """Query to list all databases in MySQL.""" + return "SELECT schema_name FROM information_schema.schemata" + + def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str: + """Query to list tables in a database.""" + sql = f"SHOW TABLES IN {self.quote_identifier(schema_name)}" + if pattern: + sql += f" LIKE '{pattern}'" + return sql + + def get_table_info_sql(self, schema_name: str, table_name: str) -> str: + """Query to get table metadata (comment, engine, etc.).""" + return ( + f"SELECT * FROM information_schema.tables " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_name = {self.quote_string(table_name)}" + ) + + def get_columns_sql(self, schema_name: str, table_name: str) -> str: + """Query to get column definitions.""" + return f"SHOW FULL COLUMNS FROM {self.quote_identifier(table_name)} IN {self.quote_identifier(schema_name)}" + + def get_primary_key_sql(self, schema_name: str, table_name: str) -> str: + """Query to get primary key columns.""" + return ( + f"SELECT COLUMN_NAME as column_name FROM information_schema.key_column_usage " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_name = {self.quote_string(table_name)} " + f"AND constraint_name = 'PRIMARY' " + f"ORDER BY ordinal_position" + ) + + def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str: + """Query to get foreign key constraints.""" + return ( + f"SELECT CONSTRAINT_NAME as constraint_name, COLUMN_NAME as column_name, " + f"REFERENCED_TABLE_NAME as referenced_table_name, REFERENCED_COLUMN_NAME as referenced_column_name " + f"FROM information_schema.key_column_usage " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_name = {self.quote_string(table_name)} " + f"AND referenced_table_name IS NOT NULL " + f"ORDER BY constraint_name, ordinal_position" + ) + + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: + """Query to get FK constraint details from information_schema.""" + return ( + "SELECT " + " COLUMN_NAME as fk_attrs, " + " CONCAT('`', REFERENCED_TABLE_SCHEMA, '`.`', REFERENCED_TABLE_NAME, '`') as parent, " + " REFERENCED_COLUMN_NAME as pk_attrs " + "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE " + "WHERE CONSTRAINT_NAME = %s AND TABLE_SCHEMA = %s AND TABLE_NAME = %s" + ) + + def parse_foreign_key_error(self, error_message: str) -> dict[str, str | list[str] | None] | None: + """Parse MySQL foreign key violation error message.""" + import re + + # MySQL FK error pattern with backticks + pattern = re.compile( + r"[\w\s:]*\((?P<child>`[^`]+`.`[^`]+`), " + r"CONSTRAINT (?P<name>`[^`]+`) " + r"(FOREIGN KEY \((?P<fk_attrs>[^)]+)\) " + r"REFERENCES (?P<parent>`[^`]+`(\.`[^`]+`)?) \((?P<pk_attrs>[^)]+)\)[\s\w]+\))?"
+ ) + + match = pattern.match(error_message) + if not match: + return None + + result = match.groupdict() + + # Parse comma-separated FK attrs if present + if result.get("fk_attrs"): + result["fk_attrs"] = [col.strip("`") for col in result["fk_attrs"].split(",")] + # Parse comma-separated PK attrs if present + if result.get("pk_attrs"): + result["pk_attrs"] = [col.strip("`") for col in result["pk_attrs"].split(",")] + + return result + + def get_indexes_sql(self, schema_name: str, table_name: str) -> str: + """Query to get index definitions. + + Note: For MySQL 8.0+, EXPRESSION column contains the expression for + functional indexes. COLUMN_NAME is NULL for such indexes. + """ + return ( + f"SELECT INDEX_NAME as index_name, " + f"COALESCE(COLUMN_NAME, CONCAT('(', EXPRESSION, ')')) as column_name, " + f"NON_UNIQUE as non_unique, SEQ_IN_INDEX as seq_in_index " + f"FROM information_schema.statistics " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_name = {self.quote_string(table_name)} " + f"AND index_name != 'PRIMARY' " + f"ORDER BY index_name, seq_in_index" + ) + + def parse_column_info(self, row: dict[str, Any]) -> dict[str, Any]: + """ + Parse MySQL SHOW FULL COLUMNS output into standardized format. + + Parameters + ---------- + row : dict + Row from SHOW FULL COLUMNS query. + + Returns + ------- + dict + Standardized column info with keys: + name, type, nullable, default, comment, key, extra + """ + return { + "name": row["Field"], + "type": row["Type"], + "nullable": row["Null"] == "YES", + "default": row["Default"], + "comment": row["Comment"], + "key": row["Key"], # PRI, UNI, MUL + "extra": row["Extra"], # auto_increment, etc. + } + + # ========================================================================= + # Transactions + # ========================================================================= + + def start_transaction_sql(self, isolation_level: str | None = None) -> str: + """Generate START TRANSACTION statement.""" + if isolation_level: + return f"START TRANSACTION WITH CONSISTENT SNAPSHOT, {isolation_level}" + return "START TRANSACTION WITH CONSISTENT SNAPSHOT" + + def commit_sql(self) -> str: + """Generate COMMIT statement.""" + return "COMMIT" + + def rollback_sql(self) -> str: + """Generate ROLLBACK statement.""" + return "ROLLBACK" + + # ========================================================================= + # Functions and Expressions + # ========================================================================= + + def current_timestamp_expr(self, precision: int | None = None) -> str: + """ + CURRENT_TIMESTAMP expression for MySQL. + + Parameters + ---------- + precision : int, optional + Fractional seconds precision (0-6). + + Returns + ------- + str + CURRENT_TIMESTAMP or CURRENT_TIMESTAMP(n). + """ + if precision is not None: + return f"CURRENT_TIMESTAMP({precision})" + return "CURRENT_TIMESTAMP" + + def interval_expr(self, value: int, unit: str) -> str: + """ + INTERVAL expression for MySQL. + + Parameters + ---------- + value : int + Interval value. + unit : str + Time unit (singular: 'second', 'minute', 'hour', 'day'). + + Returns + ------- + str + INTERVAL n UNIT (e.g., 'INTERVAL 5 SECOND'). + """ + # MySQL uses singular unit names + return f"INTERVAL {value} {unit.upper()}" + + def current_user_expr(self) -> str: + """MySQL current user expression.""" + return "user()" + + def json_path_expr(self, column: str, path: str, return_type: str | None = None) -> str: + """ + Generate MySQL json_value() expression. 
+ + Parameters + ---------- + column : str + Column name containing JSON data. + path : str + JSON path (e.g., 'field' or 'nested.field'). + return_type : str, optional + Return type specification (e.g., 'decimal(10,2)'). + + Returns + ------- + str + MySQL json_value() expression. + + Examples + -------- + >>> adapter.json_path_expr('data', 'field') + "json_value(`data`, _utf8mb4'$.field')" + >>> adapter.json_path_expr('data', 'value', 'decimal(10,2)') + "json_value(`data`, _utf8mb4'$.value' returning decimal(10,2))" + """ + quoted_col = self.quote_identifier(column) + return_clause = f" returning {return_type}" if return_type else "" + return f"json_value({quoted_col}, _utf8mb4'$.{path}'{return_clause})" + + def translate_expression(self, expr: str) -> str: + """ + Translate SQL expression for MySQL compatibility. + + Converts PostgreSQL-specific functions to MySQL equivalents: + - STRING_AGG(col, 'sep') → GROUP_CONCAT(col SEPARATOR 'sep') + - STRING_AGG(col, ',') → GROUP_CONCAT(col) + + Parameters + ---------- + expr : str + SQL expression that may contain function calls. + + Returns + ------- + str + Translated expression for MySQL. + """ + import re + + # STRING_AGG(col, 'sep') → GROUP_CONCAT(col SEPARATOR 'sep') + def replace_string_agg(match): + inner = match.group(1).strip() + # Parse arguments: col, 'separator' + # Handle both single and double quoted separators + arg_match = re.match(r"(.+?)\s*,\s*(['\"])(.+?)\2", inner) + if arg_match: + col = arg_match.group(1).strip() + sep = arg_match.group(3) + # Remove ::text cast if present (PostgreSQL-specific) + col = re.sub(r"::text$", "", col) + if sep == ",": + return f"GROUP_CONCAT({col})" + else: + return f"GROUP_CONCAT({col} SEPARATOR '{sep}')" + else: + # No separator found, just use the expression + col = re.sub(r"::text$", "", inner) + return f"GROUP_CONCAT({col})" + + expr = re.sub(r"STRING_AGG\s*\((.+?)\)", replace_string_agg, expr, flags=re.IGNORECASE) + + return expr + + # ========================================================================= + # DDL Generation + # ========================================================================= + + def format_column_definition( + self, + name: str, + sql_type: str, + nullable: bool = False, + default: str | None = None, + comment: str | None = None, + ) -> str: + """ + Format a column definition for MySQL DDL. + + Examples + -------- + >>> adapter.format_column_definition('user_id', 'bigint', nullable=False, comment='user ID') + "`user_id` bigint NOT NULL COMMENT \\"user ID\\"" + """ + parts = [self.quote_identifier(name), sql_type] + if default: + parts.append(default) # e.g., "DEFAULT NULL" or "NOT NULL DEFAULT 5" + elif not nullable: + parts.append("NOT NULL") + if comment: + parts.append(f'COMMENT "{comment}"') + return " ".join(parts) + + def table_options_clause(self, comment: str | None = None) -> str: + """ + Generate MySQL table options clause. + + Examples + -------- + >>> adapter.table_options_clause('test table') + 'ENGINE=InnoDB, COMMENT "test table"' + >>> adapter.table_options_clause() + 'ENGINE=InnoDB' + """ + clause = "ENGINE=InnoDB" + if comment: + clause += f', COMMENT "{comment}"' + return clause + + def table_comment_ddl(self, full_table_name: str, comment: str) -> str | None: + """ + MySQL uses inline COMMENT in CREATE TABLE, so no separate DDL needed. 
+ + Examples + -------- + >>> adapter.table_comment_ddl('`schema`.`table`', 'test comment') + None + """ + return None # MySQL uses inline COMMENT + + def column_comment_ddl(self, full_table_name: str, column_name: str, comment: str) -> str | None: + """ + MySQL uses inline COMMENT in column definitions, so no separate DDL needed. + + Examples + -------- + >>> adapter.column_comment_ddl('`schema`.`table`', 'column', 'test comment') + None + """ + return None # MySQL uses inline COMMENT + + def enum_type_ddl(self, type_name: str, values: list[str]) -> str | None: + """ + MySQL uses inline enum type in column definition, so no separate DDL needed. + + Examples + -------- + >>> adapter.enum_type_ddl('status_type', ['active', 'inactive']) + None + """ + return None # MySQL uses inline enum + + def job_metadata_columns(self) -> list[str]: + """ + Return MySQL-specific job metadata column definitions. + + Examples + -------- + >>> adapter.job_metadata_columns() + ["`_job_start_time` datetime(3) DEFAULT NULL", + "`_job_duration` float DEFAULT NULL", + "`_job_version` varchar(64) DEFAULT ''"] + """ + return [ + "`_job_start_time` datetime(3) DEFAULT NULL", + "`_job_duration` float DEFAULT NULL", + "`_job_version` varchar(64) DEFAULT ''", + ] + + # ========================================================================= + # Error Translation + # ========================================================================= + + def translate_error(self, error: Exception, query: str = "") -> Exception: + """ + Translate MySQL error to DataJoint exception. + + Parameters + ---------- + error : Exception + MySQL exception (typically pymysql error). + + Returns + ------- + Exception + DataJoint exception or original error. + """ + if not hasattr(error, "args") or len(error.args) == 0: + return error + + err, *args = error.args + + match err: + # Loss of connection errors + case 0 | "(0, '')": + return errors.LostConnectionError("Server connection lost due to an interface error.", *args) + case 2006: + return errors.LostConnectionError("Connection timed out", *args) + case 2013: + return errors.LostConnectionError("Server connection lost", *args) + + # Access errors + case 1044 | 1142: + query = args[0] if args else "" + return errors.AccessError("Insufficient privileges.", args[0] if args else "", query) + + # Integrity errors + case 1062: + return errors.DuplicateError(*args) + case 1217 | 1451 | 1452 | 3730: + return errors.IntegrityError(*args) + + # Syntax errors + case 1064: + query = args[0] if args else "" + return errors.QuerySyntaxError(args[0] if args else "", query) + + # Existence errors + case 1146: + query = args[0] if args else "" + return errors.MissingTableError(args[0] if args else "", query) + case 1364: + return errors.MissingAttributeError(*args) + case 1054: + return errors.UnknownAttributeError(*args) + + # All other errors pass through unchanged + case _: + return error + + # ========================================================================= + # Native Type Validation + # ========================================================================= + + def validate_native_type(self, type_str: str) -> bool: + """ + Check if a native MySQL type string is valid. + + Parameters + ---------- + type_str : str + Type string to validate. + + Returns + ------- + bool + True if valid MySQL type. 
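+
+        Examples
+        --------
+        Illustrative doctest-style checks (a hedged sketch; ``adapter`` is
+        assumed to be a ``MySQLAdapter`` instance):
+
+        >>> adapter.validate_native_type("varchar(255)")
+        True
+        >>> adapter.validate_native_type("decimal(10,2)")
+        True
+        >>> adapter.validate_native_type("bytea")
+        False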
+ """ + type_lower = type_str.lower().strip() + + # MySQL native types (simplified validation) + valid_types = { + # Integer types + "tinyint", + "smallint", + "mediumint", + "int", + "integer", + "bigint", + # Floating point + "float", + "double", + "real", + "decimal", + "numeric", + # String types + "char", + "varchar", + "binary", + "varbinary", + "tinyblob", + "blob", + "mediumblob", + "longblob", + "tinytext", + "text", + "mediumtext", + "longtext", + # Temporal types + "date", + "time", + "datetime", + "timestamp", + "year", + # Other + "enum", + "set", + "json", + "geometry", + } + + # Extract base type (before parentheses) + base_type = type_lower.split("(")[0].strip() + + return base_type in valid_types diff --git a/src/datajoint/adapters/postgres.py b/src/datajoint/adapters/postgres.py new file mode 100644 index 000000000..12fecae6a --- /dev/null +++ b/src/datajoint/adapters/postgres.py @@ -0,0 +1,1510 @@ +""" +PostgreSQL database adapter for DataJoint. + +This module provides PostgreSQL-specific implementations for SQL generation, +type mapping, error translation, and connection management. +""" + +from __future__ import annotations + +import re +from typing import Any + +try: + import psycopg2 as client + from psycopg2 import sql +except ImportError: + client = None # type: ignore + sql = None # type: ignore + +from .. import errors +from .base import DatabaseAdapter + +# Core type mapping: DataJoint core types → PostgreSQL types +CORE_TYPE_MAP = { + "int64": "bigint", + "int32": "integer", + "int16": "smallint", + "int8": "smallint", # PostgreSQL lacks tinyint; semantically equivalent + "float32": "real", + "float64": "double precision", + "bool": "boolean", + "uuid": "uuid", # Native UUID support + "bytes": "bytea", + "json": "jsonb", # Using jsonb for better performance + "date": "date", + # datetime, char, varchar, decimal, enum require parameters - handled in method +} + +# Reverse mapping: PostgreSQL types → DataJoint core types (for introspection) +SQL_TO_CORE_MAP = { + "bigint": "int64", + "integer": "int32", + "smallint": "int16", + "real": "float32", + "double precision": "float64", + "boolean": "bool", + "uuid": "uuid", + "bytea": "bytes", + "jsonb": "json", + "json": "json", + "date": "date", +} + + +class PostgreSQLAdapter(DatabaseAdapter): + """PostgreSQL database adapter implementation.""" + + def __init__(self) -> None: + """Initialize PostgreSQL adapter.""" + if client is None: + raise ImportError( + "psycopg2 is required for PostgreSQL support. " "Install it with: pip install 'datajoint[postgres]'" + ) + + # ========================================================================= + # Connection Management + # ========================================================================= + + def connect( + self, + host: str, + port: int, + user: str, + password: str, + **kwargs: Any, + ) -> Any: + """ + Establish PostgreSQL connection. + + Parameters + ---------- + host : str + PostgreSQL server hostname. + port : int + PostgreSQL server port. + user : str + Username for authentication. + password : str + Password for authentication. + **kwargs : Any + Additional PostgreSQL-specific parameters: + - dbname: Database name + - sslmode: SSL mode ('disable', 'allow', 'prefer', 'require') + - use_tls: bool or dict - DataJoint's SSL parameter (converted to sslmode) + - connect_timeout: Connection timeout in seconds + + Returns + ------- + psycopg2.connection + PostgreSQL connection object. 
+ """ + dbname = kwargs.get("dbname", "postgres") # Default to postgres database + connect_timeout = kwargs.get("connect_timeout", 10) + + # Handle use_tls parameter (from DataJoint Connection) + # Convert to PostgreSQL's sslmode + use_tls = kwargs.get("use_tls") + if "sslmode" in kwargs: + # Explicit sslmode takes precedence + sslmode = kwargs["sslmode"] + elif use_tls is False: + # use_tls=False → disable SSL + sslmode = "disable" + elif use_tls is True or isinstance(use_tls, dict): + # use_tls=True or dict → require SSL + sslmode = "require" + else: + # use_tls=None (default) → prefer SSL but allow fallback + sslmode = "prefer" + + conn = client.connect( + host=host, + port=port, + user=user, + password=password, + dbname=dbname, + sslmode=sslmode, + connect_timeout=connect_timeout, + ) + # DataJoint manages transactions explicitly via start_transaction() + # Set autocommit=True to avoid implicit transactions + conn.autocommit = True + + # Register numpy type adapters so numpy types can be used directly in queries + self._register_numpy_adapters() + + return conn + + def _register_numpy_adapters(self) -> None: + """ + Register psycopg2 adapters for numpy types. + + This allows numpy scalar types (bool_, int64, float64, etc.) to be used + directly in queries without explicit conversion to Python native types. + """ + try: + import numpy as np + from psycopg2.extensions import register_adapter, AsIs + + # Numpy bool type + register_adapter(np.bool_, lambda x: AsIs(str(bool(x)).upper())) + + # Numpy integer types + for np_type in (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64): + register_adapter(np_type, lambda x: AsIs(int(x))) + + # Numpy float types + for np_ftype in (np.float16, np.float32, np.float64): + register_adapter(np_ftype, lambda x: AsIs(repr(float(x)))) + + except ImportError: + pass # numpy not available + + def close(self, connection: Any) -> None: + """Close the PostgreSQL connection.""" + connection.close() + + def ping(self, connection: Any) -> bool: + """ + Check if PostgreSQL connection is alive. + + Returns + ------- + bool + True if connection is alive. + """ + try: + cursor = connection.cursor() + cursor.execute("SELECT 1") + cursor.close() + return True + except Exception: + return False + + def get_connection_id(self, connection: Any) -> int: + """ + Get PostgreSQL backend process ID. + + Returns + ------- + int + PostgreSQL pg_backend_pid(). + """ + cursor = connection.cursor() + cursor.execute("SELECT pg_backend_pid()") + return cursor.fetchone()[0] + + @property + def default_port(self) -> int: + """PostgreSQL default port 5432.""" + return 5432 + + @property + def backend(self) -> str: + """Backend identifier: 'postgresql'.""" + return "postgresql" + + def get_cursor(self, connection: Any, as_dict: bool = False) -> Any: + """ + Get a cursor from PostgreSQL connection. + + Parameters + ---------- + connection : Any + psycopg2 connection object. + as_dict : bool, optional + If True, return Real DictCursor that yields rows as dictionaries. + If False, return standard cursor that yields rows as tuples. + Default False. + + Returns + ------- + Any + psycopg2 cursor object. 
+ """ + import psycopg2.extras + + if as_dict: + return connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor) + return connection.cursor() + + # ========================================================================= + # SQL Syntax + # ========================================================================= + + def quote_identifier(self, name: str) -> str: + """ + Quote identifier with double quotes for PostgreSQL. + + Parameters + ---------- + name : str + Identifier to quote. + + Returns + ------- + str + Double-quoted identifier: "name" + """ + return f'"{name}"' + + def quote_string(self, value: str) -> str: + """ + Quote string literal for PostgreSQL with escaping. + + Parameters + ---------- + value : str + String value to quote. + + Returns + ------- + str + Quoted and escaped string literal. + """ + # Escape single quotes by doubling them (PostgreSQL standard) + escaped = value.replace("'", "''") + return f"'{escaped}'" + + def get_master_table_name(self, part_table: str) -> str | None: + """Extract master table name from part table (PostgreSQL double-quote format).""" + import re + + # PostgreSQL format: "schema"."master__part" + match = re.match(r'(?P"\w+"."#?\w+)__\w+"', part_table) + return match["master"] + '"' if match else None + + @property + def parameter_placeholder(self) -> str: + """PostgreSQL/psycopg2 uses %s placeholders.""" + return "%s" + + # ========================================================================= + # Type Mapping + # ========================================================================= + + def core_type_to_sql(self, core_type: str) -> str: + """ + Convert DataJoint core type to PostgreSQL type. + + Parameters + ---------- + core_type : str + DataJoint core type, possibly with parameters: + - int64, float32, bool, uuid, bytes, json, date + - datetime or datetime(n) → timestamp(n) + - char(n), varchar(n) + - decimal(p,s) → numeric(p,s) + - enum('a','b','c') → requires CREATE TYPE + + Returns + ------- + str + PostgreSQL SQL type. + + Raises + ------ + ValueError + If core_type is not recognized. 
+ """ + # Handle simple types without parameters + if core_type in CORE_TYPE_MAP: + return CORE_TYPE_MAP[core_type] + + # Handle parametrized types + if core_type.startswith("datetime"): + # datetime or datetime(precision) → timestamp or timestamp(precision) + if "(" in core_type: + # Extract precision: datetime(3) → timestamp(3) + precision = core_type[core_type.index("(") : core_type.index(")") + 1] + return f"timestamp{precision}" + return "timestamp" + + if core_type.startswith("char("): + # char(n) + return core_type + + if core_type.startswith("varchar("): + # varchar(n) + return core_type + + if core_type.startswith("decimal("): + # decimal(precision, scale) → numeric(precision, scale) + params = core_type[7:] # Remove "decimal" + return f"numeric{params}" + + if core_type.startswith("enum("): + # PostgreSQL requires CREATE TYPE for enums + # Extract enum values and generate a deterministic type name + enum_match = re.match(r"enum\s*\((.+)\)", core_type, re.I) + if enum_match: + # Parse enum values: enum('M','F') -> ['M', 'F'] + values_str = enum_match.group(1) + # Split by comma, handling quoted values + values = [v.strip().strip("'\"") for v in values_str.split(",")] + # Generate a deterministic type name based on values + # Use a hash to keep name reasonable length + import hashlib + + value_hash = hashlib.md5("_".join(sorted(values)).encode()).hexdigest()[:8] + type_name = f"enum_{value_hash}" + # Track this enum type for CREATE TYPE DDL + if not hasattr(self, "_pending_enum_types"): + self._pending_enum_types = {} + self._pending_enum_types[type_name] = values + # Return schema-qualified type reference using placeholder + # {database} will be replaced with actual schema name in table.py + return '"{database}".' + self.quote_identifier(type_name) + return "text" # Fallback if parsing fails + + raise ValueError(f"Unknown core type: {core_type}") + + def sql_type_to_core(self, sql_type: str) -> str | None: + """ + Convert PostgreSQL type to DataJoint core type (if mappable). + + Parameters + ---------- + sql_type : str + PostgreSQL SQL type. + + Returns + ------- + str or None + DataJoint core type if mappable, None otherwise. + """ + # Normalize type string (lowercase, strip spaces) + sql_type_lower = sql_type.lower().strip() + + # Direct mapping + if sql_type_lower in SQL_TO_CORE_MAP: + return SQL_TO_CORE_MAP[sql_type_lower] + + # Handle parametrized types + if sql_type_lower.startswith("timestamp"): + # timestamp(n) → datetime(n) + if "(" in sql_type_lower: + precision = sql_type_lower[sql_type_lower.index("(") : sql_type_lower.index(")") + 1] + return f"datetime{precision}" + return "datetime" + + if sql_type_lower.startswith("char("): + return sql_type # Keep size + + if sql_type_lower.startswith("varchar("): + return sql_type # Keep size + + if sql_type_lower.startswith("numeric("): + # numeric(p,s) → decimal(p,s) + params = sql_type_lower[7:] # Remove "numeric" + return f"decimal{params}" + + # Not a mappable core type + return None + + # ========================================================================= + # DDL Generation + # ========================================================================= + + def create_schema_sql(self, schema_name: str) -> str: + """ + Generate CREATE SCHEMA statement for PostgreSQL. + + Parameters + ---------- + schema_name : str + Schema name. + + Returns + ------- + str + CREATE SCHEMA SQL. 
+ """ + return f"CREATE SCHEMA {self.quote_identifier(schema_name)}" + + def drop_schema_sql(self, schema_name: str, if_exists: bool = True) -> str: + """ + Generate DROP SCHEMA statement for PostgreSQL. + + Parameters + ---------- + schema_name : str + Schema name. + if_exists : bool + Include IF EXISTS clause. + + Returns + ------- + str + DROP SCHEMA SQL. + """ + if_exists_clause = "IF EXISTS " if if_exists else "" + return f"DROP SCHEMA {if_exists_clause}{self.quote_identifier(schema_name)} CASCADE" + + def create_table_sql( + self, + table_name: str, + columns: list[dict[str, Any]], + primary_key: list[str], + foreign_keys: list[dict[str, Any]], + indexes: list[dict[str, Any]], + comment: str | None = None, + ) -> str: + """ + Generate CREATE TABLE statement for PostgreSQL. + + Parameters + ---------- + table_name : str + Fully qualified table name (schema.table). + columns : list[dict] + Column defs: [{name, type, nullable, default, comment}, ...] + primary_key : list[str] + Primary key column names. + foreign_keys : list[dict] + FK defs: [{columns, ref_table, ref_columns}, ...] + indexes : list[dict] + Index defs: [{columns, unique}, ...] + comment : str, optional + Table comment (added via separate COMMENT ON statement). + + Returns + ------- + str + CREATE TABLE SQL statement (comments via separate COMMENT ON). + """ + lines = [] + + # Column definitions + for col in columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = "NULL" if col.get("nullable", False) else "NOT NULL" + default = f" DEFAULT {col['default']}" if "default" in col else "" + # PostgreSQL comments are via COMMENT ON, not inline + lines.append(f"{col_name} {col_type} {nullable}{default}") + + # Primary key + if primary_key: + pk_cols = ", ".join(self.quote_identifier(col) for col in primary_key) + lines.append(f"PRIMARY KEY ({pk_cols})") + + # Foreign keys + for fk in foreign_keys: + fk_cols = ", ".join(self.quote_identifier(col) for col in fk["columns"]) + ref_cols = ", ".join(self.quote_identifier(col) for col in fk["ref_columns"]) + lines.append( + f"FOREIGN KEY ({fk_cols}) REFERENCES {fk['ref_table']} ({ref_cols}) " f"ON UPDATE CASCADE ON DELETE RESTRICT" + ) + + # Indexes - PostgreSQL creates indexes separately via CREATE INDEX + # (handled by caller after table creation) + + # Assemble CREATE TABLE (no ENGINE in PostgreSQL) + table_def = ",\n ".join(lines) + return f"CREATE TABLE IF NOT EXISTS {table_name} (\n {table_def}\n)" + + def drop_table_sql(self, table_name: str, if_exists: bool = True) -> str: + """Generate DROP TABLE statement for PostgreSQL.""" + if_exists_clause = "IF EXISTS " if if_exists else "" + return f"DROP TABLE {if_exists_clause}{table_name} CASCADE" + + def alter_table_sql( + self, + table_name: str, + add_columns: list[dict[str, Any]] | None = None, + drop_columns: list[str] | None = None, + modify_columns: list[dict[str, Any]] | None = None, + ) -> str: + """ + Generate ALTER TABLE statement for PostgreSQL. + + Parameters + ---------- + table_name : str + Table name. + add_columns : list[dict], optional + Columns to add. + drop_columns : list[str], optional + Column names to drop. + modify_columns : list[dict], optional + Columns to modify. + + Returns + ------- + str + ALTER TABLE SQL statement. 
+ """ + clauses = [] + + if add_columns: + for col in add_columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = "NULL" if col.get("nullable", False) else "NOT NULL" + clauses.append(f"ADD COLUMN {col_name} {col_type} {nullable}") + + if drop_columns: + for col_name in drop_columns: + clauses.append(f"DROP COLUMN {self.quote_identifier(col_name)}") + + if modify_columns: + # PostgreSQL requires ALTER COLUMN ... TYPE ... for type changes + for col in modify_columns: + col_name = self.quote_identifier(col["name"]) + col_type = col["type"] + nullable = col.get("nullable", False) + clauses.append(f"ALTER COLUMN {col_name} TYPE {col_type}") + if nullable: + clauses.append(f"ALTER COLUMN {col_name} DROP NOT NULL") + else: + clauses.append(f"ALTER COLUMN {col_name} SET NOT NULL") + + return f"ALTER TABLE {table_name} {', '.join(clauses)}" + + def add_comment_sql( + self, + object_type: str, + object_name: str, + comment: str, + ) -> str | None: + """ + Generate COMMENT ON statement for PostgreSQL. + + Parameters + ---------- + object_type : str + 'table' or 'column'. + object_name : str + Fully qualified object name. + comment : str + Comment text. + + Returns + ------- + str + COMMENT ON statement. + """ + comment_type = object_type.upper() + return f"COMMENT ON {comment_type} {object_name} IS {self.quote_string(comment)}" + + # ========================================================================= + # DML Generation + # ========================================================================= + + def insert_sql( + self, + table_name: str, + columns: list[str], + on_duplicate: str | None = None, + ) -> str: + """ + Generate INSERT statement for PostgreSQL. + + Parameters + ---------- + table_name : str + Table name. + columns : list[str] + Column names. + on_duplicate : str, optional + 'ignore' or 'update' (PostgreSQL uses ON CONFLICT). + + Returns + ------- + str + INSERT SQL with placeholders. + """ + cols = ", ".join(self.quote_identifier(col) for col in columns) + placeholders = ", ".join([self.parameter_placeholder] * len(columns)) + + base_insert = f"INSERT INTO {table_name} ({cols}) VALUES ({placeholders})" + + if on_duplicate == "ignore": + return f"{base_insert} ON CONFLICT DO NOTHING" + elif on_duplicate == "update": + # ON CONFLICT (pk_cols) DO UPDATE SET col=EXCLUDED.col + # Caller must provide constraint name or columns + updates = ", ".join(f"{self.quote_identifier(col)}=EXCLUDED.{self.quote_identifier(col)}" for col in columns) + return f"{base_insert} ON CONFLICT DO UPDATE SET {updates}" + else: + return base_insert + + def update_sql( + self, + table_name: str, + set_columns: list[str], + where_columns: list[str], + ) -> str: + """Generate UPDATE statement for PostgreSQL.""" + set_clause = ", ".join(f"{self.quote_identifier(col)} = {self.parameter_placeholder}" for col in set_columns) + where_clause = " AND ".join(f"{self.quote_identifier(col)} = {self.parameter_placeholder}" for col in where_columns) + return f"UPDATE {table_name} SET {set_clause} WHERE {where_clause}" + + def delete_sql(self, table_name: str) -> str: + """Generate DELETE statement for PostgreSQL (WHERE added separately).""" + return f"DELETE FROM {table_name}" + + def upsert_on_duplicate_sql( + self, + table_name: str, + columns: list[str], + primary_key: list[str], + num_rows: int, + ) -> str: + """Generate INSERT ... ON CONFLICT ... 
DO UPDATE statement for PostgreSQL.""" + # Build column list + col_list = ", ".join(columns) + + # Build placeholders for VALUES + placeholders = ", ".join(["(%s)" % ", ".join(["%s"] * len(columns))] * num_rows) + + # Build conflict target (primary key columns) + conflict_cols = ", ".join(primary_key) + + # Build UPDATE clause (non-PK columns only) + non_pk_columns = [col for col in columns if col not in primary_key] + update_clauses = ", ".join(f"{col} = EXCLUDED.{col}" for col in non_pk_columns) + + return f""" + INSERT INTO {table_name} ({col_list}) + VALUES {placeholders} + ON CONFLICT ({conflict_cols}) DO UPDATE SET {update_clauses} + """ + + def skip_duplicates_clause( + self, + full_table_name: str, + primary_key: list[str], + ) -> str: + """ + Generate clause to skip duplicate key insertions for PostgreSQL. + + Uses ON CONFLICT (pk_cols) DO NOTHING to skip duplicates without + raising an error. + + Parameters + ---------- + full_table_name : str + Fully qualified table name (with quotes). Unused but kept for + API compatibility with MySQL adapter. + primary_key : list[str] + Primary key column names (unquoted). + + Returns + ------- + str + PostgreSQL ON CONFLICT DO NOTHING clause. + """ + pk_cols = ", ".join(self.quote_identifier(pk) for pk in primary_key) + return f" ON CONFLICT ({pk_cols}) DO NOTHING" + + @property + def supports_inline_indexes(self) -> bool: + """ + PostgreSQL does not support inline INDEX in CREATE TABLE. + + Returns False to indicate indexes must be created separately + with CREATE INDEX statements. + """ + return False + + # ========================================================================= + # Introspection + # ========================================================================= + + def list_schemas_sql(self) -> str: + """Query to list all schemas in PostgreSQL.""" + return ( + "SELECT schema_name FROM information_schema.schemata " + "WHERE schema_name NOT IN ('pg_catalog', 'information_schema')" + ) + + def list_tables_sql(self, schema_name: str, pattern: str | None = None) -> str: + """Query to list tables in a schema.""" + sql = ( + f"SELECT table_name FROM information_schema.tables " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_type = 'BASE TABLE'" + ) + if pattern: + sql += f" AND table_name LIKE '{pattern}'" + return sql + + def get_table_info_sql(self, schema_name: str, table_name: str) -> str: + """Query to get table metadata including table comment.""" + schema_str = self.quote_string(schema_name) + table_str = self.quote_string(table_name) + regclass_expr = f"({schema_str} || '.' || {table_str})::regclass" + return ( + f"SELECT t.*, obj_description({regclass_expr}, 'pg_class') as table_comment " + f"FROM information_schema.tables t " + f"WHERE t.table_schema = {schema_str} " + f"AND t.table_name = {table_str}" + ) + + def get_columns_sql(self, schema_name: str, table_name: str) -> str: + """Query to get column definitions including comments.""" + # Use col_description() to retrieve column comments stored via COMMENT ON COLUMN + # The regclass cast allows using schema.table notation to get the OID + schema_str = self.quote_string(schema_name) + table_str = self.quote_string(table_name) + regclass_expr = f"({schema_str} || '.' 
|| {table_str})::regclass" + return ( + f"SELECT c.column_name, c.data_type, c.udt_name, c.is_nullable, c.column_default, " + f"c.character_maximum_length, c.numeric_precision, c.numeric_scale, " + f"col_description({regclass_expr}, c.ordinal_position) as column_comment " + f"FROM information_schema.columns c " + f"WHERE c.table_schema = {schema_str} " + f"AND c.table_name = {table_str} " + f"ORDER BY c.ordinal_position" + ) + + def get_primary_key_sql(self, schema_name: str, table_name: str) -> str: + """Query to get primary key columns.""" + return ( + f"SELECT column_name FROM information_schema.key_column_usage " + f"WHERE table_schema = {self.quote_string(schema_name)} " + f"AND table_name = {self.quote_string(table_name)} " + f"AND constraint_name IN (" + f" SELECT constraint_name FROM information_schema.table_constraints " + f" WHERE table_schema = {self.quote_string(schema_name)} " + f" AND table_name = {self.quote_string(table_name)} " + f" AND constraint_type = 'PRIMARY KEY'" + f") " + f"ORDER BY ordinal_position" + ) + + def get_foreign_keys_sql(self, schema_name: str, table_name: str) -> str: + """Query to get foreign key constraints.""" + return ( + f"SELECT kcu.constraint_name, kcu.column_name, " + f"ccu.table_name AS foreign_table_name, ccu.column_name AS foreign_column_name " + f"FROM information_schema.key_column_usage AS kcu " + f"JOIN information_schema.constraint_column_usage AS ccu " + f" ON kcu.constraint_name = ccu.constraint_name " + f"WHERE kcu.table_schema = {self.quote_string(schema_name)} " + f"AND kcu.table_name = {self.quote_string(table_name)} " + f"AND kcu.constraint_name IN (" + f" SELECT constraint_name FROM information_schema.table_constraints " + f" WHERE table_schema = {self.quote_string(schema_name)} " + f" AND table_name = {self.quote_string(table_name)} " + f" AND constraint_type = 'FOREIGN KEY'" + f") " + f"ORDER BY kcu.constraint_name, kcu.ordinal_position" + ) + + def get_constraint_info_sql(self, constraint_name: str, schema_name: str, table_name: str) -> str: + """ + Query to get FK constraint details from information_schema. + + Returns matched pairs of (fk_column, parent_table, pk_column) for each + column in the foreign key constraint, ordered by position. + """ + return ( + "SELECT " + " kcu.column_name as fk_attrs, " + " '\"' || ccu.table_schema || '\".\"' || ccu.table_name || '\"' as parent, " + " ccu.column_name as pk_attrs " + "FROM information_schema.key_column_usage AS kcu " + "JOIN information_schema.referential_constraints AS rc " + " ON kcu.constraint_name = rc.constraint_name " + " AND kcu.constraint_schema = rc.constraint_schema " + "JOIN information_schema.key_column_usage AS ccu " + " ON rc.unique_constraint_name = ccu.constraint_name " + " AND rc.unique_constraint_schema = ccu.constraint_schema " + " AND kcu.ordinal_position = ccu.ordinal_position " + "WHERE kcu.constraint_name = %s " + " AND kcu.table_schema = %s " + " AND kcu.table_name = %s " + "ORDER BY kcu.ordinal_position" + ) + + def parse_foreign_key_error(self, error_message: str) -> dict[str, str | list[str] | None] | None: + """ + Parse PostgreSQL foreign key violation error message. 
+
+        PostgreSQL FK error format:
+        'update or delete on table "X" violates foreign key constraint "Y" on table "Z"'
+        Where:
+        - "X" is the referenced table (being deleted/updated)
+        - "Z" is the referencing table (has the FK, needs cascade delete)
+        """
+        import re
+
+        pattern = re.compile(
+            r'.*table "(?P<referenced_table>[^"]+)" violates foreign key constraint '
+            r'"(?P<name>[^"]+)" on table "(?P<referencing_table>[^"]+)"'
+        )
+
+        match = pattern.match(error_message)
+        if not match:
+            return None
+
+        result = match.groupdict()
+
+        # The child is the referencing table (the one with the FK that needs cascade delete)
+        # The parent is the referenced table (the one being deleted)
+        # The error doesn't include schema, so we return unqualified names
+        child = f'"{result["referencing_table"]}"'
+        parent = f'"{result["referenced_table"]}"'
+
+        return {
+            "child": child,
+            "name": f'"{result["name"]}"',
+            "fk_attrs": None,  # Not in error message, will need constraint query
+            "parent": parent,
+            "pk_attrs": None,  # Not in error message, will need constraint query
+        }
+
+    def get_indexes_sql(self, schema_name: str, table_name: str) -> str:
+        """Query to get index definitions."""
+        return (
+            f"SELECT indexname, indexdef FROM pg_indexes "
+            f"WHERE schemaname = {self.quote_string(schema_name)} "
+            f"AND tablename = {self.quote_string(table_name)}"
+        )
+
+    def parse_column_info(self, row: dict[str, Any]) -> dict[str, Any]:
+        """
+        Parse PostgreSQL column info into standardized format.
+
+        Parameters
+        ----------
+        row : dict
+            Row from information_schema.columns query with col_description() join.
+
+        Returns
+        -------
+        dict
+            Standardized column info with keys:
+            name, type, nullable, default, comment, key, extra
+        """
+        # For user-defined types (enums), use udt_name instead of data_type
+        # PostgreSQL reports enums as "USER-DEFINED" in data_type
+        data_type = row["data_type"]
+        if data_type == "USER-DEFINED":
+            data_type = row["udt_name"]
+
+        # Reconstruct parametrized types that PostgreSQL splits into separate fields
+        char_max_len = row.get("character_maximum_length")
+        num_precision = row.get("numeric_precision")
+        num_scale = row.get("numeric_scale")
+
+        if data_type == "character" and char_max_len is not None:
+            # char(n) - PostgreSQL reports as "character" with length in separate field
+            data_type = f"char({char_max_len})"
+        elif data_type == "character varying" and char_max_len is not None:
+            # varchar(n)
+            data_type = f"varchar({char_max_len})"
+        elif data_type == "numeric" and num_precision is not None:
+            # numeric(p,s) - reconstruct decimal type
+            if num_scale is not None and num_scale > 0:
+                data_type = f"decimal({num_precision},{num_scale})"
+            else:
+                data_type = f"decimal({num_precision})"
+
+        return {
+            "name": row["column_name"],
+            "type": data_type,
+            "nullable": row["is_nullable"] == "YES",
+            "default": row["column_default"],
+            "comment": row.get("column_comment"),  # Retrieved via col_description()
+            "key": "",  # PostgreSQL key info retrieved separately
+            "extra": "",  # PostgreSQL doesn't have auto_increment in same way
+        }
+
+    # =========================================================================
+    # Transactions
+    # =========================================================================
+
+    def start_transaction_sql(self, isolation_level: str | None = None) -> str:
+        """Generate BEGIN statement for PostgreSQL."""
+        if isolation_level:
+            return f"BEGIN ISOLATION LEVEL {isolation_level}"
+        return "BEGIN"
+
+    def commit_sql(self) -> str:
+        """Generate COMMIT statement."""
+        return "COMMIT"
+
+    def rollback_sql(self)
-> str: + """Generate ROLLBACK statement.""" + return "ROLLBACK" + + # ========================================================================= + # Functions and Expressions + # ========================================================================= + + def current_timestamp_expr(self, precision: int | None = None) -> str: + """ + CURRENT_TIMESTAMP expression for PostgreSQL. + + Parameters + ---------- + precision : int, optional + Fractional seconds precision (0-6). + + Returns + ------- + str + CURRENT_TIMESTAMP or CURRENT_TIMESTAMP(n). + """ + if precision is not None: + return f"CURRENT_TIMESTAMP({precision})" + return "CURRENT_TIMESTAMP" + + def interval_expr(self, value: int, unit: str) -> str: + """ + INTERVAL expression for PostgreSQL. + + Parameters + ---------- + value : int + Interval value. + unit : str + Time unit (singular: 'second', 'minute', 'hour', 'day'). + + Returns + ------- + str + INTERVAL 'n units' (e.g., "INTERVAL '5 seconds'"). + """ + # PostgreSQL uses plural unit names and quotes + unit_plural = unit.lower() + "s" if not unit.endswith("s") else unit.lower() + return f"INTERVAL '{value} {unit_plural}'" + + def current_user_expr(self) -> str: + """PostgreSQL current user expression.""" + return "current_user" + + def json_path_expr(self, column: str, path: str, return_type: str | None = None) -> str: + """ + Generate PostgreSQL jsonb_extract_path_text() expression. + + Parameters + ---------- + column : str + Column name containing JSON data. + path : str + JSON path (e.g., 'field' or 'nested.field'). + return_type : str, optional + Return type specification for casting (e.g., 'float', 'decimal(10,2)'). + + Returns + ------- + str + PostgreSQL jsonb_extract_path_text() expression, with optional cast. + + Examples + -------- + >>> adapter.json_path_expr('data', 'field') + 'jsonb_extract_path_text("data", \\'field\\')' + >>> adapter.json_path_expr('data', 'nested.field') + 'jsonb_extract_path_text("data", \\'nested\\', \\'field\\')' + >>> adapter.json_path_expr('data', 'value', 'float') + 'jsonb_extract_path_text("data", \\'value\\')::float' + """ + quoted_col = self.quote_identifier(column) + # Split path by '.' for nested access, handling array notation + path_parts = [] + for part in path.split("."): + # Handle array access like field[0] + if "[" in part: + base, rest = part.split("[", 1) + path_parts.append(base) + # Extract array indices + indices = rest.rstrip("]").split("][") + path_parts.extend(indices) + else: + path_parts.append(part) + path_args = ", ".join(f"'{part}'" for part in path_parts) + expr = f"jsonb_extract_path_text({quoted_col}, {path_args})" + # Add cast if return type specified + if return_type: + # Map DataJoint types to PostgreSQL types + pg_type = return_type.lower() + if pg_type in ("unsigned", "signed"): + pg_type = "integer" + elif pg_type == "double": + pg_type = "double precision" + expr = f"({expr})::{pg_type}" + return expr + + def translate_expression(self, expr: str) -> str: + """ + Translate SQL expression for PostgreSQL compatibility. + + Converts MySQL-specific functions to PostgreSQL equivalents: + - GROUP_CONCAT(col) → STRING_AGG(col::text, ',') + - GROUP_CONCAT(col SEPARATOR 'sep') → STRING_AGG(col::text, 'sep') + + Parameters + ---------- + expr : str + SQL expression that may contain function calls. + + Returns + ------- + str + Translated expression for PostgreSQL. 
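+
+        Examples
+        --------
+        Illustrative translations (doctest-style sketch; the column names are
+        placeholders):
+
+        >>> adapter.translate_expression("GROUP_CONCAT(name)")
+        "STRING_AGG(name::text, ',')"
+        >>> adapter.translate_expression("YEAR(session_date)")
+        'EXTRACT(YEAR FROM session_date)::int'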
+ """ + import re + + # GROUP_CONCAT(col) → STRING_AGG(col::text, ',') + # GROUP_CONCAT(col SEPARATOR 'sep') → STRING_AGG(col::text, 'sep') + def replace_group_concat(match): + inner = match.group(1).strip() + # Check for SEPARATOR clause + sep_match = re.match(r"(.+?)\s+SEPARATOR\s+(['\"])(.+?)\2", inner, re.IGNORECASE) + if sep_match: + col = sep_match.group(1).strip() + sep = sep_match.group(3) + return f"STRING_AGG({col}::text, '{sep}')" + else: + return f"STRING_AGG({inner}::text, ',')" + + expr = re.sub(r"GROUP_CONCAT\s*\((.+?)\)", replace_group_concat, expr, flags=re.IGNORECASE) + + # Replace simple functions FIRST before complex patterns + # CURDATE() → CURRENT_DATE + expr = re.sub(r"CURDATE\s*\(\s*\)", "CURRENT_DATE", expr, flags=re.IGNORECASE) + + # NOW() → CURRENT_TIMESTAMP + expr = re.sub(r"\bNOW\s*\(\s*\)", "CURRENT_TIMESTAMP", expr, flags=re.IGNORECASE) + + # YEAR(date) → EXTRACT(YEAR FROM date)::int + expr = re.sub(r"\bYEAR\s*\(\s*([^)]+)\s*\)", r"EXTRACT(YEAR FROM \1)::int", expr, flags=re.IGNORECASE) + + # MONTH(date) → EXTRACT(MONTH FROM date)::int + expr = re.sub(r"\bMONTH\s*\(\s*([^)]+)\s*\)", r"EXTRACT(MONTH FROM \1)::int", expr, flags=re.IGNORECASE) + + # DAY(date) → EXTRACT(DAY FROM date)::int + expr = re.sub(r"\bDAY\s*\(\s*([^)]+)\s*\)", r"EXTRACT(DAY FROM \1)::int", expr, flags=re.IGNORECASE) + + # TIMESTAMPDIFF(YEAR, d1, d2) → EXTRACT(YEAR FROM AGE(d2, d1))::int + # Use a more robust regex that handles the comma-separated arguments + def replace_timestampdiff(match): + unit = match.group(1).upper() + date1 = match.group(2).strip() + date2 = match.group(3).strip() + if unit == "YEAR": + return f"EXTRACT(YEAR FROM AGE({date2}, {date1}))::int" + elif unit == "MONTH": + return f"(EXTRACT(YEAR FROM AGE({date2}, {date1})) * 12 + EXTRACT(MONTH FROM AGE({date2}, {date1})))::int" + elif unit == "DAY": + return f"({date2}::date - {date1}::date)" + else: + return f"EXTRACT({unit} FROM AGE({date2}, {date1}))::int" + + # Match TIMESTAMPDIFF with proper argument parsing + # The arguments are: unit, date1, date2 - we need to handle identifiers and CURRENT_DATE + expr = re.sub( + r"TIMESTAMPDIFF\s*\(\s*(\w+)\s*,\s*([^,]+)\s*,\s*([^)]+)\s*\)", + replace_timestampdiff, + expr, + flags=re.IGNORECASE, + ) + + # SUM(expr='value') → SUM((expr='value')::int) for PostgreSQL boolean handling + # This handles patterns like SUM(sex='F') which produce boolean in PostgreSQL + def replace_sum_comparison(match): + inner = match.group(1).strip() + # Check if inner contains a comparison operator + if re.search(r"[=<>!]", inner) and not inner.startswith("("): + return f"SUM(({inner})::int)" + return match.group(0) # Return unchanged if no comparison + + expr = re.sub(r"\bSUM\s*\(\s*([^)]+)\s*\)", replace_sum_comparison, expr, flags=re.IGNORECASE) + + return expr + + # ========================================================================= + # DDL Generation + # ========================================================================= + + def format_column_definition( + self, + name: str, + sql_type: str, + nullable: bool = False, + default: str | None = None, + comment: str | None = None, + ) -> str: + """ + Format a column definition for PostgreSQL DDL. 
+ + Examples + -------- + >>> adapter.format_column_definition('user_id', 'bigint', nullable=False, comment='user ID') + '"user_id" bigint NOT NULL' + """ + parts = [self.quote_identifier(name), sql_type] + if default: + parts.append(default) + elif not nullable: + parts.append("NOT NULL") + # Note: PostgreSQL comments handled separately via COMMENT ON + return " ".join(parts) + + def table_options_clause(self, comment: str | None = None) -> str: + """ + Generate PostgreSQL table options clause (empty - no ENGINE in PostgreSQL). + + Examples + -------- + >>> adapter.table_options_clause('test table') + '' + >>> adapter.table_options_clause() + '' + """ + return "" # PostgreSQL uses COMMENT ON TABLE separately + + def table_comment_ddl(self, full_table_name: str, comment: str) -> str | None: + """ + Generate COMMENT ON TABLE statement for PostgreSQL. + + Examples + -------- + >>> adapter.table_comment_ddl('"schema"."table"', 'test comment') + 'COMMENT ON TABLE "schema"."table" IS \\'test comment\\'' + """ + # Escape single quotes by doubling them + escaped_comment = comment.replace("'", "''") + return f"COMMENT ON TABLE {full_table_name} IS '{escaped_comment}'" + + def column_comment_ddl(self, full_table_name: str, column_name: str, comment: str) -> str | None: + """ + Generate COMMENT ON COLUMN statement for PostgreSQL. + + Examples + -------- + >>> adapter.column_comment_ddl('"schema"."table"', 'column', 'test comment') + 'COMMENT ON COLUMN "schema"."table"."column" IS \\'test comment\\'' + """ + quoted_col = self.quote_identifier(column_name) + # Escape single quotes by doubling them (PostgreSQL string literal syntax) + escaped_comment = comment.replace("'", "''") + return f"COMMENT ON COLUMN {full_table_name}.{quoted_col} IS '{escaped_comment}'" + + def enum_type_ddl(self, type_name: str, values: list[str]) -> str | None: + """ + Generate CREATE TYPE statement for PostgreSQL enum. + + Examples + -------- + >>> adapter.enum_type_ddl('status_type', ['active', 'inactive']) + 'CREATE TYPE "status_type" AS ENUM (\\'active\\', \\'inactive\\')' + """ + quoted_values = ", ".join(f"'{v}'" for v in values) + return f"CREATE TYPE {self.quote_identifier(type_name)} AS ENUM ({quoted_values})" + + def get_pending_enum_ddl(self, schema_name: str) -> list[str]: + """ + Get DDL statements for pending enum types and clear the pending list. + + PostgreSQL requires CREATE TYPE statements before using enum types in + column definitions. This method returns DDL for enum types accumulated + during type conversion and clears the pending list. + + Parameters + ---------- + schema_name : str + Schema name to qualify enum type names. + + Returns + ------- + list[str] + List of CREATE TYPE statements (if any pending). + """ + ddl_statements = [] + if hasattr(self, "_pending_enum_types") and self._pending_enum_types: + for type_name, values in self._pending_enum_types.items(): + # Generate CREATE TYPE with schema qualification + quoted_type = f"{self.quote_identifier(schema_name)}.{self.quote_identifier(type_name)}" + quoted_values = ", ".join(f"'{v}'" for v in values) + ddl_statements.append(f"CREATE TYPE {quoted_type} AS ENUM ({quoted_values})") + self._pending_enum_types = {} + return ddl_statements + + def job_metadata_columns(self) -> list[str]: + """ + Return PostgreSQL-specific job metadata column definitions. 
+ + Examples + -------- + >>> adapter.job_metadata_columns() + ['"_job_start_time" timestamp DEFAULT NULL', + '"_job_duration" real DEFAULT NULL', + '"_job_version" varchar(64) DEFAULT \\'\\''] + """ + return [ + '"_job_start_time" timestamp DEFAULT NULL', + '"_job_duration" real DEFAULT NULL', + "\"_job_version\" varchar(64) DEFAULT ''", + ] + + # ========================================================================= + # Error Translation + # ========================================================================= + + def translate_error(self, error: Exception, query: str = "") -> Exception: + """ + Translate PostgreSQL error to DataJoint exception. + + Parameters + ---------- + error : Exception + PostgreSQL exception (typically psycopg2 error). + query : str, optional + SQL query that caused the error (for context). + + Returns + ------- + Exception + DataJoint exception or original error. + """ + if not hasattr(error, "pgcode"): + return error + + pgcode = error.pgcode + + # PostgreSQL error code mapping + # Reference: https://www.postgresql.org/docs/current/errcodes-appendix.html + match pgcode: + # Integrity constraint violations + case "23505": # unique_violation + return errors.DuplicateError(str(error)) + case "23503": # foreign_key_violation + return errors.IntegrityError(str(error)) + case "23502": # not_null_violation + return errors.MissingAttributeError(str(error)) + + # Syntax errors + case "42601": # syntax_error + return errors.QuerySyntaxError(str(error), "") + + # Undefined errors + case "42P01": # undefined_table + return errors.MissingTableError(str(error), "") + case "42703": # undefined_column + return errors.UnknownAttributeError(str(error)) + + # Connection errors + case "08006" | "08003" | "08000": # connection_failure + return errors.LostConnectionError(str(error)) + case "57P01": # admin_shutdown + return errors.LostConnectionError(str(error)) + + # Access errors + case "42501": # insufficient_privilege + return errors.AccessError("Insufficient privileges.", str(error), "") + + # All other errors pass through unchanged + case _: + return error + + # ========================================================================= + # Native Type Validation + # ========================================================================= + + def validate_native_type(self, type_str: str) -> bool: + """ + Check if a native PostgreSQL type string is valid. + + Parameters + ---------- + type_str : str + Type string to validate. + + Returns + ------- + bool + True if valid PostgreSQL type. 
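+
+        Examples
+        --------
+        Doctest-style sketch (``adapter`` is assumed to be a
+        ``PostgreSQLAdapter`` instance):
+
+        >>> adapter.validate_native_type("double precision")
+        True
+        >>> adapter.validate_native_type("text[]")
+        True
+        >>> adapter.validate_native_type("mediumtext")
+        False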
+ """ + type_lower = type_str.lower().strip() + + # PostgreSQL native types (simplified validation) + valid_types = { + # Integer types + "smallint", + "integer", + "int", + "bigint", + "smallserial", + "serial", + "bigserial", + # Floating point + "real", + "double precision", + "numeric", + "decimal", + # String types + "char", + "varchar", + "text", + # Binary + "bytea", + # Boolean + "boolean", + "bool", + # Temporal types + "date", + "time", + "timetz", + "timestamp", + "timestamptz", + "interval", + # UUID + "uuid", + # JSON + "json", + "jsonb", + # Network types + "inet", + "cidr", + "macaddr", + # Geometric types + "point", + "line", + "lseg", + "box", + "path", + "polygon", + "circle", + # Other + "money", + "xml", + } + + # Extract base type (before parentheses or brackets) + base_type = type_lower.split("(")[0].split("[")[0].strip() + + return base_type in valid_types + + # ========================================================================= + # PostgreSQL-Specific Enum Handling + # ========================================================================= + + def create_enum_type_sql( + self, + schema: str, + table: str, + column: str, + values: list[str], + ) -> str: + """ + Generate CREATE TYPE statement for PostgreSQL enum. + + Parameters + ---------- + schema : str + Schema name. + table : str + Table name. + column : str + Column name. + values : list[str] + Enum values. + + Returns + ------- + str + CREATE TYPE ... AS ENUM statement. + """ + type_name = f"{schema}_{table}_{column}_enum" + quoted_values = ", ".join(self.quote_string(v) for v in values) + return f"CREATE TYPE {self.quote_identifier(type_name)} AS ENUM ({quoted_values})" + + def drop_enum_type_sql(self, schema: str, table: str, column: str) -> str: + """ + Generate DROP TYPE statement for PostgreSQL enum. + + Parameters + ---------- + schema : str + Schema name. + table : str + Table name. + column : str + Column name. + + Returns + ------- + str + DROP TYPE statement. + """ + type_name = f"{schema}_{table}_{column}_enum" + return f"DROP TYPE IF EXISTS {self.quote_identifier(type_name)} CASCADE" + + def get_table_enum_types_sql(self, schema_name: str, table_name: str) -> str: + """ + Query to get enum types used by a table's columns. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + str + SQL query that returns enum type names (schema-qualified). + """ + return f""" + SELECT DISTINCT + n.nspname || '.' || t.typname as enum_type + FROM pg_catalog.pg_type t + JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + JOIN pg_catalog.pg_attribute a ON a.atttypid = t.oid + JOIN pg_catalog.pg_class c ON c.oid = a.attrelid + JOIN pg_catalog.pg_namespace cn ON cn.oid = c.relnamespace + WHERE t.typtype = 'e' + AND cn.nspname = {self.quote_string(schema_name)} + AND c.relname = {self.quote_string(table_name)} + """ + + def drop_enum_types_for_table(self, schema_name: str, table_name: str) -> list[str]: + """ + Generate DROP TYPE statements for all enum types used by a table. + + Parameters + ---------- + schema_name : str + Schema name. + table_name : str + Table name. + + Returns + ------- + list[str] + List of DROP TYPE IF EXISTS statements. 
+ """ + # Returns list of DDL statements - caller should execute query first + # to get actual enum types, then call this with results + return [] # Placeholder - actual implementation requires query execution + + def drop_enum_type_ddl(self, enum_type_name: str) -> str: + """ + Generate DROP TYPE IF EXISTS statement for a PostgreSQL enum. + + Parameters + ---------- + enum_type_name : str + Fully qualified enum type name (schema.typename). + + Returns + ------- + str + DROP TYPE IF EXISTS statement with CASCADE. + """ + # Split schema.typename and quote each part + parts = enum_type_name.split(".") + if len(parts) == 2: + qualified_name = f"{self.quote_identifier(parts[0])}.{self.quote_identifier(parts[1])}" + else: + qualified_name = self.quote_identifier(enum_type_name) + return f"DROP TYPE IF EXISTS {qualified_name} CASCADE" diff --git a/src/datajoint/autopopulate.py b/src/datajoint/autopopulate.py index 02193f496..7660e43ec 100644 --- a/src/datajoint/autopopulate.py +++ b/src/datajoint/autopopulate.py @@ -11,9 +11,6 @@ import traceback from typing import TYPE_CHECKING, Any, Generator -import deepdiff -from tqdm import tqdm - from .errors import DataJointError, LostConnectionError from .expression import AndList, QueryExpression @@ -404,6 +401,8 @@ def _populate_direct( Computes keys directly from key_source, suitable for single-worker execution, development, and debugging. """ + from tqdm import tqdm + keys = (self._jobs_to_do(restrictions) - self).keys() logger.debug("Found %d keys to populate" % len(keys)) @@ -435,7 +434,9 @@ def _populate_direct( else: # spawn multiple processes self.connection.close() - del self.connection._conn.ctx # SSLContext is not pickleable + # Remove SSLContext if present (MySQL-specific, not pickleable) + if hasattr(self.connection._conn, "ctx"): + del self.connection._conn.ctx with ( mp.Pool(processes, _initialize_populate, (self, None, populate_kwargs)) as pool, tqdm(desc="Processes: ", total=nkeys) if display_progress else contextlib.nullcontext() as progress_bar, @@ -474,6 +475,8 @@ def _populate_distributed( Uses job table for multi-worker coordination, priority scheduling, and status tracking. """ + from tqdm import tqdm + from .settings import config # Define a signal handler for SIGTERM @@ -581,6 +584,8 @@ def _populate1( """ import time + import deepdiff + # use the legacy `_make_tuples` callback. 
make = self._make_tuples if hasattr(self, "_make_tuples") else self.make @@ -703,17 +708,26 @@ def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int] todo_sql = todo.make_sql() target_sql = self.make_sql() + # Get adapter for backend-specific quoting + adapter = self.connection.adapter + q = adapter.quote_identifier + + # Alias names for subqueries + ks_alias = q("$ks") + tgt_alias = q("$tgt") + # Build join condition on common attributes - join_cond = " AND ".join(f"`$ks`.`{attr}` = `$tgt`.`{attr}`" for attr in common_attrs) + join_cond = " AND ".join(f"{ks_alias}.{q(attr)} = {tgt_alias}.{q(attr)}" for attr in common_attrs) # Build DISTINCT key expression for counting unique jobs - # Use CONCAT for composite keys to create a single distinct value + # Use CONCAT_WS for composite keys (supported by both MySQL and PostgreSQL) if len(pk_attrs) == 1: - distinct_key = f"`$ks`.`{pk_attrs[0]}`" - null_check = f"`$tgt`.`{common_attrs[0]}`" + distinct_key = f"{ks_alias}.{q(pk_attrs[0])}" + null_check = f"{tgt_alias}.{q(common_attrs[0])}" else: - distinct_key = "CONCAT_WS('|', {})".format(", ".join(f"`$ks`.`{attr}`" for attr in pk_attrs)) - null_check = f"`$tgt`.`{common_attrs[0]}`" + key_cols = ", ".join(f"{ks_alias}.{q(attr)}" for attr in pk_attrs) + distinct_key = f"CONCAT_WS('|', {key_cols})" + null_check = f"{tgt_alias}.{q(common_attrs[0])}" # Single aggregation query: # - COUNT(DISTINCT key) gives total unique jobs in key_source @@ -722,8 +736,8 @@ def progress(self, *restrictions: Any, display: bool = False) -> tuple[int, int] SELECT COUNT(DISTINCT {distinct_key}) AS total, COUNT(DISTINCT CASE WHEN {null_check} IS NULL THEN {distinct_key} END) AS remaining - FROM ({todo_sql}) AS `$ks` - LEFT JOIN ({target_sql}) AS `$tgt` ON {join_cond} + FROM ({todo_sql}) AS {ks_alias} + LEFT JOIN ({target_sql}) AS {tgt_alias} ON {join_cond} """ result = self.connection.query(sql).fetchone() diff --git a/src/datajoint/blob.py b/src/datajoint/blob.py index d94417d6d..633f55b79 100644 --- a/src/datajoint/blob.py +++ b/src/datajoint/blob.py @@ -149,6 +149,9 @@ def squeeze(self, array: np.ndarray, convert_to_scalar: bool = True) -> np.ndarr return array.item() if array.ndim == 0 and convert_to_scalar else array def unpack(self, blob): + # PostgreSQL returns bytea as memoryview; convert to bytes for string operations + if isinstance(blob, memoryview): + blob = bytes(blob) self._blob = blob try: # decompress diff --git a/src/datajoint/builtin_codecs.py b/src/datajoint/builtin_codecs.py deleted file mode 100644 index c87ab4716..000000000 --- a/src/datajoint/builtin_codecs.py +++ /dev/null @@ -1,1286 +0,0 @@ -""" -Built-in DataJoint codecs. - -This module defines the standard codecs that ship with DataJoint. -These serve as both useful built-in codecs and as examples for users who -want to create their own custom codecs. 
- -Built-in Codecs: - - ````: Serialize Python objects (in-table storage) - - ````: Serialize Python objects (external with hash-addressed dedup) - - ````: File attachment (in-table storage) - - ````: File attachment (external with hash-addressed dedup) - - ````: Hash-addressed storage with MD5 deduplication (external only) - - ````: Schema-addressed storage for files/folders (external only) - - ````: Store numpy arrays as portable .npy files (external only) - - ````: Reference to existing file in store (external only) - -Example - Creating a Custom Codec: - Here's how to define your own codec, modeled after the built-in codecs:: - - import datajoint as dj - import networkx as nx - - class GraphCodec(dj.Codec): - '''Store NetworkX graphs as edge lists.''' - - name = "graph" # Use as in definitions - - def get_dtype(self, is_store: bool) -> str: - return "" # Compose with blob for serialization - - def encode(self, graph, *, key=None, store_name=None): - # Convert graph to a serializable format - return { - 'nodes': list(graph.nodes(data=True)), - 'edges': list(graph.edges(data=True)), - } - - def decode(self, stored, *, key=None): - # Reconstruct graph from stored format - G = nx.Graph() - G.add_nodes_from(stored['nodes']) - G.add_edges_from(stored['edges']) - return G - - def validate(self, value): - if not isinstance(value, nx.Graph): - raise TypeError(f"Expected nx.Graph, got {type(value).__name__}") - - # Now use in table definitions: - @schema - class Networks(dj.Manual): - definition = ''' - network_id : int - --- - topology : - ''' -""" - -from __future__ import annotations - -from typing import Any - -from .codecs import Codec -from .errors import DataJointError - - -# ============================================================================= -# Blob Codec - DataJoint's native serialization -# ============================================================================= - - -class BlobCodec(Codec): - """ - Serialize Python objects using DataJoint's blob format. - - The ```` codec handles serialization of arbitrary Python objects - including NumPy arrays, dictionaries, lists, datetime objects, and UUIDs. - - Supports both in-table and in-store storage: - - ````: Stored in database (bytes → LONGBLOB) - - ````: Stored in object store via ```` with deduplication - - ````: Stored in specific named store - - Format Features: - - Protocol headers (``mYm`` for MATLAB-compatible, ``dj0`` for Python-native) - - Optional zlib compression for data > 1KB - - Support for nested structures - - Example:: - - @schema - class ProcessedData(dj.Manual): - definition = ''' - data_id : int - --- - small_result : # in-table (in database) - large_result : # in-store (default store) - archive : # in-store (specific store) - ''' - - # Insert any serializable object - table.insert1({'data_id': 1, 'small_result': {'scores': [0.9, 0.8]}}) - """ - - name = "blob" - - def get_dtype(self, is_store: bool) -> str: - """Return bytes for in-table, for in-store storage.""" - return "" if is_store else "bytes" - - def encode(self, value: Any, *, key: dict | None = None, store_name: str | None = None) -> bytes: - """Serialize a Python object to DataJoint's blob format.""" - from . import blob - - return blob.pack(value, compress=True) - - def decode(self, stored: bytes, *, key: dict | None = None) -> Any: - """Deserialize blob bytes back to a Python object.""" - from . 
import blob - - return blob.unpack(stored, squeeze=False) - - -# ============================================================================= -# Hash-Addressed Storage Codec -# ============================================================================= - - -class HashCodec(Codec): - """ - Hash-addressed storage with SHA256 deduplication. - - The ```` codec stores raw bytes using hash-addressed storage. - Data is identified by its SHA256 hash and stored in a hierarchical directory: - ``_hash/{hash[:2]}/{hash[2:4]}/{hash}`` - - The database column stores JSON metadata: ``{hash, store, size}``. - Duplicate content is automatically deduplicated across all tables. - - Deletion: Requires garbage collection via ``dj.gc.collect()``. - - External only - requires @ modifier. - - Example:: - - @schema - class RawContent(dj.Manual): - definition = ''' - content_id : int - --- - data : - ''' - - # Insert raw bytes - table.insert1({'content_id': 1, 'data': b'raw binary content'}) - - Note: - This codec accepts only ``bytes``. For Python objects, use ````. - Typically used indirectly via ```` or ```` rather than directly. - - See Also - -------- - datajoint.gc : Garbage collection for orphaned storage. - """ - - name = "hash" - - def get_dtype(self, is_store: bool) -> str: - """Hash storage is in-store only.""" - if not is_store: - raise DataJointError(" requires @ (in-store storage only)") - return "json" - - def encode(self, value: bytes, *, key: dict | None = None, store_name: str | None = None) -> dict: - """ - Store content and return metadata. - - Parameters - ---------- - value : bytes - Raw bytes to store. - key : dict, optional - Context dict with ``_schema`` for path isolation. - store_name : str, optional - Store to use. If None, uses default store. - - Returns - ------- - dict - Metadata dict: ``{hash, path, schema, store, size}``. - """ - from .hash_registry import put_hash - - schema_name = (key or {}).get("_schema", "unknown") - return put_hash(value, schema_name=schema_name, store_name=store_name) - - def decode(self, stored: dict, *, key: dict | None = None) -> bytes: - """ - Retrieve content using stored metadata. - - Parameters - ---------- - stored : dict - Metadata dict with ``'path'``, ``'hash'``, and optionally ``'store'``. - key : dict, optional - Context dict (unused - path is in metadata). - - Returns - ------- - bytes - Original bytes. - """ - from .hash_registry import get_hash - - return get_hash(stored) - - def validate(self, value: Any) -> None: - """Validate that value is bytes.""" - if not isinstance(value, bytes): - raise TypeError(f" expects bytes, got {type(value).__name__}") - - -# ============================================================================= -# Schema-Addressed Storage Base Class -# ============================================================================= - - -class SchemaCodec(Codec, register=False): - """ - Abstract base class for schema-addressed codecs. - - Schema-addressed storage is an OAS (Object-Augmented Schema) addressing - scheme where paths mirror the database schema structure: - ``{schema}/{table}/{pk}/{attribute}``. This creates a browsable - organization in object storage that reflects the schema design. 
- - Subclasses must implement: - - ``name``: Codec name for ```` syntax - - ``encode()``: Serialize and upload content - - ``decode()``: Create lazy reference from metadata - - ``validate()``: Validate input values - - Helper Methods: - - ``_extract_context()``: Parse key dict into schema/table/field/pk - - ``_build_path()``: Construct storage path from context - - ``_get_backend()``: Get storage backend by name - - Comparison with Hash-addressed: - - **Schema-addressed** (this): Path from schema structure, no dedup - - **Hash-addressed**: Path from content hash, automatic dedup - - Example:: - - class MyCodec(SchemaCodec): - name = "my" - - def encode(self, value, *, key=None, store_name=None): - schema, table, field, pk = self._extract_context(key) - path, _ = self._build_path(schema, table, field, pk, ext=".dat") - backend = self._get_backend(store_name) - backend.put_buffer(serialize(value), path) - return {"path": path, "store": store_name, ...} - - def decode(self, stored, *, key=None): - backend = self._get_backend(stored.get("store")) - return MyRef(stored, backend) - - See Also - -------- - HashCodec : Hash-addressed storage with content deduplication. - ObjectCodec : Schema-addressed storage for files/folders. - NpyCodec : Schema-addressed storage for numpy arrays. - """ - - def get_dtype(self, is_store: bool) -> str: - """ - Return storage dtype. Schema-addressed codecs require @ modifier. - - Parameters - ---------- - is_store : bool - Must be True for schema-addressed codecs. - - Returns - ------- - str - "json" for metadata storage. - - Raises - ------ - DataJointError - If is_store is False (@ modifier missing). - """ - if not is_store: - raise DataJointError(f"<{self.name}> requires @ (store only)") - return "json" - - def _extract_context(self, key: dict | None) -> tuple[str, str, str, dict]: - """ - Extract schema, table, field, and primary key from context dict. - - Parameters - ---------- - key : dict or None - Context dict with ``_schema``, ``_table``, ``_field``, - and primary key values. - - Returns - ------- - tuple[str, str, str, dict] - ``(schema, table, field, primary_key)`` - """ - key = dict(key) if key else {} - schema = key.pop("_schema", "unknown") - table = key.pop("_table", "unknown") - field = key.pop("_field", "data") - primary_key = {k: v for k, v in key.items() if not k.startswith("_")} - return schema, table, field, primary_key - - def _build_path( - self, - schema: str, - table: str, - field: str, - primary_key: dict, - ext: str | None = None, - store_name: str | None = None, - ) -> tuple[str, str]: - """ - Build schema-addressed storage path. - - Constructs a path that mirrors the database schema structure: - ``{schema}/{table}/{pk_values}/{field}{ext}`` - - Supports partitioning if configured in the store. - - Parameters - ---------- - schema : str - Schema name. - table : str - Table name. - field : str - Field/attribute name. - primary_key : dict - Primary key values. - ext : str, optional - File extension (e.g., ".npy", ".zarr"). - store_name : str, optional - Store name for retrieving partition configuration. - - Returns - ------- - tuple[str, str] - ``(path, token)`` where path is the storage path and token - is a unique identifier. - """ - from .storage import build_object_path - from . 
import config - - # Get store configuration for partition_pattern and token_length - spec = config.get_store_spec(store_name) - partition_pattern = spec.get("partition_pattern") - token_length = spec.get("token_length", 8) - - return build_object_path( - schema=schema, - table=table, - field=field, - primary_key=primary_key, - ext=ext, - partition_pattern=partition_pattern, - token_length=token_length, - ) - - def _get_backend(self, store_name: str | None = None): - """ - Get storage backend by name. - - Parameters - ---------- - store_name : str, optional - Store name. If None, returns default store. - - Returns - ------- - StorageBackend - Storage backend instance. - """ - from .hash_registry import get_store_backend - - return get_store_backend(store_name) - - -# ============================================================================= -# Object Codec (Schema-Addressed Files/Folders) -# ============================================================================= - - -class ObjectCodec(SchemaCodec): - """ - Schema-addressed storage for files and folders. - - The ```` codec provides managed file/folder storage using - schema-addressed paths: ``{schema}/{table}/{pk}/{field}/``. This creates - a browsable organization in object storage that mirrors the database schema. - - Unlike hash-addressed storage (````), each row has its own unique path - (no deduplication). Ideal for: - - - Zarr arrays (hierarchical chunked data) - - HDF5 files - - Complex multi-file outputs - - Any content that shouldn't be deduplicated - - Store only - requires @ modifier. - - Example:: - - @schema - class Analysis(dj.Computed): - definition = ''' - -> Recording - --- - results : - ''' - - def make(self, key): - # Store a file - self.insert1({**key, 'results': '/path/to/results.zarr'}) - - # Fetch returns ObjectRef for lazy access - ref = (Analysis & key).fetch1('results') - ref.path # Storage path - ref.read() # Read file content - ref.fsmap # For zarr.open(ref.fsmap) - - Storage Structure: - Objects are stored at:: - - {store_root}/{schema}/{table}/{pk}/{field}/ - - Deletion: Requires garbage collection via ``dj.gc.collect()``. - - Comparison with hash-addressed:: - - | Aspect | | | - |----------------|---------------------|---------------------| - | Addressing | Schema-addressed | Hash-addressed | - | Deduplication | No | Yes | - | Deletion | GC required | GC required | - | Use case | Zarr, HDF5 | Blobs, attachments | - - See Also - -------- - datajoint.gc : Garbage collection for orphaned storage. - SchemaCodec : Base class for schema-addressed codecs. - NpyCodec : Schema-addressed storage for numpy arrays. - HashCodec : Hash-addressed storage with deduplication. - """ - - name = "object" - - def encode( - self, - value: Any, - *, - key: dict | None = None, - store_name: str | None = None, - ) -> dict: - """ - Store content and return metadata. - - Parameters - ---------- - value : bytes, str, or Path - Content to store: bytes (raw data), or str/Path (file/folder to upload). - key : dict, optional - Context for path construction with keys ``_schema``, ``_table``, - ``_field``, plus primary key values. - store_name : str, optional - Store to use. If None, uses default store. - - Returns - ------- - dict - Metadata dict suitable for ``ObjectRef.from_json()``. 
- """ - from datetime import datetime, timezone - from pathlib import Path - - # Extract context using inherited helper - schema, table, field, primary_key = self._extract_context(key) - - # Check for pre-computed metadata (from staged insert) - if isinstance(value, dict) and "path" in value: - # Already encoded, pass through - return value - - # Determine content type and extension - is_dir = False - ext = None - size = None - item_count = None - - if isinstance(value, bytes): - content = value - size = len(content) - elif isinstance(value, tuple) and len(value) == 2: - # Tuple format: (extension, data) where data is bytes or file-like - ext, data = value - if hasattr(data, "read"): - content = data.read() - else: - content = data - size = len(content) - elif isinstance(value, (str, Path)): - source_path = Path(value) - if not source_path.exists(): - raise DataJointError(f"Source path not found: {source_path}") - is_dir = source_path.is_dir() - ext = source_path.suffix if not is_dir else None - if is_dir: - # For directories, we'll upload later - content = None - # Count items in directory - item_count = sum(1 for _ in source_path.rglob("*") if _.is_file()) - else: - content = source_path.read_bytes() - size = len(content) - else: - raise TypeError(f" expects bytes or path, got {type(value).__name__}") - - # Build storage path using inherited helper - path, token = self._build_path(schema, table, field, primary_key, ext=ext, store_name=store_name) - - # Get storage backend using inherited helper - backend = self._get_backend(store_name) - - # Upload content - if is_dir: - # Upload directory recursively - source_path = Path(value) - backend.put_folder(str(source_path), path) - # Compute size by summing all files - size = sum(f.stat().st_size for f in source_path.rglob("*") if f.is_file()) - else: - backend.put_buffer(content, path) - - # Build metadata - timestamp = datetime.now(timezone.utc) - metadata = { - "path": path, - "store": store_name, - "size": size, - "ext": ext, - "is_dir": is_dir, - "item_count": item_count, - "timestamp": timestamp.isoformat(), - } - - return metadata - - def decode(self, stored: dict, *, key: dict | None = None) -> Any: - """ - Create ObjectRef handle for lazy access. - - Parameters - ---------- - stored : dict - Metadata dict from database. - key : dict, optional - Primary key values (unused). - - Returns - ------- - ObjectRef - Handle for accessing the stored content. - """ - from .objectref import ObjectRef - - backend = self._get_backend(stored.get("store")) - return ObjectRef.from_json(stored, backend=backend) - - def validate(self, value: Any) -> None: - """Validate value is bytes, path, dict metadata, or (ext, data) tuple.""" - from pathlib import Path - - if isinstance(value, bytes): - return - if isinstance(value, (str, Path)): - # Could be a path or pre-encoded JSON string - return - if isinstance(value, tuple) and len(value) == 2: - # Tuple format: (extension, data) - return - if isinstance(value, dict) and "path" in value: - # Pre-computed metadata dict (from staged insert) - return - raise TypeError(f" expects bytes or path, got {type(value).__name__}") - - -# ============================================================================= -# File Attachment Codecs -# ============================================================================= - - -class AttachCodec(Codec): - """ - File attachment with filename preserved. 
- - Supports both in-table and in-store storage: - - ````: Stored in database (bytes → LONGBLOB) - - ````: Stored in object store via ```` with deduplication - - ````: Stored in specific named store - - The filename is preserved and the file is extracted to the configured - download path on fetch. - - Example:: - - @schema - class Documents(dj.Manual): - definition = ''' - doc_id : int - --- - config : # in-table (small file in DB) - dataset : # in-store (default store) - archive : # in-store (specific store) - ''' - - # Insert a file - table.insert1({'doc_id': 1, 'config': '/path/to/config.json'}) - - # Fetch extracts to download_path and returns local path - local_path = (table & 'doc_id=1').fetch1('config') - - Storage Format (internal): - The blob contains: ``filename\\0contents`` - - Filename (UTF-8 encoded) + null byte + raw file contents - """ - - name = "attach" - - def get_dtype(self, is_store: bool) -> str: - """Return bytes for in-table, for in-store storage.""" - return "" if is_store else "bytes" - - def encode(self, value: Any, *, key: dict | None = None, store_name: str | None = None) -> bytes: - """ - Read file and encode as filename + contents. - - Parameters - ---------- - value : str or Path - Path to file. - key : dict, optional - Primary key values (unused). - store_name : str, optional - Unused for internal storage. - - Returns - ------- - bytes - Filename (UTF-8) + null byte + file contents. - """ - from pathlib import Path - - path = Path(value) - if not path.exists(): - raise FileNotFoundError(f"Attachment file not found: {path}") - if path.is_dir(): - raise IsADirectoryError(f" does not support directories: {path}") - - filename = path.name - contents = path.read_bytes() - return filename.encode("utf-8") + b"\x00" + contents - - def decode(self, stored: bytes, *, key: dict | None = None) -> str: - """ - Extract file to download path and return local path. - - Parameters - ---------- - stored : bytes - Blob containing filename + null + contents. - key : dict, optional - Primary key values (unused). - - Returns - ------- - str - Path to extracted file. 
- """ - from pathlib import Path - - from .settings import config - - # Split on first null byte - null_pos = stored.index(b"\x00") - filename = stored[:null_pos].decode("utf-8") - contents = stored[null_pos + 1 :] - - # Write to download path - download_path = Path(config.get("download_path", ".")) - download_path.mkdir(parents=True, exist_ok=True) - local_path = download_path / filename - - # Handle filename collision - if file exists with different content, add suffix - if local_path.exists(): - existing_contents = local_path.read_bytes() - if existing_contents != contents: - # Find unique filename - stem = local_path.stem - suffix = local_path.suffix - counter = 1 - while local_path.exists() and local_path.read_bytes() != contents: - local_path = download_path / f"{stem}_{counter}{suffix}" - counter += 1 - - # Only write if file doesn't exist or has different content - if not local_path.exists(): - local_path.write_bytes(contents) - - return str(local_path) - - def validate(self, value: Any) -> None: - """Validate that value is a valid file path.""" - from pathlib import Path - - if not isinstance(value, (str, Path)): - raise TypeError(f" expects a file path, got {type(value).__name__}") - - -# ============================================================================= -# Filepath Reference Codec -# ============================================================================= - - -class FilepathCodec(Codec): - """ - Reference to existing file in configured store. - - The ```` codec stores a reference to a file that already - exists in the storage backend. Unlike ```` or ````, no - file copying occurs - only the path is recorded. - - External only - requires @store. - - This codec gives users maximum freedom in organizing their files while - reusing DataJoint's store configuration. Files can be placed anywhere - in the store EXCEPT the reserved ``_hash/`` and ``_schema/`` sections - which are managed by DataJoint. - - This is useful when: - - Files are managed externally (e.g., by acquisition software) - - Files are too large to copy - - You want to reference shared datasets - - You need custom directory structures - - Example:: - - @schema - class Recordings(dj.Manual): - definition = ''' - recording_id : int - --- - raw_data : - ''' - - # Reference an existing file (no copy) - # Path is relative to store location - table.insert1({'recording_id': 1, 'raw_data': 'subject01/session001/data.bin'}) - - # Fetch returns ObjectRef for lazy access - ref = (table & 'recording_id=1').fetch1('raw_data') - ref.read() # Read file content - ref.download() # Download to local path - - Storage Format: - JSON metadata: ``{path, store, size, timestamp}`` - - Reserved Sections: - Paths cannot start with ``_hash/`` or ``_schema/`` - these are managed by DataJoint. - - Warning: - The file must exist in the store at the specified path. - DataJoint does not manage the lifecycle of referenced files. - """ - - name = "filepath" - - def get_dtype(self, is_store: bool) -> str: - """Filepath is external only.""" - if not is_store: - raise DataJointError( - " requires @ symbol. Use for default store " "or to specify store." - ) - return "json" - - def encode(self, value: Any, *, key: dict | None = None, store_name: str | None = None) -> dict: - """ - Store path reference as JSON metadata. - - Parameters - ---------- - value : str - Relative path within the store. Cannot use reserved sections (_hash/, _schema/). - key : dict, optional - Primary key values (unused). 
- store_name : str, optional - Store where the file exists. - - Returns - ------- - dict - Metadata dict: ``{path, store}``. - - Raises - ------ - ValueError - If path uses reserved sections (_hash/ or _schema/). - FileNotFoundError - If file does not exist in the store. - """ - from datetime import datetime, timezone - - from . import config - from .hash_registry import get_store_backend - - path = str(value) - - # Get store spec to check prefix configuration - # Use filepath_default if no store specified (filepath is not part of OAS) - spec = config.get_store_spec(store_name, use_filepath_default=True) - - # Validate path doesn't use reserved sections (hash and schema) - path_normalized = path.lstrip("/") - reserved_prefixes = [] - - hash_prefix = spec.get("hash_prefix") - if hash_prefix: - reserved_prefixes.append(("hash_prefix", hash_prefix)) - - schema_prefix = spec.get("schema_prefix") - if schema_prefix: - reserved_prefixes.append(("schema_prefix", schema_prefix)) - - # Check if path starts with any reserved prefix - for prefix_name, prefix_value in reserved_prefixes: - prefix_normalized = prefix_value.strip("/") + "/" - if path_normalized.startswith(prefix_normalized): - raise ValueError( - f" cannot use reserved section '{prefix_value}' ({prefix_name}). " - f"This section is managed by DataJoint. " - f"Got path: {path}" - ) - - # If filepath_prefix is configured, enforce it - filepath_prefix = spec.get("filepath_prefix") - if filepath_prefix: - filepath_prefix_normalized = filepath_prefix.strip("/") + "/" - if not path_normalized.startswith(filepath_prefix_normalized): - raise ValueError(f" must use prefix '{filepath_prefix}' (filepath_prefix). " f"Got path: {path}") - - # Verify file exists - backend = get_store_backend(store_name) - if not backend.exists(path): - raise FileNotFoundError(f"File not found in store '{store_name or 'default'}': {path}") - - # Get file info - try: - size = backend.size(path) - except Exception: - size = None - - return { - "path": path, - "store": store_name, - "size": size, - "is_dir": False, - "timestamp": datetime.now(timezone.utc).isoformat(), - } - - def decode(self, stored: dict, *, key: dict | None = None) -> Any: - """ - Create ObjectRef handle for lazy access. - - Parameters - ---------- - stored : dict - Metadata dict with path and store. - key : dict, optional - Primary key values (unused). - - Returns - ------- - ObjectRef - Handle for accessing the file. - """ - from .objectref import ObjectRef - from .hash_registry import get_store_backend - - store_name = stored.get("store") - backend = get_store_backend(store_name) - return ObjectRef.from_json(stored, backend=backend) - - def validate(self, value: Any) -> None: - """Validate that value is a path string or Path object.""" - from pathlib import Path - - if not isinstance(value, (str, Path)): - raise TypeError(f" expects a path string or Path, got {type(value).__name__}") - - -# ============================================================================= -# NumPy Array Codec (.npy format) -# ============================================================================= - - -class NpyRef: - """ - Lazy reference to a numpy array stored as a .npy file. - - This class provides metadata access without I/O and transparent - integration with numpy operations via the ``__array__`` protocol. - - Attributes - ---------- - shape : tuple[int, ...] - Array shape (from metadata, no I/O). - dtype : numpy.dtype - Array dtype (from metadata, no I/O). - path : str - Storage path within the store. 
- store : str or None - Store name (None for default). - - Examples - -------- - Metadata access without download:: - - ref = (Recording & key).fetch1('waveform') - print(ref.shape) # (1000, 32) - no download - print(ref.dtype) # float64 - no download - - Explicit loading:: - - arr = ref.load() # Downloads and returns np.ndarray - - Transparent numpy integration:: - - # These all trigger automatic download via __array__ - result = ref + 1 - result = np.mean(ref) - result = ref[0:100] # Slicing works too - """ - - __slots__ = ("_meta", "_backend", "_cached") - - def __init__(self, metadata: dict, backend: Any): - """ - Initialize NpyRef from metadata and storage backend. - - Parameters - ---------- - metadata : dict - JSON metadata containing path, store, dtype, shape. - backend : StorageBackend - Storage backend for file operations. - """ - self._meta = metadata - self._backend = backend - self._cached = None - - @property - def shape(self) -> tuple: - """Array shape (no I/O required).""" - return tuple(self._meta["shape"]) - - @property - def dtype(self): - """Array dtype (no I/O required).""" - import numpy as np - - return np.dtype(self._meta["dtype"]) - - @property - def ndim(self) -> int: - """Number of dimensions (no I/O required).""" - return len(self._meta["shape"]) - - @property - def size(self) -> int: - """Total number of elements (no I/O required).""" - import math - - return math.prod(self._meta["shape"]) - - @property - def nbytes(self) -> int: - """Total bytes (estimated from shape and dtype, no I/O required).""" - return self.size * self.dtype.itemsize - - @property - def path(self) -> str: - """Storage path within the store.""" - return self._meta["path"] - - @property - def store(self) -> str | None: - """Store name (None for default store).""" - return self._meta.get("store") - - @property - def is_loaded(self) -> bool: - """True if array data has been downloaded and cached.""" - return self._cached is not None - - def load(self, mmap_mode=None): - """ - Download and return the array. - - Parameters - ---------- - mmap_mode : str, optional - Memory-map mode for lazy, random-access loading of large arrays: - - - ``'r'``: Read-only - - ``'r+'``: Read-write - - ``'c'``: Copy-on-write (changes not saved to disk) - - If None (default), loads entire array into memory. - - Returns - ------- - numpy.ndarray or numpy.memmap - The array data. Returns ``numpy.memmap`` if mmap_mode is specified. - - Notes - ----- - When ``mmap_mode`` is None, the array is cached after first load. - - For local filesystem stores, memory mapping accesses the file directly - with no download. For remote stores (S3, etc.), the file is downloaded - to a local cache (``{tempdir}/datajoint_mmap/``) before memory mapping. 
- - Examples - -------- - Standard loading:: - - arr = ref.load() # Loads entire array into memory - - Memory-mapped for random access to large arrays:: - - arr = ref.load(mmap_mode='r') - slice = arr[1000:2000] # Only reads the needed portion from disk - """ - import io - - import numpy as np - - if mmap_mode is None: - # Standard loading with caching - if self._cached is None: - buffer = self._backend.get_buffer(self.path) - self._cached = np.load(io.BytesIO(buffer), allow_pickle=False) - return self._cached - else: - # Memory-mapped loading - if self._backend.protocol == "file": - # Local filesystem - mmap directly, no download needed - local_path = self._backend._full_path(self.path) - return np.load(local_path, mmap_mode=mmap_mode, allow_pickle=False) - else: - # Remote storage - download to local cache first - import hashlib - import tempfile - from pathlib import Path - - path_hash = hashlib.md5(self.path.encode()).hexdigest()[:12] - cache_dir = Path(tempfile.gettempdir()) / "datajoint_mmap" - cache_dir.mkdir(exist_ok=True) - cache_path = cache_dir / f"{path_hash}.npy" - - if not cache_path.exists(): - buffer = self._backend.get_buffer(self.path) - cache_path.write_bytes(buffer) - - return np.load(str(cache_path), mmap_mode=mmap_mode, allow_pickle=False) - - def __array__(self, dtype=None): - """ - NumPy array protocol for transparent integration. - - This method is called automatically when the NpyRef is used - in numpy operations (arithmetic, ufuncs, etc.). - - Parameters - ---------- - dtype : numpy.dtype, optional - Desired output dtype. - - Returns - ------- - numpy.ndarray - The array data, optionally cast to dtype. - """ - arr = self.load() - if dtype is not None: - return arr.astype(dtype) - return arr - - def __getitem__(self, key): - """Support indexing/slicing by loading then indexing.""" - return self.load()[key] - - def __len__(self) -> int: - """Length of first dimension.""" - if not self._meta["shape"]: - raise TypeError("len() of 0-dimensional array") - return self._meta["shape"][0] - - def __repr__(self) -> str: - status = "loaded" if self.is_loaded else "not loaded" - return f"NpyRef(shape={self.shape}, dtype={self.dtype}, {status})" - - def __str__(self) -> str: - return repr(self) - - -class NpyCodec(SchemaCodec): - """ - Schema-addressed storage for numpy arrays as .npy files. - - The ```` codec stores numpy arrays as standard ``.npy`` files - using schema-addressed paths: ``{schema}/{table}/{pk}/{attribute}.npy``. - Arrays are fetched lazily via ``NpyRef``, which provides metadata access - without I/O and transparent numpy integration via ``__array__``. - - Store only - requires ``@`` modifier. - - Key Features: - - **Portable**: Standard .npy format readable by numpy, MATLAB, etc. 
- - **Lazy loading**: Metadata (shape, dtype) available without download - - **Transparent**: Use in numpy operations triggers automatic download - - **Safe bulk fetch**: Fetching many rows doesn't download until needed - - **Schema-addressed**: Browsable paths that mirror database structure - - Example:: - - @schema - class Recording(dj.Manual): - definition = ''' - recording_id : int - --- - waveform : # default store - spectrogram : # specific store - ''' - - # Insert - just pass the array - Recording.insert1({ - 'recording_id': 1, - 'waveform': np.random.randn(1000, 32), - }) - - # Fetch - returns NpyRef (lazy) - ref = (Recording & 'recording_id=1').fetch1('waveform') - ref.shape # (1000, 32) - no download - ref.dtype # float64 - no download - - # Use in numpy ops - downloads automatically - result = np.mean(ref, axis=0) - - # Or load explicitly - arr = ref.load() - - Storage Details: - - File format: NumPy .npy (version 1.0 or 2.0) - - Path: ``{schema}/{table}/{pk}/{attribute}.npy`` - - Database column: JSON with ``{path, store, dtype, shape}`` - - Deletion: Requires garbage collection via ``dj.gc.collect()``. - - See Also - -------- - datajoint.gc : Garbage collection for orphaned storage. - NpyRef : The lazy array reference returned on fetch. - SchemaCodec : Base class for schema-addressed codecs. - ObjectCodec : Schema-addressed storage for files/folders. - """ - - name = "npy" - - def validate(self, value: Any) -> None: - """ - Validate that value is a numpy array suitable for .npy storage. - - Parameters - ---------- - value : Any - Value to validate. - - Raises - ------ - DataJointError - If value is not a numpy array or has object dtype. - """ - import numpy as np - - if not isinstance(value, np.ndarray): - raise DataJointError(f" requires numpy.ndarray, got {type(value).__name__}") - if value.dtype == object: - raise DataJointError(" does not support object dtype arrays") - - def encode( - self, - value: Any, - *, - key: dict | None = None, - store_name: str | None = None, - ) -> dict: - """ - Serialize array to .npy and upload to storage. - - Parameters - ---------- - value : numpy.ndarray - Array to store. - key : dict, optional - Context dict with ``_schema``, ``_table``, ``_field``, - and primary key values for path construction. - store_name : str, optional - Target store. If None, uses default store. - - Returns - ------- - dict - JSON metadata: ``{path, store, dtype, shape}``. - """ - import io - - import numpy as np - - # Extract context using inherited helper - schema, table, field, primary_key = self._extract_context(key) - - # Build schema-addressed storage path - path, _ = self._build_path(schema, table, field, primary_key, ext=".npy", store_name=store_name) - - # Serialize to .npy format - buffer = io.BytesIO() - np.save(buffer, value, allow_pickle=False) - npy_bytes = buffer.getvalue() - - # Upload to storage using inherited helper - backend = self._get_backend(store_name) - backend.put_buffer(npy_bytes, path) - - # Return metadata (includes numpy-specific shape/dtype) - return { - "path": path, - "store": store_name, - "dtype": str(value.dtype), - "shape": list(value.shape), - } - - def decode(self, stored: dict, *, key: dict | None = None) -> NpyRef: - """ - Create lazy NpyRef from stored metadata. - - Parameters - ---------- - stored : dict - JSON metadata from database. - key : dict, optional - Primary key values (unused). - - Returns - ------- - NpyRef - Lazy array reference with metadata access and numpy integration. 
- """ - backend = self._get_backend(stored.get("store")) - return NpyRef(stored, backend) diff --git a/src/datajoint/builtin_codecs/__init__.py b/src/datajoint/builtin_codecs/__init__.py new file mode 100644 index 000000000..1f2dd2ec7 --- /dev/null +++ b/src/datajoint/builtin_codecs/__init__.py @@ -0,0 +1,77 @@ +""" +Built-in DataJoint codecs. + +This package defines the standard codecs that ship with DataJoint. +These serve as both useful built-in codecs and as examples for users who +want to create their own custom codecs. + +Built-in Codecs: + - ````: Serialize Python objects (in-table storage) + - ````: Serialize Python objects (in-store with hash-addressed dedup) + - ````: File attachment (in-table storage) + - ````: File attachment (in-store with hash-addressed dedup) + - ````: Hash-addressed storage with MD5 deduplication (store only) + - ````: Schema-addressed storage for files/folders (store only) + - ````: Store numpy arrays as portable .npy files (store only) + - ````: Reference to existing file in store (store only) + +Example - Creating a Custom Codec: + Here's how to define your own codec, modeled after the built-in codecs:: + + import datajoint as dj + import networkx as nx + + class GraphCodec(dj.Codec): + '''Store NetworkX graphs as edge lists.''' + + name = "graph" # Use as in definitions + + def get_dtype(self, is_store: bool) -> str: + return "" # Compose with blob for serialization + + def encode(self, graph, *, key=None, store_name=None): + # Convert graph to a serializable format + return { + 'nodes': list(graph.nodes(data=True)), + 'edges': list(graph.edges(data=True)), + } + + def decode(self, stored, *, key=None): + # Reconstruct graph from stored format + G = nx.Graph() + G.add_nodes_from(stored['nodes']) + G.add_edges_from(stored['edges']) + return G + + def validate(self, value): + if not isinstance(value, nx.Graph): + raise TypeError(f"Expected nx.Graph, got {type(value).__name__}") + + # Now use in table definitions: + @schema + class Networks(dj.Manual): + definition = ''' + network_id : int + --- + topology : + ''' +""" + +from .attach import AttachCodec +from .blob import BlobCodec +from .filepath import FilepathCodec +from .hash import HashCodec +from .npy import NpyCodec, NpyRef +from .object import ObjectCodec +from .schema import SchemaCodec + +__all__ = [ + "BlobCodec", + "HashCodec", + "SchemaCodec", + "ObjectCodec", + "AttachCodec", + "FilepathCodec", + "NpyCodec", + "NpyRef", +] diff --git a/src/datajoint/builtin_codecs/attach.py b/src/datajoint/builtin_codecs/attach.py new file mode 100644 index 000000000..f9a454b1a --- /dev/null +++ b/src/datajoint/builtin_codecs/attach.py @@ -0,0 +1,136 @@ +""" +File attachment codec with filename preservation. +""" + +from __future__ import annotations + +from typing import Any + +from ..codecs import Codec + + +class AttachCodec(Codec): + """ + File attachment with filename preserved. + + Supports both in-table and in-store storage: + - ````: Stored in database (bytes → LONGBLOB) + - ````: Stored in object store via ```` with deduplication + - ````: Stored in specific named store + + The filename is preserved and the file is extracted to the configured + download path on fetch. 
+ + Example:: + + @schema + class Documents(dj.Manual): + definition = ''' + doc_id : int + --- + config : # in-table (small file in DB) + dataset : # in-store (default store) + archive : # in-store (specific store) + ''' + + # Insert a file + table.insert1({'doc_id': 1, 'config': '/path/to/config.json'}) + + # Fetch extracts to download_path and returns local path + local_path = (table & 'doc_id=1').fetch1('config') + + Storage Format (internal): + The blob contains: ``filename\\0contents`` + - Filename (UTF-8 encoded) + null byte + raw file contents + """ + + name = "attach" + + def get_dtype(self, is_store: bool) -> str: + """Return bytes for in-table, for in-store storage.""" + return "" if is_store else "bytes" + + def encode(self, value: Any, *, key: dict | None = None, store_name: str | None = None) -> bytes: + """ + Read file and encode as filename + contents. + + Parameters + ---------- + value : str or Path + Path to file. + key : dict, optional + Primary key values (unused). + store_name : str, optional + Unused for internal storage. + + Returns + ------- + bytes + Filename (UTF-8) + null byte + file contents. + """ + from pathlib import Path + + path = Path(value) + if not path.exists(): + raise FileNotFoundError(f"Attachment file not found: {path}") + if path.is_dir(): + raise IsADirectoryError(f" does not support directories: {path}") + + filename = path.name + contents = path.read_bytes() + return filename.encode("utf-8") + b"\x00" + contents + + def decode(self, stored: bytes, *, key: dict | None = None) -> str: + """ + Extract file to download path and return local path. + + Parameters + ---------- + stored : bytes + Blob containing filename + null + contents. + key : dict, optional + Primary key values (unused). + + Returns + ------- + str + Path to extracted file. + """ + from pathlib import Path + + from ..settings import config + + # Split on first null byte + null_pos = stored.index(b"\x00") + filename = stored[:null_pos].decode("utf-8") + contents = stored[null_pos + 1 :] + + # Write to download path + download_path = Path(config.get("download_path", ".")) + download_path.mkdir(parents=True, exist_ok=True) + local_path = download_path / filename + + # Handle filename collision - if file exists with different content, add suffix + if local_path.exists(): + existing_contents = local_path.read_bytes() + if existing_contents != contents: + # Find unique filename + stem = local_path.stem + suffix = local_path.suffix + counter = 1 + while local_path.exists() and local_path.read_bytes() != contents: + local_path = download_path / f"{stem}_{counter}{suffix}" + counter += 1 + + # Only write if file doesn't exist or has different content + if not local_path.exists(): + local_path.write_bytes(contents) + + return str(local_path) + + def validate(self, value: Any) -> None: + """Validate that value is a valid file path.""" + from pathlib import Path + + if not isinstance(value, (str, Path)): + raise TypeError(f" expects a file path, got {type(value).__name__}") diff --git a/src/datajoint/builtin_codecs/blob.py b/src/datajoint/builtin_codecs/blob.py new file mode 100644 index 000000000..ff65161f4 --- /dev/null +++ b/src/datajoint/builtin_codecs/blob.py @@ -0,0 +1,61 @@ +""" +Blob codec for Python object serialization. +""" + +from __future__ import annotations + +from typing import Any + +from ..codecs import Codec + + +class BlobCodec(Codec): + """ + Serialize Python objects using DataJoint's blob format. 
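A hedged sketch of the ``filename\0contents`` layout that the attachment codec shown above stores; ``pack_attachment``/``unpack_attachment`` are illustrative helpers, not part of the library, and the real codec additionally resolves filename collisions on fetch::

    from pathlib import Path

    def pack_attachment(path: str) -> bytes:
        p = Path(path)
        # filename (UTF-8) + NUL separator + raw file bytes
        return p.name.encode("utf-8") + b"\x00" + p.read_bytes()

    def unpack_attachment(blob: bytes) -> tuple[str, bytes]:
        sep = blob.index(b"\x00")  # the first NUL delimits the filename
        return blob[:sep].decode("utf-8"), blob[sep + 1:]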
+ + The ```` codec handles serialization of arbitrary Python objects + including NumPy arrays, dictionaries, lists, datetime objects, and UUIDs. + + Supports both in-table and in-store storage: + - ````: Stored in database (bytes → LONGBLOB) + - ````: Stored in object store via ```` with deduplication + - ````: Stored in specific named store + + Format Features: + - Protocol headers (``mYm`` for MATLAB-compatible, ``dj0`` for Python-native) + - Optional zlib compression for data > 1KB + - Support for nested structures + + Example:: + + @schema + class ProcessedData(dj.Manual): + definition = ''' + data_id : int + --- + small_result : # in-table (in database) + large_result : # in-store (default store) + archive : # in-store (specific store) + ''' + + # Insert any serializable object + table.insert1({'data_id': 1, 'small_result': {'scores': [0.9, 0.8]}}) + """ + + name = "blob" + + def get_dtype(self, is_store: bool) -> str: + """Return bytes for in-table, for in-store storage.""" + return "" if is_store else "bytes" + + def encode(self, value: Any, *, key: dict | None = None, store_name: str | None = None) -> bytes: + """Serialize a Python object to DataJoint's blob format.""" + from .. import blob + + return blob.pack(value, compress=True) + + def decode(self, stored: bytes, *, key: dict | None = None) -> Any: + """Deserialize blob bytes back to a Python object.""" + from .. import blob + + return blob.unpack(stored, squeeze=False) diff --git a/src/datajoint/builtin_codecs/filepath.py b/src/datajoint/builtin_codecs/filepath.py new file mode 100644 index 000000000..9c05b2385 --- /dev/null +++ b/src/datajoint/builtin_codecs/filepath.py @@ -0,0 +1,186 @@ +""" +Filepath reference codec for existing files in storage. +""" + +from __future__ import annotations + +from typing import Any + +from ..codecs import Codec +from ..errors import DataJointError + + +class FilepathCodec(Codec): + """ + Reference to existing file in configured store. + + The ```` codec stores a reference to a file that already + exists in the storage backend. Unlike ```` or ````, no + file copying occurs - only the path is recorded. + + Store only - requires @store. + + This codec gives users maximum freedom in organizing their files while + reusing DataJoint's store configuration. Files can be placed anywhere + in the store EXCEPT the reserved ``_hash/`` and ``_schema/`` sections + which are managed by DataJoint. + + This is useful when: + - Files are managed externally (e.g., by acquisition software) + - Files are too large to copy + - You want to reference shared datasets + - You need custom directory structures + + Example:: + + @schema + class Recordings(dj.Manual): + definition = ''' + recording_id : int + --- + raw_data : + ''' + + # Reference an existing file (no copy) + # Path is relative to store location + table.insert1({'recording_id': 1, 'raw_data': 'subject01/session001/data.bin'}) + + # Fetch returns ObjectRef for lazy access + ref = (table & 'recording_id=1').fetch1('raw_data') + ref.read() # Read file content + ref.download() # Download to local path + + Storage Format: + JSON metadata: ``{path, store, size, timestamp}`` + + Reserved Sections: + Paths cannot start with ``_hash/`` or ``_schema/`` - these are managed by DataJoint. + + Warning: + The file must exist in the store at the specified path. + DataJoint does not manage the lifecycle of referenced files. 
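Because the blob codec above delegates to ``blob.pack``/``blob.unpack``, a round trip through ``encode``/``decode`` returns an equivalent object. A minimal sketch, assuming a working ``datajoint`` install; the payload values are illustrative::

    from datajoint.builtin_codecs import BlobCodec

    codec = BlobCodec()
    payload = {"scores": [0.9, 0.8], "label": "unit-7"}

    packed = codec.encode(payload)    # bytes with protocol header, zlib-compressed when large
    restored = codec.decode(packed)   # back to a Python mapping

    assert set(restored) == {"scores", "label"}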
+ """ + + name = "filepath" + + def get_dtype(self, is_store: bool) -> str: + """Filepath requires a store (use @store syntax).""" + if not is_store: + raise DataJointError( + " requires @ symbol. Use for default store or to specify store." + ) + return "json" + + def encode(self, value: Any, *, key: dict | None = None, store_name: str | None = None) -> dict: + """ + Store path reference as JSON metadata. + + Parameters + ---------- + value : str + Relative path within the store. Cannot use reserved sections (_hash/, _schema/). + key : dict, optional + Primary key values (unused). + store_name : str, optional + Store where the file exists. + + Returns + ------- + dict + Metadata dict: ``{path, store}``. + + Raises + ------ + ValueError + If path uses reserved sections (_hash/ or _schema/). + FileNotFoundError + If file does not exist in the store. + """ + from datetime import datetime, timezone + + from .. import config + from ..hash_registry import get_store_backend + + path = str(value) + + # Get store spec to check prefix configuration + # Use filepath_default if no store specified (filepath is not part of OAS) + spec = config.get_store_spec(store_name, use_filepath_default=True) + + # Validate path doesn't use reserved sections (hash and schema) + path_normalized = path.lstrip("/") + reserved_prefixes = [] + + hash_prefix = spec.get("hash_prefix") + if hash_prefix: + reserved_prefixes.append(("hash_prefix", hash_prefix)) + + schema_prefix = spec.get("schema_prefix") + if schema_prefix: + reserved_prefixes.append(("schema_prefix", schema_prefix)) + + # Check if path starts with any reserved prefix + for prefix_name, prefix_value in reserved_prefixes: + prefix_normalized = prefix_value.strip("/") + "/" + if path_normalized.startswith(prefix_normalized): + raise ValueError( + f" cannot use reserved section '{prefix_value}' ({prefix_name}). " + f"This section is managed by DataJoint. " + f"Got path: {path}" + ) + + # If filepath_prefix is configured, enforce it + filepath_prefix = spec.get("filepath_prefix") + if filepath_prefix: + filepath_prefix_normalized = filepath_prefix.strip("/") + "/" + if not path_normalized.startswith(filepath_prefix_normalized): + raise ValueError(f" must use prefix '{filepath_prefix}' (filepath_prefix). Got path: {path}") + + # Verify file exists + backend = get_store_backend(store_name) + if not backend.exists(path): + raise FileNotFoundError(f"File not found in store '{store_name or 'default'}': {path}") + + # Get file info + try: + size = backend.size(path) + except Exception: + size = None + + return { + "path": path, + "store": store_name, + "size": size, + "is_dir": False, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + + def decode(self, stored: dict, *, key: dict | None = None) -> Any: + """ + Create ObjectRef handle for lazy access. + + Parameters + ---------- + stored : dict + Metadata dict with path and store. + key : dict, optional + Primary key values (unused). + + Returns + ------- + ObjectRef + Handle for accessing the file. 
+ """ + from ..objectref import ObjectRef + from ..hash_registry import get_store_backend + + store_name = stored.get("store") + backend = get_store_backend(store_name) + return ObjectRef.from_json(stored, backend=backend) + + def validate(self, value: Any) -> None: + """Validate that value is a path string or Path object.""" + from pathlib import Path + + if not isinstance(value, (str, Path)): + raise TypeError(f" expects a path string or Path, got {type(value).__name__}") diff --git a/src/datajoint/builtin_codecs/hash.py b/src/datajoint/builtin_codecs/hash.py new file mode 100644 index 000000000..676c1916f --- /dev/null +++ b/src/datajoint/builtin_codecs/hash.py @@ -0,0 +1,104 @@ +""" +Hash-addressed storage codec with SHA256 deduplication. +""" + +from __future__ import annotations + +from typing import Any + +from ..codecs import Codec +from ..errors import DataJointError + + +class HashCodec(Codec): + """ + Hash-addressed storage with SHA256 deduplication. + + The ```` codec stores raw bytes using hash-addressed storage. + Data is identified by its SHA256 hash and stored in a hierarchical directory: + ``_hash/{hash[:2]}/{hash[2:4]}/{hash}`` + + The database column stores JSON metadata: ``{hash, store, size}``. + Duplicate content is automatically deduplicated across all tables. + + Deletion: Requires garbage collection via ``dj.gc.collect()``. + + External only - requires @ modifier. + + Example:: + + @schema + class RawContent(dj.Manual): + definition = ''' + content_id : int + --- + data : + ''' + + # Insert raw bytes + table.insert1({'content_id': 1, 'data': b'raw binary content'}) + + Note: + This codec accepts only ``bytes``. For Python objects, use ````. + Typically used indirectly via ```` or ```` rather than directly. + + See Also + -------- + datajoint.gc : Garbage collection for orphaned storage. + """ + + name = "hash" + + def get_dtype(self, is_store: bool) -> str: + """Hash storage is in-store only.""" + if not is_store: + raise DataJointError(" requires @ (in-store storage only)") + return "json" + + def encode(self, value: bytes, *, key: dict | None = None, store_name: str | None = None) -> dict: + """ + Store content and return metadata. + + Parameters + ---------- + value : bytes + Raw bytes to store. + key : dict, optional + Context dict with ``_schema`` for path isolation. + store_name : str, optional + Store to use. If None, uses default store. + + Returns + ------- + dict + Metadata dict: ``{hash, path, schema, store, size}``. + """ + from ..hash_registry import put_hash + + schema_name = (key or {}).get("_schema", "unknown") + return put_hash(value, schema_name=schema_name, store_name=store_name) + + def decode(self, stored: dict, *, key: dict | None = None) -> bytes: + """ + Retrieve content using stored metadata. + + Parameters + ---------- + stored : dict + Metadata dict with ``'path'``, ``'hash'``, and optionally ``'store'``. + key : dict, optional + Context dict (unused - path is in metadata). + + Returns + ------- + bytes + Original bytes. + """ + from ..hash_registry import get_hash + + return get_hash(stored) + + def validate(self, value: Any) -> None: + """Validate that value is bytes.""" + if not isinstance(value, bytes): + raise TypeError(f" expects bytes, got {type(value).__name__}") diff --git a/src/datajoint/builtin_codecs/npy.py b/src/datajoint/builtin_codecs/npy.py new file mode 100644 index 000000000..51c5731ee --- /dev/null +++ b/src/datajoint/builtin_codecs/npy.py @@ -0,0 +1,377 @@ +""" +NumPy array codec using .npy format. 
+""" + +from __future__ import annotations + +from typing import Any + +from ..errors import DataJointError +from .schema import SchemaCodec + + +class NpyRef: + """ + Lazy reference to a numpy array stored as a .npy file. + + This class provides metadata access without I/O and transparent + integration with numpy operations via the ``__array__`` protocol. + + Attributes + ---------- + shape : tuple[int, ...] + Array shape (from metadata, no I/O). + dtype : numpy.dtype + Array dtype (from metadata, no I/O). + path : str + Storage path within the store. + store : str or None + Store name (None for default). + + Examples + -------- + Metadata access without download:: + + ref = (Recording & key).fetch1('waveform') + print(ref.shape) # (1000, 32) - no download + print(ref.dtype) # float64 - no download + + Explicit loading:: + + arr = ref.load() # Downloads and returns np.ndarray + + Transparent numpy integration:: + + # These all trigger automatic download via __array__ + result = ref + 1 + result = np.mean(ref) + result = ref[0:100] # Slicing works too + """ + + __slots__ = ("_meta", "_backend", "_cached") + + def __init__(self, metadata: dict, backend: Any): + """ + Initialize NpyRef from metadata and storage backend. + + Parameters + ---------- + metadata : dict + JSON metadata containing path, store, dtype, shape. + backend : StorageBackend + Storage backend for file operations. + """ + self._meta = metadata + self._backend = backend + self._cached = None + + @property + def shape(self) -> tuple: + """Array shape (no I/O required).""" + return tuple(self._meta["shape"]) + + @property + def dtype(self): + """Array dtype (no I/O required).""" + import numpy as np + + return np.dtype(self._meta["dtype"]) + + @property + def ndim(self) -> int: + """Number of dimensions (no I/O required).""" + return len(self._meta["shape"]) + + @property + def size(self) -> int: + """Total number of elements (no I/O required).""" + import math + + return math.prod(self._meta["shape"]) + + @property + def nbytes(self) -> int: + """Total bytes (estimated from shape and dtype, no I/O required).""" + return self.size * self.dtype.itemsize + + @property + def path(self) -> str: + """Storage path within the store.""" + return self._meta["path"] + + @property + def store(self) -> str | None: + """Store name (None for default store).""" + return self._meta.get("store") + + @property + def is_loaded(self) -> bool: + """True if array data has been downloaded and cached.""" + return self._cached is not None + + def load(self, mmap_mode=None): + """ + Download and return the array. + + Parameters + ---------- + mmap_mode : str, optional + Memory-map mode for lazy, random-access loading of large arrays: + + - ``'r'``: Read-only + - ``'r+'``: Read-write + - ``'c'``: Copy-on-write (changes not saved to disk) + + If None (default), loads entire array into memory. + + Returns + ------- + numpy.ndarray or numpy.memmap + The array data. Returns ``numpy.memmap`` if mmap_mode is specified. + + Notes + ----- + When ``mmap_mode`` is None, the array is cached after first load. + + For local filesystem stores, memory mapping accesses the file directly + with no download. For remote stores (S3, etc.), the file is downloaded + to a local cache (``{tempdir}/datajoint_mmap/``) before memory mapping. 
+ + Examples + -------- + Standard loading:: + + arr = ref.load() # Loads entire array into memory + + Memory-mapped for random access to large arrays:: + + arr = ref.load(mmap_mode='r') + slice = arr[1000:2000] # Only reads the needed portion from disk + """ + import io + + import numpy as np + + if mmap_mode is None: + # Standard loading with caching + if self._cached is None: + buffer = self._backend.get_buffer(self.path) + self._cached = np.load(io.BytesIO(buffer), allow_pickle=False) + return self._cached + else: + # Memory-mapped loading + if self._backend.protocol == "file": + # Local filesystem - mmap directly, no download needed + local_path = self._backend._full_path(self.path) + return np.load(local_path, mmap_mode=mmap_mode, allow_pickle=False) + else: + # Remote storage - download to local cache first + import hashlib + import tempfile + from pathlib import Path + + path_hash = hashlib.md5(self.path.encode()).hexdigest()[:12] + cache_dir = Path(tempfile.gettempdir()) / "datajoint_mmap" + cache_dir.mkdir(exist_ok=True) + cache_path = cache_dir / f"{path_hash}.npy" + + if not cache_path.exists(): + buffer = self._backend.get_buffer(self.path) + cache_path.write_bytes(buffer) + + return np.load(str(cache_path), mmap_mode=mmap_mode, allow_pickle=False) + + def __array__(self, dtype=None): + """ + NumPy array protocol for transparent integration. + + This method is called automatically when the NpyRef is used + in numpy operations (arithmetic, ufuncs, etc.). + + Parameters + ---------- + dtype : numpy.dtype, optional + Desired output dtype. + + Returns + ------- + numpy.ndarray + The array data, optionally cast to dtype. + """ + arr = self.load() + if dtype is not None: + return arr.astype(dtype) + return arr + + def __getitem__(self, key): + """Support indexing/slicing by loading then indexing.""" + return self.load()[key] + + def __len__(self) -> int: + """Length of first dimension.""" + if not self._meta["shape"]: + raise TypeError("len() of 0-dimensional array") + return self._meta["shape"][0] + + def __repr__(self) -> str: + status = "loaded" if self.is_loaded else "not loaded" + return f"NpyRef(shape={self.shape}, dtype={self.dtype}, {status})" + + def __str__(self) -> str: + return repr(self) + + +class NpyCodec(SchemaCodec): + """ + Schema-addressed storage for numpy arrays as .npy files. + + The ```` codec stores numpy arrays as standard ``.npy`` files + using schema-addressed paths: ``{schema}/{table}/{pk}/{attribute}.npy``. + Arrays are fetched lazily via ``NpyRef``, which provides metadata access + without I/O and transparent numpy integration via ``__array__``. + + Store only - requires ``@`` modifier. + + Key Features: + - **Portable**: Standard .npy format readable by numpy, MATLAB, etc. 
+ - **Lazy loading**: Metadata (shape, dtype) available without download + - **Transparent**: Use in numpy operations triggers automatic download + - **Safe bulk fetch**: Fetching many rows doesn't download until needed + - **Schema-addressed**: Browsable paths that mirror database structure + + Example:: + + @schema + class Recording(dj.Manual): + definition = ''' + recording_id : int + --- + waveform : # default store + spectrogram : # specific store + ''' + + # Insert - just pass the array + Recording.insert1({ + 'recording_id': 1, + 'waveform': np.random.randn(1000, 32), + }) + + # Fetch - returns NpyRef (lazy) + ref = (Recording & 'recording_id=1').fetch1('waveform') + ref.shape # (1000, 32) - no download + ref.dtype # float64 - no download + + # Use in numpy ops - downloads automatically + result = np.mean(ref, axis=0) + + # Or load explicitly + arr = ref.load() + + Storage Details: + - File format: NumPy .npy (version 1.0 or 2.0) + - Path: ``{schema}/{table}/{pk}/{attribute}.npy`` + - Database column: JSON with ``{path, store, dtype, shape}`` + + Deletion: Requires garbage collection via ``dj.gc.collect()``. + + See Also + -------- + datajoint.gc : Garbage collection for orphaned storage. + NpyRef : The lazy array reference returned on fetch. + SchemaCodec : Base class for schema-addressed codecs. + ObjectCodec : Schema-addressed storage for files/folders. + """ + + name = "npy" + + def validate(self, value: Any) -> None: + """ + Validate that value is a numpy array suitable for .npy storage. + + Parameters + ---------- + value : Any + Value to validate. + + Raises + ------ + DataJointError + If value is not a numpy array or has object dtype. + """ + import numpy as np + + if not isinstance(value, np.ndarray): + raise DataJointError(f" requires numpy.ndarray, got {type(value).__name__}") + if value.dtype == object: + raise DataJointError(" does not support object dtype arrays") + + def encode( + self, + value: Any, + *, + key: dict | None = None, + store_name: str | None = None, + ) -> dict: + """ + Serialize array to .npy and upload to storage. + + Parameters + ---------- + value : numpy.ndarray + Array to store. + key : dict, optional + Context dict with ``_schema``, ``_table``, ``_field``, + and primary key values for path construction. + store_name : str, optional + Target store. If None, uses default store. + + Returns + ------- + dict + JSON metadata: ``{path, store, dtype, shape}``. + """ + import io + + import numpy as np + + # Extract context using inherited helper + schema, table, field, primary_key = self._extract_context(key) + + # Build schema-addressed storage path + path, _ = self._build_path(schema, table, field, primary_key, ext=".npy", store_name=store_name) + + # Serialize to .npy format + buffer = io.BytesIO() + np.save(buffer, value, allow_pickle=False) + npy_bytes = buffer.getvalue() + + # Upload to storage using inherited helper + backend = self._get_backend(store_name) + backend.put_buffer(npy_bytes, path) + + # Return metadata (includes numpy-specific shape/dtype) + return { + "path": path, + "store": store_name, + "dtype": str(value.dtype), + "shape": list(value.shape), + } + + def decode(self, stored: dict, *, key: dict | None = None) -> NpyRef: + """ + Create lazy NpyRef from stored metadata. + + Parameters + ---------- + stored : dict + JSON metadata from database. + key : dict, optional + Primary key values (unused). + + Returns + ------- + NpyRef + Lazy array reference with metadata access and numpy integration. 
+ """ + backend = self._get_backend(stored.get("store")) + return NpyRef(stored, backend) diff --git a/src/datajoint/builtin_codecs/object.py b/src/datajoint/builtin_codecs/object.py new file mode 100644 index 000000000..268651aea --- /dev/null +++ b/src/datajoint/builtin_codecs/object.py @@ -0,0 +1,213 @@ +""" +Schema-addressed storage for files and folders. +""" + +from __future__ import annotations + +from typing import Any + +from ..errors import DataJointError +from .schema import SchemaCodec + + +class ObjectCodec(SchemaCodec): + """ + Schema-addressed storage for files and folders. + + The ```` codec provides managed file/folder storage using + schema-addressed paths: ``{schema}/{table}/{pk}/{field}/``. This creates + a browsable organization in object storage that mirrors the database schema. + + Unlike hash-addressed storage (````), each row has its own unique path + (no deduplication). Ideal for: + + - Zarr arrays (hierarchical chunked data) + - HDF5 files + - Complex multi-file outputs + - Any content that shouldn't be deduplicated + + Store only - requires @ modifier. + + Example:: + + @schema + class Analysis(dj.Computed): + definition = ''' + -> Recording + --- + results : + ''' + + def make(self, key): + # Store a file + self.insert1({**key, 'results': '/path/to/results.zarr'}) + + # Fetch returns ObjectRef for lazy access + ref = (Analysis & key).fetch1('results') + ref.path # Storage path + ref.read() # Read file content + ref.fsmap # For zarr.open(ref.fsmap) + + Storage Structure: + Objects are stored at:: + + {store_root}/{schema}/{table}/{pk}/{field}/ + + Deletion: Requires garbage collection via ``dj.gc.collect()``. + + Comparison with hash-addressed:: + + | Aspect | | | + |----------------|---------------------|---------------------| + | Addressing | Schema-addressed | Hash-addressed | + | Deduplication | No | Yes | + | Deletion | GC required | GC required | + | Use case | Zarr, HDF5 | Blobs, attachments | + + See Also + -------- + datajoint.gc : Garbage collection for orphaned storage. + SchemaCodec : Base class for schema-addressed codecs. + NpyCodec : Schema-addressed storage for numpy arrays. + HashCodec : Hash-addressed storage with deduplication. + """ + + name = "object" + + def encode( + self, + value: Any, + *, + key: dict | None = None, + store_name: str | None = None, + ) -> dict: + """ + Store content and return metadata. + + Parameters + ---------- + value : bytes, str, or Path + Content to store: bytes (raw data), or str/Path (file/folder to upload). + key : dict, optional + Context for path construction with keys ``_schema``, ``_table``, + ``_field``, plus primary key values. + store_name : str, optional + Store to use. If None, uses default store. + + Returns + ------- + dict + Metadata dict suitable for ``ObjectRef.from_json()``. 
+ """ + from datetime import datetime, timezone + from pathlib import Path + + # Extract context using inherited helper + schema, table, field, primary_key = self._extract_context(key) + + # Check for pre-computed metadata (from staged insert) + if isinstance(value, dict) and "path" in value: + # Already encoded, pass through + return value + + # Determine content type and extension + is_dir = False + ext = None + size = None + item_count = None + + if isinstance(value, bytes): + content = value + size = len(content) + elif isinstance(value, tuple) and len(value) == 2: + # Tuple format: (extension, data) where data is bytes or file-like + ext, data = value + if hasattr(data, "read"): + content = data.read() + else: + content = data + size = len(content) + elif isinstance(value, (str, Path)): + source_path = Path(value) + if not source_path.exists(): + raise DataJointError(f"Source path not found: {source_path}") + is_dir = source_path.is_dir() + ext = source_path.suffix if not is_dir else None + if is_dir: + # For directories, we'll upload later + content = None + # Count items in directory + item_count = sum(1 for _ in source_path.rglob("*") if _.is_file()) + else: + content = source_path.read_bytes() + size = len(content) + else: + raise TypeError(f" expects bytes or path, got {type(value).__name__}") + + # Build storage path using inherited helper + path, token = self._build_path(schema, table, field, primary_key, ext=ext, store_name=store_name) + + # Get storage backend using inherited helper + backend = self._get_backend(store_name) + + # Upload content + if is_dir: + # Upload directory recursively + source_path = Path(value) + backend.put_folder(str(source_path), path) + # Compute size by summing all files + size = sum(f.stat().st_size for f in source_path.rglob("*") if f.is_file()) + else: + backend.put_buffer(content, path) + + # Build metadata + timestamp = datetime.now(timezone.utc) + metadata = { + "path": path, + "store": store_name, + "size": size, + "ext": ext, + "is_dir": is_dir, + "item_count": item_count, + "timestamp": timestamp.isoformat(), + } + + return metadata + + def decode(self, stored: dict, *, key: dict | None = None) -> Any: + """ + Create ObjectRef handle for lazy access. + + Parameters + ---------- + stored : dict + Metadata dict from database. + key : dict, optional + Primary key values (unused). + + Returns + ------- + ObjectRef + Handle for accessing the stored content. + """ + from ..objectref import ObjectRef + + backend = self._get_backend(stored.get("store")) + return ObjectRef.from_json(stored, backend=backend) + + def validate(self, value: Any) -> None: + """Validate value is bytes, path, dict metadata, or (ext, data) tuple.""" + from pathlib import Path + + if isinstance(value, bytes): + return + if isinstance(value, (str, Path)): + # Could be a path or pre-encoded JSON string + return + if isinstance(value, tuple) and len(value) == 2: + # Tuple format: (extension, data) + return + if isinstance(value, dict) and "path" in value: + # Pre-computed metadata dict (from staged insert) + return + raise TypeError(f" expects bytes or path, got {type(value).__name__}") diff --git a/src/datajoint/builtin_codecs/schema.py b/src/datajoint/builtin_codecs/schema.py new file mode 100644 index 000000000..18bd62d00 --- /dev/null +++ b/src/datajoint/builtin_codecs/schema.py @@ -0,0 +1,175 @@ +""" +Schema-addressed storage base class. 
+""" + +from __future__ import annotations + +from ..codecs import Codec +from ..errors import DataJointError + + +class SchemaCodec(Codec, register=False): + """ + Abstract base class for schema-addressed codecs. + + Schema-addressed storage is an OAS (Object-Augmented Schema) addressing + scheme where paths mirror the database schema structure: + ``{schema}/{table}/{pk}/{attribute}``. This creates a browsable + organization in object storage that reflects the schema design. + + Subclasses must implement: + - ``name``: Codec name for ```` syntax + - ``encode()``: Serialize and upload content + - ``decode()``: Create lazy reference from metadata + - ``validate()``: Validate input values + + Helper Methods: + - ``_extract_context()``: Parse key dict into schema/table/field/pk + - ``_build_path()``: Construct storage path from context + - ``_get_backend()``: Get storage backend by name + + Comparison with Hash-addressed: + - **Schema-addressed** (this): Path from schema structure, no dedup + - **Hash-addressed**: Path from content hash, automatic dedup + + Example:: + + class MyCodec(SchemaCodec): + name = "my" + + def encode(self, value, *, key=None, store_name=None): + schema, table, field, pk = self._extract_context(key) + path, _ = self._build_path(schema, table, field, pk, ext=".dat") + backend = self._get_backend(store_name) + backend.put_buffer(serialize(value), path) + return {"path": path, "store": store_name, ...} + + def decode(self, stored, *, key=None): + backend = self._get_backend(stored.get("store")) + return MyRef(stored, backend) + + See Also + -------- + HashCodec : Hash-addressed storage with content deduplication. + ObjectCodec : Schema-addressed storage for files/folders. + NpyCodec : Schema-addressed storage for numpy arrays. + """ + + def get_dtype(self, is_store: bool) -> str: + """ + Return storage dtype. Schema-addressed codecs require @ modifier. + + Parameters + ---------- + is_store : bool + Must be True for schema-addressed codecs. + + Returns + ------- + str + "json" for metadata storage. + + Raises + ------ + DataJointError + If is_store is False (@ modifier missing). + """ + if not is_store: + raise DataJointError(f"<{self.name}> requires @ (store only)") + return "json" + + def _extract_context(self, key: dict | None) -> tuple[str, str, str, dict]: + """ + Extract schema, table, field, and primary key from context dict. + + Parameters + ---------- + key : dict or None + Context dict with ``_schema``, ``_table``, ``_field``, + and primary key values. + + Returns + ------- + tuple[str, str, str, dict] + ``(schema, table, field, primary_key)`` + """ + key = dict(key) if key else {} + schema = key.pop("_schema", "unknown") + table = key.pop("_table", "unknown") + field = key.pop("_field", "data") + primary_key = {k: v for k, v in key.items() if not k.startswith("_")} + return schema, table, field, primary_key + + def _build_path( + self, + schema: str, + table: str, + field: str, + primary_key: dict, + ext: str | None = None, + store_name: str | None = None, + ) -> tuple[str, str]: + """ + Build schema-addressed storage path. + + Constructs a path that mirrors the database schema structure: + ``{schema}/{table}/{pk_values}/{field}{ext}`` + + Supports partitioning if configured in the store. + + Parameters + ---------- + schema : str + Schema name. + table : str + Table name. + field : str + Field/attribute name. + primary_key : dict + Primary key values. + ext : str, optional + File extension (e.g., ".npy", ".zarr"). 
+ store_name : str, optional + Store name for retrieving partition configuration. + + Returns + ------- + tuple[str, str] + ``(path, token)`` where path is the storage path and token + is a unique identifier. + """ + from ..storage import build_object_path + from .. import config + + # Get store configuration for partition_pattern and token_length + spec = config.get_store_spec(store_name) + partition_pattern = spec.get("partition_pattern") + token_length = spec.get("token_length", 8) + + return build_object_path( + schema=schema, + table=table, + field=field, + primary_key=primary_key, + ext=ext, + partition_pattern=partition_pattern, + token_length=token_length, + ) + + def _get_backend(self, store_name: str | None = None): + """ + Get storage backend by name. + + Parameters + ---------- + store_name : str, optional + Store name. If None, returns default store. + + Returns + ------- + StorageBackend + Storage backend instance. + """ + from ..hash_registry import get_store_backend + + return get_store_backend(store_name) diff --git a/src/datajoint/codecs.py b/src/datajoint/codecs.py index afa60321f..5c192d46e 100644 --- a/src/datajoint/codecs.py +++ b/src/datajoint/codecs.py @@ -544,7 +544,9 @@ def decode_attribute(attr, data, squeeze: bool = False): # Process the final storage type (what's in the database) if final_dtype.lower() == "json": - data = json.loads(data) + # psycopg2 auto-deserializes JSON to dict/list; only parse strings + if isinstance(data, str): + data = json.loads(data) elif final_dtype.lower() in ("longblob", "blob", "mediumblob", "tinyblob"): pass # Blob data is already bytes elif final_dtype.lower() == "binary(16)": @@ -562,7 +564,10 @@ def decode_attribute(attr, data, squeeze: bool = False): # No codec - handle native types if attr.json: - return json.loads(data) + # psycopg2 auto-deserializes JSON to dict/list; only parse strings + if isinstance(data, str): + return json.loads(data) + return data if attr.uuid: import uuid as uuid_module diff --git a/src/datajoint/condition.py b/src/datajoint/condition.py index 9c6f933d1..0335d6adb 100644 --- a/src/datajoint/condition.py +++ b/src/datajoint/condition.py @@ -31,7 +31,7 @@ JSON_PATTERN = re.compile(r"^(?P\w+)(\.(?P[\w.*\[\]]+))?(:(?P[\w(,\s)]+))?$") -def translate_attribute(key: str) -> tuple[dict | None, str]: +def translate_attribute(key: str, adapter=None) -> tuple[dict | None, str]: """ Translate an attribute key, handling JSON path notation. @@ -39,6 +39,9 @@ def translate_attribute(key: str) -> tuple[dict | None, str]: ---------- key : str Attribute name, optionally with JSON path (e.g., ``"attr.path.field"``). + adapter : DatabaseAdapter, optional + Database adapter for backend-specific SQL generation. + If not provided, uses MySQL syntax for backward compatibility. 
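import json

# Illustrative sketch (not part of the patch) of the driver-agnostic JSON handling added
# above: PyMySQL returns JSON columns as str, while psycopg2 already deserializes them
# to dict/list, so only strings are parsed. The helper name is hypothetical.
def maybe_parse_json(data):
    return json.loads(data) if isinstance(data, str) else data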
Returns ------- @@ -53,9 +56,14 @@ def translate_attribute(key: str) -> tuple[dict | None, str]: if match["path"] is None: return match, match["attr"] else: - return match, "json_value(`{}`, _utf8mb4'$.{}'{})".format( - *[((f" returning {v}" if k == "type" else v) if v else "") for k, v in match.items()] - ) + # Use adapter's json_path_expr if available, otherwise fall back to MySQL syntax + if adapter is not None: + return match, adapter.json_path_expr(match["attr"], match["path"], match["type"]) + else: + # Legacy MySQL syntax for backward compatibility + return match, "json_value(`{}`, _utf8mb4'$.{}'{})".format( + *[((f" returning {v}" if k == "type" else v) if v else "") for k, v in match.items()] + ) class PromiscuousOperand: @@ -301,16 +309,21 @@ def make_condition( """ from .expression import Aggregation, QueryExpression, U + # Get adapter for backend-agnostic SQL generation + adapter = query_expression.connection.adapter + def prep_value(k, v): """prepare SQL condition""" - key_match, k = translate_attribute(k) - if key_match["path"] is None: - k = f"`{k}`" - if query_expression.heading[key_match["attr"]].json and key_match["path"] is not None and isinstance(v, dict): + key_match, k = translate_attribute(k, adapter) + is_json_path = key_match is not None and key_match.get("path") is not None + + if not is_json_path: + k = adapter.quote_identifier(k) + if is_json_path and isinstance(v, dict): return f"{k}='{json.dumps(v)}'" if v is None: return f"{k} IS NULL" - if query_expression.heading[key_match["attr"]].uuid: + if key_match is not None and query_expression.heading[key_match["attr"]].uuid: if not isinstance(v, uuid.UUID): try: v = uuid.UUID(v) @@ -327,10 +340,12 @@ def prep_value(k, v): list, ), ): - return f'{k}="{v}"' + # Use single quotes for string literals (works for both MySQL and PostgreSQL) + return f"{k}='{v}'" if isinstance(v, str): - v = v.replace("%", "%%").replace("\\", "\\\\") - return f'{k}="{v}"' + # Escape single quotes by doubling them, and escape % for driver + v = v.replace("'", "''").replace("%", "%%").replace("\\", "\\\\") + return f"{k}='{v}'" return f"{k}={v}" def combine_conditions(negate, conditions): @@ -410,10 +425,12 @@ def combine_conditions(negate, conditions): # without common attributes, any non-empty set matches everything (not negate if condition else negate) if not common_attributes - else "({fields}) {not_}in ({subquery})".format( - fields="`" + "`,`".join(common_attributes) + "`", - not_="not " if negate else "", - subquery=condition.make_sql(common_attributes), + else ( + "({fields}) {not_}in ({subquery})".format( + fields=", ".join(adapter.quote_identifier(a) for a in common_attributes), + not_="not " if negate else "", + subquery=condition.make_sql(common_attributes), + ) ) ) diff --git a/src/datajoint/connection.py b/src/datajoint/connection.py index 1445300ed..21b48e638 100644 --- a/src/datajoint/connection.py +++ b/src/datajoint/connection.py @@ -13,9 +13,8 @@ from contextlib import contextmanager from typing import Callable -import pymysql as client - from . import errors +from .adapters import get_adapter from .blob import pack, unpack from .dependencies import Dependencies from .settings import config @@ -28,7 +27,7 @@ cache_key = "query_cache" # the key to lookup the query_cache folder in dj.config -def translate_query_error(client_error: Exception, query: str) -> Exception: +def translate_query_error(client_error: Exception, query: str, adapter) -> Exception: """ Translate client error to the corresponding DataJoint exception. 
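# Illustrative sketch (not part of the patch) of the backend-neutral string quoting now
# used in prep_value() above; the helper name is hypothetical.
def sql_string_literal(value: str) -> str:
    # double single quotes for SQL, double % so the DB-API driver does not treat it as a
    # placeholder, and escape backslashes
    value = value.replace("'", "''").replace("%", "%%").replace("\\", "\\\\")
    return f"'{value}'"

assert sql_string_literal("it's 100%") == "'it''s 100%%'"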
@@ -38,6 +37,8 @@ def translate_query_error(client_error: Exception, query: str) -> Exception: The exception raised by the client interface. query : str SQL query with placeholders. + adapter : DatabaseAdapter + The database adapter instance. Returns ------- @@ -46,47 +47,7 @@ def translate_query_error(client_error: Exception, query: str) -> Exception: or the original error if no mapping exists. """ logger.debug("type: {}, args: {}".format(type(client_error), client_error.args)) - - err, *args = client_error.args - - match err: - # Loss of connection errors - case 0 | "(0, '')": - return errors.LostConnectionError("Server connection lost due to an interface error.", *args) - case 2006: - return errors.LostConnectionError("Connection timed out", *args) - case 2013: - return errors.LostConnectionError("Server connection lost", *args) - - # Access errors - case 1044 | 1142: - return errors.AccessError("Insufficient privileges.", args[0], query) - - # Integrity errors - case 1062: - return errors.DuplicateError(*args) - case 1217 | 1451 | 1452 | 3730: - # 1217: Cannot delete parent row (FK constraint) - # 1451: Cannot delete/update parent row (FK constraint) - # 1452: Cannot add/update child row (FK constraint) - # 3730: Cannot drop table referenced by FK constraint - return errors.IntegrityError(*args) - - # Syntax errors - case 1064: - return errors.QuerySyntaxError(args[0], query) - - # Existence errors - case 1146: - return errors.MissingTableError(args[0], query) - case 1364: - return errors.MissingAttributeError(*args) - case 1054: - return errors.UnknownAttributeError(*args) - - # All other errors pass through unchanged - case _: - return client_error + return adapter.translate_error(client_error, query) def conn( @@ -219,15 +180,29 @@ def __init__( port = config["database.port"] self.conn_info = dict(host=host, port=port, user=user, passwd=password) if use_tls is not False: - self.conn_info["ssl"] = use_tls if isinstance(use_tls, dict) else {"ssl": {}} + # use_tls can be: None (auto-detect), True (enable), False (disable), or dict (custom config) + if isinstance(use_tls, dict): + self.conn_info["ssl"] = use_tls + elif use_tls is None: + # Auto-detect: try SSL, fallback to non-SSL if server doesn't support it + self.conn_info["ssl"] = True + else: + # use_tls=True: enable SSL with default settings + self.conn_info["ssl"] = True self.conn_info["ssl_input"] = use_tls self.init_fun = init_fun self._conn = None self._query_cache = None + self._is_closed = True # Mark as closed until connect() succeeds + + # Select adapter based on configured backend + backend = config["database.backend"] + self.adapter = get_adapter(backend) + self.connect() if self.is_connected: logger.info("DataJoint {version} connected to {user}@{host}:{port}".format(version=__version__, **self.conn_info)) - self.connection_id = self.query("SELECT connection_id()").fetchone()[0] + self.connection_id = self.adapter.get_connection_id(self._conn) else: raise errors.LostConnectionError("Connection failed {user}@{host}:{port}".format(**self.conn_info)) self._in_transaction = False @@ -246,26 +221,36 @@ def connect(self) -> None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*deprecated.*") try: - self._conn = client.connect( + # Use adapter to create connection + self._conn = self.adapter.connect( + host=self.conn_info["host"], + port=self.conn_info["port"], + user=self.conn_info["user"], + password=self.conn_info["passwd"], init_command=self.init_fun, - 
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," - "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", charset=config["connection.charset"], - **{k: v for k, v in self.conn_info.items() if k not in ["ssl_input"]}, + use_tls=self.conn_info.get("ssl"), ) - except client.err.InternalError: - self._conn = client.connect( - init_command=self.init_fun, - sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO," - "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY", - charset=config["connection.charset"], - **{ - k: v - for k, v in self.conn_info.items() - if not (k == "ssl_input" or k == "ssl" and self.conn_info["ssl_input"] is None) - }, - ) - self._conn.autocommit(True) + except Exception as ssl_error: + # If SSL fails, retry without SSL (if it was auto-detected) + if self.conn_info.get("ssl_input") is None: + logger.warning( + "SSL connection failed (%s). Falling back to non-SSL connection. " + "To require SSL, set use_tls=True explicitly.", + ssl_error, + ) + self._conn = self.adapter.connect( + host=self.conn_info["host"], + port=self.conn_info["port"], + user=self.conn_info["user"], + password=self.conn_info["passwd"], + init_command=self.init_fun, + charset=config["connection.charset"], + use_tls=False, # Explicitly disable SSL for fallback + ) + else: + raise + self._is_closed = False # Mark as connected after successful connection def set_query_cache(self, query_cache: str | None = None) -> None: """ @@ -293,7 +278,9 @@ def purge_query_cache(self) -> None: def close(self) -> None: """Close the database connection.""" - self._conn.close() + if self._conn is not None: + self._conn.close() + self._is_closed = True def __enter__(self) -> "Connection": """ @@ -355,7 +342,7 @@ def ping(self) -> None: Exception If the connection is closed. """ - self._conn.ping(reconnect=False) + self.adapter.ping(self._conn) @property def is_connected(self) -> bool: @@ -367,22 +354,24 @@ def is_connected(self) -> bool: bool True if connected. 
""" + if self._is_closed: + return False try: self.ping() except: + self._is_closed = True return False return True - @staticmethod - def _execute_query(cursor, query, args, suppress_warnings): + def _execute_query(self, cursor, query, args, suppress_warnings): try: with warnings.catch_warnings(): if suppress_warnings: # suppress all warnings arising from underlying SQL library warnings.simplefilter("ignore") cursor.execute(query, args) - except client.err.Error as err: - raise translate_query_error(err, query) + except Exception as err: + raise translate_query_error(err, query, self.adapter) def query( self, @@ -426,7 +415,8 @@ def query( if use_query_cache: if not config[cache_key]: raise errors.DataJointError(f"Provide filepath dj.config['{cache_key}'] when using query caching.") - hash_ = hashlib.md5((str(self._query_cache) + re.sub(r"`\$\w+`", "", query)).encode() + pack(args)).hexdigest() + # Cache key is backend-specific (no identifier normalization needed) + hash_ = hashlib.md5((str(self._query_cache)).encode() + pack(args) + query.encode()).hexdigest() cache_path = pathlib.Path(config[cache_key]) / str(hash_) try: buffer = cache_path.read_bytes() @@ -438,20 +428,19 @@ def query( if reconnect is None: reconnect = config["database.reconnect"] logger.debug("Executing SQL:" + query[:query_log_max_length]) - cursor_class = client.cursors.DictCursor if as_dict else client.cursors.Cursor - cursor = self._conn.cursor(cursor=cursor_class) + cursor = self.adapter.get_cursor(self._conn, as_dict=as_dict) try: self._execute_query(cursor, query, args, suppress_warnings) except errors.LostConnectionError: if not reconnect: raise - logger.warning("Reconnecting to MySQL server.") + logger.warning("Reconnecting to database server.") self.connect() if self._in_transaction: self.cancel_transaction() raise errors.LostConnectionError("Connection was lost during a transaction.") logger.debug("Re-executing") - cursor = self._conn.cursor(cursor=cursor_class) + cursor = self.adapter.get_cursor(self._conn, as_dict=as_dict) self._execute_query(cursor, query, args, suppress_warnings) if use_query_cache: @@ -470,7 +459,7 @@ def get_user(self) -> str: str User name and host as ``'user@host'``. """ - return self.query("SELECT user()").fetchone()[0] + return self.query(f"SELECT {self.adapter.current_user_expr()}").fetchone()[0] # ---------- transaction processing @property @@ -497,19 +486,19 @@ def start_transaction(self) -> None: """ if self.in_transaction: raise errors.DataJointError("Nested connections are not supported.") - self.query("START TRANSACTION WITH CONSISTENT SNAPSHOT") + self.query(self.adapter.start_transaction_sql()) self._in_transaction = True logger.debug("Transaction started") def cancel_transaction(self) -> None: """Cancel the current transaction and roll back all changes.""" - self.query("ROLLBACK") + self.query(self.adapter.rollback_sql()) self._in_transaction = False logger.debug("Transaction cancelled. 
Rolling back ...") def commit_transaction(self) -> None: """Commit all changes and close the transaction.""" - self.query("COMMIT") + self.query(self.adapter.commit_sql()) self._in_transaction = False logger.debug("Transaction committed and closed.") diff --git a/src/datajoint/declare.py b/src/datajoint/declare.py index a640c444f..375daa07e 100644 --- a/src/datajoint/declare.py +++ b/src/datajoint/declare.py @@ -31,8 +31,8 @@ "bool": (r"bool$", "tinyint"), # UUID (stored as binary) "uuid": (r"uuid$", "binary(16)"), - # JSON - "json": (r"json$", None), # json passes through as-is + # JSON (matches both json and jsonb for PostgreSQL compatibility) + "json": (r"jsonb?$", None), # json/jsonb passes through as-is # Binary (bytes maps to longblob in MySQL, bytea in PostgreSQL) "bytes": (r"bytes$", "longblob"), # Temporal @@ -190,6 +190,7 @@ def compile_foreign_key( attr_sql: list[str], foreign_key_sql: list[str], index_sql: list[str], + adapter, fk_attribute_map: dict[str, tuple[str, str]] | None = None, ) -> None: """ @@ -212,6 +213,8 @@ def compile_foreign_key( SQL FOREIGN KEY constraints. Updated in place. index_sql : list[str] SQL INDEX declarations. Updated in place. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. fk_attribute_map : dict, optional Mapping of ``child_attr -> (parent_table, parent_attr)``. Updated in place. @@ -261,30 +264,57 @@ def compile_foreign_key( attributes.append(attr) if primary_key is not None: primary_key.append(attr) - attr_sql.append(ref.heading[attr].sql.replace("NOT NULL ", "", int(is_nullable))) + + # Build foreign key column definition using adapter + parent_attr = ref.heading[attr] + sql_type = parent_attr.sql_type + # For PostgreSQL enum types, qualify with schema name + # Enum type names start with "enum_" (generated hash-based names) + if sql_type.startswith("enum_") and adapter.backend == "postgresql": + sql_type = f"{adapter.quote_identifier(ref.database)}.{adapter.quote_identifier(sql_type)}" + col_def = adapter.format_column_definition( + name=attr, + sql_type=sql_type, + nullable=is_nullable, + default=None, + comment=parent_attr.sql_comment, + ) + attr_sql.append(col_def) + # Track FK attribute mapping for lineage: child_attr -> (parent_table, parent_attr) if fk_attribute_map is not None: parent_table = ref.support[0] # e.g., `schema`.`table` parent_attr = ref.heading[attr].original_name fk_attribute_map[attr] = (parent_table, parent_attr) - # declare the foreign key + # declare the foreign key using adapter for identifier quoting + fk_cols = ", ".join(adapter.quote_identifier(col) for col in ref.primary_key) + pk_cols = ", ".join(adapter.quote_identifier(ref.heading[name].original_name) for name in ref.primary_key) + + # Build referenced table name with proper quoting + # ref.support[0] may have cached quoting from a different backend + # Extract database and table name and rebuild with current adapter + parent_full_name = ref.support[0] + # Try to parse as database.table (with or without quotes) + parts = parent_full_name.replace('"', "").replace("`", "").split(".") + if len(parts) == 2: + ref_table_name = f"{adapter.quote_identifier(parts[0])}.{adapter.quote_identifier(parts[1])}" + else: + ref_table_name = adapter.quote_identifier(parts[0]) + foreign_key_sql.append( - "FOREIGN KEY (`{fk}`) REFERENCES {ref} (`{pk}`) ON UPDATE CASCADE ON DELETE RESTRICT".format( - fk="`,`".join(ref.primary_key), - pk="`,`".join(ref.heading[name].original_name for name in ref.primary_key), - ref=ref.support[0], - ) + f"FOREIGN 
KEY ({fk_cols}) REFERENCES {ref_table_name} ({pk_cols}) ON UPDATE CASCADE ON DELETE RESTRICT" ) # declare unique index if is_unique: - index_sql.append("UNIQUE INDEX ({attrs})".format(attrs=",".join("`%s`" % attr for attr in ref.primary_key))) + index_cols = ", ".join(adapter.quote_identifier(attr) for attr in ref.primary_key) + index_sql.append(f"UNIQUE INDEX ({index_cols})") def prepare_declare( - definition: str, context: dict -) -> tuple[str, list[str], list[str], list[str], list[str], list[str], dict[str, tuple[str, str]]]: + definition: str, context: dict, adapter +) -> tuple[str, list[str], list[str], list[str], list[str], list[str], dict[str, tuple[str, str]], dict[str, str]]: """ Parse a table definition into its components. @@ -294,11 +324,13 @@ def prepare_declare( DataJoint table definition string. context : dict Namespace for resolving foreign key references. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. Returns ------- tuple - Seven-element tuple containing: + Eight-element tuple containing: - table_comment : str - primary_key : list[str] @@ -307,6 +339,7 @@ def prepare_declare( - index_sql : list[str] - external_stores : list[str] - fk_attribute_map : dict[str, tuple[str, str]] + - column_comments : dict[str, str] - Column name to comment mapping """ # split definition into lines definition = re.split(r"\s*\n\s*", definition.strip()) @@ -322,11 +355,12 @@ def prepare_declare( index_sql = [] external_stores = [] fk_attribute_map = {} # child_attr -> (parent_table, parent_attr) + column_comments = {} # column_name -> comment (for PostgreSQL COMMENT ON) for line in definition: if not line or line.startswith("#"): # ignore additional comments pass - elif line.startswith("---") or line.startswith("___"): + elif line.startswith("---"): in_key = False # start parsing dependent attributes elif is_foreign_key(line): compile_foreign_key( @@ -337,12 +371,13 @@ def prepare_declare( attribute_sql, foreign_key_sql, index_sql, + adapter, fk_attribute_map, ) elif re.match(r"^(unique\s+)?index\s*.*$", line, re.I): # index - compile_index(line, index_sql) + compile_index(line, index_sql, adapter) else: - name, sql, store = compile_attribute(line, in_key, foreign_key_sql, context) + name, sql, store, comment = compile_attribute(line, in_key, foreign_key_sql, context, adapter) if store: external_stores.append(store) if in_key and name not in primary_key: @@ -350,6 +385,8 @@ def prepare_declare( if name not in attributes: attributes.append(name) attribute_sql.append(sql) + if comment: + column_comments[name] = comment return ( table_comment, @@ -359,40 +396,55 @@ def prepare_declare( index_sql, external_stores, fk_attribute_map, + column_comments, ) def declare( - full_table_name: str, definition: str, context: dict -) -> tuple[str, list[str], list[str], dict[str, tuple[str, str]]]: + full_table_name: str, definition: str, context: dict, adapter +) -> tuple[str, list[str], list[str], dict[str, tuple[str, str]], list[str], list[str]]: r""" Parse a definition and generate SQL CREATE TABLE statement. Parameters ---------- full_table_name : str - Fully qualified table name (e.g., ```\`schema\`.\`table\```). + Fully qualified table name (e.g., ```\`schema\`.\`table\``` or ```"schema"."table"```). definition : str DataJoint table definition string. context : dict Namespace for resolving foreign key references. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. 
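from datajoint.adapters import get_adapter

# Illustrative sketch (not part of the patch): compile_foreign_key() and compile_attribute()
# now delegate column rendering to adapter.format_column_definition(). The exact output is
# adapter-defined; the strings in the trailing comments are assumptions, not verified output.
adapter = get_adapter("mysql")
col = adapter.format_column_definition(
    name="duration",
    sql_type="float",
    nullable=False,
    default=None,
    comment="recording duration (s)",
)
# MySQL-style result, e.g.:      `duration` float NOT NULL COMMENT "recording duration (s)"
# PostgreSQL-style result, e.g.: "duration" real NOT NULL   (comment emitted via COMMENT ON COLUMN)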
Returns ------- tuple - Four-element tuple: + Six-element tuple: - sql : str - SQL CREATE TABLE statement - external_stores : list[str] - External store names used - primary_key : list[str] - Primary key attribute names - fk_attribute_map : dict - FK attribute lineage mapping + - pre_ddl : list[str] - DDL statements to run BEFORE CREATE TABLE (e.g., CREATE TYPE) + - post_ddl : list[str] - DDL statements to run AFTER CREATE TABLE (e.g., COMMENT ON) Raises ------ DataJointError If table name exceeds max length or has no primary key. """ - table_name = full_table_name.strip("`").split(".")[1] + # Parse table name without assuming quote character + # Extract schema.table from quoted name using adapter + quote_char = adapter.quote_identifier("x")[0] # Get quote char from adapter + parts = full_table_name.split(".") + if len(parts) == 2: + schema_name = parts[0].strip(quote_char) + table_name = parts[1].strip(quote_char) + else: + schema_name = None + table_name = parts[0].strip(quote_char) + if len(table_name) > MAX_TABLE_NAME_LENGTH: raise DataJointError( "Table name `{name}` exceeds the max length of {max_length}".format( @@ -408,35 +460,87 @@ def declare( index_sql, external_stores, fk_attribute_map, - ) = prepare_declare(definition, context) + column_comments, + ) = prepare_declare(definition, context, adapter) # Add hidden job metadata for Computed/Imported tables (not parts) - # Note: table_name may still have backticks, strip them for prefix checking - clean_table_name = table_name.strip("`") if config.jobs.add_job_metadata: # Check if this is a Computed (__) or Imported (_) table, but not a Part (contains __ in middle) - is_computed = clean_table_name.startswith("__") and "__" not in clean_table_name[2:] - is_imported = clean_table_name.startswith("_") and not clean_table_name.startswith("__") + is_computed = table_name.startswith("__") and "__" not in table_name[2:] + is_imported = table_name.startswith("_") and not table_name.startswith("__") if is_computed or is_imported: - job_metadata_sql = [ - "`_job_start_time` datetime(3) DEFAULT NULL", - "`_job_duration` float DEFAULT NULL", - "`_job_version` varchar(64) DEFAULT ''", - ] + job_metadata_sql = adapter.job_metadata_columns() attribute_sql.extend(job_metadata_sql) if not primary_key: - raise DataJointError("Table must have a primary key") + # Singleton table: add hidden sentinel attribute + primary_key = ["_singleton"] + singleton_comment = ":bool:singleton primary key" + sql_type = adapter.core_type_to_sql("bool") + singleton_sql = adapter.format_column_definition( + name="_singleton", + sql_type=sql_type, + nullable=False, + default="NOT NULL DEFAULT TRUE", + comment=singleton_comment, + ) + attribute_sql.insert(0, singleton_sql) + column_comments["_singleton"] = singleton_comment + + pre_ddl = [] # DDL to run BEFORE CREATE TABLE (e.g., CREATE TYPE for enums) + post_ddl = [] # DDL to run AFTER CREATE TABLE (e.g., COMMENT ON) + # Get pending enum type DDL for PostgreSQL (must run before CREATE TABLE) + if schema_name and hasattr(adapter, "get_pending_enum_ddl"): + pre_ddl.extend(adapter.get_pending_enum_ddl(schema_name)) + + # Build PRIMARY KEY clause using adapter + pk_cols = ", ".join(adapter.quote_identifier(pk) for pk in primary_key) + pk_clause = f"PRIMARY KEY ({pk_cols})" + + # Handle indexes - inline for MySQL, separate CREATE INDEX for PostgreSQL + if adapter.supports_inline_indexes: + # MySQL: include indexes in CREATE TABLE + create_table_indexes = index_sql + else: + # PostgreSQL: convert to CREATE INDEX statements for 
post_ddl + create_table_indexes = [] + for idx_def in index_sql: + # Parse index definition: "unique index (cols)" or "index (cols)" + idx_match = re.match(r"(unique\s+)?index\s*\(([^)]+)\)", idx_def, re.I) + if idx_match: + is_unique = idx_match.group(1) is not None + # Extract column names (may be quoted or have expressions) + cols_str = idx_match.group(2) + # Simple split on comma - columns are already quoted + columns = [c.strip().strip('`"') for c in cols_str.split(",")] + # Generate CREATE INDEX DDL + create_idx_ddl = adapter.create_index_ddl(full_table_name, columns, unique=is_unique) + post_ddl.append(create_idx_ddl) + + # Assemble CREATE TABLE sql = ( - "CREATE TABLE IF NOT EXISTS %s (\n" % full_table_name - + ",\n".join(attribute_sql + ["PRIMARY KEY (`" + "`,`".join(primary_key) + "`)"] + foreign_key_sql + index_sql) - + '\n) ENGINE=InnoDB, COMMENT "%s"' % table_comment + f"CREATE TABLE IF NOT EXISTS {full_table_name} (\n" + + ",\n".join(attribute_sql + [pk_clause] + foreign_key_sql + create_table_indexes) + + f"\n) {adapter.table_options_clause(table_comment)}" ) - return sql, external_stores, primary_key, fk_attribute_map + + # Add table-level comment DDL if needed (PostgreSQL) + table_comment_ddl = adapter.table_comment_ddl(full_table_name, table_comment) + if table_comment_ddl: + post_ddl.append(table_comment_ddl) + + # Add column-level comments DDL if needed (PostgreSQL) + # Column comments contain type specifications like ::user_comment + for col_name, comment in column_comments.items(): + col_comment_ddl = adapter.column_comment_ddl(full_table_name, col_name, comment) + if col_comment_ddl: + post_ddl.append(col_comment_ddl) + + return sql, external_stores, primary_key, fk_attribute_map, pre_ddl, post_ddl -def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str]) -> list[str]: +def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str], adapter) -> list[str]: """ Generate SQL ALTER commands for attribute changes. @@ -448,6 +552,8 @@ def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str] Old attribute SQL declarations. primary_key : list[str] Primary key attribute names (cannot be altered). + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. Returns ------- @@ -459,8 +565,9 @@ def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str] DataJointError If an attribute is renamed twice or renamed from non-existent attribute. 
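from datajoint.declare import declare

# Illustrative sketch (not part of the patch): how a caller is expected to order the DDL
# returned by declare() after this change. The wrapper function is hypothetical; the real
# call site is the table-declaration code elsewhere in the patch.
def create_table(connection, adapter, full_table_name, definition, context):
    sql, stores, pk, fk_map, pre_ddl, post_ddl = declare(full_table_name, definition, context, adapter)
    for ddl in pre_ddl:        # e.g. CREATE TYPE ... AS ENUM (PostgreSQL enums)
        connection.query(ddl)
    connection.query(sql)      # CREATE TABLE IF NOT EXISTS ...
    for ddl in post_ddl:       # e.g. CREATE INDEX ..., COMMENT ON TABLE/COLUMN ...
        connection.query(ddl)
    return stores, pk, fk_map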
""" - # parse attribute names - name_regexp = re.compile(r"^`(?P\w+)`") + # parse attribute names - use adapter's quote character + quote_char = re.escape(adapter.quote_identifier("x")[0]) + name_regexp = re.compile(rf"^{quote_char}(?P\w+){quote_char}") original_regexp = re.compile(r'COMMENT "{\s*(?P\w+)\s*}') matched = ((name_regexp.match(d), original_regexp.search(d)) for d in new) new_names = dict((d.group("name"), n and n.group("name")) for d, n in matched) @@ -486,7 +593,7 @@ def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str] # dropping attributes to_drop = [n for n in old_names if n not in renamed and n not in new_names] - sql = ["DROP `%s`" % n for n in to_drop] + sql = [f"DROP {adapter.quote_identifier(n)}" for n in to_drop] old_names = [name for name in old_names if name not in to_drop] # add or change attributes in order @@ -503,25 +610,24 @@ def _make_attribute_alter(new: list[str], old: list[str], primary_key: list[str] if idx >= 1 and old_names[idx - 1] != (prev[1] or prev[0]): after = prev[0] if new_def not in old or after: - sql.append( - "{command} {new_def} {after}".format( - command=( - "ADD" - if (old_name or new_name) not in old_names - else "MODIFY" - if not old_name - else "CHANGE `%s`" % old_name - ), - new_def=new_def, - after="" if after is None else "AFTER `%s`" % after, - ) - ) + # Determine command type + if (old_name or new_name) not in old_names: + command = "ADD" + elif not old_name: + command = "MODIFY" + else: + command = f"CHANGE {adapter.quote_identifier(old_name)}" + + # Build after clause + after_clause = "" if after is None else f"AFTER {adapter.quote_identifier(after)}" + + sql.append(f"{command} {new_def} {after_clause}") prev = new_name, old_name return sql -def alter(definition: str, old_definition: str, context: dict) -> tuple[list[str], list[str]]: +def alter(definition: str, old_definition: str, context: dict, adapter) -> tuple[list[str], list[str]]: """ Generate SQL ALTER commands for table definition changes. @@ -533,6 +639,8 @@ def alter(definition: str, old_definition: str, context: dict) -> tuple[list[str Current table definition. context : dict Namespace for resolving foreign key references. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. 
Returns ------- @@ -555,7 +663,8 @@ def alter(definition: str, old_definition: str, context: dict) -> tuple[list[str index_sql, external_stores, _fk_attribute_map, - ) = prepare_declare(definition, context) + _column_comments, + ) = prepare_declare(definition, context, adapter) ( table_comment_, primary_key_, @@ -564,7 +673,8 @@ def alter(definition: str, old_definition: str, context: dict) -> tuple[list[str index_sql_, external_stores_, _fk_attribute_map_, - ) = prepare_declare(old_definition, context) + _column_comments_, + ) = prepare_declare(old_definition, context, adapter) # analyze differences between declarations sql = list() @@ -575,9 +685,12 @@ def alter(definition: str, old_definition: str, context: dict) -> tuple[list[str if index_sql != index_sql_: raise NotImplementedError("table.alter cannot alter indexes (yet)") if attribute_sql != attribute_sql_: - sql.extend(_make_attribute_alter(attribute_sql, attribute_sql_, primary_key)) + sql.extend(_make_attribute_alter(attribute_sql, attribute_sql_, primary_key, adapter)) if table_comment != table_comment_: - sql.append('COMMENT="%s"' % table_comment) + # For MySQL: COMMENT="new comment" + # For PostgreSQL: would need COMMENT ON TABLE, but that's not an ALTER TABLE clause + # Keep MySQL syntax for now (ALTER TABLE ... COMMENT="...") + sql.append(f'COMMENT="{table_comment}"') return sql, [e for e in external_stores if e not in external_stores_] @@ -620,7 +733,7 @@ def _parse_index_args(args: str) -> list[str]: return [arg for arg in result if arg] # Filter empty strings -def compile_index(line: str, index_sql: list[str]) -> None: +def compile_index(line: str, index_sql: list[str], adapter) -> None: """ Parse an index declaration and append SQL to index_sql. @@ -631,6 +744,8 @@ def compile_index(line: str, index_sql: list[str]) -> None: ``"unique index(attr)"``). index_sql : list[str] List of index SQL declarations. Updated in place. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. Raises ------ @@ -639,11 +754,11 @@ def compile_index(line: str, index_sql: list[str]) -> None: """ def format_attribute(attr): - match, attr = translate_attribute(attr) + match, attr = translate_attribute(attr, adapter) if match is None: return attr if match["path"] is None: - return f"`{attr}`" + return adapter.quote_identifier(attr) return f"({attr})" match = re.match(r"(?Punique\s+)?index\s*\(\s*(?P.*)\)", line, re.I) @@ -660,7 +775,7 @@ def format_attribute(attr): ) -def substitute_special_type(match: dict, category: str, foreign_key_sql: list[str], context: dict) -> None: +def substitute_special_type(match: dict, category: str, foreign_key_sql: list[str], context: dict, adapter) -> None: """ Substitute special types with their native SQL equivalents. @@ -679,6 +794,8 @@ def substitute_special_type(match: dict, category: str, foreign_key_sql: list[st Foreign key declarations (unused, kept for API compatibility). context : dict Namespace for codec lookup (unused, kept for API compatibility). + adapter : DatabaseAdapter + Database adapter for backend-specific type mapping. 
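# Illustrative sketch (not part of the patch): the kind of mapping adapter.core_type_to_sql()
# performs for core DataJoint types. The MySQL column and the bytea mapping come from the
# comments in this file; the remaining PostgreSQL entries are assumptions, not verified
# against the adapter implementations.
CORE_TYPE_EXAMPLES = {
    # core type        MySQL           PostgreSQL
    "bool":          ("tinyint",      "boolean"),      # PostgreSQL value assumed
    "uuid":          ("binary(16)",   "uuid"),         # PostgreSQL value assumed
    "bytes":         ("longblob",     "bytea"),
    "varchar(64)":   ("varchar(64)",  "varchar(64)"),  # passes through unchanged
}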
""" if category == "CODEC": # Codec - resolve to underlying dtype @@ -699,11 +816,11 @@ def substitute_special_type(match: dict, category: str, foreign_key_sql: list[st # Recursively resolve if dtype is also a special type category = match_type(match["type"]) if category in SPECIAL_TYPES: - substitute_special_type(match, category, foreign_key_sql, context) + substitute_special_type(match, category, foreign_key_sql, context, adapter) elif category in CORE_TYPE_NAMES: - # Core DataJoint type - substitute with native SQL type if mapping exists - core_name = category.lower() - sql_type = CORE_TYPE_SQL.get(core_name) + # Core DataJoint type - substitute with native SQL type using adapter + # Pass the full type string (e.g., "varchar(255)") not just category name + sql_type = adapter.core_type_to_sql(match["type"]) if sql_type is not None: match["type"] = sql_type # else: type passes through as-is (json, date, datetime, char, varchar, enum) @@ -711,7 +828,9 @@ def substitute_special_type(match: dict, category: str, foreign_key_sql: list[st raise DataJointError(f"Unknown special type: {category}") -def compile_attribute(line: str, in_key: bool, foreign_key_sql: list[str], context: dict) -> tuple[str, str, str | None]: +def compile_attribute( + line: str, in_key: bool, foreign_key_sql: list[str], context: dict, adapter +) -> tuple[str, str, str | None, str | None]: """ Convert an attribute definition from DataJoint format to SQL. @@ -725,15 +844,18 @@ def compile_attribute(line: str, in_key: bool, foreign_key_sql: list[str], conte Foreign key declarations (passed to type substitution). context : dict Namespace for codec lookup. + adapter : DatabaseAdapter + Database adapter for backend-specific SQL generation. Returns ------- tuple - Three-element tuple: + Four-element tuple: - name : str - Attribute name - sql : str - SQL column declaration - store : str or None - External store name if applicable + - comment : str or None - Column comment (for PostgreSQL COMMENT ON) Raises ------ @@ -760,8 +882,22 @@ def compile_attribute(line: str, in_key: bool, foreign_key_sql: list[str], conte match["default"] = "DEFAULT NULL" # nullable attributes default to null else: if match["default"]: - quote = match["default"].split("(")[0].upper() not in CONSTANT_LITERALS and match["default"][0] not in "\"'" - match["default"] = "NOT NULL DEFAULT " + ('"%s"' if quote else "%s") % match["default"] + default_val = match["default"] + base_val = default_val.split("(")[0].upper() + + if base_val in CONSTANT_LITERALS: + # SQL constants like NULL, CURRENT_TIMESTAMP - use as-is + match["default"] = f"NOT NULL DEFAULT {default_val}" + elif default_val.startswith('"') and default_val.endswith('"'): + # Double-quoted string - convert to single quotes for PostgreSQL + inner = default_val[1:-1].replace("'", "''") # Escape single quotes + match["default"] = f"NOT NULL DEFAULT '{inner}'" + elif default_val.startswith("'"): + # Already single-quoted - use as-is + match["default"] = f"NOT NULL DEFAULT {default_val}" + else: + # Unquoted value - wrap in single quotes + match["default"] = f"NOT NULL DEFAULT '{default_val}'" else: match["default"] = "NOT NULL" @@ -775,7 +911,7 @@ def compile_attribute(line: str, in_key: bool, foreign_key_sql: list[str], conte if category in SPECIAL_TYPES: # Core types and Codecs are recorded in comment for reconstruction match["comment"] = ":{type}:{comment}".format(**match) - substitute_special_type(match, category, foreign_key_sql, context) + substitute_special_type(match, category, foreign_key_sql, 
context, adapter) elif category in NATIVE_TYPES: # Native type - warn user logger.warning( @@ -789,5 +925,12 @@ def compile_attribute(line: str, in_key: bool, foreign_key_sql: list[str], conte if ("blob" in final_type) and match["default"] not in {"DEFAULT NULL", "NOT NULL"}: raise DataJointError("The default value for blob attributes can only be NULL in:\n{line}".format(line=line)) - sql = ("`{name}` {type} {default}" + (' COMMENT "{comment}"' if match["comment"] else "")).format(**match) - return match["name"], sql, match.get("store") + # Use adapter to format column definition + sql = adapter.format_column_definition( + name=match["name"], + sql_type=match["type"], + nullable=match["nullable"], + default=match["default"] if match["default"] else None, + comment=match["comment"] if match["comment"] else None, + ) + return match["name"], sql, match.get("store"), match["comment"] if match["comment"] else None diff --git a/src/datajoint/dependencies.py b/src/datajoint/dependencies.py index 621011426..83162a112 100644 --- a/src/datajoint/dependencies.py +++ b/src/datajoint/dependencies.py @@ -31,8 +31,14 @@ def extract_master(part_table: str) -> str | None: str or None Master table name if part_table is a part table, None otherwise. """ - match = re.match(r"(?P`\w+`.`#?\w+)__\w+`", part_table) - return match["master"] + "`" if match else None + # Match both MySQL backticks and PostgreSQL double quotes + # MySQL: `schema`.`master__part` + # PostgreSQL: "schema"."master__part" + match = re.match(r'(?P(?P[`"])[\w]+(?P=q)\.(?P=q)#?[\w]+)__[\w]+(?P=q)', part_table) + if match: + q = match["q"] + return match["master"] + q + return None def topo_sort(graph: nx.DiGraph) -> list[str]: @@ -131,6 +137,7 @@ def __init__(self, connection=None) -> None: def clear(self) -> None: """Clear the graph and reset loaded state.""" self._loaded = False + self._node_alias_count = itertools.count() # reset alias IDs for consistency super().clear() def load(self, force: bool = True) -> None: @@ -151,39 +158,105 @@ def load(self, force: bool = True) -> None: self.clear() - # load primary key info - keys = self._conn.query( - """ - SELECT - concat('`', table_schema, '`.`', table_name, '`') as tab, column_name + # Get adapter for backend-specific SQL generation + adapter = self._conn.adapter + + # Build schema list for IN clause + schemas_list = ", ".join(adapter.quote_string(s) for s in self._conn.schemas) + + # Backend-specific queries for primary keys and foreign keys + # Note: Both PyMySQL and psycopg2 use %s placeholders, so escape % as %% + like_pattern = "'~%%'" + + if adapter.backend == "mysql": + # MySQL: use concat() and MySQL-specific information_schema columns + tab_expr = "concat('`', table_schema, '`.`', table_name, '`')" + + # load primary key info (MySQL uses constraint_name='PRIMARY') + keys = self._conn.query( + f""" + SELECT {tab_expr} as tab, column_name FROM information_schema.key_column_usage - WHERE table_name not LIKE "~%%" AND table_schema in ('{schemas}') AND constraint_name="PRIMARY" - """.format(schemas="','".join(self._conn.schemas)) - ) - pks = defaultdict(set) - for key in keys: - pks[key[0]].add(key[1]) + WHERE table_name NOT LIKE {like_pattern} + AND table_schema in ({schemas_list}) + AND constraint_name='PRIMARY' + """ + ) + pks = defaultdict(set) + for key in keys: + pks[key[0]].add(key[1]) + + # load foreign keys (MySQL has referenced_* columns) + ref_tab_expr = "concat('`', referenced_table_schema, '`.`', referenced_table_name, '`')" + fk_keys = self._conn.query( + f""" + SELECT 
constraint_name, + {tab_expr} as referencing_table, + {ref_tab_expr} as referenced_table, + column_name, referenced_column_name + FROM information_schema.key_column_usage + WHERE referenced_table_name NOT LIKE {like_pattern} + AND (referenced_table_schema in ({schemas_list}) + OR referenced_table_schema is not NULL AND table_schema in ({schemas_list})) + """, + as_dict=True, + ) + else: + # PostgreSQL: use || concatenation and different query structure + tab_expr = "'\"' || kcu.table_schema || '\".\"' || kcu.table_name || '\"'" + + # load primary key info (PostgreSQL uses constraint_type='PRIMARY KEY') + keys = self._conn.query( + f""" + SELECT {tab_expr} as tab, kcu.column_name + FROM information_schema.key_column_usage kcu + JOIN information_schema.table_constraints tc + ON kcu.constraint_name = tc.constraint_name + AND kcu.table_schema = tc.table_schema + WHERE kcu.table_name NOT LIKE {like_pattern} + AND kcu.table_schema in ({schemas_list}) + AND tc.constraint_type = 'PRIMARY KEY' + """ + ) + pks = defaultdict(set) + for key in keys: + pks[key[0]].add(key[1]) + + # load foreign keys using pg_constraint system catalogs + # The information_schema approach creates a Cartesian product for composite FKs + # because constraint_column_usage doesn't have ordinal_position. + # Using pg_constraint with unnest(conkey, confkey) WITH ORDINALITY gives correct mapping. + fk_keys = self._conn.query( + f""" + SELECT + c.conname as constraint_name, + '"' || ns1.nspname || '"."' || cl1.relname || '"' as referencing_table, + '"' || ns2.nspname || '"."' || cl2.relname || '"' as referenced_table, + a1.attname as column_name, + a2.attname as referenced_column_name + FROM pg_constraint c + JOIN pg_class cl1 ON c.conrelid = cl1.oid + JOIN pg_namespace ns1 ON cl1.relnamespace = ns1.oid + JOIN pg_class cl2 ON c.confrelid = cl2.oid + JOIN pg_namespace ns2 ON cl2.relnamespace = ns2.oid + CROSS JOIN LATERAL unnest(c.conkey, c.confkey) WITH ORDINALITY AS cols(conkey, confkey, ord) + JOIN pg_attribute a1 ON a1.attrelid = cl1.oid AND a1.attnum = cols.conkey + JOIN pg_attribute a2 ON a2.attrelid = cl2.oid AND a2.attnum = cols.confkey + WHERE c.contype = 'f' + AND cl1.relname NOT LIKE {like_pattern} + AND (ns2.nspname in ({schemas_list}) + OR ns1.nspname in ({schemas_list})) + ORDER BY c.conname, cols.ord + """, + as_dict=True, + ) # add nodes to the graph for n, pk in pks.items(): self.add_node(n, primary_key=pk) - # load foreign keys - keys = ( - {k.lower(): v for k, v in elem.items()} - for elem in self._conn.query( - """ - SELECT constraint_name, - concat('`', table_schema, '`.`', table_name, '`') as referencing_table, - concat('`', referenced_table_schema, '`.`', referenced_table_name, '`') as referenced_table, - column_name, referenced_column_name - FROM information_schema.key_column_usage - WHERE referenced_table_name NOT LIKE "~%%" AND (referenced_table_schema in ('{schemas}') OR - referenced_table_schema is not NULL AND table_schema in ('{schemas}')) - """.format(schemas="','".join(self._conn.schemas)), - as_dict=True, - ) - ) + # Process foreign keys (same for both backends) + keys = ({k.lower(): v for k, v in elem.items()} for elem in fk_keys) fks = defaultdict(lambda: dict(attr_map=dict())) for key in keys: d = fks[ diff --git a/src/datajoint/diagram.py b/src/datajoint/diagram.py index c52340f46..48e18fd0d 100644 --- a/src/datajoint/diagram.py +++ b/src/datajoint/diagram.py @@ -16,6 +16,7 @@ from .dependencies import topo_sort from .errors import DataJointError +from .settings import config from .table 
import Table, lookup_class_name from .user_tables import Computed, Imported, Lookup, Manual, Part, _AliasNode, _get_tier @@ -90,12 +91,19 @@ class Diagram(nx.DiGraph): ----- ``diagram + 1 - 1`` may differ from ``diagram - 1 + 1``. Only tables loaded in the connection are displayed. + + Layout direction is controlled via ``dj.config.display.diagram_direction`` + (default ``"TB"``). Use ``dj.config.override()`` to change temporarily:: + + with dj.config.override(display_diagram_direction="LR"): + dj.Diagram(schema).draw() """ def __init__(self, source, context=None) -> None: if isinstance(source, Diagram): # copy constructor self.nodes_to_show = set(source.nodes_to_show) + self._expanded_nodes = set(source._expanded_nodes) self.context = source.context super().__init__(source) return @@ -134,8 +142,11 @@ def __init__(self, source, context=None) -> None: except AttributeError: raise DataJointError("Cannot plot Diagram for %s" % repr(source)) for node in self: - if node.startswith("`%s`" % database): + # Handle both MySQL backticks and PostgreSQL double quotes + if node.startswith("`%s`" % database) or node.startswith('"%s"' % database): self.nodes_to_show.add(node) + # All nodes start as expanded + self._expanded_nodes = set(self.nodes_to_show) @classmethod def from_sequence(cls, sequence) -> "Diagram": @@ -173,6 +184,34 @@ def is_part(part, master): self.nodes_to_show.update(n for n in self.nodes() if any(is_part(n, m) for m in self.nodes_to_show)) return self + def collapse(self) -> "Diagram": + """ + Mark all nodes in this diagram as collapsed. + + Collapsed nodes are shown as a single node per schema. When combined + with other diagrams using ``+``, expanded nodes win: if a node is + expanded in either operand, it remains expanded in the result. + + Returns + ------- + Diagram + A copy of this diagram with all nodes collapsed. + + Examples + -------- + >>> # Show schema1 expanded, schema2 collapsed into single nodes + >>> dj.Diagram(schema1) + dj.Diagram(schema2).collapse() + + >>> # Collapse all three schemas together + >>> (dj.Diagram(schema1) + dj.Diagram(schema2) + dj.Diagram(schema3)).collapse() + + >>> # Expand one table from collapsed schema + >>> dj.Diagram(schema).collapse() + dj.Diagram(SingleTable) + """ + result = Diagram(self) + result._expanded_nodes = set() # All nodes collapsed + return result + def __add__(self, arg) -> "Diagram": """ Union or downstream expansion. @@ -187,21 +226,31 @@ def __add__(self, arg) -> "Diagram": Diagram Combined or expanded diagram. 
""" - self = Diagram(self) # copy + result = Diagram(self) # copy try: - self.nodes_to_show.update(arg.nodes_to_show) + # Merge nodes and edges from the other diagram + result.add_nodes_from(arg.nodes(data=True)) + result.add_edges_from(arg.edges(data=True)) + result.nodes_to_show.update(arg.nodes_to_show) + # Merge contexts for class name lookups + result.context = {**result.context, **arg.context} + # Expanded wins: union of expanded nodes from both operands + result._expanded_nodes = self._expanded_nodes | arg._expanded_nodes except AttributeError: try: - self.nodes_to_show.add(arg.full_table_name) + result.nodes_to_show.add(arg.full_table_name) + result._expanded_nodes.add(arg.full_table_name) except AttributeError: for i in range(arg): - new = nx.algorithms.boundary.node_boundary(self, self.nodes_to_show) + new = nx.algorithms.boundary.node_boundary(result, result.nodes_to_show) if not new: break # add nodes referenced by aliased nodes - new.update(nx.algorithms.boundary.node_boundary(self, (a for a in new if a.isdigit()))) - self.nodes_to_show.update(new) - return self + new.update(nx.algorithms.boundary.node_boundary(result, (a for a in new if a.isdigit()))) + result.nodes_to_show.update(new) + # New nodes from expansion are expanded + result._expanded_nodes = result._expanded_nodes | result.nodes_to_show + return result def __sub__(self, arg) -> "Diagram": """ @@ -274,7 +323,9 @@ def _make_graph(self) -> nx.DiGraph: """ # mark "distinguished" tables, i.e. those that introduce new primary key # attributes - for name in self.nodes_to_show: + # Filter nodes_to_show to only include nodes that exist in the graph + valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) + for name in valid_nodes: foreign_attributes = set( attr for p in self.in_edges(name, data=True) for attr in p[2]["attr_map"] if p[2]["primary"] ) @@ -282,21 +333,210 @@ def _make_graph(self) -> nx.DiGraph: "primary_key" in self.nodes[name] and foreign_attributes < self.nodes[name]["primary_key"] ) # include aliased nodes that are sandwiched between two displayed nodes - gaps = set(nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)).intersection( - nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), self.nodes_to_show) + gaps = set(nx.algorithms.boundary.node_boundary(self, valid_nodes)).intersection( + nx.algorithms.boundary.node_boundary(nx.DiGraph(self).reverse(), valid_nodes) ) - nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit) + nodes = valid_nodes.union(a for a in gaps if a.isdigit()) # construct subgraph and rename nodes to class names graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes)) nx.set_node_attributes(graph, name="node_type", values={n: _get_tier(n) for n in graph}) # relabel nodes to class names mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()} - new_names = [mapping.values()] + new_names = list(mapping.values()) if len(new_names) > len(set(new_names)): raise DataJointError("Some classes have identical names. The Diagram cannot be plotted.") nx.relabel_nodes(graph, mapping, copy=False) return graph + def _apply_collapse(self, graph: nx.DiGraph) -> tuple[nx.DiGraph, dict[str, str]]: + """ + Apply collapse logic to the graph. + + Nodes in nodes_to_show but not in _expanded_nodes are collapsed into + single schema nodes. + + Parameters + ---------- + graph : nx.DiGraph + The graph from _make_graph(). 
+ + Returns + ------- + tuple[nx.DiGraph, dict[str, str]] + Modified graph and mapping of collapsed schema labels to their table count. + """ + # Filter to valid nodes (those that exist in the underlying graph) + valid_nodes = self.nodes_to_show.intersection(set(self.nodes())) + valid_expanded = self._expanded_nodes.intersection(set(self.nodes())) + + # If all nodes are expanded, no collapse needed + if valid_expanded >= valid_nodes: + return graph, {} + + # Map full_table_names to class_names + full_to_class = {node: lookup_class_name(node, self.context) or node for node in valid_nodes} + class_to_full = {v: k for k, v in full_to_class.items()} + + # Identify expanded class names + expanded_class_names = {full_to_class.get(node, node) for node in valid_expanded} + + # Identify nodes to collapse (class names) + nodes_to_collapse = set(graph.nodes()) - expanded_class_names + + if not nodes_to_collapse: + return graph, {} + + # Group collapsed nodes by schema + collapsed_by_schema = {} # schema_name -> list of class_names + for class_name in nodes_to_collapse: + full_name = class_to_full.get(class_name) + if full_name: + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2: + schema_name = parts[1] + if schema_name not in collapsed_by_schema: + collapsed_by_schema[schema_name] = [] + collapsed_by_schema[schema_name].append(class_name) + + if not collapsed_by_schema: + return graph, {} + + # Determine labels for collapsed schemas + schema_modules = {} + for schema_name, class_names in collapsed_by_schema.items(): + schema_modules[schema_name] = set() + for class_name in class_names: + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + # Collect module names for ALL schemas in the diagram (not just collapsed) + all_schema_modules = {} # schema_name -> module_name + for node in graph.nodes(): + full_name = class_to_full.get(node) + if full_name: + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2: + db_schema = parts[1] + cls = self._resolve_class(node) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + all_schema_modules[db_schema] = module_name + + # Check which module names are shared by multiple schemas + module_to_schemas = {} + for db_schema, module_name in all_schema_modules.items(): + if module_name not in module_to_schemas: + module_to_schemas[module_name] = [] + module_to_schemas[module_name].append(db_schema) + + ambiguous_modules = {m for m, schemas in module_to_schemas.items() if len(schemas) > 1} + + # Determine labels for collapsed schemas + collapsed_labels = {} # schema_name -> label + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + module_name = next(iter(modules)) + # Use database schema name if module is ambiguous + if module_name in ambiguous_modules: + label = schema_name + else: + label = module_name + else: + label = schema_name + collapsed_labels[schema_name] = label + + # Build counts using final labels + collapsed_counts = {} # label -> count of tables + for schema_name, class_names in collapsed_by_schema.items(): + label = collapsed_labels[schema_name] + collapsed_counts[label] = len(class_names) + + # Create new graph with collapsed nodes + new_graph = nx.DiGraph() + + # Map old node names to new names (collapsed nodes -> schema label) + node_mapping = {} + for node in graph.nodes(): + full_name = class_to_full.get(node) 
+ if full_name: + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2 and node in nodes_to_collapse: + schema_name = parts[1] + node_mapping[node] = collapsed_labels[schema_name] + else: + node_mapping[node] = node + else: + # Alias nodes - check if they should be collapsed + # An alias node should be collapsed if ALL its neighbors are collapsed + neighbors = set(graph.predecessors(node)) | set(graph.successors(node)) + if neighbors and neighbors <= nodes_to_collapse: + # Get schema from first neighbor + neighbor = next(iter(neighbors)) + full_name = class_to_full.get(neighbor) + if full_name: + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2: + schema_name = parts[1] + node_mapping[node] = collapsed_labels[schema_name] + continue + node_mapping[node] = node + + # Build reverse mapping: label -> schema_name + label_to_schema = {label: schema for schema, label in collapsed_labels.items()} + + # Add nodes + added_collapsed = set() + for old_node, new_node in node_mapping.items(): + if new_node in collapsed_counts: + # This is a collapsed schema node + if new_node not in added_collapsed: + schema_name = label_to_schema.get(new_node, new_node) + new_graph.add_node( + new_node, + node_type=None, + collapsed=True, + table_count=collapsed_counts[new_node], + schema_name=schema_name, + ) + added_collapsed.add(new_node) + else: + new_graph.add_node(new_node, **graph.nodes[old_node]) + + # Add edges (avoiding self-loops and duplicates) + for src, dest, data in graph.edges(data=True): + new_src = node_mapping[src] + new_dest = node_mapping[dest] + if new_src != new_dest and not new_graph.has_edge(new_src, new_dest): + new_graph.add_edge(new_src, new_dest, **data) + + return new_graph, collapsed_counts + + def _resolve_class(self, name: str): + """ + Safely resolve a table class from a dotted name without eval(). + + Parameters + ---------- + name : str + Dotted class name like "MyTable" or "Module.MyTable". + + Returns + ------- + type or None + The table class if found, otherwise None. + """ + parts = name.split(".") + obj = self.context.get(parts[0]) + for part in parts[1:]: + if obj is None: + return None + obj = getattr(obj, part, None) + if obj is not None and isinstance(obj, type) and issubclass(obj, Table): + return obj + return None + @staticmethod def _encapsulate_edge_attributes(graph: nx.DiGraph) -> None: """ @@ -330,8 +570,78 @@ def _encapsulate_node_names(graph: nx.DiGraph) -> None: ) def make_dot(self): + """ + Generate a pydot graph object. + + Returns + ------- + pydot.Dot + The graph object ready for rendering. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema, with the Python module name shown as the + group label when available. 
+ """ + direction = config.display.diagram_direction graph = self._make_graph() - graph.nodes() + + # Apply collapse logic if needed + graph, collapsed_counts = self._apply_collapse(graph) + + # Build schema mapping: class_name -> schema_name + # Group by database schema, label with Python module name if 1:1 mapping + schema_map = {} # class_name -> schema_name + schema_modules = {} # schema_name -> set of module names + + for full_name in self.nodes_to_show: + # Extract schema from full table name like `schema`.`table` or "schema"."table" + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2: + schema_name = parts[1] # schema is between first pair of backticks + class_name = lookup_class_name(full_name, self.context) or full_name + schema_map[class_name] = schema_name + + # Collect all module names for this schema + if schema_name not in schema_modules: + schema_modules[schema_name] = set() + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + # Determine cluster labels: use module name if 1:1, else database schema name + cluster_labels = {} # schema_name -> label + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + cluster_labels[schema_name] = next(iter(modules)) + else: + cluster_labels[schema_name] = schema_name + + # Disambiguate labels if multiple schemas share the same module name + # (e.g., all defined in __main__ in a notebook) + label_counts = {} + for label in cluster_labels.values(): + label_counts[label] = label_counts.get(label, 0) + 1 + + for schema_name, label in cluster_labels.items(): + if label_counts[label] > 1: + # Multiple schemas share this module name - add schema name + cluster_labels[schema_name] = f"{label} ({schema_name})" + + # Assign alias nodes (orange dots) to the same schema as their child table + for node, data in graph.nodes(data=True): + if data.get("node_type") is _AliasNode: + # Find the child (successor) - the table that declares the renamed FK + successors = list(graph.successors(node)) + if successors and successors[0] in schema_map: + schema_map[node] = schema_map[successors[0]] + + # Assign collapsed nodes to their schema so they appear in the cluster + for node, data in graph.nodes(data=True): + if data.get("collapsed") and data.get("schema_name"): + schema_map[node] = data["schema_name"] scale = 1.2 # scaling factor for fonts and boxes label_props = { # http://matplotlib.org/examples/color/named_colors.html @@ -372,8 +682,8 @@ def make_dot(self): color="#FF000020", fontcolor="#7F0000A0", fontsize=round(scale * 10), - size=0.3 * scale, - fixed=True, + size=0.4 * scale, + fixed=False, ), Imported: dict( shape="ellipse", @@ -385,18 +695,33 @@ def make_dot(self): ), Part: dict( shape="plaintext", - color="#0000000", + color="#00000000", fontcolor="black", fontsize=round(scale * 8), size=0.1 * scale, fixed=False, ), + "collapsed": dict( + shape="box3d", + color="#80808060", + fontcolor="#404040", + fontsize=round(scale * 10), + size=0.5 * scale, + fixed=False, + ), } - node_props = {node: label_props[d["node_type"]] for node, d in dict(graph.nodes(data=True)).items()} + # Build node_props, handling collapsed nodes specially + node_props = {} + for node, d in graph.nodes(data=True): + if d.get("collapsed"): + node_props[node] = label_props["collapsed"] + else: + node_props[node] = label_props[d["node_type"]] self._encapsulate_node_names(graph) 
self._encapsulate_edge_attributes(graph) dot = nx.drawing.nx_pydot.to_pydot(graph) + dot.set_rankdir(direction) for node in dot.get_nodes(): node.set_shape("circle") name = node.get_name().strip('"') @@ -408,17 +733,36 @@ def make_dot(self): node.set_fixedsize("shape" if props["fixed"] else False) node.set_width(props["size"]) node.set_height(props["size"]) - if name.split(".")[0] in self.context: - cls = eval(name, self.context) - assert issubclass(cls, Table) - description = cls().describe(context=self.context).split("\n") - description = ( - ("-" * 30 if q.startswith("---") else (q.replace("->", "→") if "->" in q else q.split(":")[0])) - for q in description - if not q.startswith("#") - ) - node.set_tooltip(" ".join(description)) - node.set_label("<" + name + ">" if node.get("distinguished") == "True" else name) + + # Handle collapsed nodes specially + node_data = graph.nodes.get(f'"{name}"', {}) + if node_data.get("collapsed"): + table_count = node_data.get("table_count", 0) + label = f"({table_count} tables)" if table_count != 1 else "(1 table)" + node.set_label(label) + node.set_tooltip(f"Collapsed schema: {table_count} tables") + else: + cls = self._resolve_class(name) + if cls is not None: + description = cls().describe(context=self.context).split("\n") + description = ( + ( + "-" * 30 + if q.startswith("---") + else (q.replace("->", "→") if "->" in q else q.split(":")[0]) + ) + for q in description + if not q.startswith("#") + ) + node.set_tooltip(" ".join(description)) + # Strip module prefix from label if it matches the cluster label + display_name = name + schema_name = schema_map.get(name) + if schema_name and "." in name: + prefix = name.rsplit(".", 1)[0] + if prefix == cluster_labels.get(schema_name): + display_name = name.rsplit(".", 1)[1] + node.set_label("<" + display_name + ">" if node.get("distinguished") == "True" else display_name) node.set_color(props["color"]) node.set_style("filled") @@ -430,11 +774,41 @@ def make_dot(self): if props is None: raise DataJointError("Could not find edge with source '{}' and destination '{}'".format(src, dest)) edge.set_color("#00000040") - edge.set_style("solid" if props["primary"] else "dashed") - master_part = graph.nodes[dest]["node_type"] is Part and dest.startswith(src + ".") + edge.set_style("solid" if props.get("primary") else "dashed") + dest_node_type = graph.nodes[dest].get("node_type") + master_part = dest_node_type is Part and dest.startswith(src + ".") edge.set_weight(3 if master_part else 1) edge.set_arrowhead("none") - edge.set_penwidth(0.75 if props["multi"] else 2) + edge.set_penwidth(0.75 if props.get("multi") else 2) + + # Group nodes into schema clusters (always on) + if schema_map: + import pydot + + # Group nodes by schema + schemas = {} + for node in list(dot.get_nodes()): + name = node.get_name().strip('"') + schema_name = schema_map.get(name) + if schema_name: + if schema_name not in schemas: + schemas[schema_name] = [] + schemas[schema_name].append(node) + + # Create clusters for each schema + # Use Python module name if 1:1 mapping, otherwise database schema name + for schema_name, nodes in schemas.items(): + label = cluster_labels.get(schema_name, schema_name) + cluster = pydot.Cluster( + f"cluster_{schema_name}", + label=label, + style="dashed", + color="gray", + fontcolor="gray", + ) + for node in nodes: + cluster.add_node(node) + dot.add_subgraph(cluster) return dot @@ -452,6 +826,159 @@ def make_image(self): else: raise DataJointError("pyplot was not imported") + def make_mermaid(self) -> str: + """ + 
Generate Mermaid diagram syntax. + + Produces a flowchart in Mermaid syntax that can be rendered in + Markdown documentation, GitHub, or https://mermaid.live. + + Returns + ------- + str + Mermaid flowchart syntax. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema using Mermaid subgraphs, with the Python + module name shown as the group label when available. + + Examples + -------- + >>> print(dj.Diagram(schema).make_mermaid()) + flowchart TB + subgraph my_pipeline + Mouse[Mouse]:::manual + Session[Session]:::manual + Neuron([Neuron]):::computed + end + Mouse --> Session + Session --> Neuron + """ + graph = self._make_graph() + direction = config.display.diagram_direction + + # Apply collapse logic if needed + graph, collapsed_counts = self._apply_collapse(graph) + + # Build schema mapping for grouping + schema_map = {} # class_name -> schema_name + schema_modules = {} # schema_name -> set of module names + + for full_name in self.nodes_to_show: + parts = full_name.replace('"', "`").split("`") + if len(parts) >= 2: + schema_name = parts[1] + class_name = lookup_class_name(full_name, self.context) or full_name + schema_map[class_name] = schema_name + + # Collect all module names for this schema + if schema_name not in schema_modules: + schema_modules[schema_name] = set() + cls = self._resolve_class(class_name) + if cls is not None and hasattr(cls, "__module__"): + module_name = cls.__module__.split(".")[-1] + schema_modules[schema_name].add(module_name) + + # Determine cluster labels: use module name if 1:1, else database schema name + cluster_labels = {} + for schema_name, modules in schema_modules.items(): + if len(modules) == 1: + cluster_labels[schema_name] = next(iter(modules)) + else: + cluster_labels[schema_name] = schema_name + + # Assign alias nodes to the same schema as their child table + for node, data in graph.nodes(data=True): + if data.get("node_type") is _AliasNode: + successors = list(graph.successors(node)) + if successors and successors[0] in schema_map: + schema_map[node] = schema_map[successors[0]] + + lines = [f"flowchart {direction}"] + + # Define class styles matching Graphviz colors + lines.append(" classDef manual fill:#90EE90,stroke:#006400") + lines.append(" classDef lookup fill:#D3D3D3,stroke:#696969") + lines.append(" classDef computed fill:#FFB6C1,stroke:#8B0000") + lines.append(" classDef imported fill:#ADD8E6,stroke:#00008B") + lines.append(" classDef part fill:#FFFFFF,stroke:#000000") + lines.append(" classDef collapsed fill:#808080,stroke:#404040") + lines.append("") + + # Shape mapping: Manual=box, Computed/Imported=stadium, Lookup/Part=box + shape_map = { + Manual: ("[", "]"), # box + Lookup: ("[", "]"), # box + Computed: ("([", "])"), # stadium/pill + Imported: ("([", "])"), # stadium/pill + Part: ("[", "]"), # box + _AliasNode: ("((", "))"), # circle + None: ("((", "))"), # circle + } + + tier_class = { + Manual: "manual", + Lookup: "lookup", + Computed: "computed", + Imported: "imported", + Part: "part", + _AliasNode: "", + None: "", + } + + # Group nodes by schema into subgraphs (including collapsed nodes) + schemas = {} + for node, data in graph.nodes(data=True): + if data.get("collapsed"): + # Collapsed nodes use their schema_name attribute + schema_name = data.get("schema_name") + else: + schema_name = schema_map.get(node) + if schema_name: + if schema_name not in schemas: + schemas[schema_name] = [] + schemas[schema_name].append((node, data)) + + # Add nodes grouped by 
schema subgraphs + for schema_name, nodes in schemas.items(): + label = cluster_labels.get(schema_name, schema_name) + lines.append(f" subgraph {label}") + for node, data in nodes: + safe_id = node.replace(".", "_").replace(" ", "_") + if data.get("collapsed"): + # Collapsed node - show only table count + table_count = data.get("table_count", 0) + count_text = f"{table_count} tables" if table_count != 1 else "1 table" + lines.append(f' {safe_id}[["({count_text})"]]:::collapsed') + else: + # Regular node + tier = data.get("node_type") + left, right = shape_map.get(tier, ("[", "]")) + cls = tier_class.get(tier, "") + # Strip module prefix from display name if it matches the cluster label + display_name = node + if "." in node: + prefix = node.rsplit(".", 1)[0] + if prefix == label: + display_name = node.rsplit(".", 1)[1] + class_suffix = f":::{cls}" if cls else "" + lines.append(f" {safe_id}{left}{display_name}{right}{class_suffix}") + lines.append(" end") + + lines.append("") + + # Add edges + for src, dest, data in graph.edges(data=True): + safe_src = src.replace(".", "_").replace(" ", "_") + safe_dest = dest.replace(".", "_").replace(" ", "_") + # Solid arrow for primary FK, dotted for non-primary + style = "-->" if data.get("primary") else "-.->" + lines.append(f" {safe_src} {style} {safe_dest}") + + return "\n".join(lines) + def _repr_svg_(self): return self.make_svg()._repr_svg_() @@ -472,24 +999,38 @@ def save(self, filename: str, format: str | None = None) -> None: filename : str Output filename. format : str, optional - File format (``'png'`` or ``'svg'``). Inferred from extension if None. + File format (``'png'``, ``'svg'``, or ``'mermaid'``). + Inferred from extension if None. Raises ------ DataJointError If format is unsupported. + + Notes + ----- + Layout direction is controlled via ``dj.config.display.diagram_direction``. + Tables are grouped by schema, with the Python module name shown as the + group label when available. """ if format is None: if filename.lower().endswith(".png"): format = "png" elif filename.lower().endswith(".svg"): format = "svg" + elif filename.lower().endswith((".mmd", ".mermaid")): + format = "mermaid" + if format is None: + raise DataJointError("Could not infer format from filename. 
Specify format explicitly.") if format.lower() == "png": with open(filename, "wb") as f: f.write(self.make_png().getbuffer().tobytes()) elif format.lower() == "svg": with open(filename, "w") as f: f.write(self.make_svg().data) + elif format.lower() == "mermaid": + with open(filename, "w") as f: + f.write(self.make_mermaid()) else: raise DataJointError("Unsupported file format") diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py index a9a7ddfe7..667479cdd 100644 --- a/src/datajoint/expression.py +++ b/src/datajoint/expression.py @@ -104,9 +104,10 @@ def primary_key(self): _subquery_alias_count = count() # count for alias names used in the FROM clause def from_clause(self): + adapter = self.connection.adapter support = ( ( - "(" + src.make_sql() + ") as `$%x`" % next(self._subquery_alias_count) + "({}) as {}".format(src.make_sql(), adapter.quote_identifier(f"${next(self._subquery_alias_count):x}")) if isinstance(src, QueryExpression) else src ) @@ -116,7 +117,8 @@ def from_clause(self): for s, (is_left, using_attrs) in zip(support, self._joins): left_kw = "LEFT " if is_left else "" if using_attrs: - using = "USING ({})".format(", ".join(f"`{a}`" for a in using_attrs)) + quoted_attrs = ", ".join(adapter.quote_identifier(a) for a in using_attrs) + using = f"USING ({quoted_attrs})" clause += f" {left_kw}JOIN {s} {using}" else: # Cross join (no common non-hidden attributes) @@ -134,7 +136,8 @@ def sorting_clauses(self): return "" # Default to KEY ordering if order_by is None (inherit with no existing order) order_by = self._top.order_by if self._top.order_by is not None else ["KEY"] - clause = ", ".join(_wrap_attributes(_flatten_attribute_list(self.primary_key, order_by))) + adapter = self.connection.adapter + clause = ", ".join(_wrap_attributes(_flatten_attribute_list(self.primary_key, order_by), adapter)) if clause: clause = f" ORDER BY {clause}" if self._top.limit is not None: @@ -146,11 +149,19 @@ def make_sql(self, fields=None): """ Make the SQL SELECT statement. - :param fields: used to explicitly set the select attributes + Parameters + ---------- + fields : list, optional + Used to explicitly set the select attributes. + + Returns + ------- + str + The SQL SELECT statement. """ return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format( distinct="DISTINCT " if self._distinct else "", - fields=self.heading.as_sql(fields or self.heading.names), + fields=self.heading.as_sql(fields or self.heading.names, adapter=self.connection.adapter), from_=self.from_clause(), where=self.where_clause(), sorting=self.sorting_clauses(), @@ -169,23 +180,17 @@ def restrict(self, restriction, semantic_check=True): """ Produces a new expression with the new restriction applied. - :param restriction: a sequence or an array (treated as OR list), another QueryExpression, - an SQL condition string, or an AndList. - :param semantic_check: If True (default), use semantic matching - only match on - homologous namesakes and error on non-homologous namesakes. - If False, use natural matching on all namesakes (no lineage checking). - :return: A new QueryExpression with the restriction applied. - - rel.restrict(restriction) is equivalent to rel & restriction. - rel.restrict(Not(restriction)) is equivalent to rel - restriction + ``rel.restrict(restriction)`` is equivalent to ``rel & restriction``. + ``rel.restrict(Not(restriction))`` is equivalent to ``rel - restriction``. The primary key of the result is unaffected. 
- Successive restrictions are combined as logical AND: r & a & b is equivalent to r & AndList((a, b)) + Successive restrictions are combined as logical AND: ``r & a & b`` is equivalent to + ``r & AndList((a, b))``. Any QueryExpression, collection, or sequence other than an AndList are treated as OrLists - (logical disjunction of conditions) + (logical disjunction of conditions). Inverse restriction is accomplished by either using the subtraction operator or the Not class. - The expressions in each row equivalent: + The expressions in each row are equivalent: rel & True rel rel & False the empty entity set @@ -207,14 +212,31 @@ def restrict(self, restriction, semantic_check=True): rel - None rel rel - any_empty_entity_set rel - When arg is another QueryExpression, the restriction rel & arg restricts rel to elements that match at least - one element in arg (hence arg is treated as an OrList). - Conversely, rel - arg restricts rel to elements that do not match any elements in arg. - Two elements match when their common attributes have equal values or when they have no common attributes. - All shared attributes must be in the primary key of either rel or arg or both or an error will be raised. + When arg is another QueryExpression, the restriction ``rel & arg`` restricts rel to elements + that match at least one element in arg (hence arg is treated as an OrList). + Conversely, ``rel - arg`` restricts rel to elements that do not match any elements in arg. + Two elements match when their common attributes have equal values or when they have no + common attributes. + All shared attributes must be in the primary key of either rel or arg or both or an error + will be raised. + + QueryExpression.restrict is the only access point that modifies restrictions. All other + operators must ultimately call restrict(). + + Parameters + ---------- + restriction : QueryExpression, AndList, str, dict, list, or array-like + A sequence or an array (treated as OR list), another QueryExpression, + an SQL condition string, or an AndList. + semantic_check : bool, optional + If True (default), use semantic matching - only match on homologous namesakes + and error on non-homologous namesakes. + If False, use natural matching on all namesakes (no lineage checking). - QueryExpression.restrict is the only access point that modifies restrictions. All other operators must - ultimately call restrict() + Returns + ------- + QueryExpression + A new QueryExpression with the restriction applied. """ attributes = set() if isinstance(restriction, Top): @@ -261,8 +283,13 @@ def restrict_in_place(self, restriction): def __and__(self, restriction): """ Restriction operator e.g. ``q1 & q2``. - :return: a restricted copy of the input argument + See QueryExpression.restrict for more detail. + + Returns + ------- + QueryExpression + A restricted copy of the input argument. """ return self.restrict(restriction) @@ -276,16 +303,26 @@ def __xor__(self, restriction): def __sub__(self, restriction): """ Inverted restriction e.g. ``q1 - q2``. - :return: a restricted copy of the input argument + See QueryExpression.restrict for more detail. + + Returns + ------- + QueryExpression + A restricted copy of the input argument. """ return self.restrict(Not(restriction)) def __neg__(self): """ Convert between restriction and inverted restriction e.g. ``-q1``. - :return: target restriction + See QueryExpression.restrict for more detail. + + Returns + ------- + QueryExpression or Not + The target restriction. 
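A short sketch of the restriction operators described above. Table names are placeholders; the SQL string uses a single-quoted literal, which stays portable across backends:

    sessions_of_1 = Session & {"subject_id": 1}                    # dict restriction
    recent = Session & "session_date > '2024-01-01'"               # SQL condition string
    both = Session & dj.AndList([{"subject_id": 1}, "session_date > '2024-01-01'"])  # logical AND
    no_ephys = Session - Ephys                                      # anti-restriction: sessions with no matching Ephys entry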
""" if isinstance(self, Not): return self.restriction @@ -308,18 +345,28 @@ def join(self, other, semantic_check=True, left=False, allow_nullable_pk=False): """ Create the joined QueryExpression. - :param other: QueryExpression to join with - :param semantic_check: If True (default), use semantic matching - only match on - homologous namesakes (same lineage) and error on non-homologous namesakes. + ``a * b`` is short for ``a.join(b)``. + + Parameters + ---------- + other : QueryExpression + QueryExpression to join with. + semantic_check : bool, optional + If True (default), use semantic matching - only match on homologous namesakes + (same lineage) and error on non-homologous namesakes. If False, use natural join on all namesakes (no lineage checking). - :param left: If True, perform a left join (retain all rows from self) - :param allow_nullable_pk: If True, bypass the left join constraint that requires - self to determine other. When bypassed, the result PK is the union of both - operands' PKs, and PK attributes from the right operand could be NULL. - Used internally by aggregation when exclude_nonmatching=False. - :return: The joined QueryExpression - - a * b is short for a.join(b) + left : bool, optional + If True, perform a left join (retain all rows from self). Default False. + allow_nullable_pk : bool, optional + If True, bypass the left join constraint that requires self to determine other. + When bypassed, the result PK is the union of both operands' PKs, and PK + attributes from the right operand could be NULL. + Used internally by aggregation when exclude_nonmatching=False. Default False. + + Returns + ------- + QueryExpression + The joined QueryExpression. """ # Joining with U is no longer supported if isinstance(other, U): @@ -411,21 +458,36 @@ def extend(self, other, semantic_check=True): extend is closer to projection—it adds new attributes to existing entities without changing which entities are in the result. - Example: - # Session determines Trial (session_id is in Trial's PK) - # But Trial does NOT determine Session (trial_num not in Session) + Examples + -------- + Session determines Trial (session_id is in Trial's PK), but Trial does NOT + determine Session (trial_num not in Session). + + Valid: extend trials with session info:: - # Valid: extend trials with session info Trial.extend(Session) # Adds 'date' from Session to each Trial - # Invalid: Session cannot extend to Trial + Invalid: Session cannot extend to Trial:: + Session.extend(Trial) # Error: trial_num not in Session - :param other: QueryExpression whose attributes will extend self - :param semantic_check: If True (default), require homologous namesakes. + Parameters + ---------- + other : QueryExpression + QueryExpression whose attributes will extend self. + semantic_check : bool, optional + If True (default), require homologous namesakes. If False, match on all namesakes without lineage checking. - :return: Extended QueryExpression with self's PK and combined attributes - :raises DataJointError: If self does not determine other + + Returns + ------- + QueryExpression + Extended QueryExpression with self's PK and combined attributes. + + Raises + ------ + DataJointError + If self does not determine other. """ return self.join(other, semantic_check=semantic_check, left=True) @@ -437,24 +499,40 @@ def proj(self, *attributes, **named_attributes): """ Projection operator. - :param attributes: attributes to be included in the result. (The primary key is already included). 
- :param named_attributes: new attributes computed or renamed from existing attributes. - :return: the projected expression. Primary key attributes cannot be excluded but may be renamed. - If the attribute list contains an Ellipsis ..., then all secondary attributes are included too - Prefixing an attribute name with a dash '-attr' removes the attribute from the list if present. - Keyword arguments can be used to rename attributes as in name='attr', duplicate them as in name='(attr)', or - self.proj(...) or self.proj(Ellipsis) -- include all attributes (return self) - self.proj() -- include only primary key - self.proj('attr1', 'attr2') -- include primary key and attributes attr1 and attr2 - self.proj(..., '-attr1', '-attr2') -- include all attributes except attr1 and attr2 - self.proj(name1='attr1') -- include primary key and 'attr1' renamed as name1 - self.proj('attr1', dup='(attr1)') -- include primary key and attribute attr1 twice, with the duplicate 'dup' - self.proj(k='abs(attr1)') adds the new attribute k with the value computed as an expression (SQL syntax) - from other attributes available before the projection. + If the attribute list contains an Ellipsis ``...``, then all secondary attributes + are included too. + Prefixing an attribute name with a dash ``-attr`` removes the attribute from the list + if present. + Keyword arguments can be used to rename attributes as in ``name='attr'``, duplicate + them as in ``name='(attr)'``, or compute new attributes. + + - ``self.proj(...)`` or ``self.proj(Ellipsis)`` -- include all attributes (return self) + - ``self.proj()`` -- include only primary key + - ``self.proj('attr1', 'attr2')`` -- include primary key and attributes attr1 and attr2 + - ``self.proj(..., '-attr1', '-attr2')`` -- include all attributes except attr1 and attr2 + - ``self.proj(name1='attr1')`` -- include primary key and 'attr1' renamed as name1 + - ``self.proj('attr1', dup='(attr1)')`` -- include primary key and attr1 twice, with + the duplicate 'dup' + - ``self.proj(k='abs(attr1)')`` adds the new attribute k with the value computed as an + expression (SQL syntax) from other attributes available before the projection. + Each attribute name can only be used once. + + Parameters + ---------- + *attributes : str + Attributes to be included in the result. The primary key is already included. + **named_attributes : str + New attributes computed or renamed from existing attributes. + + Returns + ------- + QueryExpression + The projected expression. """ - named_attributes = {k: translate_attribute(v)[1] for k, v in named_attributes.items()} + adapter = self.connection.adapter if hasattr(self, "connection") and self.connection else None + named_attributes = {k: translate_attribute(v, adapter)[1] for k, v in named_attributes.items()} # new attributes in parentheses are included again with the new name without removing original duplication_pattern = re.compile(rf"^\s*\(\s*(?!{'|'.join(CONSTANT_LITERALS)})(?P[a-zA-Z_]\w*)\s*\)\s*$") # attributes without parentheses renamed @@ -552,21 +630,34 @@ def aggr(self, group, *attributes, exclude_nonmatching=False, **named_attributes """ Aggregation/grouping operation, similar to proj but with computations over a grouped relation. - By default, keeps all rows from self (like proj). Use exclude_nonmatching=True to + By default, keeps all rows from self (like proj). Use ``exclude_nonmatching=True`` to keep only rows that have matches in group. - :param group: The query expression to be aggregated. 
- :param exclude_nonmatching: If True, exclude rows from self that have no matching - entries in group (INNER JOIN). Default False keeps all rows (LEFT JOIN). - :param named_attributes: computations of the form new_attribute="sql expression on attributes of group" - :return: The derived query expression + Parameters + ---------- + group : QueryExpression + The query expression to be aggregated. + *attributes : str + Attributes to include in the result. + exclude_nonmatching : bool, optional + If True, exclude rows from self that have no matching entries in group + (INNER JOIN). Default False keeps all rows (LEFT JOIN). + **named_attributes : str + Computations of the form ``new_attribute="sql expression on attributes of group"``. + + Returns + ------- + QueryExpression + The derived query expression. - Example:: + Examples + -------- + Count sessions per subject (keeps all subjects, even those with 0 sessions):: - # Count sessions per subject (keeps all subjects, even those with 0 sessions) Subject.aggr(Session, n="count(*)") - # Count sessions per subject (only subjects with at least one session) + Count sessions per subject (only subjects with at least one session):: + Subject.aggr(Session, n="count(*)", exclude_nonmatching=True) """ if Ellipsis in attributes: @@ -661,12 +752,26 @@ def fetch1(self, *attrs, squeeze=False): If no attributes are specified, returns the result as a dict. If attributes are specified, returns the corresponding values as a tuple. - :param attrs: attribute names to fetch (if empty, fetch all as dict) - :param squeeze: if True, remove extra dimensions from arrays - :return: dict (no attrs) or tuple/value (with attrs) - :raises DataJointError: if not exactly one row in result + Parameters + ---------- + *attrs : str + Attribute names to fetch. If empty, fetch all as dict. + squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. + + Returns + ------- + dict or tuple or value + Dict (no attrs) or tuple/value (with attrs). + + Raises + ------ + DataJointError + If not exactly one row in result. - Examples:: + Examples + -------- + :: d = table.fetch1() # returns dict with all attributes a, b = table.fetch1('a', 'b') # returns tuple of attribute values @@ -730,17 +835,27 @@ def to_dicts(self, order_by=None, limit=None, offset=None, squeeze=False): """ Fetch all rows as a list of dictionaries. - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :param squeeze: if True, remove extra dimensions from arrays - :return: list of dictionaries, one per row - For object storage types (attachments, filepaths), files are downloaded - to config["download_path"]. Use config.override() to change:: + to ``config["download_path"]``. Use ``config.override()`` to change:: with dj.config.override(download_path="/data"): data = table.to_dicts() + + Parameters + ---------- + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. + + Returns + ------- + list[dict] + List of dictionaries, one per row. 
""" expr = self._apply_top(order_by, limit, offset) cursor = expr.cursor(as_dict=True) @@ -751,11 +866,21 @@ def to_pandas(self, order_by=None, limit=None, offset=None, squeeze=False): """ Fetch all rows as a pandas DataFrame with primary key as index. - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :param squeeze: if True, remove extra dimensions from arrays - :return: pandas DataFrame with primary key columns as index + Parameters + ---------- + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. + + Returns + ------- + pandas.DataFrame + DataFrame with primary key columns as index. """ dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze) df = pandas.DataFrame(dicts) @@ -767,13 +892,23 @@ def to_polars(self, order_by=None, limit=None, offset=None, squeeze=False): """ Fetch all rows as a polars DataFrame. - Requires polars: pip install datajoint[polars] + Requires polars: ``pip install datajoint[polars]`` + + Parameters + ---------- + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :param squeeze: if True, remove extra dimensions from arrays - :return: polars DataFrame + Returns + ------- + polars.DataFrame + Polars DataFrame. """ try: import polars @@ -786,13 +921,23 @@ def to_arrow(self, order_by=None, limit=None, offset=None, squeeze=False): """ Fetch all rows as a PyArrow Table. - Requires pyarrow: pip install datajoint[arrow] + Requires pyarrow: ``pip install datajoint[arrow]`` + + Parameters + ---------- + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :param squeeze: if True, remove extra dimensions from arrays - :return: pyarrow Table + Returns + ------- + pyarrow.Table + PyArrow Table. """ try: import pyarrow @@ -810,24 +955,39 @@ def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset If no attrs specified, returns a numpy structured array (recarray) of all columns. If attrs specified, returns a tuple of numpy arrays (one per attribute). - :param attrs: attribute names to fetch (if empty, fetch all) - :param include_key: if True and attrs specified, prepend primary keys as list of dicts - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :param squeeze: if True, remove extra dimensions from arrays - :return: numpy recarray (no attrs) or tuple of arrays (with attrs). 
- With include_key=True: (keys, *arrays) where keys is list[dict] + Parameters + ---------- + *attrs : str + Attribute names to fetch. If empty, fetch all. + include_key : bool, optional + If True and attrs specified, prepend primary keys as list of dicts. Default False. + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. - Examples:: + Returns + ------- + np.recarray or tuple of np.ndarray + Numpy recarray (no attrs) or tuple of arrays (with attrs). + With ``include_key=True``: ``(keys, *arrays)`` where keys is ``list[dict]``. + + Examples + -------- + Fetch as structured array:: - # Fetch as structured array data = table.to_arrays() - # Fetch specific columns as separate arrays + Fetch specific columns as separate arrays:: + a, b = table.to_arrays('a', 'b') - # Fetch with primary keys for later restrictions + Fetch with primary keys for later restrictions:: + keys, a, b = table.to_arrays('a', 'b', include_key=True) # keys = [{'id': 1}, {'id': 2}, ...] # same format as table.keys() """ @@ -897,10 +1057,19 @@ def keys(self, order_by=None, limit=None, offset=None): """ Fetch primary key values as a list of dictionaries. - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :return: list of dictionaries containing only primary key columns + Parameters + ---------- + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + + Returns + ------- + list[dict] + List of dictionaries containing only primary key columns. """ return self.proj().to_dicts(order_by=order_by, limit=limit, offset=offset) @@ -908,8 +1077,15 @@ def head(self, limit=25): """ Preview the first few entries from query expression. - :param limit: number of entries (default 25) - :return: list of dictionaries + Parameters + ---------- + limit : int, optional + Number of entries. Default 25. + + Returns + ------- + list[dict] + List of dictionaries. """ return self.to_dicts(order_by="KEY", limit=limit) @@ -917,33 +1093,57 @@ def tail(self, limit=25): """ Preview the last few entries from query expression. - :param limit: number of entries (default 25) - :return: list of dictionaries + Parameters + ---------- + limit : int, optional + Number of entries. Default 25. + + Returns + ------- + list[dict] + List of dictionaries. """ return list(reversed(self.to_dicts(order_by="KEY DESC", limit=limit))) def __len__(self): - """:return: number of elements in the result set e.g. ``len(q1)``.""" + """ + Return number of elements in the result set e.g. ``len(q1)``. + + Returns + ------- + int + Number of elements in the result set. 
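A combined sketch of the fetch helpers defined above, using placeholder tables:

    rows = Session().to_dicts(order_by="KEY", limit=10)                  # list of dicts
    df = (Session & {"subject_id": 1}).to_pandas()                       # DataFrame indexed by the primary key
    keys, dates = (Session & {"subject_id": 1}).to_arrays("session_date", include_key=True)
    n = len(Session & {"subject_id": 1})                                 # issues a COUNT query; no rows are fetched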
+ """ result = self.make_subquery() if self._top else copy.copy(self) has_left_join = any(is_left for is_left, _ in result._joins) - return result.connection.query( - "SELECT {select_} FROM {from_}{where}".format( - select_=( - "count(*)" - if has_left_join - else "count(DISTINCT {fields})".format( - fields=result.heading.as_sql(result.primary_key, include_aliases=False) - ) - ), - from_=result.from_clause(), - where=result.where_clause(), + + # Build COUNT query - PostgreSQL requires different syntax for multi-column DISTINCT + adapter = result.connection.adapter + if has_left_join or len(result.primary_key) > 1: + # Use subquery with DISTINCT for multi-column primary keys (backend-agnostic) + fields = result.heading.as_sql(result.primary_key, include_aliases=False, adapter=adapter) + query = ( + f"SELECT count(*) FROM (" + f"SELECT DISTINCT {fields} FROM {result.from_clause()}{result.where_clause()}" + f") AS distinct_count" ) - ).fetchone()[0] + else: + # Single column - can use count(DISTINCT col) directly + fields = result.heading.as_sql(result.primary_key, include_aliases=False, adapter=adapter) + query = f"SELECT count(DISTINCT {fields}) FROM {result.from_clause()}{result.where_clause()}" + + return result.connection.query(query).fetchone()[0] def __bool__(self): """ - :return: True if the result is not empty. Equivalent to len(self) > 0 but often - faster e.g. ``bool(q1)``. + Check if the result is not empty. + + Equivalent to ``len(self) > 0`` but often faster e.g. ``bool(q1)``. + + Returns + ------- + bool + True if the result is not empty. """ return bool( self.connection.query( @@ -953,12 +1153,20 @@ def __bool__(self): def __contains__(self, item): """ - returns True if the restriction in item matches any entries in self - e.g. ``restriction in q1``. + Check if the restriction in item matches any entries in self. - :param item: any restriction - (item in query_expression) is equivalent to bool(query_expression & item) but may be - executed more efficiently. + ``(item in query_expression)`` is equivalent to ``bool(query_expression & item)`` + but may be executed more efficiently. + + Parameters + ---------- + item : any + Any restriction. + + Returns + ------- + bool + True if the restriction matches any entries e.g. ``restriction in q1``. """ return bool(self & item) # May be optimized e.g. using an EXISTS query @@ -980,8 +1188,15 @@ def cursor(self, as_dict=False): """ Execute the query and return a database cursor. - :param as_dict: if True, rows are returned as dictionaries - :return: database query cursor + Parameters + ---------- + as_dict : bool, optional + If True, rows are returned as dictionaries. Default False. + + Returns + ------- + cursor + Database query cursor. """ sql = self.make_sql() logger.debug(sql) @@ -989,20 +1204,42 @@ def cursor(self, as_dict=False): def __repr__(self): """ - returns the string representation of a QueryExpression object e.g. ``str(q1)``. + Return the string representation of a QueryExpression object e.g. ``str(q1)``. - :param self: A query expression - :type self: :class:`QueryExpression` - :rtype: str + Returns + ------- + str + String representation of the QueryExpression. """ return super().__repr__() if config["loglevel"].lower() == "debug" else self.preview() def preview(self, limit=None, width=None): - """:return: a string of preview of the contents of the query.""" + """ + Return a string preview of the contents of the query. + + Parameters + ---------- + limit : int, optional + Maximum number of rows to preview. 
+ width : int, optional + Maximum width of the preview output. + + Returns + ------- + str + A string preview of the contents of the query. + """ return preview(self, limit, width) def _repr_html_(self): - """:return: HTML to display table in Jupyter notebook.""" + """ + Return HTML to display table in Jupyter notebook. + + Returns + ------- + str + HTML to display table in Jupyter notebook. + """ return repr_html(self) @@ -1024,9 +1261,19 @@ def create(cls, groupby, group, keep_all_rows=False): """ Create an aggregation expression. - :param groupby: The expression to GROUP BY (determines the result's primary key) - :param group: The expression to aggregate over - :param keep_all_rows: If True, use left join to keep all rows from groupby + Parameters + ---------- + groupby : QueryExpression + The expression to GROUP BY (determines the result's primary key). + group : QueryExpression + The expression to aggregate over. + keep_all_rows : bool, optional + If True, use left join to keep all rows from groupby. Default False. + + Returns + ------- + Aggregation + The aggregation expression. """ if inspect.isclass(group) and issubclass(group, QueryExpression): group = group() # instantiate if a class @@ -1062,31 +1309,49 @@ def where_clause(self): return "" if not self._left_restrict else " WHERE (%s)" % ")AND(".join(str(s) for s in self._left_restrict) def make_sql(self, fields=None): - fields = self.heading.as_sql(fields or self.heading.names) + adapter = self.connection.adapter + fields = self.heading.as_sql(fields or self.heading.names, adapter=adapter) assert self._grouping_attributes or not self.restriction distinct = set(self.heading.names) == set(self.primary_key) - return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}".format( - distinct="DISTINCT " if distinct else "", - fields=fields, - from_=self.from_clause(), - where=self.where_clause(), - group_by=( - "" - if not self.primary_key - else ( - " GROUP BY `%s`" % "`,`".join(self._grouping_attributes) - + ("" if not self.restriction else " HAVING (%s)" % ")AND(".join(self.restriction)) - ) - ), - sorting=self.sorting_clauses(), - ) - def __len__(self): - return self.connection.query( - "SELECT count(1) FROM ({subquery}) `${alias:x}`".format( - subquery=self.make_sql(), alias=next(self._subquery_alias_count) + # PostgreSQL doesn't allow column aliases in HAVING clause (SQL standard). + # For PostgreSQL with restrictions, wrap aggregation in subquery and use WHERE. 
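For illustration, restricting an aggregation on a computed alias (placeholder tables, following the aggr docstring above) compiles roughly as follows on each backend:

    q = Subject.aggr(Session, n="count(*)") & "n > 2"
    # MySQL keeps the alias filter in HAVING:
    #   ... GROUP BY `subject_id` HAVING (n > 2)
    # PostgreSQL wraps the grouped query and filters the alias in an outer WHERE:
    #   SELECT * FROM (... GROUP BY "subject_id") AS "_aggr0" WHERE (n > 2)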
+ use_subquery_for_having = adapter.backend == "postgresql" and self.restriction and self._grouping_attributes + + if use_subquery_for_having: + # Generate inner query without HAVING + inner_sql = "SELECT {distinct}{fields} FROM {from_}{where}{group_by}".format( + distinct="DISTINCT " if distinct else "", + fields=fields, + from_=self.from_clause(), + where=self.where_clause(), + group_by=" GROUP BY {}".format(", ".join(adapter.quote_identifier(col) for col in self._grouping_attributes)), + ) + # Wrap in subquery with WHERE for the HAVING conditions + subquery_alias = adapter.quote_identifier(f"_aggr{next(self._subquery_alias_count)}") + outer_where = " WHERE (%s)" % ")AND(".join(self.restriction) + return f"SELECT * FROM ({inner_sql}) AS {subquery_alias}{outer_where}{self.sorting_clauses()}" + else: + # MySQL path: use HAVING directly + return "SELECT {distinct}{fields} FROM {from_}{where}{group_by}{sorting}".format( + distinct="DISTINCT " if distinct else "", + fields=fields, + from_=self.from_clause(), + where=self.where_clause(), + group_by=( + "" + if not self.primary_key + else ( + " GROUP BY {}".format(", ".join(adapter.quote_identifier(col) for col in self._grouping_attributes)) + + ("" if not self.restriction else " HAVING (%s)" % ")AND(".join(self.restriction)) + ) + ), + sorting=self.sorting_clauses(), ) - ).fetchone()[0] + + def __len__(self): + alias = self.connection.adapter.quote_identifier(f"${next(self._subquery_alias_count):x}") + return self.connection.query(f"SELECT count(1) FROM ({self.make_sql()}) {alias}").fetchone()[0] def __bool__(self): return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0]) @@ -1122,12 +1387,11 @@ def make_sql(self): if not arg1.heading.secondary_attributes and not arg2.heading.secondary_attributes: # no secondary attributes: use UNION DISTINCT fields = arg1.primary_key - return "SELECT * FROM (({sql1}) UNION ({sql2})) as `_u{alias}{sorting}`".format( - sql1=(arg1.make_sql() if isinstance(arg1, Union) else arg1.make_sql(fields)), - sql2=(arg2.make_sql() if isinstance(arg2, Union) else arg2.make_sql(fields)), - alias=next(self.__count), - sorting=self.sorting_clauses(), - ) + alias_name = f"_u{next(self.__count)}{self.sorting_clauses()}" + alias_quoted = self.connection.adapter.quote_identifier(alias_name) + sql1 = arg1.make_sql() if isinstance(arg1, Union) else arg1.make_sql(fields) + sql2 = arg2.make_sql() if isinstance(arg2, Union) else arg2.make_sql(fields) + return f"SELECT * FROM (({sql1}) UNION ({sql2})) as {alias_quoted}" # with secondary attributes, use union of left join with anti-restriction fields = self.heading.names sql1 = arg1.join(arg2, left=True).make_sql(fields) @@ -1143,12 +1407,8 @@ def where_clause(self): raise NotImplementedError("Union does not use a WHERE clause") def __len__(self): - return self.connection.query( - "SELECT count(1) FROM ({subquery}) `${alias:x}`".format( - subquery=self.make_sql(), - alias=next(QueryExpression._subquery_alias_count), - ) - ).fetchone()[0] + alias = self.connection.adapter.quote_identifier(f"${next(QueryExpression._subquery_alias_count):x}") + return self.connection.query(f"SELECT count(1) FROM ({self.make_sql()}) {alias}").fetchone()[0] def __bool__(self): return bool(self.connection.query("SELECT EXISTS({sql})".format(sql=self.make_sql())).fetchone()[0]) @@ -1239,14 +1499,24 @@ def __sub__(self, other): def aggr(self, group, **named_attributes): """ - Aggregation of the type U('attr1','attr2').aggr(group, computation="QueryExpression") - has 
the primary key ('attr1','attr2') and performs aggregation computations for all matching elements of `group`. + Aggregation of the type ``U('attr1','attr2').aggr(group, computation="QueryExpression")``. + + Has the primary key ``('attr1','attr2')`` and performs aggregation computations for all + matching elements of ``group``. - Note: exclude_nonmatching is always True for dj.U (cannot keep all rows from infinite set). + Note: ``exclude_nonmatching`` is always True for dj.U (cannot keep all rows from infinite set). - :param group: The query expression to be aggregated. - :param named_attributes: computations of the form new_attribute="sql expression on attributes of group" - :return: The derived query expression + Parameters + ---------- + group : QueryExpression + The query expression to be aggregated. + **named_attributes : str + Computations of the form ``new_attribute="sql expression on attributes of group"``. + + Returns + ------- + QueryExpression + The derived query expression. """ if named_attributes.pop("exclude_nonmatching", True) is False: raise DataJointError("Cannot set exclude_nonmatching=False when aggregating on a universal set.") @@ -1277,9 +1547,19 @@ def aggr(self, group, **named_attributes): def _flatten_attribute_list(primary_key, attrs): """ - :param primary_key: list of attributes in primary key - :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC" - :return: generator of attributes where "KEY" is replaced with its component attributes + Flatten an attribute list, replacing "KEY" with primary key attributes. + + Parameters + ---------- + primary_key : list + List of attributes in primary key. + attrs : list + List of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC". + + Yields + ------ + str + Attributes where "KEY" is replaced with its component attributes. """ for a in attrs: if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): @@ -1292,6 +1572,14 @@ def _flatten_attribute_list(primary_key, attrs): yield a -def _wrap_attributes(attr): - for entry in attr: # wrap attribute names in backquotes - yield re.sub(r"\b((?!asc|desc)\w+)\b", r"`\1`", entry, flags=re.IGNORECASE) +def _wrap_attributes(attr, adapter): + """Wrap attribute names with database-specific quotes.""" + for entry in attr: + # Replace word boundaries (not 'asc' or 'desc') with quoted version + def quote_match(match): + word = match.group(1) + if word.lower() not in ("asc", "desc"): + return adapter.quote_identifier(word) + return word + + yield re.sub(r"\b((?!asc|desc)\w+)\b", quote_match, entry, flags=re.IGNORECASE) diff --git a/src/datajoint/heading.py b/src/datajoint/heading.py index c8486021a..4d7f0c62a 100644 --- a/src/datajoint/heading.py +++ b/src/datajoint/heading.py @@ -133,7 +133,7 @@ def sql_comment(self) -> str: Comment with optional ``:uuid:`` prefix. 
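A small sketch of how the ``:uuid:`` prefix round-trips through the stored column comment; the parsing regex appears in Heading._init_from_database below, and the comment text is hypothetical:

    import re

    stored = ":uuid:globally unique session id"   # what sql_comment emits for a uuid attribute
    m = re.match(r":(?P<type>[^:]+):(?P<comment>.*)", stored)
    assert m.group("type") == "uuid" and m.group("comment") == "globally unique session id"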
""" # UUID info is stored in the comment for reconstruction - return (":uuid:" if self.uuid else "") + self.comment + return (":uuid:" if self.uuid else "") + (self.comment or "") @property def sql(self) -> str: @@ -164,8 +164,9 @@ def original_name(self) -> str: """ if self.attribute_expression is None: return self.name - assert self.attribute_expression.startswith("`") - return self.attribute_expression.strip("`") + # Backend-agnostic quote stripping (MySQL uses `, PostgreSQL uses ") + assert self.attribute_expression.startswith(("`", '"')) + return self.attribute_expression.strip('`"') class Heading: @@ -290,7 +291,9 @@ def __repr__(self) -> str: in_key = True ret = "" if self._table_status is not None: - ret += "# " + self.table_status["comment"] + "\n" + comment = self.table_status.get("comment", "") + if comment: + ret += "# " + comment + "\n" for v in self.attributes.values(): if in_key and not v.in_key: ret += "---\n" @@ -319,7 +322,7 @@ def as_dtype(self) -> np.dtype: """ return np.dtype(dict(names=self.names, formats=[v.dtype for v in self.attributes.values()])) - def as_sql(self, fields: list[str], include_aliases: bool = True) -> str: + def as_sql(self, fields: list[str], include_aliases: bool = True, adapter=None) -> str: """ Generate SQL SELECT clause for specified fields. @@ -329,20 +332,37 @@ def as_sql(self, fields: list[str], include_aliases: bool = True) -> str: Attribute names to include. include_aliases : bool, optional Include AS clauses for computed attributes. Default True. + adapter : DatabaseAdapter, optional + Database adapter for identifier quoting. If not provided, attempts + to get from table_info connection. Returns ------- str Comma-separated SQL field list. """ - return ",".join( - ( - "`%s`" % name - if self.attributes[name].attribute_expression is None - else self.attributes[name].attribute_expression + (" as `%s`" % name if include_aliases else "") - ) - for name in fields - ) + # Get adapter for proper identifier quoting + if adapter is None and self.table_info and "conn" in self.table_info and self.table_info["conn"]: + adapter = self.table_info["conn"].adapter + + def quote(name): + # Use adapter if available, otherwise use ANSI SQL double quotes (not backticks) + return adapter.quote_identifier(name) if adapter else f'"{name}"' + + def render_field(name): + attr = self.attributes[name] + if attr.attribute_expression is None: + return quote(name) + else: + # Translate expression for backend compatibility (e.g., GROUP_CONCAT ↔ STRING_AGG) + expr = attr.attribute_expression + if adapter: + expr = adapter.translate_expression(expr) + if include_aliases: + return f"{expr} as {quote(name)}" + return expr + + return ",".join(render_field(name) for name in fields) def __iter__(self): return iter(self.attributes) @@ -350,38 +370,42 @@ def __iter__(self): def _init_from_database(self) -> None: """Initialize heading from an existing database table.""" conn, database, table_name, context = (self.table_info[k] for k in ("conn", "database", "table_name", "context")) + adapter = conn.adapter + + # Get table metadata info = conn.query( - 'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format(table_name=table_name, database=database), + adapter.get_table_info_sql(database, table_name), as_dict=True, ).fetchone() if info is None: - raise DataJointError( - "The table `{database}`.`{table_name}` is not defined.".format(table_name=table_name, database=database) - ) + raise DataJointError(f"The table {database}.{table_name} is not defined.") + # Normalize 
table_comment to comment for backward compatibility self._table_status = {k.lower(): v for k, v in info.items()} + if "table_comment" in self._table_status: + self._table_status["comment"] = self._table_status["table_comment"] + + # Get column information cur = conn.query( - "SHOW FULL COLUMNS FROM `{table_name}` IN `{database}`".format(table_name=table_name, database=database), + adapter.get_columns_sql(database, table_name), as_dict=True, ) - attributes = cur.fetchall() - - rename_map = { - "Field": "name", - "Type": "type", - "Null": "nullable", - "Default": "default", - "Key": "in_key", - "Comment": "comment", - } + # Parse columns using adapter-specific parser + raw_attributes = cur.fetchall() + attributes = [adapter.parse_column_info(row) for row in raw_attributes] - fields_to_drop = ("Privileges", "Collation") + # Get primary key information and mark primary key columns + pk_query = conn.query( + adapter.get_primary_key_sql(database, table_name), + as_dict=True, + ) + pk_columns = {row["column_name"] for row in pk_query.fetchall()} + for attr in attributes: + if attr["name"] in pk_columns: + attr["key"] = "PRI" - # rename and drop attributes - attributes = [ - {rename_map[k] if k in rename_map else k: v for k, v in x.items() if k not in fields_to_drop} for x in attributes - ] numeric_types = { + # MySQL types ("float", False): np.float64, ("float", True): np.float64, ("double", False): np.float64, @@ -396,6 +420,13 @@ def _init_from_database(self) -> None: ("int", True): np.int64, ("bigint", False): np.int64, ("bigint", True): np.uint64, + # PostgreSQL types + ("integer", False): np.int64, + ("integer", True): np.int64, + ("real", False): np.float64, + ("real", True): np.float64, + ("double precision", False): np.float64, + ("double precision", True): np.float64, } sql_literals = ["CURRENT_TIMESTAMP"] @@ -403,9 +434,9 @@ def _init_from_database(self) -> None: # additional attribute properties for attr in attributes: attr.update( - in_key=(attr["in_key"] == "PRI"), - nullable=attr["nullable"] == "YES", - autoincrement=bool(re.search(r"auto_increment", attr["Extra"], flags=re.I)), + in_key=(attr["key"] == "PRI"), + nullable=attr["nullable"], # Already boolean from parse_column_info + autoincrement=bool(re.search(r"auto_increment", attr["extra"], flags=re.I)), numeric=any(TYPE_PATTERN[t].match(attr["type"]) for t in ("DECIMAL", "INTEGER", "FLOAT")), string=any(TYPE_PATTERN[t].match(attr["type"]) for t in ("ENUM", "TEMPORAL", "STRING")), is_blob=any(TYPE_PATTERN[t].match(attr["type"]) for t in ("BYTES", "NATIVE_BLOB")), @@ -421,10 +452,12 @@ def _init_from_database(self) -> None: if any(TYPE_PATTERN[t].match(attr["type"]) for t in ("INTEGER", "FLOAT")): attr["type"] = re.sub(r"\(\d+\)", "", attr["type"], count=1) # strip size off integers and floats attr["unsupported"] = not any((attr["is_blob"], attr["numeric"], attr["numeric"])) - attr.pop("Extra") + attr.pop("extra") + attr.pop("key") # process custom DataJoint types stored in comment - special = re.match(r":(?P[^:]+):(?P.*)", attr["comment"]) + comment = attr["comment"] or "" # Handle None for PostgreSQL + special = re.match(r":(?P[^:]+):(?P.*)", comment) if special: special = special.groupdict() attr["comment"] = special["comment"] # Always update the comment @@ -519,35 +552,59 @@ def _init_from_database(self) -> None: # Read and tabulate secondary indexes keys = defaultdict(dict) for item in conn.query( - "SHOW KEYS FROM `{db}`.`{tab}`".format(db=database, tab=table_name), + adapter.get_indexes_sql(database, table_name), 
as_dict=True, ): - if item["Key_name"] != "PRIMARY": - keys[item["Key_name"]][item["Seq_in_index"]] = dict( - column=item["Column_name"] or f"({item['Expression']})".replace(r"\'", "'"), - unique=(item["Non_unique"] == 0), - nullable=item["Null"].lower() == "yes", - ) + # Note: adapter.get_indexes_sql() already filters out PRIMARY key + # MySQL/PostgreSQL adapters return: index_name, column_name, non_unique + index_name = item.get("index_name") or item.get("Key_name") + seq = item.get("seq_in_index") or item.get("Seq_in_index") or len(keys[index_name]) + 1 + column = item.get("column_name") or item.get("Column_name") + # MySQL EXPRESSION column stores escaped single quotes - unescape them + if column: + column = column.replace("\\'", "'") + non_unique = item.get("non_unique") or item.get("Non_unique") + nullable = item.get("nullable") or (item.get("Null", "NO").lower() == "yes") + + keys[index_name][seq] = dict( + column=column, + unique=(non_unique == 0 or not non_unique), + nullable=nullable, + ) self.indexes = { - tuple(item[k]["column"] for k in sorted(item.keys())): dict( + tuple(item[k]["column"] for k in sorted(item.keys()) if item[k]["column"] is not None): dict( unique=item[1]["unique"], nullable=any(v["nullable"] for v in item.values()), ) for item in keys.values() + if any(item[k]["column"] is not None for k in item.keys()) } def select(self, select_list, rename_map=None, compute_map=None): """ - derive a new heading by selecting, renaming, or computing attributes. - In relational algebra these operators are known as project, rename, and extend. + Derive a new heading by selecting, renaming, or computing attributes. - :param select_list: the full list of existing attributes to include - :param rename_map: dictionary of renamed attributes: keys=new names, values=old names - :param compute_map: a direction of computed attributes + In relational algebra these operators are known as project, rename, and extend. This low-level method performs no error checking. + + Parameters + ---------- + select_list : list + The full list of existing attributes to include. + rename_map : dict, optional + Dictionary of renamed attributes: keys=new names, values=old names. + compute_map : dict, optional + A dictionary of computed attributes. + + Returns + ------- + Heading + New heading with selected, renamed, and computed attributes. 
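# A minimal standalone sketch of the two identifier-quoting styles that the rewritten
# as_sql()/quote() logic above has to accommodate: backticks on MySQL and ANSI double
# quotes on PostgreSQL (adapter.translate_expression() additionally rewrites expressions
# such as GROUP_CONCAT vs STRING_AGG). Doubling an embedded quote character is an
# assumption about how the adapters escape identifiers.
def quote_mysql(name: str) -> str:
    return "`" + name.replace("`", "``") + "`"

def quote_ansi(name: str) -> str:
    return '"' + name.replace('"', '""') + '"'

assert quote_mysql("subject_id") == "`subject_id`"
assert quote_ansi("subject_id") == '"subject_id"'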
""" rename_map = rename_map or {} compute_map = compute_map or {} + # Get adapter for proper identifier quoting + adapter = self.table_info["conn"].adapter if self.table_info else None copy_attrs = list() for name in self.attributes: if name in select_list: @@ -557,7 +614,7 @@ def select(self, select_list, rename_map=None, compute_map=None): dict( self.attributes[old_name].todict(), name=new_name, - attribute_expression="`%s`" % old_name, + attribute_expression=(adapter.quote_identifier(old_name) if adapter else f"`{old_name}`"), ) for new_name, old_name in rename_map.items() if old_name == name @@ -567,7 +624,10 @@ def select(self, select_list, rename_map=None, compute_map=None): dict(default_attribute_properties, name=new_name, attribute_expression=expr) for new_name, expr in compute_map.items() ) - return Heading(chain(copy_attrs, compute_attrs), lineage_available=self._lineage_available) + # Inherit table_info so the new heading has access to the adapter + new_heading = Heading(chain(copy_attrs, compute_attrs), lineage_available=self._lineage_available) + new_heading.table_info = self.table_info + return new_heading def _join_dependent(self, dependent): """Build attribute list when self → dependent: PK = PK(self), self's attrs first.""" @@ -582,16 +642,27 @@ def join(self, other, nullable_pk=False): Join two headings into a new one. The primary key of the result depends on functional dependencies: - - A → B: PK = PK(A), A's attributes first - - B → A (not A → B): PK = PK(B), B's attributes first - - Both: PK = PK(A), left operand takes precedence - - Neither: PK = PK(A) ∪ PK(B), A's PK first then B's new PK attrs - :param nullable_pk: If True, skip PK optimization and use combined PK from both - operands. Used for left joins that bypass the A → B constraint, where the - right operand's PK attributes could be NULL. + - A -> B: PK = PK(A), A's attributes first + - B -> A (not A -> B): PK = PK(B), B's attributes first + - Both: PK = PK(A), left operand takes precedence + - Neither: PK = PK(A) | PK(B), A's PK first then B's new PK attrs It assumes that self and other are headings that share no common dependent attributes. + + Parameters + ---------- + other : Heading + The other heading to join with. + nullable_pk : bool, optional + If True, skip PK optimization and use combined PK from both + operands. Used for left joins that bypass the A -> B constraint, where the + right operand's PK attributes could be NULL. Default False. + + Returns + ------- + Heading + New heading resulting from the join. """ if nullable_pk: a_determines_b = b_determines_a = False diff --git a/src/datajoint/jobs.py b/src/datajoint/jobs.py index 97cfafb15..4deff4804 100644 --- a/src/datajoint/jobs.py +++ b/src/datajoint/jobs.py @@ -250,7 +250,7 @@ def pending(self) -> "Job": Job Restricted query with ``status='pending'``. """ - return self & 'status="pending"' + return self & "status='pending'" @property def reserved(self) -> "Job": @@ -262,7 +262,7 @@ def reserved(self) -> "Job": Job Restricted query with ``status='reserved'``. """ - return self & 'status="reserved"' + return self & "status='reserved'" @property def errors(self) -> "Job": @@ -274,7 +274,7 @@ def errors(self) -> "Job": Job Restricted query with ``status='error'``. """ - return self & 'status="error"' + return self & "status='error'" @property def ignored(self) -> "Job": @@ -286,7 +286,7 @@ def ignored(self) -> "Job": Job Restricted query with ``status='ignore'``. 
""" - return self & 'status="ignore"' + return self & "status='ignore'" @property def completed(self) -> "Job": @@ -298,7 +298,7 @@ def completed(self) -> "Job": Job Restricted query with ``status='success'``. """ - return self & 'status="success"' + return self & "status='success'" # ------------------------------------------------------------------------- # Core job management methods @@ -377,7 +377,8 @@ def refresh( if new_key_list: # Use server time for scheduling (CURRENT_TIMESTAMP(3) matches datetime(3) precision) - scheduled_time = self.connection.query(f"SELECT CURRENT_TIMESTAMP(3) + INTERVAL {delay} SECOND").fetchone()[0] + interval_expr = self.adapter.interval_expr(delay, "second") + scheduled_time = self.connection.query(f"SELECT CURRENT_TIMESTAMP(3) + {interval_expr}").fetchone()[0] for key in new_key_list: job_entry = { @@ -405,7 +406,8 @@ def refresh( # 3. Remove stale jobs (not ignore status) - use server CURRENT_TIMESTAMP for consistent timing if stale_timeout > 0: - old_jobs = self & f"created_time < CURRENT_TIMESTAMP - INTERVAL {stale_timeout} SECOND" & 'status != "ignore"' + stale_interval = self.adapter.interval_expr(stale_timeout, "second") + old_jobs = self & f"created_time < CURRENT_TIMESTAMP - {stale_interval}" & "status != 'ignore'" for key in old_jobs.keys(): # Check if key still in key_source @@ -415,7 +417,8 @@ def refresh( # 4. Handle orphaned reserved jobs - use server CURRENT_TIMESTAMP for consistent timing if orphan_timeout is not None and orphan_timeout > 0: - orphaned_jobs = self.reserved & f"reserved_time < CURRENT_TIMESTAMP - INTERVAL {orphan_timeout} SECOND" + orphan_interval = self.adapter.interval_expr(orphan_timeout, "second") + orphaned_jobs = self.reserved & f"reserved_time < CURRENT_TIMESTAMP - {orphan_interval}" for key in orphaned_jobs.keys(): (self & key).delete_quick() @@ -442,7 +445,7 @@ def reserve(self, key: dict) -> bool: True if reservation successful, False if job not available. """ # Check if job is pending and scheduled (use CURRENT_TIMESTAMP(3) for datetime(3) precision) - job = (self & key & 'status="pending"' & "scheduled_time <= CURRENT_TIMESTAMP(3)").to_dicts() + job = (self & key & "status='pending'" & "scheduled_time <= CURRENT_TIMESTAMP(3)").to_dicts() if not job: return False diff --git a/src/datajoint/lineage.py b/src/datajoint/lineage.py index d40ed8dd8..bb911a876 100644 --- a/src/datajoint/lineage.py +++ b/src/datajoint/lineage.py @@ -38,17 +38,30 @@ def ensure_lineage_table(connection, database): database : str The schema/database name. 
""" - connection.query( - """ - CREATE TABLE IF NOT EXISTS `{database}`.`~lineage` ( - table_name VARCHAR(64) NOT NULL COMMENT 'table name within the schema', - attribute_name VARCHAR(64) NOT NULL COMMENT 'attribute name', - lineage VARCHAR(255) NOT NULL COMMENT 'origin: schema.table.attribute', - PRIMARY KEY (table_name, attribute_name) - ) ENGINE=InnoDB - """.format(database=database) + adapter = connection.adapter + + # Build fully qualified table name + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + + # Build column definitions using adapter + columns = [ + adapter.format_column_definition("table_name", "VARCHAR(64)", nullable=False, comment="table name within the schema"), + adapter.format_column_definition("attribute_name", "VARCHAR(64)", nullable=False, comment="attribute name"), + adapter.format_column_definition("lineage", "VARCHAR(255)", nullable=False, comment="origin: schema.table.attribute"), + ] + + # Build PRIMARY KEY using adapter + pk_cols = adapter.quote_identifier("table_name") + ", " + adapter.quote_identifier("attribute_name") + pk_clause = f"PRIMARY KEY ({pk_cols})" + + sql = ( + f"CREATE TABLE IF NOT EXISTS {lineage_table} (\n" + + ",\n".join(columns + [pk_clause]) + + f"\n) {adapter.table_options_clause()}" ) + connection.query(sql) + def lineage_table_exists(connection, database): """ @@ -99,11 +112,14 @@ def get_lineage(connection, database, table_name, attribute_name): if not lineage_table_exists(connection, database): return None + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + result = connection.query( - """ - SELECT lineage FROM `{database}`.`~lineage` + f""" + SELECT lineage FROM {lineage_table} WHERE table_name = %s AND attribute_name = %s - """.format(database=database), + """, args=(table_name, attribute_name), ).fetchone() return result[0] if result else None @@ -130,11 +146,14 @@ def get_table_lineages(connection, database, table_name): if not lineage_table_exists(connection, database): return {} + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + results = connection.query( - """ - SELECT attribute_name, lineage FROM `{database}`.`~lineage` + f""" + SELECT attribute_name, lineage FROM {lineage_table} WHERE table_name = %s - """.format(database=database), + """, args=(table_name,), ).fetchall() return {row[0]: row[1] for row in results} @@ -159,10 +178,13 @@ def get_schema_lineages(connection, database): if not lineage_table_exists(connection, database): return {} + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + results = connection.query( - """ - SELECT table_name, attribute_name, lineage FROM `{database}`.`~lineage` - """.format(database=database), + f""" + SELECT table_name, attribute_name, lineage FROM {lineage_table} + """, ).fetchall() return {f"{database}.{table}.{attr}": lineage for table, attr, lineage in results} @@ -184,18 +206,25 @@ def insert_lineages(connection, database, entries): if not entries: return ensure_lineage_table(connection, database) - # Build a single INSERT statement with multiple values for atomicity - placeholders = ", ".join(["(%s, %s, %s)"] * len(entries)) + + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + + # Build backend-agnostic upsert statement + 
columns = ["table_name", "attribute_name", "lineage"] + primary_key = ["table_name", "attribute_name"] + + sql = adapter.upsert_on_duplicate_sql( + lineage_table, + columns, + primary_key, + len(entries), + ) + # Flatten the entries into a single args tuple args = tuple(val for entry in entries for val in entry) - connection.query( - """ - INSERT INTO `{database}`.`~lineage` (table_name, attribute_name, lineage) - VALUES {placeholders} - ON DUPLICATE KEY UPDATE lineage = VALUES(lineage) - """.format(database=database, placeholders=placeholders), - args=args, - ) + + connection.query(sql, args=args) def delete_table_lineages(connection, database, table_name): @@ -213,11 +242,15 @@ def delete_table_lineages(connection, database, table_name): """ if not lineage_table_exists(connection, database): return + + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + connection.query( - """ - DELETE FROM `{database}`.`~lineage` + f""" + DELETE FROM {lineage_table} WHERE table_name = %s - """.format(database=database), + """, args=(table_name,), ) @@ -251,8 +284,11 @@ def rebuild_schema_lineage(connection, database): # Ensure the lineage table exists ensure_lineage_table(connection, database) + adapter = connection.adapter + lineage_table = f"{adapter.quote_identifier(database)}.{adapter.quote_identifier('~lineage')}" + # Clear all existing lineage entries for this schema - connection.query(f"DELETE FROM `{database}`.`~lineage`") + connection.query(f"DELETE FROM {lineage_table}") # Get all tables in the schema (excluding hidden tables) tables_result = connection.query( diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py index 1878870df..2955fd67d 100644 --- a/src/datajoint/schemas.py +++ b/src/datajoint/schemas.py @@ -192,7 +192,8 @@ def activate( # create database logger.debug("Creating schema `{name}`.".format(name=schema_name)) try: - self.connection.query("CREATE DATABASE `{name}`".format(name=schema_name)) + create_sql = self.connection.adapter.create_schema_sql(schema_name) + self.connection.query(create_sql) except AccessError: raise DataJointError( "Schema `{name}` does not exist and could not be created. Check permissions.".format(name=schema_name) @@ -415,7 +416,8 @@ def drop(self, prompt: bool | None = None) -> None: elif not prompt or user_choice("Proceed to delete entire schema `%s`?" 
% self.database, default="no") == "yes": logger.debug("Dropping `{database}`.".format(database=self.database)) try: - self.connection.query("DROP DATABASE `{database}`".format(database=self.database)) + drop_sql = self.connection.adapter.drop_schema_sql(self.database) + self.connection.query(drop_sql) logger.debug("Schema `{database}` was dropped successfully.".format(database=self.database)) except AccessError: raise AccessError( @@ -517,13 +519,17 @@ def jobs(self) -> list[Job]: jobs_list = [] # Get all existing job tables (~~prefix) - # Note: %% escapes the % in pymysql - result = self.connection.query(f"SHOW TABLES IN `{self.database}` LIKE '~~%%'").fetchall() + # Note: %% escapes the % in pymysql/psycopg2 + adapter = self.connection.adapter + sql = adapter.list_tables_sql(self.database, pattern="~~%%") + result = self.connection.query(sql).fetchall() existing_job_tables = {row[0] for row in result} # Iterate over auto-populated tables and check if their job table exists for table_name in self.list_tables(): - table = FreeTable(self.connection, f"`{self.database}`.`{table_name}`") + adapter = self.connection.adapter + full_name = f"{adapter.quote_identifier(self.database)}." f"{adapter.quote_identifier(table_name)}" + table = FreeTable(self.connection, full_name) tier = _get_tier(table.full_table_name) if tier in (Computed, Imported): # Compute expected job table name: ~~base_name @@ -696,7 +702,8 @@ def get_table(self, name: str) -> FreeTable: if table_name is None: raise DataJointError(f"Table `{name}` does not exist in schema `{self.database}`.") - full_name = f"`{self.database}`.`{table_name}`" + adapter = self.connection.adapter + full_name = f"{adapter.quote_identifier(self.database)}.{adapter.quote_identifier(table_name)}" return FreeTable(self.connection, full_name) def __getitem__(self, name: str) -> FreeTable: @@ -894,7 +901,7 @@ def virtual_schema( -------- >>> lab = dj.virtual_schema('my_lab') >>> lab.Subject.fetch() - >>> lab.Session & 'subject_id="M001"' + >>> lab.Session & "subject_id='M001'" See Also -------- diff --git a/src/datajoint/settings.py b/src/datajoint/settings.py index 0338555ed..e373ca38f 100644 --- a/src/datajoint/settings.py +++ b/src/datajoint/settings.py @@ -15,6 +15,10 @@ >>> import datajoint as dj >>> dj.config.database.host 'localhost' +>>> dj.config.database.backend +'mysql' +>>> dj.config.database.port # Auto-detects: 3306 for MySQL, 5432 for PostgreSQL +3306 >>> with dj.config.override(safemode=False): ... # dangerous operations here ... 
pass @@ -43,7 +47,7 @@ from pathlib import Path from typing import Any, Iterator, Literal -from pydantic import Field, SecretStr, field_validator +from pydantic import Field, SecretStr, field_validator, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict from .errors import DataJointError @@ -59,10 +63,12 @@ "database.host": "DJ_HOST", "database.user": "DJ_USER", "database.password": "DJ_PASS", + "database.backend": "DJ_BACKEND", "database.port": "DJ_PORT", "database.database_prefix": "DJ_DATABASE_PREFIX", "database.create_tables": "DJ_CREATE_TABLES", "loglevel": "DJ_LOG_LEVEL", + "display.diagram_direction": "DJ_DIAGRAM_DIRECTION", } Role = Enum("Role", "manual lookup imported computed job") @@ -184,9 +190,14 @@ class DatabaseSettings(BaseSettings): host: str = Field(default="localhost", validation_alias="DJ_HOST") user: str | None = Field(default=None, validation_alias="DJ_USER") password: SecretStr | None = Field(default=None, validation_alias="DJ_PASS") - port: int = Field(default=3306, validation_alias="DJ_PORT") + backend: Literal["mysql", "postgresql"] = Field( + default="mysql", + validation_alias="DJ_BACKEND", + description="Database backend: 'mysql' or 'postgresql'", + ) + port: int | None = Field(default=None, validation_alias="DJ_PORT") reconnect: bool = True - use_tls: bool | None = None + use_tls: bool | None = Field(default=None, validation_alias="DJ_USE_TLS") database_prefix: str = Field( default="", validation_alias="DJ_DATABASE_PREFIX", @@ -200,6 +211,13 @@ class DatabaseSettings(BaseSettings): "Set to False for production mode to prevent automatic table creation.", ) + @model_validator(mode="after") + def set_default_port_from_backend(self) -> "DatabaseSettings": + """Set default port based on backend if not explicitly provided.""" + if self.port is None: + self.port = 5432 if self.backend == "postgresql" else 3306 + return self + class ConnectionSettings(BaseSettings): """Connection behavior settings.""" @@ -218,6 +236,11 @@ class DisplaySettings(BaseSettings): limit: int = 12 width: int = 14 show_tuple_count: bool = True + diagram_direction: Literal["TB", "LR"] = Field( + default="LR", + validation_alias="DJ_DIAGRAM_DIRECTION", + description="Default diagram layout direction: 'TB' (top-to-bottom) or 'LR' (left-to-right)", + ) class StoresSettings(BaseSettings): diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 950ab513f..59279489e 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -4,7 +4,6 @@ import itertools import json import logging -import re import uuid import warnings from dataclasses import dataclass, field @@ -30,24 +29,8 @@ logger = logging.getLogger(__name__.split(".")[0]) -foreign_key_error_regexp = re.compile( - r"[\w\s:]*\((?P`[^`]+`.`[^`]+`), " - r"CONSTRAINT (?P`[^`]+`) " - r"(FOREIGN KEY \((?P[^)]+)\) " - r"REFERENCES (?P`[^`]+`(\.`[^`]+`)?) \((?P[^)]+)\)[\s\w]+\))?" -) - -constraint_info_query = " ".join( - """ - SELECT - COLUMN_NAME as fk_attrs, - CONCAT('`', REFERENCED_TABLE_SCHEMA, '`.`', REFERENCED_TABLE_NAME, '`') as parent, - REFERENCED_COLUMN_NAME as pk_attrs - FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE - WHERE - CONSTRAINT_NAME = %s AND TABLE_SCHEMA = %s AND TABLE_NAME = %s; - """.split() -) +# Note: Foreign key error parsing is now handled by adapter methods +# Legacy regexp and query kept for reference but no longer used @dataclass @@ -146,7 +129,10 @@ def declare(self, context=None): """ Declare the table in the schema based on self.definition. 
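# Standalone restatement of the port-defaulting rule added to DatabaseSettings above:
# when no port is configured, fall back to the backend's conventional default.
def default_port(backend: str, port: int | None) -> int:
    return port if port is not None else (5432 if backend == "postgresql" else 3306)

assert default_port("mysql", None) == 3306
assert default_port("postgresql", None) == 5432
assert default_port("postgresql", 6543) == 6543  # an explicit port always wins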
- :param context: the context for foreign key resolution. If None, foreign keys are + Parameters + ---------- + context : dict, optional + The context for foreign key resolution. If None, foreign keys are not allowed. """ if self.connection.in_transaction: @@ -166,14 +152,26 @@ def declare(self, context=None): f"Table class name `{self.class_name}` is invalid. " "Class names must be in CamelCase, starting with a capital letter." ) - sql, _external_stores, primary_key, fk_attribute_map = declare(self.full_table_name, self.definition, context) + sql, _external_stores, primary_key, fk_attribute_map, pre_ddl, post_ddl = declare( + self.full_table_name, self.definition, context, self.connection.adapter + ) # Call declaration hook for validation (subclasses like AutoPopulate can override) self._declare_check(primary_key, fk_attribute_map) sql = sql.format(database=self.database) try: + # Execute pre-DDL statements (e.g., CREATE TYPE for PostgreSQL enums) + for ddl in pre_ddl: + try: + self.connection.query(ddl.format(database=self.database)) + except Exception: + # Ignore errors (type may already exist) + pass self.connection.query(sql) + # Execute post-DDL statements (e.g., COMMENT ON for PostgreSQL) + for ddl in post_ddl: + self.connection.query(ddl.format(database=self.database)) except AccessError: # Only suppress if table already exists (idempotent declaration) # Otherwise raise - user needs to know about permission issues @@ -195,8 +193,12 @@ def _declare_check(self, primary_key, fk_attribute_map): Called before the table is created in the database. Override this method to add validation logic (e.g., AutoPopulate validates FK-only primary keys). - :param primary_key: list of primary key attribute names - :param fk_attribute_map: dict mapping child_attr -> (parent_table, parent_attr) + Parameters + ---------- + primary_key : list + List of primary key attribute names. + fk_attribute_map : dict + Dict mapping child_attr -> (parent_table, parent_attr). """ pass # Default: no validation @@ -208,8 +210,12 @@ def _populate_lineage(self, primary_key, fk_attribute_map): - All FK attributes (traced to their origin) - Native primary key attributes (lineage = self) - :param primary_key: list of primary key attribute names - :param fk_attribute_map: dict mapping child_attr -> (parent_table, parent_attr) + Parameters + ---------- + primary_key : list + List of primary key attribute names. + fk_attribute_map : dict + Dict mapping child_attr -> (parent_table, parent_attr). """ from .lineage import ( ensure_lineage_table, @@ -228,8 +234,8 @@ def _populate_lineage(self, primary_key, fk_attribute_map): # FK attributes: copy lineage from parent (whether in PK or not) for attr, (parent_table, parent_attr) in fk_attribute_map.items(): - # Parse parent table name: `schema`.`table` -> (schema, table) - parent_clean = parent_table.replace("`", "") + # Parse parent table name: `schema`.`table` or "schema"."table" -> (schema, table) + parent_clean = parent_table.replace("`", "").replace('"', "") if "." 
in parent_clean: parent_db, parent_tbl = parent_clean.split(".", 1) else: @@ -273,7 +279,7 @@ def alter(self, prompt=True, context=None): context = dict(frame.f_globals, **frame.f_locals) del frame old_definition = self.describe(context=context) - sql, _external_stores = alter(self.definition, old_definition, context) + sql, _external_stores = alter(self.definition, old_definition, context, self.connection.adapter) if not sql: if prompt: logger.warning("Nothing to alter.") @@ -293,26 +299,51 @@ def alter(self, prompt=True, context=None): def from_clause(self): """ - :return: the FROM clause of SQL SELECT statements. + Return the FROM clause of SQL SELECT statements. + + Returns + ------- + str + The full table name for use in SQL FROM clauses. """ return self.full_table_name def get_select_fields(self, select_fields=None): """ - :return: the selected attributes from the SQL SELECT statement. + Return the selected attributes from the SQL SELECT statement. + + Parameters + ---------- + select_fields : list, optional + List of attribute names to select. If None, selects all attributes. + + Returns + ------- + str + The SQL field selection string. """ return "*" if select_fields is None else self.heading.project(select_fields).as_sql def parents(self, primary=None, as_objects=False, foreign_key_info=False): """ - - :param primary: if None, then all parents are returned. If True, then only foreign keys composed of - primary key attributes are considered. If False, return foreign keys including at least one - secondary attribute. - :param as_objects: if False, return table names. If True, return table objects. - :param foreign_key_info: if True, each element in result also includes foreign key info. - :return: list of parents as table names or table objects - with (optional) foreign key information. + Return the list of parent tables. + + Parameters + ---------- + primary : bool, optional + If None, then all parents are returned. If True, then only foreign keys + composed of primary key attributes are considered. If False, return + foreign keys including at least one secondary attribute. + as_objects : bool, optional + If False, return table names. If True, return table objects. + foreign_key_info : bool, optional + If True, each element in result also includes foreign key info. + + Returns + ------- + list + List of parents as table names or table objects with (optional) foreign + key information. """ get_edge = self.connection.dependencies.parents nodes = [ @@ -327,13 +358,24 @@ def parents(self, primary=None, as_objects=False, foreign_key_info=False): def children(self, primary=None, as_objects=False, foreign_key_info=False): """ - :param primary: if None, then all children are returned. If True, then only foreign keys composed of - primary key attributes are considered. If False, return foreign keys including at least one - secondary attribute. - :param as_objects: if False, return table names. If True, return table objects. - :param foreign_key_info: if True, each element in result also includes foreign key info. - :return: list of children as table names or table objects - with (optional) foreign key information. + Return the list of child tables. + + Parameters + ---------- + primary : bool, optional + If None, then all children are returned. If True, then only foreign keys + composed of primary key attributes are considered. If False, return + foreign keys including at least one secondary attribute. + as_objects : bool, optional + If False, return table names. 
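# Sketch of the execution order introduced in declare() above: backend-specific DDL can
# run before the CREATE TABLE statement (e.g., PostgreSQL CREATE TYPE for enum columns)
# and after it (e.g., PostgreSQL COMMENT ON). The literal statements below are
# illustrative assumptions; only the pre -> create -> post ordering and the "{database}"
# formatting placeholder come from the code above.
pre_ddl = ["CREATE TYPE \"{database}\".sex_enum AS ENUM ('M', 'F', 'U')"]
create_sql = 'CREATE TABLE "{database}"."subject" (subject_id INT PRIMARY KEY, sex "{database}".sex_enum)'
post_ddl = ["COMMENT ON TABLE \"{database}\".\"subject\" IS 'experimental subjects'"]

for statement in pre_ddl + [create_sql] + post_ddl:
    print(statement.format(database="lab"))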
If True, return table objects. + foreign_key_info : bool, optional + If True, each element in result also includes foreign key info. + + Returns + ------- + list + List of children as table names or table objects with (optional) foreign + key information. """ get_edge = self.connection.dependencies.children nodes = [ @@ -348,8 +390,18 @@ def children(self, primary=None, as_objects=False, foreign_key_info=False): def descendants(self, as_objects=False): """ - :param as_objects: False - a list of table names; True - a list of table objects. - :return: list of tables descendants in topological order. + Return list of descendant tables in topological order. + + Parameters + ---------- + as_objects : bool, optional + If False (default), return a list of table names. If True, return a + list of table objects. + + Returns + ------- + list + List of descendant tables in topological order. """ return [ FreeTable(self.connection, node) if as_objects else node @@ -359,8 +411,18 @@ def descendants(self, as_objects=False): def ancestors(self, as_objects=False): """ - :param as_objects: False - a list of table names; True - a list of table objects. - :return: list of tables ancestors in topological order. + Return list of ancestor tables in topological order. + + Parameters + ---------- + as_objects : bool, optional + If False (default), return a list of table names. If True, return a + list of table objects. + + Returns + ------- + list + List of ancestor tables in topological order. """ return [ FreeTable(self.connection, node) if as_objects else node @@ -370,9 +432,18 @@ def ancestors(self, as_objects=False): def parts(self, as_objects=False): """ - return part tables either as entries in a dict with foreign key information or a list of objects - - :param as_objects: if False (default), the output is a dict describing the foreign keys. If True, return table objects. + Return part tables for this master table. + + Parameters + ---------- + as_objects : bool, optional + If False (default), the output is a list of full table names. If True, + return table objects. + + Returns + ------- + list + List of part table names or table objects. """ self.connection.dependencies.load(force=False) nodes = [ @@ -385,43 +456,58 @@ def parts(self, as_objects=False): @property def is_declared(self): """ - :return: True is the table is declared in the schema. + Check if the table is declared in the schema. + + Returns + ------- + bool + True if the table is declared in the schema. """ - return ( - self.connection.query( - 'SHOW TABLES in `{database}` LIKE "{table_name}"'.format(database=self.database, table_name=self.table_name) - ).rowcount - > 0 - ) + query = self.connection.adapter.get_table_info_sql(self.database, self.table_name) + return self.connection.query(query).rowcount > 0 @property def full_table_name(self): """ - :return: full table name in the schema + Return the full table name in the schema. + + Returns + ------- + str + Full table name in the format `database`.`table_name`. """ if self.database is None or self.table_name is None: raise DataJointError( f"Class {self.__class__.__name__} is not associated with a schema. " "Apply a schema decorator or use schema() to bind it." 
) - return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name) + return f"{self.adapter.quote_identifier(self.database)}.{self.adapter.quote_identifier(self.table_name)}" + + @property + def adapter(self): + """Database adapter for backend-agnostic SQL generation.""" + return self.connection.adapter def update1(self, row): """ - ``update1`` updates one existing entry in the table. + Update one existing entry in the table. + Caution: In DataJoint the primary modes for data manipulation is to ``insert`` and ``delete`` entire records since referential integrity works on the level of records, not fields. Therefore, updates are reserved for corrective operations outside of main workflow. Use UPDATE methods sparingly with full awareness of potential violations of assumptions. - :param row: a ``dict`` containing the primary key values and the attributes to update. - Setting an attribute value to None will reset it to the default value (if any). - The primary key attributes must always be provided. - Examples: + Parameters + ---------- + row : dict + A dict containing the primary key values and the attributes to update. + Setting an attribute value to None will reset it to the default value (if any). + Examples + -------- >>> table.update1({'id': 1, 'value': 3}) # update value in record with id=1 >>> table.update1({'id': 1, 'value': None}) # reset value to default """ @@ -441,9 +527,10 @@ def update1(self, row): raise DataJointError("Update can only be applied to one existing entry.") # UPDATE query row = [self.__make_placeholder(k, v) for k, v in row.items() if k not in self.primary_key] + assignments = ",".join(f"{self.adapter.quote_identifier(r[0])}={r[1]}" for r in row) query = "UPDATE {table} SET {assignments} WHERE {where}".format( table=self.full_table_name, - assignments=",".join("`%s`=%s" % r[:2] for r in row), + assignments=assignments, where=make_condition(self, key, set()), ) self.connection.query(query, args=list(r[2] for r in row if r[2] is not None)) @@ -452,11 +539,6 @@ def validate(self, rows, *, ignore_extra_fields=False) -> ValidationResult: """ Validate rows without inserting them. - :param rows: Same format as insert() - iterable of dicts, tuples, numpy records, - or a pandas DataFrame. - :param ignore_extra_fields: If True, ignore fields not in the table heading. - :return: ValidationResult with is_valid, errors list, and rows_checked count. - Validates: - Field existence (all fields must be in table heading) - Row format (correct number of attributes for positional inserts) @@ -470,13 +552,26 @@ def validate(self, rows, *, ignore_extra_fields=False) -> ValidationResult: - Unique constraints (other than PK) - Custom MySQL constraints - Example:: - - result = table.validate(rows) - if result: - table.insert(rows) - else: - print(result.summary()) + Parameters + ---------- + rows : iterable + Same format as insert() - iterable of dicts, tuples, numpy records, + or a pandas DataFrame. + ignore_extra_fields : bool, optional + If True, ignore fields not in the table heading. + + Returns + ------- + ValidationResult + Result with is_valid, errors list, and rows_checked count. + + Examples + -------- + >>> result = table.validate(rows) + >>> if result: + ... table.insert(rows) + ... else: + ... print(result.summary()) """ errors = [] @@ -587,10 +682,21 @@ def validate(self, rows, *, ignore_extra_fields=False) -> ValidationResult: def insert1(self, row, **kwargs): """ - Insert one data record into the table. For ``kwargs``, see ``insert()``. 
+ Insert one data record into the table. - :param row: a numpy record, a dict-like object, or an ordered sequence to be inserted + For ``kwargs``, see ``insert()``. + + Parameters + ---------- + row : numpy.void, dict, or sequence + A numpy record, a dict-like object, or an ordered sequence to be inserted as one row. + **kwargs + Additional arguments passed to ``insert()``. + + See Also + -------- + insert : Insert multiple data records. """ self.insert((row,), **kwargs) @@ -635,27 +741,36 @@ def insert( """ Insert a collection of rows. - :param rows: Either (a) an iterable where an element is a numpy record, a - dict-like object, a pandas.DataFrame, a polars.DataFrame, a pyarrow.Table, - a sequence, or a query expression with the same heading as self, or + Parameters + ---------- + rows : iterable or pathlib.Path + Either (a) an iterable where an element is a numpy record, a dict-like + object, a pandas.DataFrame, a polars.DataFrame, a pyarrow.Table, a + sequence, or a query expression with the same heading as self, or (b) a pathlib.Path object specifying a path relative to the current directory with a CSV file, the contents of which will be inserted. - :param replace: If True, replaces the existing tuple. - :param skip_duplicates: If True, silently skip duplicate inserts. - :param ignore_extra_fields: If False, fields that are not in the heading raise error. - :param allow_direct_insert: Only applies in auto-populated tables. If False (default), - insert may only be called from inside the make callback. - :param chunk_size: If set, insert rows in batches of this size. Useful for very - large inserts to avoid memory issues. Each chunk is a separate transaction. - - Example: - - >>> Table.insert([ - >>> dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"), - >>> dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")]) - - # Large insert with chunking - >>> Table.insert(large_dataset, chunk_size=10000) + replace : bool, optional + If True, replaces the existing tuple. + skip_duplicates : bool, optional + If True, silently skip duplicate inserts. + ignore_extra_fields : bool, optional + If False (default), fields that are not in the heading raise error. + allow_direct_insert : bool, optional + Only applies in auto-populated tables. If False (default), insert may + only be called from inside the make callback. + chunk_size : int, optional + If set, insert rows in batches of this size. Useful for very large + inserts to avoid memory issues. Each chunk is a separate transaction. + + Examples + -------- + >>> Table.insert([ + ... dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"), + ... 
dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")]) + + Large insert with chunking: + + >>> Table.insert(large_dataset, chunk_size=10000) """ if isinstance(rows, pandas.DataFrame): # drop 'extra' synthetic index for 1-field index case - @@ -697,17 +812,16 @@ def insert( except StopIteration: pass fields = list(name for name in rows.heading if name in self.heading) - query = "{command} INTO {table} ({fields}) {select}{duplicate}".format( - command="REPLACE" if replace else "INSERT", - fields="`" + "`,`".join(fields) + "`", - table=self.full_table_name, - select=rows.make_sql(fields), - duplicate=( - " ON DUPLICATE KEY UPDATE `{pk}`={table}.`{pk}`".format(table=self.full_table_name, pk=self.primary_key[0]) - if skip_duplicates - else "" - ), - ) + quoted_fields = ",".join(self.adapter.quote_identifier(f) for f in fields) + + # Duplicate handling (backend-agnostic) + if skip_duplicates: + duplicate = self.adapter.skip_duplicates_clause(self.full_table_name, self.primary_key) + else: + duplicate = "" + + command = "REPLACE" if replace else "INSERT" + query = f"{command} INTO {self.full_table_name} ({quoted_fields}) {rows.make_sql(fields)}{duplicate}" self.connection.query(query) return @@ -728,10 +842,16 @@ def _insert_rows(self, rows, replace, skip_duplicates, ignore_extra_fields): """ Internal helper to insert a batch of rows. - :param rows: Iterable of rows to insert - :param replace: If True, use REPLACE instead of INSERT - :param skip_duplicates: If True, use ON DUPLICATE KEY UPDATE - :param ignore_extra_fields: If True, ignore unknown fields + Parameters + ---------- + rows : iterable + Iterable of rows to insert. + replace : bool + If True, use REPLACE instead of INSERT. + skip_duplicates : bool + If True, use ON DUPLICATE KEY UPDATE. + ignore_extra_fields : bool + If True, ignore unknown fields. """ # collects the field list from first row (passed by reference) field_list = [] @@ -739,16 +859,20 @@ def _insert_rows(self, rows, replace, skip_duplicates, ignore_extra_fields): if rows: try: # Handle empty field_list (all-defaults insert) - fields_clause = f"(`{'`,`'.join(field_list)}`)" if field_list else "()" - query = "{command} INTO {destination}{fields} VALUES {placeholders}{duplicate}".format( - command="REPLACE" if replace else "INSERT", - destination=self.from_clause(), - fields=fields_clause, - placeholders=",".join("(" + ",".join(row["placeholders"]) + ")" for row in rows), - duplicate=( - " ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`".format(pk=self.primary_key[0]) if skip_duplicates else "" - ), - ) + if field_list: + fields_clause = f"({','.join(self.adapter.quote_identifier(f) for f in field_list)})" + else: + fields_clause = "()" + + # Build duplicate clause (backend-agnostic) + if skip_duplicates: + duplicate = self.adapter.skip_duplicates_clause(self.full_table_name, self.primary_key) + else: + duplicate = "" + + command = "REPLACE" if replace else "INSERT" + placeholders = ",".join("(" + ",".join(row["placeholders"]) + ")" for row in rows) + query = f"{command} INTO {self.from_clause()}{fields_clause} VALUES {placeholders}{duplicate}" self.connection.query( query, args=list(itertools.chain.from_iterable((v for v in r["values"] if v is not None) for r in rows)), @@ -766,26 +890,34 @@ def insert_dataframe(self, df, index_as_pk=None, **insert_kwargs): (which sets primary key as index) can be modified and re-inserted using insert_dataframe() without manual index manipulation. 
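# Sketch of the clause that adapter.skip_duplicates_clause() now supplies in the insert
# paths above. The MySQL form mirrors the inline SQL that was removed; the PostgreSQL
# form (ON CONFLICT ... DO NOTHING) is an assumption about that adapter's output.
def skip_duplicates_mysql(table: str, primary_key: list[str]) -> str:
    pk = primary_key[0]
    return f" ON DUPLICATE KEY UPDATE `{pk}`={table}.`{pk}`"

def skip_duplicates_postgres(table: str, primary_key: list[str]) -> str:
    cols = ", ".join(f'"{c}"' for c in primary_key)
    return f" ON CONFLICT ({cols}) DO NOTHING"

print(skip_duplicates_mysql("`lab`.`subject`", ["subject_id"]))
print(skip_duplicates_postgres('"lab"."subject"', ["subject_id"]))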
- :param df: pandas DataFrame to insert - :param index_as_pk: How to handle DataFrame index: + Parameters + ---------- + df : pandas.DataFrame + DataFrame to insert. + index_as_pk : bool, optional + How to handle DataFrame index: + - None (default): Auto-detect. Use index as primary key if index names match primary_key columns. Drop if unnamed RangeIndex. - True: Treat index as primary key columns. Raises if index names don't match table primary key. - False: Ignore index entirely (drop it). - :param **insert_kwargs: Passed to insert() - replace, skip_duplicates, - ignore_extra_fields, allow_direct_insert, chunk_size + **insert_kwargs + Passed to insert() - replace, skip_duplicates, ignore_extra_fields, + allow_direct_insert, chunk_size. - Example:: + Examples + -------- + Round-trip with to_pandas(): - # Round-trip with to_pandas() - df = table.to_pandas() # PK becomes index - df['value'] = df['value'] * 2 # Modify data - table.insert_dataframe(df) # Auto-detects index as PK + >>> df = table.to_pandas() # PK becomes index + >>> df['value'] = df['value'] * 2 # Modify data + >>> table.insert_dataframe(df) # Auto-detects index as PK - # Explicit control - table.insert_dataframe(df, index_as_pk=True) # Use index - table.insert_dataframe(df, index_as_pk=False) # Ignore index + Explicit control: + + >>> table.insert_dataframe(df, index_as_pk=True) # Use index + >>> table.insert_dataframe(df, index_as_pk=False) # Ignore index """ if not isinstance(df, pandas.DataFrame): raise DataJointError("insert_dataframe requires a pandas DataFrame") @@ -839,8 +971,9 @@ def delete_quick(self, get_count=False): If this table has populated dependent tables, this will fail. """ query = "DELETE FROM " + self.full_table_name + self.where_clause() - self.connection.query(query) - count = self.connection.query("SELECT ROW_COUNT()").fetchone()[0] if get_count else None + cursor = self.connection.query(query) + # Use cursor.rowcount (DB-API 2.0 standard, works for both MySQL and PostgreSQL) + count = cursor.rowcount if get_count else None return count def delete( @@ -881,44 +1014,72 @@ def cascade(table): """service function to perform cascading deletes recursively.""" max_attempts = 50 for _ in range(max_attempts): + # Set savepoint before delete attempt (for PostgreSQL transaction handling) + savepoint_name = f"cascade_delete_{id(table)}" + if transaction: + table.connection.query(f"SAVEPOINT {savepoint_name}") + try: delete_count = table.delete_quick(get_count=True) except IntegrityError as error: - match = foreign_key_error_regexp.match(error.args[0]) + # Rollback to savepoint so we can continue querying (PostgreSQL requirement) + if transaction: + table.connection.query(f"ROLLBACK TO SAVEPOINT {savepoint_name}") + # Use adapter to parse FK error message + match = table.connection.adapter.parse_foreign_key_error(error.args[0]) if match is None: raise DataJointError( - "Cascading deletes failed because the error message is missing foreign key information." + "Cascading deletes failed because the error message is missing foreign key information. " "Make sure you have REFERENCES privilege to all dependent tables." 
) from None - match = match.groupdict() - # if schema name missing, use table - if "`.`" not in match["child"]: - match["child"] = "{}.{}".format(table.full_table_name.split(".")[0], match["child"]) - if match["pk_attrs"] is not None: # fully matched, adjusting the keys - match["fk_attrs"] = [k.strip("`") for k in match["fk_attrs"].split(",")] - match["pk_attrs"] = [k.strip("`") for k in match["pk_attrs"].split(",")] - else: # only partially matched, querying with constraint to determine keys - match["fk_attrs"], match["parent"], match["pk_attrs"] = list( - map( - list, - zip( - *table.connection.query( - constraint_info_query, - args=( - match["name"].strip("`"), - *[_.strip("`") for _ in match["child"].split("`.`")], - ), - ).fetchall() - ), - ) + + # Strip quotes from parsed values for backend-agnostic processing + quote_chars = ("`", '"') + + def strip_quotes(s): + if s and any(s.startswith(q) for q in quote_chars): + return s.strip('`"') + return s + + # Extract schema and table name from child (work with unquoted names) + child_table_raw = strip_quotes(match["child"]) + if "." in child_table_raw: + child_parts = child_table_raw.split(".") + child_schema = strip_quotes(child_parts[0]) + child_table_name = strip_quotes(child_parts[1]) + else: + # Add schema from current table + schema_parts = table.full_table_name.split(".") + child_schema = strip_quotes(schema_parts[0]) + child_table_name = child_table_raw + + # If FK/PK attributes not in error message, query information_schema + if match["fk_attrs"] is None or match["pk_attrs"] is None: + constraint_query = table.connection.adapter.get_constraint_info_sql( + strip_quotes(match["name"]), + child_schema, + child_table_name, ) - match["parent"] = match["parent"][0] + + results = table.connection.query( + constraint_query, + args=(strip_quotes(match["name"]), child_schema, child_table_name), + ).fetchall() + if results: + match["fk_attrs"], match["parent"], match["pk_attrs"] = list(map(list, zip(*results))) + match["parent"] = match["parent"][0] # All rows have same parent + + # Build properly quoted full table name for FreeTable + child_full_name = ( + f"{table.connection.adapter.quote_identifier(child_schema)}." + f"{table.connection.adapter.quote_identifier(child_table_name)}" + ) # Restrict child by table if # 1. if table's restriction attributes are not in child's primary key # 2. if child renames any attributes # Otherwise restrict child by table's restriction. 
- child = FreeTable(table.connection, match["child"]) + child = FreeTable(table.connection, child_full_name) if set(table.restriction_attributes) <= set(child.primary_key) and match["fk_attrs"] == match["pk_attrs"]: child._restriction = table._restriction child._restriction_attributes = table.restriction_attributes @@ -927,7 +1088,7 @@ def cascade(table): else: child &= table.proj() - master_name = get_master(child.full_table_name) + master_name = get_master(child.full_table_name, table.connection.adapter) if ( part_integrity == "cascade" and master_name @@ -948,6 +1109,9 @@ def cascade(table): else: cascade(child) else: + # Successful delete - release savepoint + if transaction: + table.connection.query(f"RELEASE SAVEPOINT {savepoint_name}") deleted.add(table.full_table_name) logger.info("Deleting {count} rows from {table}".format(count=delete_count, table=table.full_table_name)) break @@ -980,7 +1144,7 @@ def cascade(table): if part_integrity == "enforce": # Avoid deleting from part before master (See issue #151) for part in deleted: - master = get_master(part) + master = get_master(part, self.connection.adapter) if master and master not in deleted: if transaction: self.connection.cancel_transaction() @@ -1023,9 +1187,31 @@ def drop_quick(self): delete_table_lineages(self.connection, self.database, self.table_name) + # For PostgreSQL, get enum types used by this table before dropping + # (we need to query this before the table is dropped) + enum_types_to_drop = [] + adapter = self.connection.adapter + if hasattr(adapter, "get_table_enum_types_sql"): + try: + enum_query = adapter.get_table_enum_types_sql(self.database, self.table_name) + result = self.connection.query(enum_query) + enum_types_to_drop = [row[0] for row in result.fetchall()] + except Exception: + pass # Ignore errors - enum cleanup is best-effort + query = "DROP TABLE %s" % self.full_table_name self.connection.query(query) logger.info("Dropped table %s" % self.full_table_name) + + # For PostgreSQL, clean up enum types after dropping the table + if enum_types_to_drop and hasattr(adapter, "drop_enum_type_ddl"): + for enum_type in enum_types_to_drop: + try: + drop_ddl = adapter.drop_enum_type_ddl(enum_type) + self.connection.query(drop_ddl) + logger.debug("Dropped enum type %s" % enum_type) + except Exception: + pass # Ignore errors - type may be used by other tables else: logger.info("Nothing to drop: table %s is not declared" % self.full_table_name) @@ -1049,7 +1235,7 @@ def drop(self, prompt: bool | None = None): # avoid dropping part tables without their masters: See issue #374 for part in tables: - master = get_master(part) + master = get_master(part, self.connection.adapter) if master and master not in tables: raise DataJointError( "Attempt to drop part table {part} before dropping its master. Drop {master} first.".format( @@ -1069,7 +1255,12 @@ def drop(self, prompt: bool | None = None): @property def size_on_disk(self): """ - :return: size of data and indices in bytes on the storage device + Return the size of data and indices in bytes on the storage device. + + Returns + ------- + int + Size of data and indices in bytes. """ ret = self.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format(database=self.database, table=self.table_name), @@ -1079,7 +1270,20 @@ def size_on_disk(self): def describe(self, context=None, printout=False): """ - :return: the definition string for the query using DataJoint DDL. + Return the definition string for the query using DataJoint DDL. 
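# Sketch of the savepoint pattern used by cascade() above: PostgreSQL aborts the whole
# transaction after an IntegrityError unless the failing statement ran under a savepoint
# that can be rolled back, so each delete attempt is bracketed by SAVEPOINT /
# ROLLBACK TO SAVEPOINT / RELEASE SAVEPOINT. `run` is a hypothetical stand-in for
# connection.query().
def delete_with_savepoint(run, delete_sql: str, name: str = "cascade_delete_demo") -> bool:
    run(f"SAVEPOINT {name}")
    try:
        run(delete_sql)
    except Exception:                         # IntegrityError in the real code path
        run(f"ROLLBACK TO SAVEPOINT {name}")  # keep the transaction usable for further queries
        return False                          # caller recurses into the blocking child table
    run(f"RELEASE SAVEPOINT {name}")
    return True

statements = []
delete_with_savepoint(statements.append, 'DELETE FROM "lab"."session"')
assert statements[0].startswith("SAVEPOINT")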
+ + Parameters + ---------- + context : dict, optional + The context for foreign key resolution. If None, uses the caller's + local and global namespace. + printout : bool, optional + If True, also log the definition string. + + Returns + ------- + str + The definition string for the table in DataJoint DDL format. """ if context is None: frame = inspect.currentframe().f_back @@ -1092,7 +1296,7 @@ def describe(self, context=None, printout=False): definition = "# " + self.heading.table_status["comment"] + "\n" if self.heading.table_status["comment"] else "" attributes_thus_far = set() attributes_declared = set() - indexes = self.heading.indexes.copy() + indexes = self.heading.indexes.copy() if self.heading.indexes else {} for attr in self.heading.attributes.values(): if in_key and not attr.in_key: definition += "---\n" @@ -1157,9 +1361,11 @@ def describe(self, context=None, printout=False): # --- private helper functions ---- def __make_placeholder(self, name, value, ignore_extra_fields=False, row=None): """ - For a given attribute `name` with `value`, return its processed value or value placeholder - as a string to be included in the query and the value, if any, to be submitted for - processing by mysql API. + Return processed value or placeholder for an attribute. + + For a given attribute `name` with `value`, return its processed value or + value placeholder as a string to be included in the query and the value, + if any, to be submitted for processing by mysql API. In the simplified type system: - Codecs handle all custom encoding via type chains @@ -1168,10 +1374,22 @@ def __make_placeholder(self, name, value, ignore_extra_fields=False, row=None): - Blob values pass through as bytes - Numeric values are stringified - :param name: name of attribute to be inserted - :param value: value of attribute to be inserted - :param ignore_extra_fields: if True, return None for unknown fields - :param row: the full row dict (unused in simplified model) + Parameters + ---------- + name : str + Name of attribute to be inserted. + value : any + Value of attribute to be inserted. + ignore_extra_fields : bool, optional + If True, return None for unknown fields. + row : dict, optional + The full row dict (used for context in codec encoding). + + Returns + ------- + tuple or None + A tuple of (name, placeholder, value) or None if the field should be + ignored. """ if ignore_extra_fields and name not in self.heading: return None @@ -1239,17 +1457,31 @@ def __make_placeholder(self, name, value, ignore_extra_fields=False, row=None): def __make_row_to_insert(self, row, field_list, ignore_extra_fields): """ - Helper function for insert and update - - :param row: A tuple to insert - :return: a dict with fields 'names', 'placeholders', 'values' + Helper function for insert and update. + + Parameters + ---------- + row : tuple, dict, or numpy.void + A row to insert. + field_list : list + List to be populated with field names from the first row. + ignore_extra_fields : bool + If True, ignore fields not in the heading. + + Returns + ------- + dict + A dict with fields 'names', 'placeholders', 'values'. """ def check_fields(fields): """ - Validates that all items in `fields` are valid attributes in the heading + Validate that all items in `fields` are valid attributes in the heading. - :param fields: field names of a tuple + Parameters + ---------- + fields : list + Field names of a tuple. 
""" if not field_list: if not ignore_extra_fields: @@ -1329,12 +1561,24 @@ def check_fields(fields): def lookup_class_name(name, context, depth=3): """ - given a table name in the form `schema_name`.`table_name`, find its class in the context. - - :param name: `schema_name`.`table_name` - :param context: dictionary representing the namespace - :param depth: search depth into imported modules, helps avoid infinite recursion. - :return: class name found in the context or None if not found + Find a table's class in the context given its full table name. + + Given a table name in the form `schema_name`.`table_name`, find its class in + the context. + + Parameters + ---------- + name : str + Full table name in format `schema_name`.`table_name`. + context : dict + Dictionary representing the namespace. + depth : int, optional + Search depth into imported modules, helps avoid infinite recursion. + + Returns + ------- + str or None + Class name found in the context or None if not found. """ # breadth-first search nodes = [dict(context=context, context_name="", depth=depth)] @@ -1370,15 +1614,21 @@ def lookup_class_name(name, context, depth=3): class FreeTable(Table): """ - A base table without a dedicated class. Each instance is associated with a table - specified by full_table_name. + A base table without a dedicated class. + + Each instance is associated with a table specified by full_table_name. - :param conn: a dj.Connection object - :param full_table_name: in format `database`.`table_name` + Parameters + ---------- + conn : datajoint.Connection + A DataJoint connection object. + full_table_name : str + Full table name in format `database`.`table_name`. """ def __init__(self, conn, full_table_name): - self.database, self._table_name = (s.strip("`") for s in full_table_name.split(".")) + # Backend-agnostic quote stripping (MySQL uses `, PostgreSQL uses ") + self.database, self._table_name = (s.strip('`"') for s in full_table_name.split(".")) self._connection = conn self._support = [full_table_name] self._heading = Heading( @@ -1391,4 +1641,4 @@ def __init__(self, conn, full_table_name): ) def __repr__(self): - return "FreeTable(`%s`.`%s`)\n" % (self.database, self._table_name) + super().__repr__() + return f"FreeTable({self.full_table_name})\n" + super().__repr__() diff --git a/src/datajoint/user_tables.py b/src/datajoint/user_tables.py index f85273a1e..4c2ba8d4c 100644 --- a/src/datajoint/user_tables.py +++ b/src/datajoint/user_tables.py @@ -103,10 +103,11 @@ def table_name(cls): @property def full_table_name(cls): - """The fully qualified table name (`database`.`table`).""" + """The fully qualified table name (quoted per backend).""" if cls.database is None: return None - return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name) + adapter = cls._connection.adapter + return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}" class UserTable(Table, metaclass=TableMeta): @@ -182,10 +183,11 @@ def table_name(cls): @property def full_table_name(cls): - """The fully qualified table name (`database`.`table`).""" + """The fully qualified table name (quoted per backend).""" if cls.database is None or cls.table_name is None: return None - return r"`{0:s}`.`{1:s}`".format(cls.database, cls.table_name) + adapter = cls._connection.adapter + return f"{adapter.quote_identifier(cls.database)}.{adapter.quote_identifier(cls.table_name)}" @property def master(cls): @@ -275,10 +277,16 @@ class _AliasNode: def _get_tier(table_name): """given the table name, return the user 
table class.""" - if not table_name.startswith("`"): - return _AliasNode + # Handle both MySQL backticks and PostgreSQL double quotes + if table_name.startswith("`"): + # MySQL format: `schema`.`table_name` + extracted_name = table_name.split("`")[-2] + elif table_name.startswith('"'): + # PostgreSQL format: "schema"."table_name" + extracted_name = table_name.split('"')[-2] else: - try: - return next(tier for tier in user_table_classes if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2])) - except StopIteration: - return None + return _AliasNode + try: + return next(tier for tier in user_table_classes if re.fullmatch(tier.tier_regexp, extracted_name)) + except StopIteration: + return None diff --git a/src/datajoint/utils.py b/src/datajoint/utils.py index d7bf9ac6d..0441af354 100644 --- a/src/datajoint/utils.py +++ b/src/datajoint/utils.py @@ -10,12 +10,23 @@ def user_choice(prompt, choices=("yes", "no"), default=None): """ - Prompts the user for confirmation. The default value, if any, is capitalized. + Prompt the user for confirmation. - :param prompt: Information to display to the user. - :param choices: an iterable of possible choices. - :param default: default choice - :return: the user's choice + The default value, if any, is capitalized. + + Parameters + ---------- + prompt : str + Information to display to the user. + choices : tuple, optional + An iterable of possible choices. Default ("yes", "no"). + default : str, optional + Default choice. Default None. + + Returns + ------- + str + The user's choice. """ assert default is None or default in choices choice_list = ", ".join((choice.title() if choice == default else choice for choice in choices)) @@ -26,46 +37,88 @@ def user_choice(prompt, choices=("yes", "no"), default=None): return response -def get_master(full_table_name: str) -> str: +def get_master(full_table_name: str, adapter=None) -> str: """ + Get the master table name from a part table name. + If the table name is that of a part table, then return what the master table name would be. This follows DataJoint's table naming convention where a master and a part must be in the same schema and the part table is prefixed with the master table name + ``__``. - Example: - `ephys`.`session` -- master - `ephys`.`session__recording` -- part - - :param full_table_name: Full table name including part. - :type full_table_name: str - :return: Supposed master full table name or empty string if not a part table name. - :rtype: str - """ - match = re.match(r"(?P`\w+`.`\w+)__(?P\w+)`", full_table_name) - return match["master"] + "`" if match else "" + Parameters + ---------- + full_table_name : str + Full table name including part. + adapter : DatabaseAdapter, optional + Database adapter for backend-specific parsing. Default None. + + Returns + ------- + str + Supposed master full table name or empty string if not a part table name. 
+
+    Examples
+    --------
+    >>> get_master('`ephys`.`session__recording`')  # MySQL part table
+    '`ephys`.`session`'
+    >>> get_master('"ephys"."session__recording"')  # PostgreSQL part table
+    '"ephys"."session"'
+    >>> get_master('`ephys`.`session`')  # Not a part table
+    ''
+    """
+    if adapter is not None:
+        result = adapter.get_master_table_name(full_table_name)
+        return result if result else ""
+
+    # Fallback: handle both MySQL backticks and PostgreSQL double quotes
+    match = re.match(r'(?P<master>(?P<q>[`"])[\w]+(?P=q)\.(?P=q)[\w]+)__[\w]+(?P=q)', full_table_name)
+    if match:
+        return match["master"] + match["q"]
+    return ""


 def is_camel_case(s):
     """
     Check if a string is in CamelCase notation.

-    :param s: string to check
-    :returns: True if the string is in CamelCase notation, False otherwise
-    Example:
-    >>> is_camel_case("TableName") # returns True
-    >>> is_camel_case("table_name") # returns False
+    Parameters
+    ----------
+    s : str
+        String to check.
+
+    Returns
+    -------
+    bool
+        True if the string is in CamelCase notation, False otherwise.
+
+    Examples
+    --------
+    >>> is_camel_case("TableName")
+    True
+    >>> is_camel_case("table_name")
+    False
     """
     return bool(re.match(r"^[A-Z][A-Za-z0-9]*$", s))


 def to_camel_case(s):
     """
-    Convert names with under score (_) separation into camel case names.
+    Convert names with underscore (_) separation into camel case names.
+
+    Parameters
+    ----------
+    s : str
+        String in under_score notation.

-    :param s: string in under_score notation
-    :returns: string in CamelCase notation
-    Example:
-    >>> to_camel_case("table_name") # returns "TableName"
+    Returns
+    -------
+    str
+        String in CamelCase notation.
+
+    Examples
+    --------
+    >>> to_camel_case("table_name")
+    'TableName'
     """

     def to_upper(match):
@@ -76,12 +129,27 @@ def to_upper(match):

 def from_camel_case(s):
     """
-    Convert names in camel case into underscore (_) separated names
+    Convert names in camel case into underscore (_) separated names.
+
+    Parameters
+    ----------
+    s : str
+        String in CamelCase notation.

-    :param s: string in CamelCase notation
-    :returns: string in under_score notation
-    Example:
-    >>> from_camel_case("TableName") # yields "table_name"
+    Returns
+    -------
+    str
+        String in under_score notation.
+
+    Raises
+    ------
+    DataJointError
+        If the string is not in valid CamelCase notation.
+
+    Examples
+    --------
+    >>> from_camel_case("TableName")
+    'table_name'
     """

     def convert(match):
@@ -102,10 +170,17 @@ def convert(match):

 def safe_write(filepath, blob):
     """
-    A two-step write.
+    Write data to a file using a two-step process.
+
+    Writes to a temporary file first, then renames to the final path.
+    This ensures atomic writes and prevents partial file corruption.

-    :param filename: full path
-    :param blob: binary data
+    Parameters
+    ----------
+    filepath : str or Path
+        Full path to the destination file.
+    blob : bytes
+        Binary data to write.
     """
     filepath = Path(filepath)
     if not filepath.is_file():
@@ -117,7 +192,19 @@ def safe_write(filepath, blob):

 def safe_copy(src, dest, overwrite=False):
     """
-    Copy the contents of src file into dest file as a two-step process. Skip if dest exists already
+    Copy the contents of src file into dest file as a two-step process.
+
+    Copies to a temporary file first, then renames to the final path.
+    Skips if dest exists already (unless overwrite is True).
+
+    Parameters
+    ----------
+    src : str or Path
+        Source file path.
+    dest : str or Path
+        Destination file path.
+    overwrite : bool, optional
+        If True, overwrite existing destination file. Default False.
""" src, dest = Path(src), Path(dest) if not (dest.exists() and src.samefile(dest)) and (overwrite or not dest.is_file()): @@ -125,24 +212,3 @@ def safe_copy(src, dest, overwrite=False): temp_file = dest.with_suffix(dest.suffix + ".copying") shutil.copyfile(str(src), str(temp_file)) temp_file.rename(dest) - - -def parse_sql(filepath): - """ - yield SQL statements from an SQL file - """ - delimiter = ";" - statement = [] - with Path(filepath).open("rt") as f: - for line in f: - line = line.strip() - if not line.startswith("--") and len(line) > 1: - if line.startswith("delimiter"): - delimiter = line.split()[1] - else: - statement.append(line) - if line.endswith(delimiter): - yield " ".join(statement) - statement = [] - if statement: - yield " ".join(statement) diff --git a/src/datajoint/version.py b/src/datajoint/version.py index 551aaff5a..6d2daf6d7 100644 --- a/src/datajoint/version.py +++ b/src/datajoint/version.py @@ -1,4 +1,4 @@ # version bump auto managed by Github Actions: # label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit) # manually set this version will be eventually overwritten by the above actions -__version__ = "2.0.1" +__version__ = "2.1.0a9" diff --git a/tests/conftest.py b/tests/conftest.py index dc2eb73b6..4d6adf09c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,6 +66,12 @@ def pytest_collection_modifyitems(config, items): "stores_config", "mock_stores", } + # Tests that use these fixtures are backend-parameterized + backend_fixtures = { + "backend", + "db_creds_by_backend", + "connection_by_backend", + } for item in items: # Get all fixtures this test uses (directly or indirectly) @@ -80,6 +86,13 @@ def pytest_collection_modifyitems(config, items): if fixturenames & minio_fixtures: item.add_marker(pytest.mark.requires_minio) + # Auto-mark backend-parameterized tests + if fixturenames & backend_fixtures: + # Test will run for both backends - add all backend markers + item.add_marker(pytest.mark.mysql) + item.add_marker(pytest.mark.postgresql) + item.add_marker(pytest.mark.backend_agnostic) + # ============================================================================= # Container Fixtures - Auto-start MySQL and MinIO via testcontainers @@ -101,7 +114,7 @@ def mysql_container(): from testcontainers.mysql import MySqlContainer container = MySqlContainer( - image="mysql:8.0", + image="datajoint/mysql:8.0", # Use datajoint image which has SSL configured username="root", password="password", dbname="test", @@ -118,6 +131,35 @@ def mysql_container(): logger.info("MySQL container stopped") +@pytest.fixture(scope="session") +def postgres_container(): + """Start PostgreSQL container for the test session (or use external).""" + if USE_EXTERNAL_CONTAINERS: + # Use external container - return None, credentials come from env + logger.info("Using external PostgreSQL container") + yield None + return + + from testcontainers.postgres import PostgresContainer + + container = PostgresContainer( + image="postgres:15", + username="postgres", + password="password", + dbname="test", + ) + container.start() + + host = container.get_container_host_ip() + port = container.get_exposed_port(5432) + logger.info(f"PostgreSQL container started at {host}:{port}") + + yield container + + container.stop() + logger.info("PostgreSQL container stopped") + + @pytest.fixture(scope="session") def minio_container(): """Start MinIO container for the test session (or use external).""" @@ -225,6 +267,107 @@ def s3_creds(minio_container) -> Dict: ) +# 
============================================================================= +# Backend-Parameterized Fixtures +# ============================================================================= + + +@pytest.fixture(scope="session", params=["mysql", "postgresql"]) +def backend(request): + """Parameterize tests to run against both backends.""" + return request.param + + +@pytest.fixture(scope="session") +def db_creds_by_backend(backend, mysql_container, postgres_container): + """Get root database credentials for the specified backend.""" + if backend == "mysql": + if mysql_container is not None: + host = mysql_container.get_container_host_ip() + port = mysql_container.get_exposed_port(3306) + return { + "backend": "mysql", + "host": f"{host}:{port}", + "user": "root", + "password": "password", + } + else: + # External MySQL container + host = os.environ.get("DJ_HOST", "localhost") + port = os.environ.get("DJ_PORT", "3306") + return { + "backend": "mysql", + "host": f"{host}:{port}" if port else host, + "user": os.environ.get("DJ_USER", "root"), + "password": os.environ.get("DJ_PASS", "password"), + } + + elif backend == "postgresql": + if postgres_container is not None: + host = postgres_container.get_container_host_ip() + port = postgres_container.get_exposed_port(5432) + return { + "backend": "postgresql", + "host": f"{host}:{port}", + "user": "postgres", + "password": "password", + } + else: + # External PostgreSQL container + host = os.environ.get("DJ_PG_HOST", "localhost") + port = os.environ.get("DJ_PG_PORT", "5432") + return { + "backend": "postgresql", + "host": f"{host}:{port}" if port else host, + "user": os.environ.get("DJ_PG_USER", "postgres"), + "password": os.environ.get("DJ_PG_PASS", "password"), + } + + +@pytest.fixture(scope="function") +def connection_by_backend(db_creds_by_backend): + """Create connection for the specified backend. + + This fixture is function-scoped to ensure database.backend config + is restored after each test, preventing config pollution between tests. 
+ """ + # Save original config to restore after tests + original_backend = dj.config.get("database.backend", "mysql") + original_host = dj.config.get("database.host") + original_port = dj.config.get("database.port") + + # Configure backend + dj.config["database.backend"] = db_creds_by_backend["backend"] + + # Parse host:port + host_port = db_creds_by_backend["host"] + if ":" in host_port: + host, port = host_port.rsplit(":", 1) + else: + host = host_port + port = "3306" if db_creds_by_backend["backend"] == "mysql" else "5432" + + dj.config["database.host"] = host + dj.config["database.port"] = int(port) + dj.config["safemode"] = False + + connection = dj.Connection( + host=host_port, + user=db_creds_by_backend["user"], + password=db_creds_by_backend["password"], + ) + + yield connection + + # Restore original config + connection.close() + dj.config["database.backend"] = original_backend + if original_host is not None: + dj.config["database.host"] = original_host + if original_port is not None: + dj.config["database.port"] = original_port + + # ============================================================================= # DataJoint Configuration # ============================================================================= diff --git a/tests/integration/test_cascade_delete.py b/tests/integration/test_cascade_delete.py new file mode 100644 index 000000000..caf5f331b --- /dev/null +++ b/tests/integration/test_cascade_delete.py @@ -0,0 +1,190 @@ +""" +Integration tests for cascade delete on multiple backends. +""" + +import pytest + +import datajoint as dj + + +@pytest.fixture(scope="function") +def schema_by_backend(connection_by_backend, db_creds_by_backend, request): + """Create a schema for cascade delete tests.""" + backend = db_creds_by_backend["backend"] + # Use unique schema name for each test + import time + + test_id = str(int(time.time() * 1000))[-8:] # Last 8 digits of timestamp + schema_name = f"djtest_cascade_{backend}_{test_id}"[:64] # Limit length + + # Drop schema if exists (cleanup from any previous failed runs) + if connection_by_backend.is_connected: + try: + connection_by_backend.query( + f"DROP DATABASE IF EXISTS {connection_by_backend.adapter.quote_identifier(schema_name)}" + ) + except Exception: + pass # Ignore errors during cleanup + + # Create fresh schema + schema = dj.Schema(schema_name, connection=connection_by_backend) + + yield schema + + # Cleanup after test + if connection_by_backend.is_connected: + try: + connection_by_backend.query( + f"DROP DATABASE IF EXISTS {connection_by_backend.adapter.quote_identifier(schema_name)}" + ) + except Exception: + pass # Ignore errors during cleanup + + +def test_simple_cascade_delete(schema_by_backend): + """Test basic cascade delete with foreign keys.""" + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + parent_id : int + --- + name : varchar(255) + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int + --- + data : varchar(255) + """ + + # Insert test data + Parent.insert1((1, "Parent1")) + Parent.insert1((2, "Parent2")) + Child.insert1((1, 1, "Child1-1")) + Child.insert1((1, 2, "Child1-2")) + Child.insert1((2, 1, "Child2-1")) + + assert len(Parent()) == 2 + assert len(Child()) == 3 + + # Delete parent with cascade + (Parent & {"parent_id": 1}).delete() + + # Check cascade worked + assert len(Parent()) == 1 + assert len(Child()) == 1 + + # Verify remaining data (using to_dicts for DJ 2.0) + remaining = Child().to_dicts() + assert len(remaining) == 1 + 
assert remaining[0]["parent_id"] == 2 + assert remaining[0]["child_id"] == 1 + assert remaining[0]["data"] == "Child2-1" + + +def test_multi_level_cascade_delete(schema_by_backend): + """Test cascade delete through multiple levels of foreign keys.""" + + @schema_by_backend + class GrandParent(dj.Manual): + definition = """ + gp_id : int + --- + name : varchar(255) + """ + + @schema_by_backend + class Parent(dj.Manual): + definition = """ + -> GrandParent + parent_id : int + --- + name : varchar(255) + """ + + @schema_by_backend + class Child(dj.Manual): + definition = """ + -> Parent + child_id : int + --- + data : varchar(255) + """ + + # Insert test data + GrandParent.insert1((1, "GP1")) + Parent.insert1((1, 1, "P1")) + Parent.insert1((1, 2, "P2")) + Child.insert1((1, 1, 1, "C1")) + Child.insert1((1, 1, 2, "C2")) + Child.insert1((1, 2, 1, "C3")) + + assert len(GrandParent()) == 1 + assert len(Parent()) == 2 + assert len(Child()) == 3 + + # Delete grandparent - should cascade through parent to child + (GrandParent & {"gp_id": 1}).delete() + + # Check everything is deleted + assert len(GrandParent()) == 0 + assert len(Parent()) == 0 + assert len(Child()) == 0 + + # Verify all tables are empty + assert len(GrandParent().to_dicts()) == 0 + assert len(Parent().to_dicts()) == 0 + assert len(Child().to_dicts()) == 0 + + +def test_cascade_delete_with_renamed_attrs(schema_by_backend): + """Test cascade delete when foreign key renames attributes.""" + + @schema_by_backend + class Animal(dj.Manual): + definition = """ + animal_id : int + --- + species : varchar(255) + """ + + @schema_by_backend + class Observation(dj.Manual): + definition = """ + obs_id : int + --- + -> Animal.proj(subject_id='animal_id') + measurement : float + """ + + # Insert test data + Animal.insert1((1, "Mouse")) + Animal.insert1((2, "Rat")) + Observation.insert1((1, 1, 10.5)) + Observation.insert1((2, 1, 11.2)) + Observation.insert1((3, 2, 15.3)) + + assert len(Animal()) == 2 + assert len(Observation()) == 3 + + # Delete animal - should cascade to observations + (Animal & {"animal_id": 1}).delete() + + # Check cascade worked + assert len(Animal()) == 1 + assert len(Observation()) == 1 + + # Verify remaining data + remaining_animals = Animal().to_dicts() + assert len(remaining_animals) == 1 + assert remaining_animals[0]["animal_id"] == 2 + + remaining_obs = Observation().to_dicts() + assert len(remaining_obs) == 1 + assert remaining_obs[0]["obs_id"] == 3 + assert remaining_obs[0]["subject_id"] == 2 + assert remaining_obs[0]["measurement"] == 15.3 diff --git a/tests/integration/test_declare.py b/tests/integration/test_declare.py index d82f9e5cc..2379f1a9e 100644 --- a/tests/integration/test_declare.py +++ b/tests/integration/test_declare.py @@ -44,27 +44,30 @@ def test_describe(schema_any): """real_definition should match original definition""" rel = Experiment() context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 + adapter = rel.connection.adapter + s1 = declare(rel.full_table_name, rel.definition, context, adapter) + s2 = declare(rel.full_table_name, rel.describe(), context, adapter) + assert s1[0] == s2[0] # Compare SQL only (declare now returns tuple) def test_describe_indexes(schema_any): """real_definition should match original definition""" rel = IndexRich() context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = 
declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 + adapter = rel.connection.adapter + s1 = declare(rel.full_table_name, rel.definition, context, adapter) + s2 = declare(rel.full_table_name, rel.describe(), context, adapter) + assert s1[0] == s2[0] # Compare SQL only (declare now returns tuple) def test_describe_dependencies(schema_any): """real_definition should match original definition""" rel = ThingC() context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 + adapter = rel.connection.adapter + s1 = declare(rel.full_table_name, rel.definition, context, adapter) + s2 = declare(rel.full_table_name, rel.describe(), context, adapter) + assert s1[0] == s2[0] # Compare SQL only (declare now returns tuple) def test_part(schema_any): @@ -368,3 +371,96 @@ class Table_With_Underscores(dj.Manual): schema_any(Table_With_Underscores) # Verify the table was created successfully assert Table_With_Underscores.is_declared + + +class TestSingletonTables: + """Tests for singleton tables (empty primary keys).""" + + def test_singleton_declaration(self, schema_any): + """Singleton table creates correctly with hidden _singleton attribute.""" + + @schema_any + class Config(dj.Lookup): + definition = """ + # Global configuration + --- + setting : varchar(100) + """ + + # Access attributes first to trigger lazy loading from database + visible_attrs = Config.heading.attributes + all_attrs = Config.heading._attributes + + # Table should exist and have _singleton as hidden PK + assert "_singleton" in all_attrs + assert "_singleton" not in visible_attrs + assert Config.heading.primary_key == [] # Visible PK is empty for singleton + + def test_singleton_insert_and_fetch(self, schema_any): + """Insert and fetch work without specifying _singleton.""" + + @schema_any + class Settings(dj.Lookup): + definition = """ + --- + value : int32 + """ + + # Insert without specifying _singleton + Settings.insert1({"value": 42}) + + # Fetch should work + result = Settings.fetch1() + assert result["value"] == 42 + assert "_singleton" not in result # Hidden attribute excluded + + def test_singleton_uniqueness(self, schema_any): + """Second insert raises DuplicateError.""" + + @schema_any + class SingleValue(dj.Lookup): + definition = """ + --- + data : varchar(50) + """ + + SingleValue.insert1({"data": "first"}) + + # Second insert should fail + with pytest.raises(dj.errors.DuplicateError): + SingleValue.insert1({"data": "second"}) + + def test_singleton_with_multiple_attributes(self, schema_any): + """Singleton table with multiple secondary attributes.""" + + @schema_any + class PipelineConfig(dj.Lookup): + definition = """ + # Pipeline configuration singleton + --- + version : varchar(20) + max_workers : int32 + debug_mode : bool + """ + + PipelineConfig.insert1({"version": "1.0.0", "max_workers": 4, "debug_mode": False}) + + result = PipelineConfig.fetch1() + assert result["version"] == "1.0.0" + assert result["max_workers"] == 4 + assert result["debug_mode"] == 0 # bool stored as tinyint + + def test_singleton_describe(self, schema_any): + """Describe should show the singleton nature.""" + + @schema_any + class Metadata(dj.Lookup): + definition = """ + --- + info : varchar(255) + """ + + description = Metadata.describe() + # Description should show just the secondary attribute + assert "info" in description + # _singleton is hidden, implementation detail diff --git 
a/tests/integration/test_foreign_keys.py b/tests/integration/test_foreign_keys.py index 014340898..588c12cbf 100644 --- a/tests/integration/test_foreign_keys.py +++ b/tests/integration/test_foreign_keys.py @@ -31,8 +31,9 @@ def test_describe(schema_adv): """real_definition should match original definition""" for rel in (LocalSynapse, GlobalSynapse): describe = rel.describe() - s1 = declare(rel.full_table_name, rel.definition, schema_adv.context)[0].split("\n") - s2 = declare(rel.full_table_name, describe, globals())[0].split("\n") + adapter = rel.connection.adapter + s1 = declare(rel.full_table_name, rel.definition, schema_adv.context, adapter)[0].split("\n") + s2 = declare(rel.full_table_name, describe, globals(), adapter)[0].split("\n") for c1, c2 in zip(s1, s2): assert c1 == c2 diff --git a/tests/integration/test_json.py b/tests/integration/test_json.py index 40c8074de..97d0c73bf 100644 --- a/tests/integration/test_json.py +++ b/tests/integration/test_json.py @@ -122,9 +122,10 @@ def test_insert_update(schema_json): def test_describe(schema_json): rel = Team() context = inspect.currentframe().f_globals - s1 = declare(rel.full_table_name, rel.definition, context) - s2 = declare(rel.full_table_name, rel.describe(), context) - assert s1 == s2 + adapter = rel.connection.adapter + s1 = declare(rel.full_table_name, rel.definition, context, adapter) + s2 = declare(rel.full_table_name, rel.describe(), context, adapter) + assert s1[0] == s2[0] # Compare SQL only (declare now returns tuple) def test_restrict(schema_json): diff --git a/tests/integration/test_multi_backend.py b/tests/integration/test_multi_backend.py new file mode 100644 index 000000000..bf904e362 --- /dev/null +++ b/tests/integration/test_multi_backend.py @@ -0,0 +1,143 @@ +""" +Integration tests that verify backend-agnostic behavior. + +These tests run against both MySQL and PostgreSQL to ensure: +1. DDL generation is correct +2. SQL queries work identically +3. 
Data types map correctly + +To run these tests: + pytest tests/integration/test_multi_backend.py # Run against both backends + pytest -m "mysql" tests/integration/test_multi_backend.py # MySQL only + pytest -m "postgresql" tests/integration/test_multi_backend.py # PostgreSQL only +""" + +import pytest +import datajoint as dj + + +@pytest.mark.backend_agnostic +def test_simple_table_declaration(connection_by_backend, backend, prefix): + """Test that simple tables can be declared on both backends.""" + schema = dj.Schema( + f"{prefix}_multi_backend_{backend}_simple", + connection=connection_by_backend, + ) + + @schema + class User(dj.Manual): + definition = """ + user_id : int + --- + username : varchar(255) + created_at : datetime + """ + + # Verify table exists + assert User.is_declared + + # Insert and fetch data + from datetime import datetime + + User.insert1((1, "alice", datetime(2025, 1, 1))) + data = User.fetch1() + + assert data["user_id"] == 1 + assert data["username"] == "alice" + + # Cleanup + schema.drop() + + +@pytest.mark.backend_agnostic +def test_foreign_keys(connection_by_backend, backend, prefix): + """Test foreign key declarations work on both backends.""" + schema = dj.Schema( + f"{prefix}_multi_backend_{backend}_fk", + connection=connection_by_backend, + ) + + @schema + class Animal(dj.Manual): + definition = """ + animal_id : int + --- + name : varchar(255) + """ + + @schema + class Observation(dj.Manual): + definition = """ + -> Animal + obs_id : int + --- + notes : varchar(1000) + """ + + # Insert data + Animal.insert1((1, "Mouse")) + Observation.insert1((1, 1, "Active")) + + # Verify data was inserted + assert len(Animal()) == 1 + assert len(Observation()) == 1 + + # Cleanup + schema.drop() + + +@pytest.mark.backend_agnostic +def test_data_types(connection_by_backend, backend, prefix): + """Test that core data types work on both backends.""" + schema = dj.Schema( + f"{prefix}_multi_backend_{backend}_types", + connection=connection_by_backend, + ) + + @schema + class TypeTest(dj.Manual): + definition = """ + id : int + --- + int_value : int + str_value : varchar(255) + float_value : float + bool_value : bool + """ + + # Insert data + TypeTest.insert1((1, 42, "test", 3.14, True)) + + # Fetch and verify + data = (TypeTest & {"id": 1}).fetch1() + assert data["int_value"] == 42 + assert data["str_value"] == "test" + assert abs(data["float_value"] - 3.14) < 0.001 + assert data["bool_value"] == 1 # MySQL stores as tinyint(1) + + # Cleanup + schema.drop() + + +@pytest.mark.backend_agnostic +def test_table_comments(connection_by_backend, backend, prefix): + """Test that table comments are preserved on both backends.""" + schema = dj.Schema( + f"{prefix}_multi_backend_{backend}_comments", + connection=connection_by_backend, + ) + + @schema + class Commented(dj.Manual): + definition = """ + # This is a test table for backend testing + id : int # primary key + --- + value : varchar(255) # some value + """ + + # Verify table was created + assert Commented.is_declared + + # Cleanup + schema.drop() diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 6fcaffc6d..ef621765d 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -62,7 +62,7 @@ def test_schema_size_on_disk(schema_any): def test_schema_list(schema_any): - schemas = dj.list_schemas() + schemas = dj.list_schemas(connection=schema_any.connection) assert schema_any.database in schemas diff --git a/tests/integration/test_tls.py b/tests/integration/test_tls.py 
index e46825227..19ed087b7 100644 --- a/tests/integration/test_tls.py +++ b/tests/integration/test_tls.py @@ -1,20 +1,51 @@ +import logging +import os + import pytest from pymysql.err import OperationalError import datajoint as dj +# SSL tests require docker-compose with datajoint/mysql image (has SSL configured) +# Testcontainers with official mysql image doesn't have SSL certificates +requires_ssl = pytest.mark.skipif( + os.environ.get("DJ_USE_EXTERNAL_CONTAINERS", "").lower() not in ("1", "true", "yes"), + reason="SSL tests require external containers (docker-compose) with SSL configured", +) + + +@requires_ssl +def test_explicit_ssl_connection(db_creds_test, connection_test): + """When use_tls=True is specified, SSL must be active.""" + result = dj.conn(use_tls=True, reset=True, **db_creds_test).query("SHOW STATUS LIKE 'Ssl_cipher';").fetchone()[1] + assert len(result) > 0, "SSL should be active when use_tls=True" + + +@requires_ssl +def test_ssl_auto_detect(db_creds_test, connection_test, caplog): + """When use_tls is not specified, SSL is preferred but fallback is allowed with warning.""" + with caplog.at_level(logging.WARNING): + conn = dj.conn(reset=True, **db_creds_test) + result = conn.query("SHOW STATUS LIKE 'Ssl_cipher';").fetchone()[1] -def test_secure_connection(db_creds_test, connection_test): - result = dj.conn(reset=True, **db_creds_test).query("SHOW STATUS LIKE 'Ssl_cipher';").fetchone()[1] - assert len(result) > 0 + if len(result) > 0: + # SSL connected successfully + assert "SSL connection failed" not in caplog.text + else: + # SSL failed and fell back - warning should be logged + assert "SSL connection failed" in caplog.text + assert "Falling back to non-SSL" in caplog.text def test_insecure_connection(db_creds_test, connection_test): + """When use_tls=False, SSL should not be used.""" result = dj.conn(use_tls=False, reset=True, **db_creds_test).query("SHOW STATUS LIKE 'Ssl_cipher';").fetchone()[1] assert result == "" +@requires_ssl def test_reject_insecure(db_creds_test, connection_test): + """Users with REQUIRE SSL cannot connect without SSL.""" with pytest.raises(OperationalError): dj.conn( db_creds_test["host"], diff --git a/tests/test_package.py b/tests/test_package.py new file mode 100644 index 000000000..b278dd1d4 --- /dev/null +++ b/tests/test_package.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +import importlib.metadata + +import datajoint as m + + +def test_version(): + assert importlib.metadata.version("datajoint") == m.__version__ diff --git a/tests/unit/test_adapters.py b/tests/unit/test_adapters.py new file mode 100644 index 000000000..edbff9d52 --- /dev/null +++ b/tests/unit/test_adapters.py @@ -0,0 +1,544 @@ +""" +Unit tests for database adapters. + +Tests adapter functionality without requiring actual database connections. 
+""" + +import pytest + +from datajoint.adapters import DatabaseAdapter, MySQLAdapter, PostgreSQLAdapter, get_adapter + + +class TestAdapterRegistry: + """Test adapter registry and factory function.""" + + def test_get_adapter_mysql(self): + """Test getting MySQL adapter.""" + adapter = get_adapter("mysql") + assert isinstance(adapter, MySQLAdapter) + assert isinstance(adapter, DatabaseAdapter) + + def test_get_adapter_postgresql(self): + """Test getting PostgreSQL adapter.""" + pytest.importorskip("psycopg2") + adapter = get_adapter("postgresql") + assert isinstance(adapter, PostgreSQLAdapter) + assert isinstance(adapter, DatabaseAdapter) + + def test_get_adapter_postgres_alias(self): + """Test 'postgres' alias for PostgreSQL.""" + pytest.importorskip("psycopg2") + adapter = get_adapter("postgres") + assert isinstance(adapter, PostgreSQLAdapter) + + def test_get_adapter_case_insensitive(self): + """Test case-insensitive backend names.""" + assert isinstance(get_adapter("MySQL"), MySQLAdapter) + # Only test PostgreSQL if psycopg2 is available + try: + pytest.importorskip("psycopg2") + assert isinstance(get_adapter("POSTGRESQL"), PostgreSQLAdapter) + assert isinstance(get_adapter("PoStGrEs"), PostgreSQLAdapter) + except pytest.skip.Exception: + pass # Skip PostgreSQL tests if psycopg2 not available + + def test_get_adapter_invalid(self): + """Test error on invalid backend name.""" + with pytest.raises(ValueError, match="Unknown database backend"): + get_adapter("sqlite") + + +class TestMySQLAdapter: + """Test MySQL adapter implementation.""" + + @pytest.fixture + def adapter(self): + """MySQL adapter instance.""" + return MySQLAdapter() + + def test_default_port(self, adapter): + """Test MySQL default port is 3306.""" + assert adapter.default_port == 3306 + + def test_parameter_placeholder(self, adapter): + """Test MySQL parameter placeholder is %s.""" + assert adapter.parameter_placeholder == "%s" + + def test_quote_identifier(self, adapter): + """Test identifier quoting with backticks.""" + assert adapter.quote_identifier("table_name") == "`table_name`" + assert adapter.quote_identifier("my_column") == "`my_column`" + + def test_quote_string(self, adapter): + """Test string literal quoting.""" + assert "test" in adapter.quote_string("test") + # Should handle escaping + result = adapter.quote_string("It's a test") + assert "It" in result + + def test_core_type_to_sql_simple(self, adapter): + """Test core type mapping for simple types.""" + assert adapter.core_type_to_sql("int64") == "bigint" + assert adapter.core_type_to_sql("int32") == "int" + assert adapter.core_type_to_sql("int16") == "smallint" + assert adapter.core_type_to_sql("int8") == "tinyint" + assert adapter.core_type_to_sql("float32") == "float" + assert adapter.core_type_to_sql("float64") == "double" + assert adapter.core_type_to_sql("bool") == "tinyint" + assert adapter.core_type_to_sql("uuid") == "binary(16)" + assert adapter.core_type_to_sql("bytes") == "longblob" + assert adapter.core_type_to_sql("json") == "json" + assert adapter.core_type_to_sql("date") == "date" + + def test_core_type_to_sql_parametrized(self, adapter): + """Test core type mapping for parametrized types.""" + assert adapter.core_type_to_sql("datetime") == "datetime" + assert adapter.core_type_to_sql("datetime(3)") == "datetime(3)" + assert adapter.core_type_to_sql("char(10)") == "char(10)" + assert adapter.core_type_to_sql("varchar(255)") == "varchar(255)" + assert adapter.core_type_to_sql("decimal(10,2)") == "decimal(10,2)" + assert 
adapter.core_type_to_sql("enum('a','b','c')") == "enum('a','b','c')" + + def test_core_type_to_sql_invalid(self, adapter): + """Test error on invalid core type.""" + with pytest.raises(ValueError, match="Unknown core type"): + adapter.core_type_to_sql("invalid_type") + + def test_sql_type_to_core(self, adapter): + """Test reverse type mapping.""" + assert adapter.sql_type_to_core("bigint") == "int64" + assert adapter.sql_type_to_core("int") == "int32" + assert adapter.sql_type_to_core("float") == "float32" + assert adapter.sql_type_to_core("double") == "float64" + assert adapter.sql_type_to_core("longblob") == "bytes" + assert adapter.sql_type_to_core("datetime(3)") == "datetime(3)" + # Unmappable types return None + assert adapter.sql_type_to_core("mediumint") is None + + def test_create_schema_sql(self, adapter): + """Test CREATE DATABASE statement.""" + sql = adapter.create_schema_sql("test_db") + assert sql == "CREATE DATABASE `test_db`" + + def test_drop_schema_sql(self, adapter): + """Test DROP DATABASE statement.""" + sql = adapter.drop_schema_sql("test_db") + assert "DROP DATABASE" in sql + assert "IF EXISTS" in sql + assert "`test_db`" in sql + + def test_insert_sql_basic(self, adapter): + """Test basic INSERT statement.""" + sql = adapter.insert_sql("users", ["id", "name"]) + assert sql == "INSERT INTO users (`id`, `name`) VALUES (%s, %s)" + + def test_insert_sql_ignore(self, adapter): + """Test INSERT IGNORE statement.""" + sql = adapter.insert_sql("users", ["id", "name"], on_duplicate="ignore") + assert "INSERT IGNORE" in sql + + def test_insert_sql_replace(self, adapter): + """Test REPLACE INTO statement.""" + sql = adapter.insert_sql("users", ["id"], on_duplicate="replace") + assert "REPLACE INTO" in sql + + def test_insert_sql_update(self, adapter): + """Test INSERT ... 
ON DUPLICATE KEY UPDATE statement.""" + sql = adapter.insert_sql("users", ["id", "name"], on_duplicate="update") + assert "INSERT INTO" in sql + assert "ON DUPLICATE KEY UPDATE" in sql + + def test_update_sql(self, adapter): + """Test UPDATE statement.""" + sql = adapter.update_sql("users", ["name"], ["id"]) + assert "UPDATE users SET" in sql + assert "`name` = %s" in sql + assert "WHERE" in sql + assert "`id` = %s" in sql + + def test_delete_sql(self, adapter): + """Test DELETE statement.""" + sql = adapter.delete_sql("users") + assert sql == "DELETE FROM users" + + def test_current_timestamp_expr(self, adapter): + """Test CURRENT_TIMESTAMP expression.""" + assert adapter.current_timestamp_expr() == "CURRENT_TIMESTAMP" + assert adapter.current_timestamp_expr(3) == "CURRENT_TIMESTAMP(3)" + + def test_interval_expr(self, adapter): + """Test INTERVAL expression.""" + assert adapter.interval_expr(5, "second") == "INTERVAL 5 SECOND" + assert adapter.interval_expr(10, "minute") == "INTERVAL 10 MINUTE" + + def test_json_path_expr(self, adapter): + """Test JSON path extraction.""" + assert adapter.json_path_expr("data", "field") == "json_value(`data`, _utf8mb4'$.field')" + assert adapter.json_path_expr("record", "nested") == "json_value(`record`, _utf8mb4'$.nested')" + + def test_json_path_expr_with_return_type(self, adapter): + """Test JSON path extraction with return type.""" + result = adapter.json_path_expr("data", "value", "decimal(10,2)") + assert result == "json_value(`data`, _utf8mb4'$.value' returning decimal(10,2))" + + def test_transaction_sql(self, adapter): + """Test transaction statements.""" + assert "START TRANSACTION" in adapter.start_transaction_sql() + assert adapter.commit_sql() == "COMMIT" + assert adapter.rollback_sql() == "ROLLBACK" + + def test_validate_native_type(self, adapter): + """Test native type validation.""" + assert adapter.validate_native_type("int") + assert adapter.validate_native_type("bigint") + assert adapter.validate_native_type("varchar(255)") + assert adapter.validate_native_type("text") + assert adapter.validate_native_type("json") + assert not adapter.validate_native_type("invalid_type") + + +class TestPostgreSQLAdapter: + """Test PostgreSQL adapter implementation.""" + + @pytest.fixture + def adapter(self): + """PostgreSQL adapter instance.""" + # Skip if psycopg2 not installed + pytest.importorskip("psycopg2") + return PostgreSQLAdapter() + + def test_default_port(self, adapter): + """Test PostgreSQL default port is 5432.""" + assert adapter.default_port == 5432 + + def test_parameter_placeholder(self, adapter): + """Test PostgreSQL parameter placeholder is %s.""" + assert adapter.parameter_placeholder == "%s" + + def test_quote_identifier(self, adapter): + """Test identifier quoting with double quotes.""" + assert adapter.quote_identifier("table_name") == '"table_name"' + assert adapter.quote_identifier("my_column") == '"my_column"' + + def test_quote_string(self, adapter): + """Test string literal quoting.""" + assert adapter.quote_string("test") == "'test'" + # PostgreSQL doubles single quotes for escaping + assert adapter.quote_string("It's a test") == "'It''s a test'" + + def test_core_type_to_sql_simple(self, adapter): + """Test core type mapping for simple types.""" + assert adapter.core_type_to_sql("int64") == "bigint" + assert adapter.core_type_to_sql("int32") == "integer" + assert adapter.core_type_to_sql("int16") == "smallint" + assert adapter.core_type_to_sql("int8") == "smallint" # No tinyint in PostgreSQL + assert 
adapter.core_type_to_sql("float32") == "real" + assert adapter.core_type_to_sql("float64") == "double precision" + assert adapter.core_type_to_sql("bool") == "boolean" + assert adapter.core_type_to_sql("uuid") == "uuid" + assert adapter.core_type_to_sql("bytes") == "bytea" + assert adapter.core_type_to_sql("json") == "jsonb" + assert adapter.core_type_to_sql("date") == "date" + + def test_core_type_to_sql_parametrized(self, adapter): + """Test core type mapping for parametrized types.""" + assert adapter.core_type_to_sql("datetime") == "timestamp" + assert adapter.core_type_to_sql("datetime(3)") == "timestamp(3)" + assert adapter.core_type_to_sql("char(10)") == "char(10)" + assert adapter.core_type_to_sql("varchar(255)") == "varchar(255)" + assert adapter.core_type_to_sql("decimal(10,2)") == "numeric(10,2)" + + def test_sql_type_to_core(self, adapter): + """Test reverse type mapping.""" + assert adapter.sql_type_to_core("bigint") == "int64" + assert adapter.sql_type_to_core("integer") == "int32" + assert adapter.sql_type_to_core("real") == "float32" + assert adapter.sql_type_to_core("double precision") == "float64" + assert adapter.sql_type_to_core("boolean") == "bool" + assert adapter.sql_type_to_core("uuid") == "uuid" + assert adapter.sql_type_to_core("bytea") == "bytes" + assert adapter.sql_type_to_core("jsonb") == "json" + assert adapter.sql_type_to_core("timestamp") == "datetime" + assert adapter.sql_type_to_core("timestamp(3)") == "datetime(3)" + assert adapter.sql_type_to_core("numeric(10,2)") == "decimal(10,2)" + + def test_create_schema_sql(self, adapter): + """Test CREATE SCHEMA statement.""" + sql = adapter.create_schema_sql("test_schema") + assert sql == 'CREATE SCHEMA "test_schema"' + + def test_drop_schema_sql(self, adapter): + """Test DROP SCHEMA statement.""" + sql = adapter.drop_schema_sql("test_schema") + assert "DROP SCHEMA" in sql + assert "IF EXISTS" in sql + assert '"test_schema"' in sql + assert "CASCADE" in sql + + def test_insert_sql_basic(self, adapter): + """Test basic INSERT statement.""" + sql = adapter.insert_sql("users", ["id", "name"]) + assert sql == 'INSERT INTO users ("id", "name") VALUES (%s, %s)' + + def test_insert_sql_ignore(self, adapter): + """Test INSERT ... ON CONFLICT DO NOTHING statement.""" + sql = adapter.insert_sql("users", ["id", "name"], on_duplicate="ignore") + assert "INSERT INTO" in sql + assert "ON CONFLICT DO NOTHING" in sql + + def test_insert_sql_update(self, adapter): + """Test INSERT ... 
ON CONFLICT DO UPDATE statement.""" + sql = adapter.insert_sql("users", ["id", "name"], on_duplicate="update") + assert "INSERT INTO" in sql + assert "ON CONFLICT DO UPDATE" in sql + assert "EXCLUDED" in sql + + def test_update_sql(self, adapter): + """Test UPDATE statement.""" + sql = adapter.update_sql("users", ["name"], ["id"]) + assert "UPDATE users SET" in sql + assert '"name" = %s' in sql + assert "WHERE" in sql + assert '"id" = %s' in sql + + def test_delete_sql(self, adapter): + """Test DELETE statement.""" + sql = adapter.delete_sql("users") + assert sql == "DELETE FROM users" + + def test_current_timestamp_expr(self, adapter): + """Test CURRENT_TIMESTAMP expression.""" + assert adapter.current_timestamp_expr() == "CURRENT_TIMESTAMP" + assert adapter.current_timestamp_expr(3) == "CURRENT_TIMESTAMP(3)" + + def test_interval_expr(self, adapter): + """Test INTERVAL expression with PostgreSQL syntax.""" + assert adapter.interval_expr(5, "second") == "INTERVAL '5 seconds'" + assert adapter.interval_expr(10, "minute") == "INTERVAL '10 minutes'" + + def test_json_path_expr(self, adapter): + """Test JSON path extraction for PostgreSQL.""" + assert adapter.json_path_expr("data", "field") == "jsonb_extract_path_text(\"data\", 'field')" + assert adapter.json_path_expr("record", "name") == "jsonb_extract_path_text(\"record\", 'name')" + + def test_json_path_expr_nested(self, adapter): + """Test JSON path extraction with nested paths.""" + result = adapter.json_path_expr("data", "nested.field") + assert result == "jsonb_extract_path_text(\"data\", 'nested', 'field')" + + def test_transaction_sql(self, adapter): + """Test transaction statements.""" + assert adapter.start_transaction_sql() == "BEGIN" + assert adapter.commit_sql() == "COMMIT" + assert adapter.rollback_sql() == "ROLLBACK" + + def test_validate_native_type(self, adapter): + """Test native type validation.""" + assert adapter.validate_native_type("integer") + assert adapter.validate_native_type("bigint") + assert adapter.validate_native_type("varchar") + assert adapter.validate_native_type("text") + assert adapter.validate_native_type("jsonb") + assert adapter.validate_native_type("uuid") + assert adapter.validate_native_type("boolean") + assert not adapter.validate_native_type("invalid_type") + + def test_enum_type_sql(self, adapter): + """Test PostgreSQL enum type creation.""" + sql = adapter.create_enum_type_sql("myschema", "mytable", "status", ["pending", "complete"]) + assert "CREATE TYPE" in sql + assert "myschema_mytable_status_enum" in sql + assert "AS ENUM" in sql + assert "'pending'" in sql + assert "'complete'" in sql + + def test_drop_enum_type_sql(self, adapter): + """Test PostgreSQL enum type dropping.""" + sql = adapter.drop_enum_type_sql("myschema", "mytable", "status") + assert "DROP TYPE" in sql + assert "IF EXISTS" in sql + assert "myschema_mytable_status_enum" in sql + assert "CASCADE" in sql + + +class TestAdapterInterface: + """Test that adapters implement the full interface.""" + + @pytest.mark.parametrize("backend", ["mysql", "postgresql"]) + def test_adapter_implements_interface(self, backend): + """Test that adapter implements all abstract methods.""" + if backend == "postgresql": + pytest.importorskip("psycopg2") + + adapter = get_adapter(backend) + + # Check that all abstract methods are implemented (not abstract) + abstract_methods = [ + "connect", + "close", + "ping", + "get_connection_id", + "quote_identifier", + "quote_string", + "core_type_to_sql", + "sql_type_to_core", + "create_schema_sql", + 
"drop_schema_sql", + "create_table_sql", + "drop_table_sql", + "alter_table_sql", + "add_comment_sql", + "insert_sql", + "update_sql", + "delete_sql", + "list_schemas_sql", + "list_tables_sql", + "get_table_info_sql", + "get_columns_sql", + "get_primary_key_sql", + "get_foreign_keys_sql", + "get_indexes_sql", + "parse_column_info", + "start_transaction_sql", + "commit_sql", + "rollback_sql", + "current_timestamp_expr", + "interval_expr", + "json_path_expr", + "format_column_definition", + "table_options_clause", + "table_comment_ddl", + "column_comment_ddl", + "enum_type_ddl", + "job_metadata_columns", + "translate_error", + "validate_native_type", + ] + + for method_name in abstract_methods: + assert hasattr(adapter, method_name), f"Adapter missing method: {method_name}" + method = getattr(adapter, method_name) + assert callable(method), f"Adapter.{method_name} is not callable" + + # Check properties + assert hasattr(adapter, "default_port") + assert isinstance(adapter.default_port, int) + assert hasattr(adapter, "parameter_placeholder") + assert isinstance(adapter.parameter_placeholder, str) + + +class TestDDLMethods: + """Test DDL generation adapter methods.""" + + @pytest.fixture + def adapter(self): + """MySQL adapter instance.""" + return MySQLAdapter() + + def test_format_column_definition_mysql(self, adapter): + """Test MySQL column definition formatting.""" + result = adapter.format_column_definition("user_id", "bigint", nullable=False, comment="user ID") + assert result == '`user_id` bigint NOT NULL COMMENT "user ID"' + + # Test without comment + result = adapter.format_column_definition("name", "varchar(255)", nullable=False) + assert result == "`name` varchar(255) NOT NULL" + + # Test nullable + result = adapter.format_column_definition("description", "text", nullable=True) + assert result == "`description` text" + + # Test with default + result = adapter.format_column_definition("status", "int", default="DEFAULT 1") + assert result == "`status` int DEFAULT 1" + + def test_table_options_clause_mysql(self, adapter): + """Test MySQL table options clause.""" + result = adapter.table_options_clause("test table") + assert result == 'ENGINE=InnoDB, COMMENT "test table"' + + result = adapter.table_options_clause() + assert result == "ENGINE=InnoDB" + + def test_table_comment_ddl_mysql(self, adapter): + """Test MySQL table comment DDL (should be None).""" + result = adapter.table_comment_ddl("`schema`.`table`", "test comment") + assert result is None + + def test_column_comment_ddl_mysql(self, adapter): + """Test MySQL column comment DDL (should be None).""" + result = adapter.column_comment_ddl("`schema`.`table`", "column", "test comment") + assert result is None + + def test_enum_type_ddl_mysql(self, adapter): + """Test MySQL enum type DDL (should be None).""" + result = adapter.enum_type_ddl("status_type", ["active", "inactive"]) + assert result is None + + def test_job_metadata_columns_mysql(self, adapter): + """Test MySQL job metadata columns.""" + result = adapter.job_metadata_columns() + assert len(result) == 3 + assert "_job_start_time" in result[0] + assert "datetime(3)" in result[0] + assert "_job_duration" in result[1] + assert "float" in result[1] + assert "_job_version" in result[2] + assert "varchar(64)" in result[2] + + +class TestPostgreSQLDDLMethods: + """Test PostgreSQL-specific DDL generation methods.""" + + @pytest.fixture + def postgres_adapter(self): + """Get PostgreSQL adapter for testing.""" + pytest.importorskip("psycopg2") + return get_adapter("postgresql") + 
+ def test_format_column_definition_postgres(self, postgres_adapter): + """Test PostgreSQL column definition formatting.""" + result = postgres_adapter.format_column_definition("user_id", "bigint", nullable=False, comment="user ID") + assert result == '"user_id" bigint NOT NULL' + + # Test without comment (comment handled separately in PostgreSQL) + result = postgres_adapter.format_column_definition("name", "varchar(255)", nullable=False) + assert result == '"name" varchar(255) NOT NULL' + + # Test nullable + result = postgres_adapter.format_column_definition("description", "text", nullable=True) + assert result == '"description" text' + + def test_table_options_clause_postgres(self, postgres_adapter): + """Test PostgreSQL table options clause (should be empty).""" + result = postgres_adapter.table_options_clause("test table") + assert result == "" + + result = postgres_adapter.table_options_clause() + assert result == "" + + def test_table_comment_ddl_postgres(self, postgres_adapter): + """Test PostgreSQL table comment DDL.""" + result = postgres_adapter.table_comment_ddl('"schema"."table"', "test comment") + assert result == 'COMMENT ON TABLE "schema"."table" IS \'test comment\'' + + def test_column_comment_ddl_postgres(self, postgres_adapter): + """Test PostgreSQL column comment DDL.""" + result = postgres_adapter.column_comment_ddl('"schema"."table"', "column", "test comment") + assert result == 'COMMENT ON COLUMN "schema"."table"."column" IS \'test comment\'' + + def test_enum_type_ddl_postgres(self, postgres_adapter): + """Test PostgreSQL enum type DDL.""" + result = postgres_adapter.enum_type_ddl("status_type", ["active", "inactive"]) + assert result == "CREATE TYPE \"status_type\" AS ENUM ('active', 'inactive')" + + def test_job_metadata_columns_postgres(self, postgres_adapter): + """Test PostgreSQL job metadata columns.""" + result = postgres_adapter.job_metadata_columns() + assert len(result) == 3 + assert "_job_start_time" in result[0] + assert "timestamp" in result[0] + assert "_job_duration" in result[1] + assert "real" in result[1] + assert "_job_version" in result[2] + assert "varchar(64)" in result[2] diff --git a/tests/unit/test_settings.py b/tests/unit/test_settings.py index 61f4439e0..af5718503 100644 --- a/tests/unit/test_settings.py +++ b/tests/unit/test_settings.py @@ -748,3 +748,123 @@ def test_similar_prefix_names_allowed(self): finally: dj.config.stores.clear() dj.config.stores.update(original_stores) + + +class TestBackendConfiguration: + """Test database backend configuration and port auto-detection.""" + + def test_backend_default(self): + """Test default backend is mysql.""" + from datajoint.settings import DatabaseSettings + + settings = DatabaseSettings() + assert settings.backend == "mysql" + assert settings.port == 3306 + + def test_backend_postgresql(self, monkeypatch): + """Test PostgreSQL backend with auto port.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "postgresql") + settings = DatabaseSettings() + assert settings.backend == "postgresql" + assert settings.port == 5432 + + def test_backend_explicit_port_overrides(self, monkeypatch): + """Test explicit port overrides auto-detection.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "postgresql") + monkeypatch.setenv("DJ_PORT", "9999") + settings = DatabaseSettings() + assert settings.backend == "postgresql" + assert settings.port == 9999 + + def test_backend_env_var(self, monkeypatch): + """Test DJ_BACKEND environment 
variable.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "postgresql") + settings = DatabaseSettings() + assert settings.backend == "postgresql" + assert settings.port == 5432 + + def test_port_env_var_overrides_backend_default(self, monkeypatch): + """Test DJ_PORT overrides backend auto-detection.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "postgresql") + monkeypatch.setenv("DJ_PORT", "8888") + settings = DatabaseSettings() + assert settings.backend == "postgresql" + assert settings.port == 8888 + + def test_invalid_backend(self, monkeypatch): + """Test invalid backend raises validation error.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "sqlite") + with pytest.raises(ValidationError, match="Input should be 'mysql' or 'postgresql'"): + DatabaseSettings() + + def test_config_file_backend(self, tmp_path, monkeypatch): + """Test loading backend from config file.""" + import json + + from datajoint.settings import Config + + # Include port in config since auto-detection only happens during initialization + config_file = tmp_path / "test_config.json" + config_file.write_text(json.dumps({"database": {"backend": "postgresql", "host": "db.example.com", "port": 5432}})) + + # Clear env vars so file values take effect + monkeypatch.delenv("DJ_BACKEND", raising=False) + monkeypatch.delenv("DJ_HOST", raising=False) + monkeypatch.delenv("DJ_PORT", raising=False) + + cfg = Config() + cfg.load(config_file) + assert cfg.database.backend == "postgresql" + assert cfg.database.port == 5432 + assert cfg.database.host == "db.example.com" + + def test_global_config_backend(self): + """Test global config has backend configuration.""" + # Global config should have backend field with default mysql + assert hasattr(dj.config.database, "backend") + # Backend should be one of the valid values + assert dj.config.database.backend in ["mysql", "postgresql"] + # Port should be set (either 3306 or 5432 or custom) + assert isinstance(dj.config.database.port, int) + assert 1 <= dj.config.database.port <= 65535 + + def test_port_auto_detection_on_initialization(self): + """Test port auto-detects only during initialization, not on live updates.""" + from datajoint.settings import DatabaseSettings + + # Start with MySQL (default) + settings = DatabaseSettings() + assert settings.port == 3306 + + # Change backend on live config - port won't auto-update + settings.backend = "postgresql" + # Port remains at previous value (this is expected behavior) + # Users should set port explicitly when changing backend on live config + assert settings.port == 3306 # Didn't auto-update + + def test_mysql_backend_with_explicit_port(self, monkeypatch): + """Test MySQL backend with explicit non-default port.""" + from datajoint.settings import DatabaseSettings + + monkeypatch.setenv("DJ_BACKEND", "mysql") + monkeypatch.setenv("DJ_PORT", "3307") + settings = DatabaseSettings() + assert settings.backend == "mysql" + assert settings.port == 3307 + + def test_backend_field_in_env_var_mapping(self): + """Test that backend is mapped to DJ_BACKEND in ENV_VAR_MAPPING.""" + from datajoint.settings import ENV_VAR_MAPPING + + assert "database.backend" in ENV_VAR_MAPPING + assert ENV_VAR_MAPPING["database.backend"] == "DJ_BACKEND"
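The unit tests above pin down the adapter surface well enough to sketch how the backend differences show up in practice. A minimal illustration (not part of the patch), assuming only the calls exercised in tests/unit/test_adapters.py — get_adapter, quote_identifier, core_type_to_sql, and insert_sql — and that psycopg2-binary is installed so the PostgreSQL adapter can be constructed:

# Sketch only: prints the backend-specific SQL fragments asserted in the unit tests above.
from datajoint.adapters import get_adapter

for backend in ("mysql", "postgresql"):
    adapter = get_adapter(backend)
    # Identifier quoting: backticks on MySQL, double quotes on PostgreSQL.
    full_name = f"{adapter.quote_identifier('ephys')}.{adapter.quote_identifier('session')}"
    # Core type mapping: int32 -> int on MySQL, integer on PostgreSQL.
    int_type = adapter.core_type_to_sql("int32")
    # Duplicate handling: INSERT IGNORE vs. INSERT ... ON CONFLICT DO NOTHING.
    insert = adapter.insert_sql("users", ["id", "name"], on_duplicate="ignore")
    print(backend, full_name, int_type)
    print(insert)

On MySQL this prints `ephys`.`session`, int, and an INSERT IGNORE statement; on PostgreSQL it prints "ephys"."session", integer, and an INSERT ... ON CONFLICT DO NOTHING statement, matching the assertions in TestMySQLAdapter and TestPostgreSQLAdapter.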