Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 140 additions & 17 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,9 @@ def fetchone(self) -> Union[None, Row]:
# Update internal position after successful fetch
self._increment_rownumber()

# Create and return a Row object
return Row(row_data, self.description)
# Create and return a Row object, passing column name map if available
column_map = getattr(self, '_column_name_map', None)
return Row(row_data, self.description, column_map)
except Exception as e:
# On error, don't increment rownumber - rethrow the error
raise e
Expand Down Expand Up @@ -948,7 +949,8 @@ def fetchmany(self, size: int = None) -> List[Row]:
self._rownumber = self._next_row_index - 1

# Convert raw data to Row objects
return [Row(row_data, self.description) for row_data in rows_data]
column_map = getattr(self, '_column_name_map', None)
return [Row(row_data, self.description, column_map) for row_data in rows_data]
except Exception as e:
# On error, don't increment rownumber - rethrow the error
raise e
Expand Down Expand Up @@ -977,7 +979,8 @@ def fetchall(self) -> List[Row]:
self._rownumber = self._next_row_index - 1

# Convert raw data to Row objects
return [Row(row_data, self.description) for row_data in rows_data]
column_map = getattr(self, '_column_name_map', None)
return [Row(row_data, self.description, column_map) for row_data in rows_data]
except Exception as e:
# On error, don't increment rownumber - rethrow the error
raise e
Expand Down Expand Up @@ -1258,19 +1261,139 @@ def skip(self, count: int) -> None:
# Clear messages
self.messages = []

# Validate arguments
if not isinstance(count, int):
raise ProgrammingError("Count must be an integer", "Invalid argument type")
# Simply delegate to the scroll method with 'relative' mode
self.scroll(count, 'relative')

def _execute_tables(self, stmt_handle, catalog_name=None, schema_name=None, table_name=None,
table_type=None, search_escape=None):
"""
Execute SQLTables ODBC function to retrieve table metadata.

if count < 0:
raise NotSupportedError("Negative skip values are not supported", "Backward scrolling not supported")
Args:
stmt_handle: ODBC statement handle
catalog_name: The catalog name pattern
schema_name: The schema name pattern
table_name: The table name pattern
table_type: The table type filter
search_escape: The escape character for pattern matching
"""
# Convert None values to empty strings for ODBC
catalog = "" if catalog_name is None else catalog_name
schema = "" if schema_name is None else schema_name
table = "" if table_name is None else table_name
types = "" if table_type is None else table_type

# Skip zero is a no-op
if count == 0:
return
# Call the ODBC SQLTables function
retcode = ddbc_bindings.DDBCSQLTables(
stmt_handle,
catalog,
schema,
table,
types
)

# Check return code and handle errors
check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, stmt_handle, retcode)

# Capture any diagnostic messages
if stmt_handle:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(stmt_handle))

def tables(self, table=None, catalog=None, schema=None, tableType=None):
"""
Returns information about tables in the database that match the given criteria using
the SQLTables ODBC function.

Args:
table (str, optional): The table name pattern. Default is None (all tables).
catalog (str, optional): The catalog name. Default is None.
schema (str, optional): The schema name pattern. Default is None.
tableType (str or list, optional): The table type filter. Default is None.
Example: "TABLE" or ["TABLE", "VIEW"]

Returns:
list: A list of Row objects containing table information with these columns:
- table_cat: Catalog name
- table_schem: Schema name
- table_name: Table name
- table_type: Table type (e.g., "TABLE", "VIEW")
- remarks: Comments about the table

Notes:
This method only processes the standard five columns as defined in the ODBC
specification. Any additional columns that might be returned by specific ODBC
drivers are not included in the result set.

Example:
# Get all tables in the database
tables = cursor.tables()

# Get all tables in schema 'dbo'
tables = cursor.tables(schema='dbo')

# Get table named 'Customers'
tables = cursor.tables(table='Customers')

# Get all views
tables = cursor.tables(tableType='VIEW')
"""
self._check_closed()

# Clear messages
self.messages = []

# Always reset the cursor first to ensure clean state
self._reset_cursor()

# Format table_type parameter - SQLTables expects comma-separated string
table_type_str = None
if tableType is not None:
if isinstance(tableType, (list, tuple)):
table_type_str = ",".join(tableType)
else:
table_type_str = str(tableType)

