diff --git a/CHANGELOG.md b/CHANGELOG.md index 225cfcee..07b96197 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## UNRELEASED +### Bug Fixes +- `Cursor.executemany` now correctly resets `rowcount` and reports the number of inserted rows after a bulk insert. Previously, `rowcount` retained the value from the previous operation. The insert summary is also appended to `cursor.summary`, consistent with the non-bulk path. In addition, passing a generator as `seq_of_parameters` no longer raises `TypeError`; the bulk-insert optimisation is now skipped for any iterable that is not a `Sequence` (such as a generator), and the operation falls through to the row-by-row path as PEP 249 requires. + ### Improvements - The Cython extension modules now declare free-threading compatibility, so importing clickhouse-connect on a free-threaded Python build such as 3.14t no longer silently re-enables the GIL. As part of this change, `ResponseBuffer.read_uint64` no longer uses a module level scratch buffer for its big-endian byte swap, which was the one piece of shared mutable state in the C modules. Building from source now requires Cython 3.1 or later. The CI test matrix now runs the full suite on free-threaded Python 3.14t as a non-blocking job. Free-threading support remains experimental. 
diff --git a/clickhouse_connect/dbapi/__init__.py b/clickhouse_connect/dbapi/__init__.py index 84a9a8c3..f1c33a75 100644 --- a/clickhouse_connect/dbapi/__init__.py +++ b/clickhouse_connect/dbapi/__init__.py @@ -1,3 +1,5 @@ +from typing import Any + from clickhouse_connect.dbapi.connection import Connection apilevel = "2.0" # PEP 249 DB API level @@ -12,18 +14,18 @@ class Error(Exception): def connect( host: str | None = None, database: str | None = None, - username: str | None = "", - password: str | None = "", + username: str = "", + password: str = "", port: int | None = None, - **kwargs, -): + **kwargs: Any, +) -> Connection: secure = kwargs.pop("secure", False) return Connection( host=host, database=database, username=username, password=password, - port=port, + port=port if port is not None else 0, secure=secure, **kwargs, ) diff --git a/clickhouse_connect/dbapi/connection.py b/clickhouse_connect/dbapi/connection.py index 3dba3c54..80c2c538 100644 --- a/clickhouse_connect/dbapi/connection.py +++ b/clickhouse_connect/dbapi/connection.py @@ -1,3 +1,5 @@ +from typing import Any + from clickhouse_connect.dbapi.cursor import Cursor from clickhouse_connect.driver import create_client from clickhouse_connect.driver.query import QueryResult @@ -10,21 +12,21 @@ class Connection: def __init__( self, - dsn: str = None, + dsn: str | None = None, username: str = "", password: str = "", - host: str = None, - database: str = None, - interface: str = None, + host: str | None = None, + database: str | None = None, + interface: str | None = None, port: int = 0, secure: bool | str = False, - **kwargs, + **kwargs: Any, ): self.client = create_client( host=host, username=username, password=password, - database=database, + database=database if database is not None else "__default__", interface=interface, port=port, secure=secure, @@ -35,20 +37,20 @@ def __init__( self.client._add_integration_tag("sqlalchemy") self.timezone = self.client.server_tz - def close(self): + def close(self) 
-> None: self.client.close() - def commit(self): + def commit(self) -> None: pass - def rollback(self): + def rollback(self) -> None: pass - def command(self, cmd: str): + def command(self, cmd: str) -> Any: return self.client.command(cmd) def raw_query(self, query: str) -> QueryResult: return self.client.query(query) - def cursor(self): + def cursor(self) -> Cursor: return Cursor(self.client) diff --git a/clickhouse_connect/dbapi/cursor.py b/clickhouse_connect/dbapi/cursor.py index af7f36f6..9b200aba 100644 --- a/clickhouse_connect/dbapi/cursor.py +++ b/clickhouse_connect/dbapi/cursor.py @@ -1,6 +1,7 @@ import logging import re from collections.abc import Mapping, Sequence +from typing import Any, cast from clickhouse_connect.datatypes.registry import get_from_name from clickhouse_connect.driver import Client @@ -23,34 +24,34 @@ class Cursor: def __init__(self, client: Client): self.client = client - self.arraysize = 1 + self.arraysize: int = 1 self.data: Sequence | None = None - self.names = [] - self.types = [] - self._rowcount = 0 - self._summary: list[dict[str, str]] = [] - self._ix = 0 + self.names: Sequence[str] = [] + self.types: Sequence[Any] = [] + self._rowcount: int = 0 + self._summary: list[dict[str, Any]] = [] + self._ix: int = 0 - def check_valid(self): + def check_valid(self) -> None: if self.data is None: raise ProgrammingError("Cursor is not valid") @property - def description(self): + def description(self) -> list[tuple[str, Any, None, None, None, None, bool]]: return [(n, t, None, None, None, None, True) for n, t in zip(self.names, self.types)] @property - def rowcount(self): + def rowcount(self) -> int: return self._rowcount @property - def summary(self) -> list[dict[str, str]]: + def summary(self) -> list[dict[str, Any]]: return self._summary - def close(self): + def close(self) -> None: self.data = None - def execute(self, operation: str, parameters=None): + def execute(self, operation: str, parameters: Any = None) -> None: if not parameters 
and isinstance(operation, str): # Per PEP 249 pyformat paramstyle, callers (e.g. SQLAlchemy) escape # literal percent signs as %% in operation strings. When there are @@ -80,7 +81,7 @@ def execute(self, operation: str, parameters=None): self.names = meta_result.column_names self.types = [x.name for x in meta_result.column_types] - def _try_bulk_insert(self, operation: str, data): + def _try_bulk_insert(self, operation: str, data: Any) -> bool: match = insert_re.match(remove_sql_comments(operation)) if not match: return False @@ -94,24 +95,31 @@ def _try_bulk_insert(self, operation: str, data): op_columns = None if "VALUES" not in temp.upper(): return False + if not isinstance(data, Sequence) or len(data) == 0: + return False first_row = data[0] + col_names: list[str] | str + data_values: Sequence[Sequence[Any]] if isinstance(first_row, Mapping): - col_names = list(first_row.keys()) - if op_columns and {unescape_identifier(x) for x in op_columns} != set(col_names): + col_names = [str(k) for k in first_row.keys()] + if op_columns and {unescape_identifier(str(x)) for x in op_columns} != set(col_names): return False # Data sent in doesn't match the columns in the insert statement data_values = [list(row.values()) for row in data] elif isinstance(first_row, Sequence) and not isinstance(first_row, (str, bytes)): # PEP 249 also allows rows as sequences; take column names from the # insert statement if present, otherwise insert into all columns - col_names = [unescape_identifier(x) for x in op_columns] if op_columns else "*" + col_names = [unescape_identifier(str(x)) for x in op_columns] if op_columns else "*" data_values = data else: return False - self.client.insert(table, data_values, col_names) + insert_summary = self.client.insert(table, data_values, col_names) self.data = [] + self._rowcount = insert_summary.written_rows + self._ix = 0 + self._summary.append(insert_summary.summary) return True - def executemany(self, operation, parameters): + def executemany(self, 
operation: str, parameters: Any) -> None: if not parameters or self._try_bulk_insert(operation, parameters): return self.data = [] @@ -138,22 +146,25 @@ def executemany(self, operation, parameters): # Need to reset cursor _ix after performing an execute self._ix = 0 - def fetchall(self): + def fetchall(self) -> Sequence: self.check_valid() - ret = self.data[self._ix :] + data = cast(Sequence, self.data) + ret = data[self._ix :] self._ix = self._rowcount return ret - def fetchone(self): + def fetchone(self) -> Any: self.check_valid() if self._ix >= self._rowcount: return None - val = self.data[self._ix] + data = cast(Sequence, self.data) + val = data[self._ix] self._ix += 1 return val - def fetchmany(self, size: int = -1): + def fetchmany(self, size: int = -1) -> Sequence: self.check_valid() + data = cast(Sequence, self.data) if size < 0: # Fetch all remaining rows @@ -163,12 +174,12 @@ def fetchmany(self, size: int = -1): return [] end = min(self._ix + size, self._rowcount) - ret = self.data[self._ix : end] + ret = data[self._ix : end] self._ix = end return ret - def nextset(self): + def nextset(self) -> None: raise NotImplementedError - def callproc(self, *args, **kwargs): + def callproc(self, *args, **kwargs) -> None: raise NotImplementedError diff --git a/pyproject.toml b/pyproject.toml index 83a45132..63c4f807 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,9 +73,6 @@ module = [ "clickhouse_connect.datatypes.string", "clickhouse_connect.datatypes.temporal", "clickhouse_connect.datatypes.vector", - "clickhouse_connect.dbapi", - "clickhouse_connect.dbapi.connection", - "clickhouse_connect.dbapi.cursor", "clickhouse_connect.driver", "clickhouse_connect.driver.asyncclient", "clickhouse_connect.driver.asyncqueue", diff --git a/tests/unit_tests/test_driver/test_cursor.py b/tests/unit_tests/test_driver/test_cursor.py index d427d7ed..9999b7e3 100644 --- a/tests/unit_tests/test_driver/test_cursor.py +++ b/tests/unit_tests/test_driver/test_cursor.py @@ -360,3 
+360,53 @@ def test_execute_empty_with_query_fetches_metadata(): "SELECT * FROM (WITH value_1 AS 13 SELECT value_1 WHERE value_1 = 79) LIMIT 0", None, ) + + +def _mock_insert_client(written_rows: int, summary_extra: dict | None = None): + """Return a mock client whose insert() returns a QuerySummary-like object.""" + summary_dict = {"written_rows": str(written_rows), **(summary_extra or {})} + insert_summary = Mock() + insert_summary.written_rows = written_rows + insert_summary.summary = summary_dict + client = Mock() + client.insert.return_value = insert_summary + return client, summary_dict + + +def test_executemany_bulk_insert_rowcount_equals_written_rows(): + """rowcount after a bulk executemany reflects the actual number of inserted rows.""" + client, _ = _mock_insert_client(written_rows=2) + cursor = Cursor(client) + + rows = [(13, "user_1"), (79, "user_2")] + cursor.executemany("INSERT INTO test_table (id, name) VALUES (%s, %s)", rows) + + assert cursor.rowcount == 2 + + +def test_executemany_bulk_insert_appends_summary(): + """summary is populated from the insert response after a bulk executemany.""" + client, summary_dict = _mock_insert_client(written_rows=2, summary_extra={"written_bytes": "64"}) + cursor = Cursor(client) + + rows = [(13, "user_1"), (79, "user_2")] + cursor.executemany("INSERT INTO test_table (id, name) VALUES (%s, %s)", rows) + + assert cursor.summary == [summary_dict] + + +def test_executemany_generator_falls_through_to_row_by_row(): + """A generator passed to executemany falls through to the row-by-row path without raising TypeError.""" + client = Mock() + client.query.return_value = create_mock_query_result([]) + + cursor = Cursor(client) + + def row_generator(): + yield {"id": 1, "name": "user_1"} + yield {"id": 2, "name": "user_2"} + + cursor.executemany("INSERT INTO test_table (id, name) VALUES (%(id)s, %(name)s)", row_generator()) + + client.insert.assert_not_called() + assert client.query.call_count == 2