Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 141 additions & 17 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, connection, timeout: int = 0) -> None:
"""
self.connection = connection
self._timeout = timeout
self._inputsizes = None
# self.connection.autocommit = False
self.hstmt = None
self._initialize_cursor()
Expand Down Expand Up @@ -463,6 +464,71 @@ def _check_closed(self):
if self.closed:
raise Exception("Operation cannot be performed: the cursor is closed.")

def setinputsizes(self, sizes):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please consider adding type annotations to new methods such as setinputsizes, _reset_inputsizes, and _get_c_type_for_sql_type. Type annotations will improve code clarity, enable better static analysis, and make the codebase more maintainable as it grows.

"""
Sets the type information to be used for parameters in execute and executemany.

This method can be used to explicitly declare the types and sizes of query parameters.
For example:

sql = "INSERT INTO product (item, price) VALUES (?, ?)"
params = [('bicycle', 499.99), ('ham', 17.95)]
# specify that parameters are for NVARCHAR(50) and DECIMAL(18,4) columns
cursor.setinputsizes([(SQL_WVARCHAR, 50, 0), (SQL_DECIMAL, 18, 4)])
cursor.executemany(sql, params)

Args:
sizes: A sequence of tuples, one for each parameter. Each tuple contains
(sql_type, size, decimal_digits) where size and decimal_digits are optional.
"""
self._inputsizes = []

if sizes:
for size_info in sizes:
if isinstance(size_info, tuple):
# Handle tuple format (sql_type, size, decimal_digits)
if len(size_info) == 1:
self._inputsizes.append((size_info[0], 0, 0))
elif len(size_info) == 2:
self._inputsizes.append((size_info[0], size_info[1], 0))
elif len(size_info) >= 3:
self._inputsizes.append((size_info[0], size_info[1], size_info[2]))
else:
# Handle single value (just sql_type)
self._inputsizes.append((size_info, 0, 0))

def _reset_inputsizes(self):
"""Reset input sizes after execution"""
self._inputsizes = None

def _get_c_type_for_sql_type(self, sql_type):
"""Map SQL type to appropriate C type for parameter binding"""
sql_to_c_type = {
ddbc_sql_const.SQL_CHAR.value: ddbc_sql_const.SQL_C_CHAR.value,
ddbc_sql_const.SQL_VARCHAR.value: ddbc_sql_const.SQL_C_CHAR.value,
ddbc_sql_const.SQL_LONGVARCHAR.value: ddbc_sql_const.SQL_C_CHAR.value,
ddbc_sql_const.SQL_WCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value,
ddbc_sql_const.SQL_WVARCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value,
ddbc_sql_const.SQL_WLONGVARCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value,
ddbc_sql_const.SQL_DECIMAL.value: ddbc_sql_const.SQL_C_NUMERIC.value,
ddbc_sql_const.SQL_NUMERIC.value: ddbc_sql_const.SQL_C_NUMERIC.value,
ddbc_sql_const.SQL_BIT.value: ddbc_sql_const.SQL_C_BIT.value,
ddbc_sql_const.SQL_TINYINT.value: ddbc_sql_const.SQL_C_TINYINT.value,
ddbc_sql_const.SQL_SMALLINT.value: ddbc_sql_const.SQL_C_SHORT.value,
ddbc_sql_const.SQL_INTEGER.value: ddbc_sql_const.SQL_C_LONG.value,
ddbc_sql_const.SQL_BIGINT.value: ddbc_sql_const.SQL_C_SBIGINT.value,
ddbc_sql_const.SQL_REAL.value: ddbc_sql_const.SQL_C_FLOAT.value,
ddbc_sql_const.SQL_FLOAT.value: ddbc_sql_const.SQL_C_DOUBLE.value,
ddbc_sql_const.SQL_DOUBLE.value: ddbc_sql_const.SQL_C_DOUBLE.value,
ddbc_sql_const.SQL_BINARY.value: ddbc_sql_const.SQL_C_BINARY.value,
ddbc_sql_const.SQL_VARBINARY.value: ddbc_sql_const.SQL_C_BINARY.value,
ddbc_sql_const.SQL_LONGVARBINARY.value: ddbc_sql_const.SQL_C_BINARY.value,
ddbc_sql_const.SQL_DATE.value: ddbc_sql_const.SQL_C_TYPE_DATE.value,
ddbc_sql_const.SQL_TIME.value: ddbc_sql_const.SQL_C_TYPE_TIME.value,
ddbc_sql_const.SQL_TIMESTAMP.value: ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value,
}
return sql_to_c_type.get(sql_type, ddbc_sql_const.SQL_C_DEFAULT.value)

def _create_parameter_types_list(self, parameter, param_info, parameters_list, i):
"""
Maps parameter types for the given parameter.
Expand All @@ -474,9 +540,20 @@ def _create_parameter_types_list(self, parameter, param_info, parameters_list, i
paraminfo.
"""
paraminfo = param_info()
sql_type, c_type, column_size, decimal_digits = self._map_sql_type(
parameter, parameters_list, i
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please double-check that all parameterized queries remain fully protected against SQL injection—even when input sizes or types are set by users via setinputsizes. It's important to ensure that user-supplied values for input sizes/types cannot be used to inject malicious SQL or bypass query parameterization. If possible, add validation or sanitization where needed, and consider adding a test case for this scenario.

# Check if we have explicit type information from setinputsizes
if hasattr(self, '_inputsizes') and self._inputsizes and i < len(self._inputsizes):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are places where you check if the object (self) has an attribute called _inputsizes using hasattr(self, '_inputsizes').
The reviewer noticed that the _inputsizes attribute is always created (initialized) when the object is constructed (in the class’s __init__ method).
If an attribute is always present (because it’s defined in the constructor), you don’t need to check if it exists every time you use it.
(It will always exist, unless something very unusual happens in your code.)
These hasattr checks are, therefore, unnecessary ("redundant").
Removing them will make your code cleaner, easier to read, and easier to maintain.

# Use explicit type information
sql_type, column_size, decimal_digits = self._inputsizes[i]

# Determine the appropriate C type based on SQL type
c_type = self._get_c_type_for_sql_type(sql_type)
else:
# Fall back to automatic type inference
sql_type, c_type, column_size, decimal_digits = self._map_sql_type(
parameter, parameters_list, i
)

paraminfo.paramCType = c_type
paraminfo.paramSQLType = sql_type
paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value
Expand Down Expand Up @@ -643,6 +720,8 @@ def execute(
except Exception as e:
# If describe fails, it's likely there are no results (e.g., for INSERT)
self.description = None

self._reset_inputsizes() # Reset input sizes after execution

@staticmethod
def _select_best_sample_value(column):
Expand Down Expand Up @@ -709,7 +788,7 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
if not seq_of_parameters:
self.rowcount = 0
return

# Apply timeout if set (non-zero)
if self._timeout > 0:
try:
Expand All @@ -720,23 +799,62 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
timeout_value
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
log('debug', f"Set query timeout to {self._timeout} seconds")
log('debug', f"Set query timeout to {timeout_value} seconds")
except Exception as e:
log('warning', f"Failed to set query timeout: {e}")

param_info = ddbc_bindings.ParamInfo
param_count = len(seq_of_parameters[0])
parameters_type = []

# Make a copy of the parameters for potential transformation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider raising a warning or error if the number of input sizes set via setinputsizes does not match the number of parameters provided to executemany. This will help catch user mistakes early and prevent subtle bugs due to mismatched parameter and input size definitions.

processed_parameters = [list(params) for params in seq_of_parameters]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is using [list(params) for params in seq_of_parameters] to create a new list of lists from seq_of_parameters (which is probably a list or sequence of parameters for batch inserts).

When you do this for a very large number of rows (for example, thousands or millions), it creates a copy of every row in memory. This can use a lot of memory and might slow things down or even cause crashes if there isn’t enough memory.

If possible, don’t create a big copy of all the data at once.
Instead, you could use a generator expression (which makes one item at a time, only when needed) or change the items in place (if it’s safe to do so).

Current Implementation:

processed_parameters = [list(params) for params in seq_of_parameters]

This creates a new list in memory that contains a copy of every params as a list.
If seq_of_parameters has 1,000,000 items, Python immediately builds a list with 1,000,000 copies in memory.
This can use a lot of memory at once.

Generator Expression:

processed_parameters = (list(params) for params in seq_of_parameters)

This creates a generator—not a list. It doesn’t copy anything right away.
Each list(params) is created only when you need it (for example, when you loop over new_seq).
Much less memory is used because only one item is in memory at a time.

List comprehension is eager: makes everything up front, uses more memory.
Generator expression is lazy: makes each result only when needed, uses less memory.


# Check if we have explicit input sizes set
if hasattr(self, '_inputsizes') and self._inputsizes:
# Use the explicitly set input sizes
for col_index in range(param_count):
if col_index < len(self._inputsizes):
sql_type, column_size, decimal_digits = self._inputsizes[col_index]
c_type = self._get_c_type_for_sql_type(sql_type)

# If using SQL_DECIMAL/NUMERIC, we need to ensure the Python values
# are properly converted for the driver
if sql_type in (ddbc_sql_const.SQL_DECIMAL.value, ddbc_sql_const.SQL_NUMERIC.value):
# Make sure all values in this column are Decimal objects
for row_idx, row in enumerate(processed_parameters):
if not isinstance(row[col_index], decimal.Decimal):
# Convert to Decimal if it's not already
processed_parameters[row_idx][col_index] = decimal.Decimal(str(row[col_index]))

paraminfo = param_info()
paraminfo.paramCType = c_type
paraminfo.paramSQLType = sql_type
paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value
paraminfo.columnSize = column_size
paraminfo.decimalDigits = decimal_digits
parameters_type.append(paraminfo)
else:
# Fall back to auto-detect for any parameters beyond those specified
column = [row[col_index] for row in seq_of_parameters]
sample_value = self._select_best_sample_value(column)
dummy_row = list(seq_of_parameters[0])
parameters_type.append(
self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index)
)
else:
# No input sizes set, use auto-detection
for col_index in range(param_count):
column = [row[col_index] for row in seq_of_parameters]
sample_value = self._select_best_sample_value(column)
dummy_row = list(seq_of_parameters[0])
parameters_type.append(
self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index)
)

for col_index in range(param_count):
column = [row[col_index] for row in seq_of_parameters]
sample_value = self._select_best_sample_value(column)
dummy_row = list(seq_of_parameters[0])
parameters_type.append(
self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index)
)

columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters)
columnwise_params = self._transpose_rowwise_to_columnwise(processed_parameters)

log('info', "Executing batch query with %d parameter sets:\n%s",
len(seq_of_parameters), "\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters))
)
Expand All @@ -749,11 +867,17 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
parameters_type,
len(seq_of_parameters)
)
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

try:
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
self.last_executed_stmt = operation
self._initialize_description()
finally:
# Reset input sizes after execution
self._reset_inputsizes()

self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt)
self.last_executed_stmt = operation
self._initialize_description()


def fetchone(self) -> Union[None, Row]:
"""
Expand Down
Loading