# Call SQLTables via the helper method
self._execute_tables(
self.hstmt,
catalog_name=catalog,
schema_name=schema,
table_name=table,
table_type=table_type_str
)

# Initialize description from column metadata
column_metadata = []
try:
ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata)
self._initialize_description(column_metadata)
except Exception:
# If describe fails, create a manual description for the standard columns
column_types = [str, str, str, str, str]
self.description = [
("table_cat", column_types[0], None, 128, 128, 0, True),
("table_schem", column_types[1], None, 128, 128, 0, True),
("table_name", column_types[2], None, 128, 128, 0, False),
("table_type", column_types[3], None, 128, 128, 0, False),
("remarks", column_types[4], None, 254, 254, 0, True)
]

# Define column names in ODBC standard order
column_names = [
"table_cat", "table_schem", "table_name", "table_type", "remarks"
]

# Fetch all rows
rows_data = []
ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)

# Create a column map for attribute access
column_map = {name: i for i, name in enumerate(column_names)}

# Create Row objects with the column map
result_rows = []
for row_data in rows_data:
row = Row(row_data, self.description, column_map)
result_rows.append(row)

# Skip the rows by fetching and discarding
for _ in range(count):
row = self.fetchone()
if row is None:
raise IndexError("Cannot skip beyond the end of the result set")
return result_rows
95 changes: 94 additions & 1 deletion mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr;

// Diagnostic APIs
SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr;
SQLTablesFunc SQLTables_ptr = nullptr;

