Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
98737c4
Fix not fetching cursor on insert/update
rad-pat Feb 6, 2025
dcb77c4
Fetch the cursor on insert/update/delete/copy into
rad-pat Feb 7, 2025
db05b94
Enable CTE tests
rad-pat Feb 7, 2025
754c7c6
Enable further tests
rad-pat Feb 10, 2025
909f53e
Support for table and column comments
rad-pat Mar 18, 2025
1bb2887
Include CTE test now bug is fixed
rad-pat Mar 18, 2025
d93a78c
Run against nightly
rad-pat Mar 18, 2025
5e6f8bf
Update pipenv
rad-pat Mar 18, 2025
1320d83
Work in SQLAlchemy 1.4
rad-pat Mar 18, 2025
dbca246
Handle JSON params
rad-pat Mar 31, 2025
df75e37
Update Reserved Words
rad-pat Mar 31, 2025
032081f
Changes to File Formats
rad-pat May 14, 2025
6c92f44
Fix copyInto files clause
rad-pat Jun 4, 2025
c463b77
Fix test - input bytes will differ, table data is random
rad-pat Jun 5, 2025
1a2fe3e
remove code for Geometry, Geography and structured types.
simozzy Jun 19, 2025
d400c57
Update test.yml
simozzy Jun 19, 2025
41817a4
reverted change to test file
simozzy Jun 24, 2025
2c5bce3
Added tests for TINYINT and BITMAP
simozzy Jun 25, 2025
e1160cf
Added tests for DOUBLE
simozzy Jun 25, 2025
9d8bd00
Update test.yml
simozzy Jun 25, 2025
485dd1b
Merge branch 'main' into sc-20112
simozzy Jun 25, 2025
1115a37
Update test_sqlalchemy.py
simozzy Jun 25, 2025
08328cb
Ensure tests work on sqlalchemy versions 1.4.54 and 2.0 +
simozzy Jun 25, 2025
840c68d
Added types for GEOMETRY and GEOGRAPHY.
simozzy Jun 25, 2025
2dd9c68
Added Zip compression type
simozzy Jun 30, 2025
81cd3b5
Added code to enable geo tables in tests.
simozzy Jul 1, 2025
381f85b
move initialization of enable_geo_create_table
simozzy Jul 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions databend_sqlalchemy/databend_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
Table("some_table", metadata, ..., databend_transient=True|False)

