diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py index 6a14a5a35..667479cdd 100644 --- a/src/datajoint/expression.py +++ b/src/datajoint/expression.py @@ -149,7 +149,15 @@ def make_sql(self, fields=None): """ Make the SQL SELECT statement. - :param fields: used to explicitly set the select attributes + Parameters + ---------- + fields : list, optional + Used to explicitly set the select attributes. + + Returns + ------- + str + The SQL SELECT statement. """ return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format( distinct="DISTINCT " if self._distinct else "", @@ -172,23 +180,17 @@ def restrict(self, restriction, semantic_check=True): """ Produces a new expression with the new restriction applied. - :param restriction: a sequence or an array (treated as OR list), another QueryExpression, - an SQL condition string, or an AndList. - :param semantic_check: If True (default), use semantic matching - only match on - homologous namesakes and error on non-homologous namesakes. - If False, use natural matching on all namesakes (no lineage checking). - :return: A new QueryExpression with the restriction applied. - - rel.restrict(restriction) is equivalent to rel & restriction. - rel.restrict(Not(restriction)) is equivalent to rel - restriction + ``rel.restrict(restriction)`` is equivalent to ``rel & restriction``. + ``rel.restrict(Not(restriction))`` is equivalent to ``rel - restriction``. The primary key of the result is unaffected. - Successive restrictions are combined as logical AND: r & a & b is equivalent to r & AndList((a, b)) - Any QueryExpression, collection, or sequence other than an AndList are treated as OrLists - (logical disjunction of conditions) + Successive restrictions are combined as logical AND: ``r & a & b`` is equivalent to + ``r & AndList((a, b))``. + Any QueryExpression, collection, or sequence other than an AndList is treated as an OrList + (logical disjunction of conditions). Inverse restriction is accomplished by either using the subtraction operator or the Not class. - The expressions in each row equivalent: + The expressions in each row are equivalent: rel & True rel rel & False the empty entity set @@ -210,14 +212,31 @@ rel - None rel rel - any_empty_entity_set rel - When arg is another QueryExpression, the restriction rel & arg restricts rel to elements that match at least - one element in arg (hence arg is treated as an OrList). - Conversely, rel - arg restricts rel to elements that do not match any elements in arg. - Two elements match when their common attributes have equal values or when they have no common attributes. - All shared attributes must be in the primary key of either rel or arg or both or an error will be raised. + When arg is another QueryExpression, the restriction ``rel & arg`` restricts rel to elements + that match at least one element in arg (hence arg is treated as an OrList). + Conversely, ``rel - arg`` restricts rel to elements that do not match any elements in arg. + Two elements match when their common attributes have equal values or when they have no + common attributes. + All shared attributes must be in the primary key of rel, arg, or both; otherwise an error + is raised. + + QueryExpression.restrict is the only access point that modifies restrictions. All other + operators must ultimately call restrict(). - QueryExpression.restrict is the only access point that modifies restrictions. All other operators must - ultimately call restrict() + + Parameters + ---------- + restriction : QueryExpression, AndList, str, dict, list, or array-like + A sequence or an array (treated as OR list), another QueryExpression, + an SQL condition string, or an AndList. + semantic_check : bool, optional + If True (default), use semantic matching: only match on homologous namesakes + and error on non-homologous namesakes. + If False, use natural matching on all namesakes (no lineage checking). + + Returns + ------- + QueryExpression + A new QueryExpression with the restriction applied. """ attributes = set() if isinstance(restriction, Top): @@ -264,8 +283,13 @@ def restrict_in_place(self, restriction): def __and__(self, restriction): """ Restriction operator e.g. ``q1 & q2``. - :return: a restricted copy of the input argument + See QueryExpression.restrict for more detail. + + Returns + ------- + QueryExpression + A restricted copy of the input argument. """ return self.restrict(restriction) @@ -279,16 +303,26 @@ def __xor__(self, restriction): def __sub__(self, restriction): """ Inverted restriction e.g. ``q1 - q2``. - :return: a restricted copy of the input argument + See QueryExpression.restrict for more detail. + + Returns + ------- + QueryExpression + A restricted copy of the input argument. """ return self.restrict(Not(restriction)) def __neg__(self): """ Convert between restriction and inverted restriction e.g. ``-q1``. - :return: target restriction + See QueryExpression.restrict for more detail. + + Returns + ------- + QueryExpression or Not + The target restriction. """ if isinstance(self, Not): return self.restriction @@ -311,18 +345,28 @@ def join(self, other, semantic_check=True, left=False, allow_nullable_pk=False): """ Create the joined QueryExpression. - :param other: QueryExpression to join with - :param semantic_check: If True (default), use semantic matching - only match on - homologous namesakes (same lineage) and error on non-homologous namesakes. + ``a * b`` is short for ``a.join(b)``. + + Parameters + ---------- + other : QueryExpression + QueryExpression to join with. + semantic_check : bool, optional + If True (default), use semantic matching: only match on homologous namesakes + (same lineage) and error on non-homologous namesakes. If False, use natural join on all namesakes (no lineage checking). - :param left: If True, perform a left join (retain all rows from self) - :param allow_nullable_pk: If True, bypass the left join constraint that requires - self to determine other. When bypassed, the result PK is the union of both - operands' PKs, and PK attributes from the right operand could be NULL. - Used internally by aggregation when exclude_nonmatching=False. - :return: The joined QueryExpression - - a * b is short for a.join(b) + left : bool, optional + If True, perform a left join (retain all rows from self). Default False. + allow_nullable_pk : bool, optional + If True, bypass the left join constraint that requires self to determine other. + When bypassed, the result PK is the union of both operands' PKs, and PK + attributes from the right operand could be NULL. + Used internally by aggregation when exclude_nonmatching=False. Default False. + + Returns + ------- + QueryExpression + The joined QueryExpression.
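+
+        Examples
+        --------
+        A minimal sketch; ``Session`` and ``Trial`` here stand for any two tables
+        that share homologous attributes (``Trial.extend(Session)`` is valid, as in
+        the extend example below):
+
+        >>> Session * Trial                   # semantic join on shared attributes
+        >>> Trial.join(Session, left=True)    # left join: keep all rows from Trial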
""" # Joining with U is no longer supported if isinstance(other, U): @@ -414,21 +458,36 @@ def extend(self, other, semantic_check=True): extend is closer to projection—it adds new attributes to existing entities without changing which entities are in the result. - Example: - # Session determines Trial (session_id is in Trial's PK) - # But Trial does NOT determine Session (trial_num not in Session) + Examples + -------- + Session determines Trial (session_id is in Trial's PK), but Trial does NOT + determine Session (trial_num not in Session). + + Valid: extend trials with session info:: - # Valid: extend trials with session info Trial.extend(Session) # Adds 'date' from Session to each Trial - # Invalid: Session cannot extend to Trial + Invalid: Session cannot extend to Trial:: + Session.extend(Trial) # Error: trial_num not in Session - :param other: QueryExpression whose attributes will extend self - :param semantic_check: If True (default), require homologous namesakes. + Parameters + ---------- + other : QueryExpression + QueryExpression whose attributes will extend self. + semantic_check : bool, optional + If True (default), require homologous namesakes. If False, match on all namesakes without lineage checking. - :return: Extended QueryExpression with self's PK and combined attributes - :raises DataJointError: If self does not determine other + + Returns + ------- + QueryExpression + Extended QueryExpression with self's PK and combined attributes. + + Raises + ------ + DataJointError + If self does not determine other. """ return self.join(other, semantic_check=semantic_check, left=True) @@ -440,22 +499,37 @@ def proj(self, *attributes, **named_attributes): """ Projection operator. - :param attributes: attributes to be included in the result. (The primary key is already included). - :param named_attributes: new attributes computed or renamed from existing attributes. - :return: the projected expression. Primary key attributes cannot be excluded but may be renamed. - If the attribute list contains an Ellipsis ..., then all secondary attributes are included too - Prefixing an attribute name with a dash '-attr' removes the attribute from the list if present. - Keyword arguments can be used to rename attributes as in name='attr', duplicate them as in name='(attr)', or - self.proj(...) or self.proj(Ellipsis) -- include all attributes (return self) - self.proj() -- include only primary key - self.proj('attr1', 'attr2') -- include primary key and attributes attr1 and attr2 - self.proj(..., '-attr1', '-attr2') -- include all attributes except attr1 and attr2 - self.proj(name1='attr1') -- include primary key and 'attr1' renamed as name1 - self.proj('attr1', dup='(attr1)') -- include primary key and attribute attr1 twice, with the duplicate 'dup' - self.proj(k='abs(attr1)') adds the new attribute k with the value computed as an expression (SQL syntax) - from other attributes available before the projection. + If the attribute list contains an Ellipsis ``...``, then all secondary attributes + are included too. + Prefixing an attribute name with a dash ``-attr`` removes the attribute from the list + if present. + Keyword arguments can be used to rename attributes as in ``name='attr'``, duplicate + them as in ``name='(attr)'``, or compute new attributes. 
+ + - ``self.proj(...)`` or ``self.proj(Ellipsis)`` -- include all attributes (return self) + - ``self.proj()`` -- include only primary key + - ``self.proj('attr1', 'attr2')`` -- include primary key and attributes attr1 and attr2 + - ``self.proj(..., '-attr1', '-attr2')`` -- include all attributes except attr1 and attr2 + - ``self.proj(name1='attr1')`` -- include primary key and 'attr1' renamed as name1 + - ``self.proj('attr1', dup='(attr1)')`` -- include primary key and attr1 twice, with + the duplicate 'dup' + - ``self.proj(k='abs(attr1)')`` adds the new attribute k with the value computed as an + expression (SQL syntax) from other attributes available before the projection. + Each attribute name can only be used once. + + Parameters + ---------- + *attributes : str + Attributes to be included in the result. The primary key is already included. + **named_attributes : str + New attributes computed or renamed from existing attributes. + + Returns + ------- + QueryExpression + The projected expression. """ adapter = self.connection.adapter if hasattr(self, "connection") and self.connection else None named_attributes = {k: translate_attribute(v, adapter)[1] for k, v in named_attributes.items()} @@ -556,21 +630,34 @@ def aggr(self, group, *attributes, exclude_nonmatching=False, **named_attributes """ Aggregation/grouping operation, similar to proj but with computations over a grouped relation. - By default, keeps all rows from self (like proj). Use exclude_nonmatching=True to + By default, keeps all rows from self (like proj). Use ``exclude_nonmatching=True`` to keep only rows that have matches in group. - :param group: The query expression to be aggregated. - :param exclude_nonmatching: If True, exclude rows from self that have no matching - entries in group (INNER JOIN). Default False keeps all rows (LEFT JOIN). - :param named_attributes: computations of the form new_attribute="sql expression on attributes of group" - :return: The derived query expression + Parameters + ---------- + group : QueryExpression + The query expression to be aggregated. + *attributes : str + Attributes to include in the result. + exclude_nonmatching : bool, optional + If True, exclude rows from self that have no matching entries in group + (INNER JOIN). Default False keeps all rows (LEFT JOIN). + **named_attributes : str + Computations of the form ``new_attribute="sql expression on attributes of group"``. + + Returns + ------- + QueryExpression + The derived query expression. - Example:: + Examples + -------- + Count sessions per subject (keeps all subjects, even those with 0 sessions):: - # Count sessions per subject (keeps all subjects, even those with 0 sessions) Subject.aggr(Session, n="count(*)") - # Count sessions per subject (only subjects with at least one session) + Count sessions per subject (only subjects with at least one session):: + Subject.aggr(Session, n="count(*)", exclude_nonmatching=True) """ if Ellipsis in attributes: @@ -665,12 +752,26 @@ def fetch1(self, *attrs, squeeze=False): If no attributes are specified, returns the result as a dict. If attributes are specified, returns the corresponding values as a tuple. - :param attrs: attribute names to fetch (if empty, fetch all as dict) - :param squeeze: if True, remove extra dimensions from arrays - :return: dict (no attrs) or tuple/value (with attrs) - :raises DataJointError: if not exactly one row in result + Parameters + ---------- + *attrs : str + Attribute names to fetch. If empty, fetch all as dict. 
+ squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. + + Returns + ------- + dict or tuple or value + Dict (no attrs) or tuple/value (with attrs). + + Raises + ------ + DataJointError + If not exactly one row in result. - Examples:: + Examples + -------- + :: d = table.fetch1() # returns dict with all attributes a, b = table.fetch1('a', 'b') # returns tuple of attribute values @@ -734,17 +835,27 @@ def to_dicts(self, order_by=None, limit=None, offset=None, squeeze=False): """ Fetch all rows as a list of dictionaries. - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :param squeeze: if True, remove extra dimensions from arrays - :return: list of dictionaries, one per row - For object storage types (attachments, filepaths), files are downloaded - to config["download_path"]. Use config.override() to change:: + to ``config["download_path"]``. Use ``config.override()`` to change:: with dj.config.override(download_path="/data"): data = table.to_dicts() + + Parameters + ---------- + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. + + Returns + ------- + list[dict] + List of dictionaries, one per row. """ expr = self._apply_top(order_by, limit, offset) cursor = expr.cursor(as_dict=True) @@ -755,11 +866,21 @@ def to_pandas(self, order_by=None, limit=None, offset=None, squeeze=False): """ Fetch all rows as a pandas DataFrame with primary key as index. - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :param squeeze: if True, remove extra dimensions from arrays - :return: pandas DataFrame with primary key columns as index + Parameters + ---------- + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. + + Returns + ------- + pandas.DataFrame + DataFrame with primary key columns as index. """ dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze) df = pandas.DataFrame(dicts) @@ -771,13 +892,23 @@ def to_polars(self, order_by=None, limit=None, offset=None, squeeze=False): """ Fetch all rows as a polars DataFrame. - Requires polars: pip install datajoint[polars] + Requires polars: ``pip install datajoint[polars]`` - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :param squeeze: if True, remove extra dimensions from arrays - :return: polars DataFrame + Parameters + ---------- + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. + + Returns + ------- + polars.DataFrame + Polars DataFrame. 
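+
+        Examples
+        --------
+        A usage sketch, where ``table`` stands for any query expression:
+
+        >>> df = table.to_polars()                          # fetch all rows
+        >>> df = table.to_polars(order_by="KEY", limit=10)  # first 10 rows by primary key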
""" try: import polars @@ -790,13 +921,23 @@ def to_arrow(self, order_by=None, limit=None, offset=None, squeeze=False): """ Fetch all rows as a PyArrow Table. - Requires pyarrow: pip install datajoint[arrow] + Requires pyarrow: ``pip install datajoint[arrow]`` + + Parameters + ---------- + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :param squeeze: if True, remove extra dimensions from arrays - :return: pyarrow Table + Returns + ------- + pyarrow.Table + PyArrow Table. """ try: import pyarrow @@ -814,24 +955,39 @@ def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset If no attrs specified, returns a numpy structured array (recarray) of all columns. If attrs specified, returns a tuple of numpy arrays (one per attribute). - :param attrs: attribute names to fetch (if empty, fetch all) - :param include_key: if True and attrs specified, prepend primary keys as list of dicts - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :param squeeze: if True, remove extra dimensions from arrays - :return: numpy recarray (no attrs) or tuple of arrays (with attrs). - With include_key=True: (keys, *arrays) where keys is list[dict] + Parameters + ---------- + *attrs : str + Attribute names to fetch. If empty, fetch all. + include_key : bool, optional + If True and attrs specified, prepend primary keys as list of dicts. Default False. + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + squeeze : bool, optional + If True, remove extra dimensions from arrays. Default False. - Examples:: + Returns + ------- + np.recarray or tuple of np.ndarray + Numpy recarray (no attrs) or tuple of arrays (with attrs). + With ``include_key=True``: ``(keys, *arrays)`` where keys is ``list[dict]``. + + Examples + -------- + Fetch as structured array:: - # Fetch as structured array data = table.to_arrays() - # Fetch specific columns as separate arrays + Fetch specific columns as separate arrays:: + a, b = table.to_arrays('a', 'b') - # Fetch with primary keys for later restrictions + Fetch with primary keys for later restrictions:: + keys, a, b = table.to_arrays('a', 'b', include_key=True) # keys = [{'id': 1}, {'id': 2}, ...] # same format as table.keys() """ @@ -901,10 +1057,19 @@ def keys(self, order_by=None, limit=None, offset=None): """ Fetch primary key values as a list of dictionaries. - :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC" - :param limit: maximum number of rows to return - :param offset: number of rows to skip - :return: list of dictionaries containing only primary key columns + Parameters + ---------- + order_by : str or list, optional + Attribute(s) to order by, or "KEY"/"KEY DESC". + limit : int, optional + Maximum number of rows to return. + offset : int, optional + Number of rows to skip. + + Returns + ------- + list[dict] + List of dictionaries containing only primary key columns. 
""" return self.proj().to_dicts(order_by=order_by, limit=limit, offset=offset) @@ -912,8 +1077,15 @@ def head(self, limit=25): """ Preview the first few entries from query expression. - :param limit: number of entries (default 25) - :return: list of dictionaries + Parameters + ---------- + limit : int, optional + Number of entries. Default 25. + + Returns + ------- + list[dict] + List of dictionaries. """ return self.to_dicts(order_by="KEY", limit=limit) @@ -921,13 +1093,27 @@ def tail(self, limit=25): """ Preview the last few entries from query expression. - :param limit: number of entries (default 25) - :return: list of dictionaries + Parameters + ---------- + limit : int, optional + Number of entries. Default 25. + + Returns + ------- + list[dict] + List of dictionaries. """ return list(reversed(self.to_dicts(order_by="KEY DESC", limit=limit))) def __len__(self): - """:return: number of elements in the result set e.g. ``len(q1)``.""" + """ + Return number of elements in the result set e.g. ``len(q1)``. + + Returns + ------- + int + Number of elements in the result set. + """ result = self.make_subquery() if self._top else copy.copy(self) has_left_join = any(is_left for is_left, _ in result._joins) @@ -950,8 +1136,14 @@ def __len__(self): def __bool__(self): """ - :return: True if the result is not empty. Equivalent to len(self) > 0 but often - faster e.g. ``bool(q1)``. + Check if the result is not empty. + + Equivalent to ``len(self) > 0`` but often faster e.g. ``bool(q1)``. + + Returns + ------- + bool + True if the result is not empty. """ return bool( self.connection.query( @@ -961,12 +1153,20 @@ def __bool__(self): def __contains__(self, item): """ - returns True if the restriction in item matches any entries in self - e.g. ``restriction in q1``. + Check if the restriction in item matches any entries in self. - :param item: any restriction - (item in query_expression) is equivalent to bool(query_expression & item) but may be - executed more efficiently. + ``(item in query_expression)`` is equivalent to ``bool(query_expression & item)`` + but may be executed more efficiently. + + Parameters + ---------- + item : any + Any restriction. + + Returns + ------- + bool + True if the restriction matches any entries e.g. ``restriction in q1``. """ return bool(self & item) # May be optimized e.g. using an EXISTS query @@ -988,8 +1188,15 @@ def cursor(self, as_dict=False): """ Execute the query and return a database cursor. - :param as_dict: if True, rows are returned as dictionaries - :return: database query cursor + Parameters + ---------- + as_dict : bool, optional + If True, rows are returned as dictionaries. Default False. + + Returns + ------- + cursor + Database query cursor. """ sql = self.make_sql() logger.debug(sql) @@ -997,20 +1204,42 @@ def cursor(self, as_dict=False): def __repr__(self): """ - returns the string representation of a QueryExpression object e.g. ``str(q1)``. + Return the string representation of a QueryExpression object e.g. ``str(q1)``. - :param self: A query expression - :type self: :class:`QueryExpression` - :rtype: str + Returns + ------- + str + String representation of the QueryExpression. """ return super().__repr__() if config["loglevel"].lower() == "debug" else self.preview() def preview(self, limit=None, width=None): - """:return: a string of preview of the contents of the query.""" + """ + Return a string preview of the contents of the query. + + Parameters + ---------- + limit : int, optional + Maximum number of rows to preview. 
+ width : int, optional + Maximum width of the preview output. + + Returns + ------- + str + A string preview of the contents of the query. + """ return preview(self, limit, width) def _repr_html_(self): - """:return: HTML to display table in Jupyter notebook.""" + """ + Return HTML to display table in Jupyter notebook. + + Returns + ------- + str + HTML to display table in Jupyter notebook. + """ return repr_html(self) @@ -1032,9 +1261,19 @@ def create(cls, groupby, group, keep_all_rows=False): """ Create an aggregation expression. - :param groupby: The expression to GROUP BY (determines the result's primary key) - :param group: The expression to aggregate over - :param keep_all_rows: If True, use left join to keep all rows from groupby + Parameters + ---------- + groupby : QueryExpression + The expression to GROUP BY (determines the result's primary key). + group : QueryExpression + The expression to aggregate over. + keep_all_rows : bool, optional + If True, use left join to keep all rows from groupby. Default False. + + Returns + ------- + Aggregation + The aggregation expression. """ if inspect.isclass(group) and issubclass(group, QueryExpression): group = group() # instantiate if a class @@ -1260,14 +1499,24 @@ def __sub__(self, other): def aggr(self, group, **named_attributes): """ - Aggregation of the type U('attr1','attr2').aggr(group, computation="QueryExpression") - has the primary key ('attr1','attr2') and performs aggregation computations for all matching elements of `group`. + Aggregation of the type ``U('attr1','attr2').aggr(group, computation="QueryExpression")``. + + Has the primary key ``('attr1','attr2')`` and performs aggregation computations for all + matching elements of ``group``. - Note: exclude_nonmatching is always True for dj.U (cannot keep all rows from infinite set). + Note: ``exclude_nonmatching`` is always True for dj.U (cannot keep all rows from infinite set). - :param group: The query expression to be aggregated. - :param named_attributes: computations of the form new_attribute="sql expression on attributes of group" - :return: The derived query expression + Parameters + ---------- + group : QueryExpression + The query expression to be aggregated. + **named_attributes : str + Computations of the form ``new_attribute="sql expression on attributes of group"``. + + Returns + ------- + QueryExpression + The derived query expression. """ if named_attributes.pop("exclude_nonmatching", True) is False: raise DataJointError("Cannot set exclude_nonmatching=False when aggregating on a universal set.") @@ -1298,9 +1547,19 @@ def aggr(self, group, **named_attributes): def _flatten_attribute_list(primary_key, attrs): """ - :param primary_key: list of attributes in primary key - :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC" - :return: generator of attributes where "KEY" is replaced with its component attributes + Flatten an attribute list, replacing "KEY" with primary key attributes. + + Parameters + ---------- + primary_key : list + List of attributes in primary key. + attrs : list + List of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC". + + Yields + ------ + str + Attributes where "KEY" is replaced with its component attributes. 
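+
+    Examples
+    --------
+    A sketch with a hypothetical two-attribute primary key:
+
+    >>> list(_flatten_attribute_list(["subject_id", "session_id"], ["KEY", "trial_num"]))
+    ['subject_id', 'session_id', 'trial_num']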
""" for a in attrs: if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): diff --git a/src/datajoint/heading.py b/src/datajoint/heading.py index 973dd3e9b..4d7f0c62a 100644 --- a/src/datajoint/heading.py +++ b/src/datajoint/heading.py @@ -582,13 +582,24 @@ def _init_from_database(self) -> None: def select(self, select_list, rename_map=None, compute_map=None): """ - derive a new heading by selecting, renaming, or computing attributes. - In relational algebra these operators are known as project, rename, and extend. + Derive a new heading by selecting, renaming, or computing attributes. - :param select_list: the full list of existing attributes to include - :param rename_map: dictionary of renamed attributes: keys=new names, values=old names - :param compute_map: a direction of computed attributes + In relational algebra these operators are known as project, rename, and extend. This low-level method performs no error checking. + + Parameters + ---------- + select_list : list + The full list of existing attributes to include. + rename_map : dict, optional + Dictionary of renamed attributes: keys=new names, values=old names. + compute_map : dict, optional + A dictionary of computed attributes. + + Returns + ------- + Heading + New heading with selected, renamed, and computed attributes. """ rename_map = rename_map or {} compute_map = compute_map or {} @@ -631,16 +642,27 @@ def join(self, other, nullable_pk=False): Join two headings into a new one. The primary key of the result depends on functional dependencies: - - A → B: PK = PK(A), A's attributes first - - B → A (not A → B): PK = PK(B), B's attributes first - - Both: PK = PK(A), left operand takes precedence - - Neither: PK = PK(A) ∪ PK(B), A's PK first then B's new PK attrs - :param nullable_pk: If True, skip PK optimization and use combined PK from both - operands. Used for left joins that bypass the A → B constraint, where the - right operand's PK attributes could be NULL. + - A -> B: PK = PK(A), A's attributes first + - B -> A (not A -> B): PK = PK(B), B's attributes first + - Both: PK = PK(A), left operand takes precedence + - Neither: PK = PK(A) | PK(B), A's PK first then B's new PK attrs It assumes that self and other are headings that share no common dependent attributes. + + Parameters + ---------- + other : Heading + The other heading to join with. + nullable_pk : bool, optional + If True, skip PK optimization and use combined PK from both + operands. Used for left joins that bypass the A -> B constraint, where the + right operand's PK attributes could be NULL. Default False. + + Returns + ------- + Heading + New heading resulting from the join. """ if nullable_pk: a_determines_b = b_determines_a = False diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 63220c45a..b32efcf57 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -129,7 +129,10 @@ def declare(self, context=None): """ Declare the table in the schema based on self.definition. - :param context: the context for foreign key resolution. If None, foreign keys are + Parameters + ---------- + context : dict, optional + The context for foreign key resolution. If None, foreign keys are not allowed. """ if self.connection.in_transaction: @@ -181,8 +184,12 @@ def _declare_check(self, primary_key, fk_attribute_map): Called before the table is created in the database. Override this method to add validation logic (e.g., AutoPopulate validates FK-only primary keys). 
- :param primary_key: list of primary key attribute names - :param fk_attribute_map: dict mapping child_attr -> (parent_table, parent_attr) + Parameters + ---------- + primary_key : list + List of primary key attribute names. + fk_attribute_map : dict + Dict mapping child_attr -> (parent_table, parent_attr). """ pass # Default: no validation @@ -194,8 +201,12 @@ def _populate_lineage(self, primary_key, fk_attribute_map): - All FK attributes (traced to their origin) - Native primary key attributes (lineage = self) - :param primary_key: list of primary key attribute names - :param fk_attribute_map: dict mapping child_attr -> (parent_table, parent_attr) + Parameters + ---------- + primary_key : list + List of primary key attribute names. + fk_attribute_map : dict + Dict mapping child_attr -> (parent_table, parent_attr). """ from .lineage import ( ensure_lineage_table, @@ -279,26 +290,51 @@ def alter(self, prompt=True, context=None): def from_clause(self): """ - :return: the FROM clause of SQL SELECT statements. + Return the FROM clause of SQL SELECT statements. + + Returns + ------- + str + The full table name for use in SQL FROM clauses. """ return self.full_table_name def get_select_fields(self, select_fields=None): """ - :return: the selected attributes from the SQL SELECT statement. + Return the selected attributes from the SQL SELECT statement. + + Parameters + ---------- + select_fields : list, optional + List of attribute names to select. If None, selects all attributes. + + Returns + ------- + str + The SQL field selection string. """ return "*" if select_fields is None else self.heading.project(select_fields).as_sql def parents(self, primary=None, as_objects=False, foreign_key_info=False): """ - - :param primary: if None, then all parents are returned. If True, then only foreign keys composed of - primary key attributes are considered. If False, return foreign keys including at least one - secondary attribute. - :param as_objects: if False, return table names. If True, return table objects. - :param foreign_key_info: if True, each element in result also includes foreign key info. - :return: list of parents as table names or table objects - with (optional) foreign key information. + Return the list of parent tables. + + Parameters + ---------- + primary : bool, optional + If None, then all parents are returned. If True, then only foreign keys + composed of primary key attributes are considered. If False, return + foreign keys including at least one secondary attribute. + as_objects : bool, optional + If False, return table names. If True, return table objects. + foreign_key_info : bool, optional + If True, each element in result also includes foreign key info. + + Returns + ------- + list + List of parents as table names or table objects with (optional) foreign + key information. """ get_edge = self.connection.dependencies.parents nodes = [ @@ -313,13 +349,24 @@ def parents(self, primary=None, as_objects=False, foreign_key_info=False): def children(self, primary=None, as_objects=False, foreign_key_info=False): """ - :param primary: if None, then all children are returned. If True, then only foreign keys composed of - primary key attributes are considered. If False, return foreign keys including at least one - secondary attribute. - :param as_objects: if False, return table names. If True, return table objects. - :param foreign_key_info: if True, each element in result also includes foreign key info. 
- :return: list of children as table names or table objects - with (optional) foreign key information. + Return the list of child tables. + + Parameters + ---------- + primary : bool, optional + If None, then all children are returned. If True, then only foreign keys + composed of primary key attributes are considered. If False, return + foreign keys including at least one secondary attribute. + as_objects : bool, optional + If False, return table names. If True, return table objects. + foreign_key_info : bool, optional + If True, each element in result also includes foreign key info. + + Returns + ------- + list + List of children as table names or table objects with (optional) foreign + key information. """ get_edge = self.connection.dependencies.children nodes = [ @@ -334,8 +381,18 @@ def children(self, primary=None, as_objects=False, foreign_key_info=False): def descendants(self, as_objects=False): """ - :param as_objects: False - a list of table names; True - a list of table objects. - :return: list of tables descendants in topological order. + Return list of descendant tables in topological order. + + Parameters + ---------- + as_objects : bool, optional + If False (default), return a list of table names. If True, return a + list of table objects. + + Returns + ------- + list + List of descendant tables in topological order. """ return [ FreeTable(self.connection, node) if as_objects else node @@ -345,8 +402,18 @@ def descendants(self, as_objects=False): def ancestors(self, as_objects=False): """ - :param as_objects: False - a list of table names; True - a list of table objects. - :return: list of tables ancestors in topological order. + Return list of ancestor tables in topological order. + + Parameters + ---------- + as_objects : bool, optional + If False (default), return a list of table names. If True, return a + list of table objects. + + Returns + ------- + list + List of ancestor tables in topological order. """ return [ FreeTable(self.connection, node) if as_objects else node @@ -356,9 +423,18 @@ def ancestors(self, as_objects=False): def parts(self, as_objects=False): """ - return part tables either as entries in a dict with foreign key information or a list of objects - - :param as_objects: if False (default), the output is a dict describing the foreign keys. If True, return table objects. + Return part tables for this master table. + + Parameters + ---------- + as_objects : bool, optional + If False (default), the output is a list of full table names. If True, + return table objects. + + Returns + ------- + list + List of part table names or table objects. """ self.connection.dependencies.load(force=False) nodes = [ @@ -371,7 +447,12 @@ def parts(self, as_objects=False): @property def is_declared(self): """ - :return: True is the table is declared in the schema. + Check if the table is declared in the schema. + + Returns + ------- + bool + True if the table is declared in the schema. """ query = self.connection.adapter.get_table_info_sql(self.database, self.table_name) return self.connection.query(query).rowcount > 0 @@ -379,7 +460,12 @@ def is_declared(self): @property def full_table_name(self): """ - :return: full table name in the schema + Return the full table name in the schema. + + Returns + ------- + str + Full table name in the format `database`.`table_name`. """ if self.database is None or self.table_name is None: raise DataJointError( @@ -395,20 +481,24 @@ def adapter(self): def update1(self, row): """ - ``update1`` updates one existing entry in the table. 
+        Update one existing entry in the table.
+
-        Caution: In DataJoint the primary modes for data manipulation is to ``insert`` and
+        Caution: In DataJoint the primary modes for data manipulation are to ``insert`` and
         ``delete`` entire records since referential integrity works on the level of records, not fields.
         Therefore, updates are reserved for corrective operations outside of main workflow.
 
         Use UPDATE methods sparingly with full awareness of potential violations of assumptions.
 
-        :param row: a ``dict`` containing the primary key values and the attributes to update.
-            Setting an attribute value to None will reset it to the default value (if any).
-            The primary key attributes must always be provided.
-        Examples:
+        Parameters
+        ----------
+        row : dict
+            A dict containing the primary key values and the attributes to update.
+            Setting an attribute value to None will reset it to the default value (if any).
+            The primary key attributes must always be provided.
+
+        Examples
+        --------
         >>> table.update1({'id': 1, 'value': 3})  # update value in record with id=1
         >>> table.update1({'id': 1, 'value': None})  # reset value to default
         """
@@ -440,11 +530,6 @@ def validate(self, rows, *, ignore_extra_fields=False) -> ValidationResult: """ Validate rows without inserting them. - :param rows: Same format as insert() - iterable of dicts, tuples, numpy records, - or a pandas DataFrame. - :param ignore_extra_fields: If True, ignore fields not in the table heading. - :return: ValidationResult with is_valid, errors list, and rows_checked count. - Validates: - Field existence (all fields must be in table heading) - Row format (correct number of attributes for positional inserts) @@ -458,13 +543,26 @@ - Unique constraints (other than PK) - Custom MySQL constraints - Example:: - - result = table.validate(rows) - if result: - table.insert(rows) - else: - print(result.summary()) + Parameters + ---------- + rows : iterable + Same format as insert(): iterable of dicts, tuples, numpy records, + or a pandas DataFrame. + ignore_extra_fields : bool, optional + If True, ignore fields not in the table heading. + + Returns + ------- + ValidationResult + Result with is_valid, errors list, and rows_checked count. + + Examples + -------- + >>> result = table.validate(rows) + >>> if result: + ... table.insert(rows) + ... else: + ... print(result.summary()) """ errors = [] @@ -575,10 +673,21 @@ def insert1(self, row, **kwargs): """ - Insert one data record into the table. For ``kwargs``, see ``insert()``. + Insert one data record into the table. + + For ``kwargs``, see ``insert()``. - :param row: a numpy record, a dict-like object, or an ordered sequence to be inserted + Parameters + ---------- + row : numpy.void, dict, or sequence + A numpy record, a dict-like object, or an ordered sequence to be inserted as one row. + **kwargs + Additional arguments passed to ``insert()``. + + See Also + -------- + insert : Insert multiple data records. """ self.insert((row,), **kwargs) @@ -623,27 +732,36 @@ def insert( """ Insert a collection of rows.
- :param rows: Either (a) an iterable where an element is a numpy record, a - dict-like object, a pandas.DataFrame, a polars.DataFrame, a pyarrow.Table, - a sequence, or a query expression with the same heading as self, or + Parameters + ---------- + rows : iterable or pathlib.Path + Either (a) an iterable where an element is a numpy record, a dict-like + object, a pandas.DataFrame, a polars.DataFrame, a pyarrow.Table, a + sequence, or a query expression with the same heading as self, or (b) a pathlib.Path object specifying a path relative to the current directory with a CSV file, the contents of which will be inserted. - :param replace: If True, replaces the existing tuple. - :param skip_duplicates: If True, silently skip duplicate inserts. - :param ignore_extra_fields: If False, fields that are not in the heading raise error. - :param allow_direct_insert: Only applies in auto-populated tables. If False (default), - insert may only be called from inside the make callback. - :param chunk_size: If set, insert rows in batches of this size. Useful for very - large inserts to avoid memory issues. Each chunk is a separate transaction. - - Example: - - >>> Table.insert([ - >>> dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"), - >>> dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")]) - - # Large insert with chunking - >>> Table.insert(large_dataset, chunk_size=10000) + replace : bool, optional + If True, replaces the existing tuple. + skip_duplicates : bool, optional + If True, silently skip duplicate inserts. + ignore_extra_fields : bool, optional + If False (default), fields that are not in the heading raise error. + allow_direct_insert : bool, optional + Only applies in auto-populated tables. If False (default), insert may + only be called from inside the make callback. + chunk_size : int, optional + If set, insert rows in batches of this size. Useful for very large + inserts to avoid memory issues. Each chunk is a separate transaction. + + Examples + -------- + >>> Table.insert([ + ... dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"), + ... dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")]) + + Large insert with chunking: + + >>> Table.insert(large_dataset, chunk_size=10000) """ if isinstance(rows, pandas.DataFrame): # drop 'extra' synthetic index for 1-field index case - @@ -715,10 +833,16 @@ def _insert_rows(self, rows, replace, skip_duplicates, ignore_extra_fields): """ Internal helper to insert a batch of rows. - :param rows: Iterable of rows to insert - :param replace: If True, use REPLACE instead of INSERT - :param skip_duplicates: If True, use ON DUPLICATE KEY UPDATE - :param ignore_extra_fields: If True, ignore unknown fields + Parameters + ---------- + rows : iterable + Iterable of rows to insert. + replace : bool + If True, use REPLACE instead of INSERT. + skip_duplicates : bool + If True, use ON DUPLICATE KEY UPDATE. + ignore_extra_fields : bool + If True, ignore unknown fields. """ # collects the field list from first row (passed by reference) field_list = [] @@ -757,26 +881,34 @@ def insert_dataframe(self, df, index_as_pk=None, **insert_kwargs): (which sets primary key as index) can be modified and re-inserted using insert_dataframe() without manual index manipulation. - :param df: pandas DataFrame to insert - :param index_as_pk: How to handle DataFrame index: + Parameters + ---------- + df : pandas.DataFrame + DataFrame to insert. 
+ index_as_pk : bool, optional + How to handle DataFrame index: + - None (default): Auto-detect. Use index as primary key if index names match primary_key columns. Drop if unnamed RangeIndex. - True: Treat index as primary key columns. Raises if index names don't match table primary key. - False: Ignore index entirely (drop it). - :param **insert_kwargs: Passed to insert() - replace, skip_duplicates, - ignore_extra_fields, allow_direct_insert, chunk_size + **insert_kwargs + Passed to insert() - replace, skip_duplicates, ignore_extra_fields, + allow_direct_insert, chunk_size. + + Examples + -------- + Round-trip with to_pandas(): - Example:: + >>> df = table.to_pandas() # PK becomes index + >>> df['value'] = df['value'] * 2 # Modify data + >>> table.insert_dataframe(df) # Auto-detects index as PK - # Round-trip with to_pandas() - df = table.to_pandas() # PK becomes index - df['value'] = df['value'] * 2 # Modify data - table.insert_dataframe(df) # Auto-detects index as PK + Explicit control: - # Explicit control - table.insert_dataframe(df, index_as_pk=True) # Use index - table.insert_dataframe(df, index_as_pk=False) # Ignore index + >>> table.insert_dataframe(df, index_as_pk=True) # Use index + >>> table.insert_dataframe(df, index_as_pk=False) # Ignore index """ if not isinstance(df, pandas.DataFrame): raise DataJointError("insert_dataframe requires a pandas DataFrame") @@ -1114,7 +1246,12 @@ def drop(self, prompt: bool | None = None): @property def size_on_disk(self): """ - :return: size of data and indices in bytes on the storage device + Return the size of data and indices in bytes on the storage device. + + Returns + ------- + int + Size of data and indices in bytes. """ ret = self.connection.query( 'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format(database=self.database, table=self.table_name), @@ -1124,7 +1261,20 @@ def size_on_disk(self): def describe(self, context=None, printout=False): """ - :return: the definition string for the query using DataJoint DDL. + Return the definition string for the query using DataJoint DDL. + + Parameters + ---------- + context : dict, optional + The context for foreign key resolution. If None, uses the caller's + local and global namespace. + printout : bool, optional + If True, also log the definition string. + + Returns + ------- + str + The definition string for the table in DataJoint DDL format. """ if context is None: frame = inspect.currentframe().f_back @@ -1202,9 +1352,11 @@ def describe(self, context=None, printout=False): # --- private helper functions ---- def __make_placeholder(self, name, value, ignore_extra_fields=False, row=None): """ - For a given attribute `name` with `value`, return its processed value or value placeholder - as a string to be included in the query and the value, if any, to be submitted for - processing by mysql API. + Return processed value or placeholder for an attribute. + + For a given attribute `name` with `value`, return its processed value or + value placeholder as a string to be included in the query and the value, + if any, to be submitted for processing by mysql API. 
In the simplified type system: - Codecs handle all custom encoding via type chains @@ -1213,10 +1365,22 @@ def __make_placeholder(self, name, value, ignore_extra_fields=False, row=None): - Blob values pass through as bytes - Numeric values are stringified - :param name: name of attribute to be inserted - :param value: value of attribute to be inserted - :param ignore_extra_fields: if True, return None for unknown fields - :param row: the full row dict (unused in simplified model) + Parameters + ---------- + name : str + Name of attribute to be inserted. + value : any + Value of attribute to be inserted. + ignore_extra_fields : bool, optional + If True, return None for unknown fields. + row : dict, optional + The full row dict (used for context in codec encoding). + + Returns + ------- + tuple or None + A tuple of (name, placeholder, value) or None if the field should be + ignored. """ if ignore_extra_fields and name not in self.heading: return None @@ -1284,17 +1448,31 @@ def __make_placeholder(self, name, value, ignore_extra_fields=False, row=None): def __make_row_to_insert(self, row, field_list, ignore_extra_fields): """ - Helper function for insert and update - - :param row: A tuple to insert - :return: a dict with fields 'names', 'placeholders', 'values' + Helper function for insert and update. + + Parameters + ---------- + row : tuple, dict, or numpy.void + A row to insert. + field_list : list + List to be populated with field names from the first row. + ignore_extra_fields : bool + If True, ignore fields not in the heading. + + Returns + ------- + dict + A dict with fields 'names', 'placeholders', 'values'. """ def check_fields(fields): """ - Validates that all items in `fields` are valid attributes in the heading + Validate that all items in `fields` are valid attributes in the heading. - :param fields: field names of a tuple + Parameters + ---------- + fields : list + Field names of a tuple. """ if not field_list: if not ignore_extra_fields: @@ -1374,12 +1552,24 @@ def check_fields(fields): def lookup_class_name(name, context, depth=3): """ - given a table name in the form `schema_name`.`table_name`, find its class in the context. - - :param name: `schema_name`.`table_name` - :param context: dictionary representing the namespace - :param depth: search depth into imported modules, helps avoid infinite recursion. - :return: class name found in the context or None if not found + Find a table's class in the context given its full table name. + + Given a table name in the form `schema_name`.`table_name`, find its class in + the context. + + Parameters + ---------- + name : str + Full table name in format `schema_name`.`table_name`. + context : dict + Dictionary representing the namespace. + depth : int, optional + Search depth into imported modules, helps avoid infinite recursion. + + Returns + ------- + str or None + Class name found in the context or None if not found. """ # breadth-first search nodes = [dict(context=context, context_name="", depth=depth)] @@ -1415,11 +1605,16 @@ def lookup_class_name(name, context, depth=3): class FreeTable(Table): """ - A base table without a dedicated class. Each instance is associated with a table - specified by full_table_name. + A base table without a dedicated class. + + Each instance is associated with a table specified by full_table_name. - :param conn: a dj.Connection object - :param full_table_name: in format `database`.`table_name` + Parameters + ---------- + conn : datajoint.Connection + A DataJoint connection object. 
+ full_table_name : str + Full table name in format `database`.`table_name`. """ def __init__(self, conn, full_table_name): diff --git a/src/datajoint/utils.py b/src/datajoint/utils.py index 4309d78b9..9716df3d6 100644 --- a/src/datajoint/utils.py +++ b/src/datajoint/utils.py @@ -9,12 +9,23 @@ def user_choice(prompt, choices=("yes", "no"), default=None): """ - Prompts the user for confirmation. The default value, if any, is capitalized. + Prompt the user for confirmation. - :param prompt: Information to display to the user. - :param choices: an iterable of possible choices. - :param default: default choice - :return: the user's choice + The default value, if any, is capitalized. + + Parameters + ---------- + prompt : str + Information to display to the user. + choices : tuple, optional + An iterable of possible choices. Default ("yes", "no"). + default : str, optional + Default choice. Default None. + + Returns + ------- + str + The user's choice. """ assert default is None or default in choices choice_list = ", ".join((choice.title() if choice == default else choice for choice in choices)) @@ -27,20 +38,32 @@ def user_choice(prompt, choices=("yes", "no"), default=None): def get_master(full_table_name: str, adapter=None) -> str: """ + Get the master table name from a part table name. + If the table name is that of a part table, then return what the master table name would be. This follows DataJoint's table naming convention where a master and a part must be in the same schema and the part table is prefixed with the master table name + ``__``. - Example: - `ephys`.`session` -- master (MySQL) - `ephys`.`session__recording` -- part (MySQL) - "ephys"."session__recording" -- part (PostgreSQL) - - :param full_table_name: Full table name including part. - :type full_table_name: str - :param adapter: Optional database adapter for backend-specific parsing. - :return: Supposed master full table name or empty string if not a part table name. - :rtype: str + Parameters + ---------- + full_table_name : str + Full table name including part. + adapter : DatabaseAdapter, optional + Database adapter for backend-specific parsing. Default None. + + Returns + ------- + str + Supposed master full table name or empty string if not a part table name. + + Examples + -------- + >>> get_master('`ephys`.`session__recording`') # MySQL part table + '`ephys`.`session`' + >>> get_master('"ephys"."session__recording"') # PostgreSQL part table + '"ephys"."session"' + >>> get_master('`ephys`.`session`') # Not a part table + '' """ if adapter is not None: result = adapter.get_master_table_name(full_table_name) @@ -57,23 +80,44 @@ def is_camel_case(s): """ Check if a string is in CamelCase notation. - :param s: string to check - :returns: True if the string is in CamelCase notation, False otherwise - Example: - >>> is_camel_case("TableName") # returns True - >>> is_camel_case("table_name") # returns False + Parameters + ---------- + s : str + String to check. + + Returns + ------- + bool + True if the string is in CamelCase notation, False otherwise. + + Examples + -------- + >>> is_camel_case("TableName") + True + >>> is_camel_case("table_name") + False """ return bool(re.match(r"^[A-Z][A-Za-z0-9]*$", s)) def to_camel_case(s): """ - Convert names with under score (_) separation into camel case names. + Convert names with underscore (_) separation into camel case names. + + Parameters + ---------- + s : str + String in under_score notation. 
- :param s: string in under_score notation - :returns: string in CamelCase notation - Example: - >>> to_camel_case("table_name") # returns "TableName" + Returns + ------- + str + String in CamelCase notation. + + Examples + -------- + >>> to_camel_case("table_name") + 'TableName' """ def to_upper(match): @@ -84,12 +128,27 @@ def to_upper(match): def from_camel_case(s): """ - Convert names in camel case into underscore (_) separated names + Convert names in camel case into underscore (_) separated names. + + Parameters + ---------- + s : str + String in CamelCase notation. - :param s: string in CamelCase notation - :returns: string in under_score notation - Example: - >>> from_camel_case("TableName") # yields "table_name" + Returns + ------- + str + String in under_score notation. + + Raises + ------ + DataJointError + If the string is not in valid CamelCase notation. + + Examples + -------- + >>> from_camel_case("TableName") + 'table_name' """ def convert(match): @@ -102,10 +161,17 @@ def convert(match): def safe_write(filepath, blob): """ - A two-step write. + Write data to a file using a two-step process. + + Writes to a temporary file first, then renames to the final path. + This ensures atomic writes and prevents partial file corruption. - :param filename: full path - :param blob: binary data + Parameters + ---------- + filepath : str or Path + Full path to the destination file. + blob : bytes + Binary data to write. """ filepath = Path(filepath) if not filepath.is_file(): @@ -117,7 +183,19 @@ def safe_write(filepath, blob): def safe_copy(src, dest, overwrite=False): """ - Copy the contents of src file into dest file as a two-step process. Skip if dest exists already + Copy the contents of src file into dest file as a two-step process. + + Copies to a temporary file first, then renames to the final path. + Skips if dest exists already (unless overwrite is True). + + Parameters + ---------- + src : str or Path + Source file path. + dest : str or Path + Destination file path. + overwrite : bool, optional + If True, overwrite existing destination file. Default False. """ src, dest = Path(src), Path(dest) if not (dest.exists() and src.samefile(dest)) and (overwrite or not dest.is_file()): @@ -125,24 +203,3 @@ def safe_copy(src, dest, overwrite=False): temp_file = dest.with_suffix(dest.suffix + ".copying") shutil.copyfile(str(src), str(temp_file)) temp_file.rename(dest) - - -def parse_sql(filepath): - """ - yield SQL statements from an SQL file - """ - delimiter = ";" - statement = [] - with Path(filepath).open("rt") as f: - for line in f: - line = line.strip() - if not line.startswith("--") and len(line) > 1: - if line.startswith("delimiter"): - delimiter = line.split()[1] - else: - statement.append(line) - if line.endswith(delimiter): - yield " ".join(statement) - statement = [] - if statement: - yield " ".join(statement)