namespace {

Expand Down Expand Up @@ -786,6 +787,7 @@ DriverHandle LoadDriverOrThrowException() {
SQLFreeStmt_ptr = GetFunctionPointer<SQLFreeStmtFunc>(handle, "SQLFreeStmt");

SQLGetDiagRec_ptr = GetFunctionPointer<SQLGetDiagRecFunc>(handle, "SQLGetDiagRecW");
SQLTables_ptr = GetFunctionPointer<SQLTablesFunc>(handle, "SQLTablesW");

bool success =
SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr &&
Expand All @@ -796,7 +798,7 @@ DriverHandle LoadDriverOrThrowException() {
SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr &&
SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr &&
SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr &&
SQLFreeStmt_ptr && SQLGetDiagRec_ptr;
SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLTables_ptr;

if (!success) {
ThrowStdException("Failed to load required function pointers from driver.");
Expand Down Expand Up @@ -982,6 +984,91 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q
return ret;
}

// Wrapper for SQLTables
SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle,
const std::wstring& catalog,
const std::wstring& schema,
const std::wstring& table,
const std::wstring& tableType) {

if (!SQLTables_ptr) {
LOG("Function pointer not initialized. Loading the driver.");
DriverLoader::getInstance().loadDriver();
}

SQLWCHAR* catalogPtr = nullptr;
SQLWCHAR* schemaPtr = nullptr;
SQLWCHAR* tablePtr = nullptr;
SQLWCHAR* tableTypePtr = nullptr;
SQLSMALLINT catalogLen = 0;
SQLSMALLINT schemaLen = 0;
SQLSMALLINT tableLen = 0;
SQLSMALLINT tableTypeLen = 0;

std::vector<SQLWCHAR> catalogBuffer;
std::vector<SQLWCHAR> schemaBuffer;
std::vector<SQLWCHAR> tableBuffer;
std::vector<SQLWCHAR> tableTypeBuffer;

#if defined(__APPLE__) || defined(__linux__)
// On Unix platforms, convert wstring to SQLWCHAR array
if (!catalog.empty()) {
catalogBuffer = WStringToSQLWCHAR(catalog);
catalogPtr = catalogBuffer.data();
catalogLen = SQL_NTS;
}
if (!schema.empty()) {
schemaBuffer = WStringToSQLWCHAR(schema);
schemaPtr = schemaBuffer.data();
schemaLen = SQL_NTS;
}
if (!table.empty()) {
tableBuffer = WStringToSQLWCHAR(table);
tablePtr = tableBuffer.data();
tableLen = SQL_NTS;
}
if (!tableType.empty()) {
tableTypeBuffer = WStringToSQLWCHAR(tableType);
tableTypePtr = tableTypeBuffer.data();
tableTypeLen = SQL_NTS;
}
#else
// On Windows, direct assignment works
if (!catalog.empty()) {
catalogPtr = const_cast<SQLWCHAR*>(catalog.c_str());
catalogLen = SQL_NTS;
}
if (!schema.empty()) {
schemaPtr = const_cast<SQLWCHAR*>(schema.c_str());
schemaLen = SQL_NTS;
}
if (!table.empty()) {
tablePtr = const_cast<SQLWCHAR*>(table.c_str());
tableLen = SQL_NTS;
}
if (!tableType.empty()) {
tableTypePtr = const_cast<SQLWCHAR*>(tableType.c_str());
tableTypeLen = SQL_NTS;
}
#endif

SQLRETURN ret = SQLTables_ptr(
StatementHandle->get(),
catalogPtr, catalogLen,
schemaPtr, schemaLen,
tablePtr, tableLen,
tableTypePtr, tableTypeLen
);

if (!SQL_SUCCEEDED(ret)) {
LOG("SQLTables failed with return code: {}", ret);
} else {
LOG("SQLTables succeeded");
}

return ret;
}

// Executes the provided query. If the query is parametrized, it prepares the statement and
// binds the parameters. Otherwise, it executes the query directly.
// 'usePrepare' parameter can be used to disable the prepare step for queries that might already
Expand Down Expand Up @@ -2616,6 +2703,12 @@ PYBIND11_MODULE(ddbc_bindings, m) {
m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords,
"Get all diagnostic records for a handle",
py::arg("handle"));
// Add to PYBIND11_MODULE section
m.def("DDBCSQLTables", &SQLTables_wrap,
"Get table information using ODBC SQLTables",
py::arg("StatementHandle"), py::arg("catalog") = std::wstring(),
py::arg("schema") = std::wstring(), py::arg("table") = std::wstring(),
py::arg("tableType") = std::wstring());

// Add a version attribute
m.attr("__version__") = "1.0.0";
Expand Down
14 changes: 13 additions & 1 deletion mssql_python/pybind/ddbc_bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,18 @@ typedef SQLRETURN (SQL_API* SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR
typedef SQLRETURN (SQL_API* SQLMoreResultsFunc)(SQLHSTMT);
typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER,
SQLSMALLINT, SQLSMALLINT*, SQLPOINTER);

typedef SQLRETURN (*SQLTablesFunc)(
SQLHSTMT StatementHandle,
SQLWCHAR* CatalogName,
SQLSMALLINT NameLength1,
SQLWCHAR* SchemaName,
SQLSMALLINT NameLength2,
SQLWCHAR* TableName,
SQLSMALLINT NameLength3,
SQLWCHAR* TableType,
SQLSMALLINT NameLength4
);

// Transaction APIs
typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT);

Expand Down Expand Up @@ -148,6 +159,7 @@ extern SQLBindColFunc SQLBindCol_ptr;
extern SQLDescribeColFunc SQLDescribeCol_ptr;
extern SQLMoreResultsFunc SQLMoreResults_ptr;
extern SQLColAttributeFunc SQLColAttribute_ptr;
extern SQLTablesFunc SQLTables_ptr;

// Transaction APIs
extern SQLEndTranFunc SQLEndTran_ptr;
Expand Down
30 changes: 15 additions & 15 deletions mssql_python/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,27 @@ class Row:
print(row.column_name) # Access by column name
"""

def __init__(self, values, cursor_description):
def __init__(self, values, description, column_map=None):
"""
Initialize a Row object with values and cursor description.
Initialize a Row object with values and description.

Args:
values: List of values for this row
cursor_description: The cursor description containing column metadata
values: List of values for this row.
description: Description of the columns (from cursor.description).
column_map: Optional mapping of column names to indices.
"""
self._values = values
self._description = description

# TODO: ADO task - Optimize memory usage by sharing column map across rows
# Instead of storing the full cursor_description in each Row object:
# 1. Build the column map once at the cursor level after setting description
# 2. Pass only this map to each Row instance
# 3. Remove cursor_description from Row objects entirely

# Create mapping of column names to indices
self._column_map = {}
for i, desc in enumerate(cursor_description):
if desc and desc[0]: # Ensure column name exists
self._column_map[desc[0]] = i
# Build column map if not provided
if column_map is None:
self._column_map = {}
for i, desc in enumerate(description):
col_name = desc[0]
self._column_map[col_name] = i
self._column_map[col_name.lower()] = i # Add lowercase for case-insensitivity
else:
self._column_map = column_map

def __getitem__(self, index):
"""Allow accessing by numeric index: row[0]"""
Expand Down
Loading