From e4f56c7c48328baacc262bd676353cc06579af7a Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Thu, 21 Aug 2025 21:04:20 +0530 Subject: [PATCH] FEAT: Adding setinputsizes --- mssql_python/cursor.py | 158 +++++++++++++++++++++++++++++---- tests/test_004_cursor.py | 183 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 324 insertions(+), 17 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 91d4b638..9183da98 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -55,6 +55,7 @@ def __init__(self, connection, timeout: int = 0) -> None: """ self.connection = connection self._timeout = timeout + self._inputsizes = None # self.connection.autocommit = False self.hstmt = None self._initialize_cursor() @@ -463,6 +464,71 @@ def _check_closed(self): if self.closed: raise Exception("Operation cannot be performed: the cursor is closed.") + def setinputsizes(self, sizes): + """ + Sets the type information to be used for parameters in execute and executemany. + + This method can be used to explicitly declare the types and sizes of query parameters. + For example: + + sql = "INSERT INTO product (item, price) VALUES (?, ?)" + params = [('bicycle', 499.99), ('ham', 17.95)] + # specify that parameters are for NVARCHAR(50) and DECIMAL(18,4) columns + cursor.setinputsizes([(SQL_WVARCHAR, 50, 0), (SQL_DECIMAL, 18, 4)]) + cursor.executemany(sql, params) + + Args: + sizes: A sequence of tuples, one for each parameter. Each tuple contains + (sql_type, size, decimal_digits) where size and decimal_digits are optional. + """ + self._inputsizes = [] + + if sizes: + for size_info in sizes: + if isinstance(size_info, tuple): + # Handle tuple format (sql_type, size, decimal_digits) + if len(size_info) == 1: + self._inputsizes.append((size_info[0], 0, 0)) + elif len(size_info) == 2: + self._inputsizes.append((size_info[0], size_info[1], 0)) + elif len(size_info) >= 3: + self._inputsizes.append((size_info[0], size_info[1], size_info[2])) + else: + # Handle single value (just sql_type) + self._inputsizes.append((size_info, 0, 0)) + + def _reset_inputsizes(self): + """Reset input sizes after execution""" + self._inputsizes = None + + def _get_c_type_for_sql_type(self, sql_type): + """Map SQL type to appropriate C type for parameter binding""" + sql_to_c_type = { + ddbc_sql_const.SQL_CHAR.value: ddbc_sql_const.SQL_C_CHAR.value, + ddbc_sql_const.SQL_VARCHAR.value: ddbc_sql_const.SQL_C_CHAR.value, + ddbc_sql_const.SQL_LONGVARCHAR.value: ddbc_sql_const.SQL_C_CHAR.value, + ddbc_sql_const.SQL_WCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value, + ddbc_sql_const.SQL_WVARCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value, + ddbc_sql_const.SQL_WLONGVARCHAR.value: ddbc_sql_const.SQL_C_WCHAR.value, + ddbc_sql_const.SQL_DECIMAL.value: ddbc_sql_const.SQL_C_NUMERIC.value, + ddbc_sql_const.SQL_NUMERIC.value: ddbc_sql_const.SQL_C_NUMERIC.value, + ddbc_sql_const.SQL_BIT.value: ddbc_sql_const.SQL_C_BIT.value, + ddbc_sql_const.SQL_TINYINT.value: ddbc_sql_const.SQL_C_TINYINT.value, + ddbc_sql_const.SQL_SMALLINT.value: ddbc_sql_const.SQL_C_SHORT.value, + ddbc_sql_const.SQL_INTEGER.value: ddbc_sql_const.SQL_C_LONG.value, + ddbc_sql_const.SQL_BIGINT.value: ddbc_sql_const.SQL_C_SBIGINT.value, + ddbc_sql_const.SQL_REAL.value: ddbc_sql_const.SQL_C_FLOAT.value, + ddbc_sql_const.SQL_FLOAT.value: ddbc_sql_const.SQL_C_DOUBLE.value, + ddbc_sql_const.SQL_DOUBLE.value: ddbc_sql_const.SQL_C_DOUBLE.value, + ddbc_sql_const.SQL_BINARY.value: ddbc_sql_const.SQL_C_BINARY.value, + ddbc_sql_const.SQL_VARBINARY.value: ddbc_sql_const.SQL_C_BINARY.value, + ddbc_sql_const.SQL_LONGVARBINARY.value: ddbc_sql_const.SQL_C_BINARY.value, + ddbc_sql_const.SQL_DATE.value: ddbc_sql_const.SQL_C_TYPE_DATE.value, + ddbc_sql_const.SQL_TIME.value: ddbc_sql_const.SQL_C_TYPE_TIME.value, + ddbc_sql_const.SQL_TIMESTAMP.value: ddbc_sql_const.SQL_C_TYPE_TIMESTAMP.value, + } + return sql_to_c_type.get(sql_type, ddbc_sql_const.SQL_C_DEFAULT.value) + def _create_parameter_types_list(self, parameter, param_info, parameters_list, i): """ Maps parameter types for the given parameter. @@ -474,9 +540,20 @@ def _create_parameter_types_list(self, parameter, param_info, parameters_list, i paraminfo. """ paraminfo = param_info() - sql_type, c_type, column_size, decimal_digits = self._map_sql_type( - parameter, parameters_list, i - ) + + # Check if we have explicit type information from setinputsizes + if hasattr(self, '_inputsizes') and self._inputsizes and i < len(self._inputsizes): + # Use explicit type information + sql_type, column_size, decimal_digits = self._inputsizes[i] + + # Determine the appropriate C type based on SQL type + c_type = self._get_c_type_for_sql_type(sql_type) + else: + # Fall back to automatic type inference + sql_type, c_type, column_size, decimal_digits = self._map_sql_type( + parameter, parameters_list, i + ) + paraminfo.paramCType = c_type paraminfo.paramSQLType = sql_type paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value @@ -643,6 +720,8 @@ def execute( except Exception as e: # If describe fails, it's likely there are no results (e.g., for INSERT) self.description = None + + self._reset_inputsizes() # Reset input sizes after execution @staticmethod def _select_best_sample_value(column): @@ -709,7 +788,7 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: if not seq_of_parameters: self.rowcount = 0 return - + # Apply timeout if set (non-zero) if self._timeout > 0: try: @@ -720,23 +799,62 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: timeout_value ) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - log('debug', f"Set query timeout to {self._timeout} seconds") + log('debug', f"Set query timeout to {timeout_value} seconds") except Exception as e: log('warning', f"Failed to set query timeout: {e}") param_info = ddbc_bindings.ParamInfo param_count = len(seq_of_parameters[0]) parameters_type = [] + + # Make a copy of the parameters for potential transformation + processed_parameters = [list(params) for params in seq_of_parameters] + + # Check if we have explicit input sizes set + if hasattr(self, '_inputsizes') and self._inputsizes: + # Use the explicitly set input sizes + for col_index in range(param_count): + if col_index < len(self._inputsizes): + sql_type, column_size, decimal_digits = self._inputsizes[col_index] + c_type = self._get_c_type_for_sql_type(sql_type) + + # If using SQL_DECIMAL/NUMERIC, we need to ensure the Python values + # are properly converted for the driver + if sql_type in (ddbc_sql_const.SQL_DECIMAL.value, ddbc_sql_const.SQL_NUMERIC.value): + # Make sure all values in this column are Decimal objects + for row_idx, row in enumerate(processed_parameters): + if not isinstance(row[col_index], decimal.Decimal): + # Convert to Decimal if it's not already + processed_parameters[row_idx][col_index] = decimal.Decimal(str(row[col_index])) + + paraminfo = param_info() + paraminfo.paramCType = c_type + paraminfo.paramSQLType = sql_type + paraminfo.inputOutputType = ddbc_sql_const.SQL_PARAM_INPUT.value + paraminfo.columnSize = column_size + paraminfo.decimalDigits = decimal_digits + parameters_type.append(paraminfo) + else: + # Fall back to auto-detect for any parameters beyond those specified + column = [row[col_index] for row in seq_of_parameters] + sample_value = self._select_best_sample_value(column) + dummy_row = list(seq_of_parameters[0]) + parameters_type.append( + self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index) + ) + else: + # No input sizes set, use auto-detection + for col_index in range(param_count): + column = [row[col_index] for row in seq_of_parameters] + sample_value = self._select_best_sample_value(column) + dummy_row = list(seq_of_parameters[0]) + parameters_type.append( + self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index) + ) - for col_index in range(param_count): - column = [row[col_index] for row in seq_of_parameters] - sample_value = self._select_best_sample_value(column) - dummy_row = list(seq_of_parameters[0]) - parameters_type.append( - self._create_parameter_types_list(sample_value, param_info, dummy_row, col_index) - ) - columnwise_params = self._transpose_rowwise_to_columnwise(seq_of_parameters) + columnwise_params = self._transpose_rowwise_to_columnwise(processed_parameters) + log('info', "Executing batch query with %d parameter sets:\n%s", len(seq_of_parameters), "\n".join(f" {i+1}: {tuple(p) if isinstance(p, (list, tuple)) else p}" for i, p in enumerate(seq_of_parameters)) ) @@ -749,11 +867,17 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: parameters_type, len(seq_of_parameters) ) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + + try: + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) + self.last_executed_stmt = operation + self._initialize_description() + finally: + # Reset input sizes after execution + self._reset_inputsizes() - self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) - self.last_executed_stmt = operation - self._initialize_description() + def fetchone(self) -> Union[None, Row]: """ diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 9a63e27f..b9ac5a45 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -1556,6 +1556,189 @@ def test_decimal_separator_calculations(cursor, db_connection): cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test") db_connection.commit() +def test_cursor_setinputsizes_basic(db_connection): + """Test the basic functionality of setinputsizes""" + from mssql_python.constants import ConstantsDDBC + + cursor = db_connection.cursor() + + # Create a test table + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes") + cursor.execute(""" + CREATE TABLE #test_inputsizes ( + string_col NVARCHAR(100), + int_col INT + ) + """) + + # Set input sizes for parameters + cursor.setinputsizes([ + (ConstantsDDBC.SQL_WVARCHAR.value, 100, 0), + (ConstantsDDBC.SQL_INTEGER.value, 0, 0) + ]) + + # Execute with parameters + cursor.execute( + "INSERT INTO #test_inputsizes VALUES (?, ?)", + "Test String", 42 + ) + + # Verify data was inserted correctly + cursor.execute("SELECT * FROM #test_inputsizes") + row = cursor.fetchone() + + assert row[0] == "Test String" + assert row[1] == 42 + + # Clean up + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes") + +def test_cursor_setinputsizes_with_executemany_float(db_connection): + """Test setinputsizes with executemany using float instead of Decimal""" + from mssql_python.constants import ConstantsDDBC + + cursor = db_connection.cursor() + + # Create a test table + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_float") + cursor.execute(""" + CREATE TABLE #test_inputsizes_float ( + id INT, + name NVARCHAR(50), + price REAL /* Use REAL instead of DECIMAL */ + ) + """) + + # Prepare data with float values + data = [ + (1, "Item 1", 10.99), + (2, "Item 2", 20.50), + (3, "Item 3", 30.75) + ] + + # Set input sizes for parameters + cursor.setinputsizes([ + (ConstantsDDBC.SQL_INTEGER.value, 0, 0), + (ConstantsDDBC.SQL_WVARCHAR.value, 50, 0), + (ConstantsDDBC.SQL_REAL.value, 0, 0) + ]) + + # Execute with parameters + cursor.executemany( + "INSERT INTO #test_inputsizes_float VALUES (?, ?, ?)", + data + ) + + # Verify all data was inserted correctly + cursor.execute("SELECT * FROM #test_inputsizes_float ORDER BY id") + rows = cursor.fetchall() + + assert len(rows) == 3 + assert rows[0][0] == 1 + assert rows[0][1] == "Item 1" + assert abs(rows[0][2] - 10.99) < 0.001 + + # Clean up + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_float") + +def test_cursor_setinputsizes_reset(db_connection): + """Test that setinputsizes is reset after execution""" + from mssql_python.constants import ConstantsDDBC + + cursor = db_connection.cursor() + + # Create a test table + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_reset") + cursor.execute(""" + CREATE TABLE #test_inputsizes_reset ( + col1 NVARCHAR(100), + col2 INT + ) + """) + + # Set input sizes for parameters + cursor.setinputsizes([ + (ConstantsDDBC.SQL_WVARCHAR.value, 100, 0), + (ConstantsDDBC.SQL_INTEGER.value, 0, 0) + ]) + + # Execute with parameters + cursor.execute( + "INSERT INTO #test_inputsizes_reset VALUES (?, ?)", + "Test String", 42 + ) + + # Verify inputsizes was reset + assert cursor._inputsizes is None + + # Now execute again without setting input sizes + cursor.execute( + "INSERT INTO #test_inputsizes_reset VALUES (?, ?)", + "Another String", 84 + ) + + # Verify both rows were inserted correctly + cursor.execute("SELECT * FROM #test_inputsizes_reset ORDER BY col2") + rows = cursor.fetchall() + + assert len(rows) == 2 + assert rows[0][0] == "Test String" + assert rows[0][1] == 42 + assert rows[1][0] == "Another String" + assert rows[1][1] == 84 + + # Clean up + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_reset") + +def test_cursor_setinputsizes_override_inference(db_connection): + """Test that setinputsizes overrides type inference""" + from mssql_python.constants import ConstantsDDBC + + cursor = db_connection.cursor() + + # Create a test table with specific types + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_override") + cursor.execute(""" + CREATE TABLE #test_inputsizes_override ( + small_int SMALLINT, + big_text NVARCHAR(MAX) + ) + """) + + # Set input sizes that override the default inference + # For SMALLINT, use a valid precision value (5 is typical for SMALLINT) + cursor.setinputsizes([ + (ConstantsDDBC.SQL_SMALLINT.value, 5, 0), # Use valid precision for SMALLINT + (ConstantsDDBC.SQL_WVARCHAR.value, 8000, 0) # Force short string to NVARCHAR(MAX) + ]) + + # Test with values that would normally be inferred differently + big_number = 30000 # Would normally be INTEGER or BIGINT + short_text = "abc" # Would normally be a regular NVARCHAR + + try: + cursor.execute( + "INSERT INTO #test_inputsizes_override VALUES (?, ?)", + big_number, short_text + ) + + # Verify the row was inserted (may have been truncated by SQL Server) + cursor.execute("SELECT * FROM #test_inputsizes_override") + row = cursor.fetchone() + + # SQL Server would either truncate or round the value + assert row[1] == short_text + + except Exception as e: + # If an exception occurs, it should be related to the data type conversion + # Add "invalid precision" to the expected error messages + error_text = str(e).lower() + assert any(text in error_text for text in ["overflow", "out of range", "convert", "invalid precision", "precision value"]), \ + f"Unexpected error: {e}" + + # Clean up + cursor.execute("DROP TABLE IF EXISTS #test_inputsizes_override") + def test_close(db_connection): """Test closing the cursor""" try: