-
Notifications
You must be signed in to change notification settings - Fork 19
FEAT: Adding setinputsizes #192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: jahnvi/connection_timeout
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,6 +55,7 @@ def __init__(self, connection, timeout: int = 0) -> None: | |
""" | ||
self.connection = connection | ||
self._timeout = timeout | ||
self._inputsizes = None | ||
# self.connection.autocommit = False | ||
self.hstmt = None | ||
self._initialize_cursor() | ||
|
@@ -463,6 +464,71 @@ def _check_closed(self): | |
if self.closed: | ||
raise Exception("Operation cannot be performed: the cursor is closed.") | ||
|
||
def setinputsizes(self, sizes): | ||
""" | ||
Sets the type information to be used for parameters in execute and executemany. | ||
|
||
This method can be used to explicitly declare the types and sizes of query parameters. | ||
For example: | ||
|
||
sql = "INSERT INTO product (item, price) VALUES (?, ?)" | ||
params = [('bicycle', 499.99), ('ham', 17.95)] | ||
# specify that parameters are for NVARCHAR(50) and DECIMAL(18,4) columns | ||
cursor.setinputsizes([(SQL_WVARCHAR, 50, 0), (SQL_DECIMAL, 18, 4)]) | ||
cursor.executemany(sql, params) | ||
|
||
Args: | ||
sizes: A sequence of tuples, one for each parameter. Each tuple contains | ||
(sql_type, size, decimal_digits) where size and decimal_digits are optional. | ||
""" | ||
self._inputsizes = [] | ||
|
||
if sizes: | ||
for size_info in sizes: | ||
if isinstance(size_info, tuple): | ||
# Handle tuple format (sql_type, size, decimal_digits) | ||
if len(size_info) == 1: | ||
self._inputsizes.append((size_info[0], 0, 0)) | ||
elif len(size_info) == 2: | ||
self._inputsizes.append((size_info[0], size_info[1], 0)) | ||
elif len(size_info) >= 3: | ||
self._inputsizes.append((size_info[0], size_info[1], size_info[2])) | ||
else: | ||
# Handle single value (just sql_type) | ||
self._inputsizes.append((size_info, 0, 0)) | ||
|
||
def _reset_inputsizes(self): | ||
"""Reset input sizes after execution""" | ||
self._inputsizes = None | ||
|
||
def _get_c_type_for_sql_type(self, sql_type): | ||
"""Map SQL type to appropriate C type for parameter binding""" | ||
sql_to_c_type = { | ||
ddbc_sql_const.SQL_CHAR.value: ddbc_sql_const.SQL_C_CHAR.value, | ||
ddbc_sql_const.SQL_VARCHAR.value: ddbc_sql_const.SQL_C_CHAR.value, | ||
ddbc_sql_const.SQL_LONGVARCHAR.value: ddbc_sql_const.SQL_C_CHAR.value, | ||
ddbc_sql_const.SQL_WCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value, | ||
ddbc_sql_const.SQL_WVARCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value, | ||
ddbc_sql_const.SQL_WLONGVARCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value, | ||
ddbc_sql_const.SQL_DECIMAL.value: ddbc_sql_const.SQL_C_NUMERIC.value, | ||
ddbc_sql_const.SQL_NUMERIC.value: ddbc_sql_const.SQL_C_NUMERIC.value, | ||
ddbc_sql_const.SQL_BIT.value: ddbc_sql_const.SQL_C_BIT.value, | ||
ddbc_sql_const.SQL_TINYINT.value: ddbc_sql_const.SQL_C_TINYINT.value, | ||
ddbc_sql_const.SQL_SMALLINT.value: ddbc_sql_const.SQL_C_SHORT.value, | ||
ddbc_sql_const.SQL_INTEGER.value: ddbc_sql_const.SQL_C_LONG.value, | ||
ddbc_sql_const.SQL_BIGINT.value: ddbc_sql_const.SQL_C_SBIGINT.value, | ||
ddbc_sql_const.SQL_REAL.value: ddbc_sql_const.SQL_C_FLOAT.value, | ||
ddbc_sql_const.SQL_FLOAT.value: ddbc_sql_const.SQL_C_DOUBLE.value, | ||
ddbc_sql_const.SQL_DOUBLE.value: ddbc_sql_const.SQL_C_DOUBLE.value, | ||
ddbc_sql_const.SQL_BINARY.value: ddbc_sql_const.SQL_C_BINARY.value, | ||
ddbc_sql_const.SQL_VARBINARY.value: ddbc_sql_const.SQL_C_BINARY.value, | ||
ddbc_sql_const.SQL_LONGVARBINARY.value: ddbc_sql_const.SQL_C_BINARY.value, | ||
ddbc_sql_const.SQL_DATE.value: ddbc_sql_const.SQL_C_TYPE_DATE.value, | ||
ddbc_sql_const.SQL_TIME.value: ddbc_sql_const.SQL_C_TYPE_TIME.value, | ||
ddbc_sql_const.SQL_TIMESTAMP.value: ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value, | ||
} | ||
return sql_to_c_type.get(sql_type, ddbc_sql_const.SQL_C_DEFAULT.value) | ||
|
||
def _create_parameter_types_list(self, parameter, param_info, parameters_list, i): | ||
""" | ||
Maps parameter types for the given parameter. | ||
|
@@ -474,9 +540,20 @@ def _create_parameter_types_list(self, parameter, param_info, parameters_list, i | |
paraminfo. | ||
""" | ||
paraminfo = param_info() | ||
sql_type, c_type, column_size, decimal_digits = self._map_sql_type( | ||
parameter, parameters_list, i | ||
) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please double-check that all parameterized queries remain fully protected against SQL injection—even when input sizes or types are set by users via |
||
# Check if we have explicit type information from setinputsizes | ||
if hasattr(self, '_inputsizes') and self._inputsizes and i < len(self._inputsizes): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are places where you check if the object ( |
||
# Use explicit type information | ||
sql_type, column_size, decimal_digits = self._inputsizes[i] | ||
|
||
# Determine the appropriate C type based on SQL type | ||
c_type = self._get_c_type_for_sql_type(sql_type) | ||
else: | ||
# Fall back to automatic type inference | ||
sql_type, c_type, column_size, decimal_digits = self._map_sql_type( | ||
parameter, parameters_list, i | ||
) | ||
|
||
paraminfo.paramCType = c_type | ||
paraminfo.paramSQLType = sql_type | ||
paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value | ||
|
@@ -643,6 +720,8 @@ def execute( | |
except Exception as e: | ||
# If describe fails, it's likely there are no results (e.g., for INSERT) | ||
self.description = None | ||
|
||
self._reset_inputsizes() # Reset input sizes after execution | ||
|
||
@staticmethod | ||
def _select_best_sample_value(column): | ||
|
@@ -709,7 +788,7 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: | |
if not seq_of_parameters: | ||
self.rowcount = 0 | ||
return | ||
|
||
# Apply timeout if set (non-zero) | ||
if self._timeout > 0: | ||
try: | ||
|
@@ -720,23 +799,62 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: | |
timeout_value | ||
) | ||
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) | ||
log('debug', f"Set query timeout to {self._timeout} seconds") | ||
log('debug', f"Set query timeout to {timeout_value} seconds") | ||
except Exception as e: | ||
log('warning', f"Failed to set query timeout: {e}") | ||
|
||
param_info = ddbc_bindings.ParamInfo | ||
param_count = len(seq_of_parameters[0]) | ||
parameters_type = [] | ||
|
||
# Make a copy of the parameters for potential transformation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consider raising a warning or error if the number of input sizes set via |
||
processed_parameters = [list(params) for params in seq_of_parameters] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The code is using [list(params) for params in seq_of_parameters] to create a new list of lists from seq_of_parameters (which is probably a list or sequence of parameters for batch inserts). When you do this for a very large number of rows (for example, thousands or millions), it creates a copy of every row in memory. This can use a lot of memory and might slow things down or even cause crashes if there isn’t enough memory. If possible, don’t create a big copy of all the data at once. Current Implementation:
This creates a new list in memory that contains a copy of every params as a list. Generator Expression:
This creates a generator—not a list. It doesn’t copy anything right away. List comprehension is eager: makes everything up front, uses more memory. |
||
|
||
# Check if we have explicit input sizes set | ||
if hasattr(self, '_inputsizes') and self._inputsizes: | ||
# Use the explicitly set input sizes | ||
for col_index in range(param_count): | ||
if col_index < len(self._inputsizes): | ||
sql_type, column_size, decimal_digits = self._inputsizes[col_index] | ||
c_type = self._get_c_type_for_sql_type(sql_type) | ||
|
||
# If using SQL_DECIMAL/NUMERIC, we need to ensure the Python values | ||
# are properly converted for the driver | ||
if sql_type in (ddbc_sql_const.SQL_DECIMAL.value, ddbc_sql_const.SQL_NUMERIC.value): | ||
# Make sure all values in this column are Decimal objects | ||
for row_idx, row in enumerate(processed_parameters): | ||
if not isinstance(row[col_index], decimal.Decimal): | ||
# Convert to Decimal if it's not already | ||
processed_parameters[row_idx][col_index] = decimal.Decimal(str(row[col_index])) | ||
|
||
paraminfo = param_info() | ||
paraminfo.paramCType = c_type | ||
paraminfo.paramSQLType = sql_type | ||
paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value | ||
paraminfo.columnSize = column_size | ||
paraminfo.decimalDigits = decimal_digits | ||
parameters_type.append(paraminfo) | ||
else: | ||
# Fall back to auto-detect for any parameters beyond those specified | ||
column = [row[col_index] for row in seq_of_parameters] | ||
sample_value = self._select_best_sample_value(column) | ||
dummy_row = list(seq_of_parameters[0]) | ||
parameters_type.append( | ||
self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index) | ||
) | ||
else: | ||
# No input sizes set, use auto-detection | ||
for col_index in range(param_count): | ||
column = [row[col_index] for row in seq_of_parameters] | ||
sample_value = self._select_best_sample_value(column) | ||
dummy_row = list(seq_of_parameters[0]) | ||
parameters_type.append( | ||
self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index) | ||
) | ||
|
||
for col_index in range(param_count): | ||
column = [row[col_index] for row in seq_of_parameters] | ||
sample_value = self._select_best_sample_value(column) | ||
dummy_row = list(seq_of_parameters[0]) | ||
parameters_type.append( | ||
self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index) | ||
) | ||
|
||
columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters) | ||
columnwise_params = self._transpose_rowwise_to_columnwise(processed_parameters) | ||
|
||
log('info', "Executing batch query with %d parameter sets:\n%s", | ||
len(seq_of_parameters), "\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters)) | ||
) | ||
|
@@ -749,11 +867,17 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: | |
parameters_type, | ||
len(seq_of_parameters) | ||
) | ||
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) | ||
|
||
try: | ||
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) | ||
self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) | ||
self.last_executed_stmt = operation | ||
self._initialize_description() | ||
finally: | ||
# Reset input sizes after execution | ||
self._reset_inputsizes() | ||
|
||
self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) | ||
self.last_executed_stmt = operation | ||
self._initialize_description() | ||
|
||
|
||
def fetchone(self) -> Union[None, Row]: | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please consider adding type annotations to new methods such as setinputsizes, _reset_inputsizes, and _get_c_type_for_sql_type. Type annotations will improve code clarity, enable better static analysis, and make the codebase more maintainable as it grows.