"""

import decimal
import re
import operator
Expand Down Expand Up @@ -60,6 +59,17 @@
CHAR,
TIMESTAMP,
)

import sqlalchemy
from sqlalchemy import types as sqltypes
from sqlalchemy.sql.base import Executable

# Check SQLAlchemy version
if sqlalchemy.__version__.startswith('2.'):
from sqlalchemy.types import DOUBLE
else:
from .types import DOUBLE

from sqlalchemy.engine import ExecutionContext, default
from sqlalchemy.exc import DBAPIError, NoSuchTableError

Expand All @@ -71,7 +81,7 @@
AzureBlobStorage,
AmazonS3,
)
from .types import INTERVAL
from .types import INTERVAL, TINYINT, BITMAP, GEOMETRY, GEOGRAPHY

RESERVED_WORDS = {
"Error",
Expand Down Expand Up @@ -693,6 +703,7 @@ def __init__(self, key_type, value_type):
super(MAP, self).__init__()



class DatabendDate(sqltypes.DATE):
__visit_name__ = "DATE"

Expand Down Expand Up @@ -793,12 +804,26 @@ class DatabendInterval(INTERVAL):
render_bind_cast = True


class DatabendBitmap(BITMAP):
render_bind_cast = True


class DatabendTinyInt(TINYINT):
render_bind_cast = True


class DatabendGeometry(GEOMETRY):
render_bind_cast = True

class DatabendGeography(GEOGRAPHY):
render_bind_cast = True

# Type converters
ischema_names = {
"bigint": BIGINT,
"int": INTEGER,
"smallint": SMALLINT,
"tinyint": SMALLINT,
"tinyint": DatabendTinyInt,
"int64": BIGINT,
"int32": INTEGER,
"int16": SMALLINT,
Expand All @@ -813,7 +838,7 @@ class DatabendInterval(INTERVAL):
"datetime": DatabendDateTime,
"timestamp": DatabendDateTime,
"float": FLOAT,
"double": FLOAT,
"double": DOUBLE,
"float64": FLOAT,
"float32": FLOAT,
"string": VARCHAR,
Expand All @@ -826,8 +851,13 @@ class DatabendInterval(INTERVAL):
"binary": BINARY,
"time": DatabendTime,
"interval": DatabendInterval,
"bitmap": DatabendBitmap,
"geometry": DatabendGeometry,
"geography": DatabendGeography
}



# Column spec
colspecs = {
sqltypes.Interval: DatabendInterval,
Expand Down Expand Up @@ -1227,6 +1257,29 @@ def visit_TIME(self, type_, **kw):
def visit_INTERVAL(self, type, **kw):
return "INTERVAL"

def visit_DOUBLE(self, type_, **kw):
return "DOUBLE"

def visit_TINYINT(self, type_, **kw):
return "TINYINT"

def visit_FLOAT(self, type_, **kw):
return "FLOAT"

def visit_BITMAP(self, type_, **kw):
return "BITMAP"

def visit_GEOMETRY(self, type_, **kw):
if type_.srid is not None:
return f"GEOMETRY(SRID {type_.srid})"
return "GEOMETRY"

def visit_GEOGRAPHY(self, type_, **kw):
if type_.srid is not None:
return f"GEOGRAPHY(SRID {type_.srid})"
return "GEOGRAPHY"



class DatabendDDLCompiler(compiler.DDLCompiler):
def visit_primary_key_constraint(self, constraint, **kw):
Expand Down
1 change: 1 addition & 0 deletions databend_sqlalchemy/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ class Compression(Enum):
RAW_DEFLATE = "RAW_DEFLATE"
XZ = "XZ"
SNAPPY = "SNAPPY"
ZIP = "ZIP"


class CopyFormat(ClauseElement):
Expand Down
78 changes: 78 additions & 0 deletions databend_sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime as dt
from typing import Optional, Type, Any

from sqlalchemy import func
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql import type_api
Expand Down Expand Up @@ -73,3 +74,80 @@ def process(value: dt.timedelta) -> str:
return f"to_interval('{value.total_seconds()} seconds')"

return process


class TINYINT(sqltypes.Integer):
__visit_name__ = "TINYINT"
native = True


class DOUBLE(sqltypes.Float):
__visit_name__ = "DOUBLE"
native = True


class FLOAT(sqltypes.Float):
__visit_name__ = "FLOAT"
native = True


# The “CamelCase” types are to the greatest degree possible database agnostic

# For these datatypes, specific SQLAlchemy dialects provide backend-specific “UPPERCASE” datatypes, for a SQL type that has no analogue on other backends


class BITMAP(sqltypes.TypeEngine):
__visit_name__ = "BITMAP"
render_bind_cast = True

def __init__(self, **kwargs):
super(BITMAP, self).__init__()

def process_result_value(self, value, dialect):
if value is None:
return None
# Databend returns bitmaps as strings of comma-separated integers
return set(int(x) for x in value.split(',') if x)

def bind_expression(self, bindvalue):
return func.to_bitmap(bindvalue, type_=self)

def column_expression(self, col):
# Convert bitmap to string using a custom function
return func.to_string(col, type_=sqltypes.String)

def bind_processor(self, dialect):
def process(value):
if value is None:
return None
if isinstance(value, set):
return ','.join(str(x) for x in sorted(value))
return str(value)
return process

def result_processor(self, dialect, coltype):
def process(value):
if value is None:
return None
return set(int(x) for x in value.split(',') if x)
return process


class GEOMETRY(sqltypes.TypeEngine):
__visit_name__ = "GEOMETRY"

def __init__(self, srid=None):
super(GEOMETRY, self).__init__()
self.srid = srid



class GEOGRAPHY(sqltypes.TypeEngine):
__visit_name__ = "GEOGRAPHY"
native = True

def __init__(self, srid=None):
super(GEOGRAPHY, self).__init__()
self.srid = srid


20 changes: 14 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from sqlalchemy.dialects import registry
from sqlalchemy import event, Engine, text
import pytest

registry.register("databend.databend", "databend_sqlalchemy.databend_dialect", "DatabendDialect")
Expand All @@ -9,9 +8,18 @@

from sqlalchemy.testing.plugin.pytestplugin import *

from packaging import version
import sqlalchemy
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'):
from sqlalchemy import event, text
from sqlalchemy import Engine


@event.listens_for(Engine, "connect")
def receive_engine_connect(conn, r):
cur = conn.cursor()
cur.execute('SET global format_null_as_str = 0')
cur.execute('SET global enable_geo_create_table = 1')
cur.close()


@event.listens_for(Engine, "connect")
def receive_engine_connect(conn, r):
cur = conn.cursor()
cur.execute('SET global format_null_as_str = 0')
cur.close()
89 changes: 46 additions & 43 deletions tests/test_copy_into.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
FileColumnClause,
StageClause,
)
import sqlalchemy
from packaging import version


class CompileDatabendCopyIntoTableTest(fixtures.TestBase, AssertsCompiledSQL):
Expand Down Expand Up @@ -215,51 +217,52 @@ def define_tables(cls, metadata):
Column("data", String(50)),
)

def test_copy_into_stage_and_table(self, connection):
# create stage
connection.execute(text('CREATE OR REPLACE STAGE mystage'))
# copy into stage from random table limiting 1000
table = self.tables.random_data
query = table.select().limit(1000)
if version.parse(sqlalchemy.__version__) >= version.parse('2.0.0'):
def test_copy_into_stage_and_table(self, connection):
# create stage
connection.execute(text('CREATE OR REPLACE STAGE mystage'))
# copy into stage from random table limiting 1000
table = self.tables.random_data
query = table.select().limit(1000)

copy_into = CopyIntoLocation(
target=StageClause(
name='mystage'
),
from_=query,
file_format=ParquetFormat(),
options=CopyIntoLocationOptions()
)
r = connection.execute(
copy_into
)
eq_(r.rowcount, 1000)
copy_into_results = r.context.copy_into_location_results()
eq_(copy_into_results['rows_unloaded'], 1000)
# eq_(copy_into_results['input_bytes'], 16250) # input bytes will differ, the table is random
# eq_(copy_into_results['output_bytes'], 4701) # output bytes differs
copy_into = CopyIntoLocation(
target=StageClause(
name='mystage'
),
from_=query,
file_format=ParquetFormat(),
options=CopyIntoLocationOptions()
)
r = connection.execute(
copy_into
)
eq_(r.rowcount, 1000)
copy_into_results = r.context.copy_into_location_results()
eq_(copy_into_results['rows_unloaded'], 1000)
# eq_(copy_into_results['input_bytes'], 16250) # input bytes will differ, the table is random
# eq_(copy_into_results['output_bytes'], 4701) # output bytes differs

# now copy into table
# now copy into table

copy_into_table = CopyIntoTable(
target=self.tables.loaded,
from_=StageClause(
name='mystage'
),
file_format=ParquetFormat(),
options=CopyIntoTableOptions()
)
r = connection.execute(
copy_into_table
)
eq_(r.rowcount, 1000)
copy_into_table_results = r.context.copy_into_table_results()
assert len(copy_into_table_results) == 1
result = copy_into_table_results[0]
assert result['file'].endswith('.parquet')
eq_(result['rows_loaded'], 1000)
eq_(result['errors_seen'], 0)
eq_(result['first_error'], None)
eq_(result['first_error_line'], None)
copy_into_table = CopyIntoTable(
target=self.tables.loaded,
from_=StageClause(
name='mystage'
),
file_format=ParquetFormat(),
options=CopyIntoTableOptions()
)
r = connection.execute(
copy_into_table
)
eq_(r.rowcount, 1000)
copy_into_table_results = r.context.copy_into_table_results()
assert len(copy_into_table_results) == 1
result = copy_into_table_results[0]
assert result['file'].endswith('.parquet')
eq_(result['rows_loaded'], 1000)
eq_(result['errors_seen'], 0)
eq_(result['first_error'], None)
eq_(result['first_error_line'], None)


Loading