Skip to content

Commit 0c8d54e

Browse files
authored
Add merge statement (#32)
* Add merge statement * ideas taken from snowflake-sqlalchemy * Fix Merge statement so a table, select or subquery may be used as source * Fix get_table_names, get_view_names for server update * Quote target table * Quote field name in update
1 parent e0197e5 commit 0c8d54e

File tree

4 files changed

+535
-3
lines changed

4 files changed

+535
-3
lines changed

README.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,32 @@ Compatibility
4040

4141
- If databend version >= v0.9.0 or later, you need to use databend-sqlalchemy version >= v0.1.0.
4242
- The databend-sqlalchemy use [databend-py](https://github.com/datafuselabs/databend-py) as internal driver when version < v0.4.0, but when version >= v0.4.0 it use [databend driver python binding](https://github.com/datafuselabs/bendsql/blob/main/bindings/python/README.md) as internal driver. The only difference between the two is that the connection parameters provided in the DSN are different. When using the corresponding version, you should refer to the connection parameters provided by the corresponding Driver.
43+
44+
45+
Merge Command Support
46+
---------------------
47+
48+
Databend SQLAlchemy supports upserts via its `Merge` custom expression.
49+
See [Merge](https://docs.databend.com/sql/sql-commands/dml/dml-merge) for full documentation.
50+
51+
The Merge command can be used as below::
52+
53+
from sqlalchemy.orm import sessionmaker
54+
from sqlalchemy import MetaData, create_engine
55+
from databend_sqlalchemy.databend_dialect import Merge
56+
57+
engine = create_engine(db.url, echo=False)
58+
session = sessionmaker(bind=engine)()
59+
connection = engine.connect()
60+
61+
meta = MetaData()
62+
meta.reflect(bind=session.bind)
63+
t1 = meta.tables['t1']
64+
t2 = meta.tables['t2']
65+
66+
merge = Merge(target=t1, source=t2, on=t1.c.t1key == t2.c.t2key)
67+
merge.when_matched_then_delete().where(t2.c.marked == 1)
68+
merge.when_matched_then_update().where(t2.c.isnewstatus == 1).values(val = t2.c.newval, status=t2.c.newstatus)
69+
merge.when_matched_then_update().values(val=t2.c.newval)
70+
merge.when_not_matched_then_insert().values(val=t2.c.newval, status=t2.c.newstatus)
71+
connection.execute(merge)

databend_sqlalchemy/databend_dialect.py

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
# licensed under the same Apache 2.0 License
55
import decimal
66
import re
7+
import operator
78
import datetime
89
import sqlalchemy.types as sqltypes
910
from typing import Any, Dict, Optional, Union
1011
from sqlalchemy import util as sa_util
1112
from sqlalchemy.engine import reflection
12-
from sqlalchemy.sql import compiler, text, bindparam
13+
from sqlalchemy.sql import compiler, text, bindparam, select, TableClause, Select, Subquery
1314
from sqlalchemy.dialects.postgresql.base import PGCompiler, PGIdentifierPreparer
1415
from sqlalchemy.types import (
1516
BIGINT,
@@ -27,6 +28,7 @@
2728
)
2829
from sqlalchemy.engine import ExecutionContext, default
2930
from sqlalchemy.exc import DBAPIError, NoSuchTableError
31+
from .dml import Merge
3032

3133

3234
# Type decorators
@@ -231,6 +233,85 @@ def visit_not_like_op_binary(self, binary, operator, **kw):
231233
)
232234

233235

236+
def visit_merge(self, merge, **kw):
237+
clauses = "\n ".join(
238+
clause._compiler_dispatch(self, **kw)
239+
for clause in merge.clauses
240+
)
241+
source_kw = {'asfrom': True}
242+
if isinstance(merge.source, TableClause):
243+
source = select(merge.source).subquery().alias(merge.source.name)._compiler_dispatch(self, **source_kw)
244+
elif isinstance(merge.source, Select):
245+
source = merge.source.subquery().alias(merge.source.get_final_froms()[0].name)._compiler_dispatch(self, **source_kw)
246+
elif isinstance(merge.source, Subquery):
247+
source = merge.source._compiler_dispatch(self, **source_kw)
248+
249+
target_table = self.preparer.format_table(merge.target)
250+
return (
251+
f"MERGE INTO {target_table}\n"
252+
f" USING {source}\n"
253+
f" ON {merge.on}\n"
254+
f"{clauses if clauses else ''}"
255+
)
256+
257+
def visit_when_merge_matched_update(self, merge_matched_update, **kw):
258+
case_predicate = (
259+
f" AND {str(merge_matched_update.predicate._compiler_dispatch(self, **kw))}"
260+
if merge_matched_update.predicate is not None
261+
else ""
262+
)
263+
update_str = (
264+
f"WHEN MATCHED{case_predicate} THEN\n"
265+
f"\tUPDATE"
266+
)
267+
if not merge_matched_update.set:
268+
return f"{update_str} *"
269+
270+
set_list = list(merge_matched_update.set.items())
271+
if kw.get("deterministic", False):
272+
set_list.sort(key=operator.itemgetter(0))
273+
set_values = (
274+
", ".join(
275+
[
276+
f"{self.preparer.quote_identifier(set_item[0])} = {set_item[1]._compiler_dispatch(self, **kw)}"
277+
for set_item in set_list
278+
]
279+
)
280+
)
281+
return f"{update_str} SET {str(set_values)}"
282+
283+
def visit_when_merge_matched_delete(self, merge_matched_delete, **kw):
284+
case_predicate = (
285+
f" AND {str(merge_matched_delete.predicate._compiler_dispatch(self, **kw))}"
286+
if merge_matched_delete.predicate is not None
287+
else ""
288+
)
289+
return f"WHEN MATCHED{case_predicate} THEN DELETE"
290+
291+
def visit_when_merge_unmatched(self, merge_unmatched, **kw):
292+
case_predicate = (
293+
f" AND {str(merge_unmatched.predicate._compiler_dispatch(self, **kw))}"
294+
if merge_unmatched.predicate is not None
295+
else ""
296+
)
297+
insert_str = (
298+
f"WHEN NOT MATCHED{case_predicate} THEN\n"
299+
f"\tINSERT"
300+
)
301+
if not merge_unmatched.set:
302+
return f"{insert_str} *"
303+
304+
set_cols, sets_vals = zip(*merge_unmatched.set.items())
305+
set_cols, sets_vals = list(set_cols), list(sets_vals)
306+
if kw.get("deterministic", False):
307+
set_cols, sets_vals = zip(
308+
*sorted(merge_unmatched.set.items(), key=operator.itemgetter(0))
309+
)
310+
return "{} ({}) VALUES ({})".format(
311+
insert_str,
312+
", ".join(set_cols),
313+
", ".join(map(lambda e: e._compiler_dispatch(self, **kw), sets_vals)),
314+
)
234315
class DatabendExecutionContext(default.DefaultExecutionContext):
235316
@sa_util.memoized_property
236317
def should_autocommit(self):
@@ -489,12 +570,17 @@ def get_indexes(self, connection, table_name, schema=None, **kw):
489570

490571
@reflection.cache
491572
def get_table_names(self, connection, schema=None, **kw):
492-
query = text("""
573+
table_name_query = """
493574
select table_name
494575
from information_schema.tables
495576
where table_schema = :schema_name
577+
"""
578+
if self.server_version_info <= (1, 2, 410):
579+
table_name_query += """
496580
and engine NOT LIKE '%VIEW%'
497581
"""
582+
query = text(
583+
table_name_query
498584
).bindparams(
499585
bindparam("schema_name", type_=sqltypes.Unicode)
500586
)
@@ -506,13 +592,20 @@ def get_table_names(self, connection, schema=None, **kw):
506592

507593
@reflection.cache
508594
def get_view_names(self, connection, schema=None, **kw):
509-
query = text(
595+
view_name_query = """
596+
select table_name
597+
from information_schema.views
598+
where table_schema = :schema_name
510599
"""
600+
if self.server_version_info <= (1, 2, 410):
601+
view_name_query = """
511602
select table_name
512603
from information_schema.tables
513604
where table_schema = :schema_name
514605
and engine LIKE '%VIEW%'
515606
"""
607+
query = text(
608+
view_name_query
516609
).bindparams(
517610
bindparam("schema_name", type_=sqltypes.Unicode)
518611
)

databend_sqlalchemy/dml.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#!/usr/bin/env python
2+
#
3+
# Note: parts of the file come from https://github.com/snowflakedb/snowflake-sqlalchemy
4+
# licensed under the same Apache 2.0 License
5+
6+
from sqlalchemy.sql.selectable import Select, Subquery, TableClause
7+
from sqlalchemy.sql.dml import UpdateBase
8+
from sqlalchemy.sql.elements import ClauseElement
9+
from sqlalchemy.sql.expression import select
10+
11+
12+
class _OnMergeBaseClause(ClauseElement):
13+
# __visit_name__ = "on_merge_base_clause"
14+
15+
def __init__(self):
16+
self.set = {}
17+
self.predicate = None
18+
19+
def __repr__(self):
20+
return f" AND {str(self.predicate)}" if self.predicate is not None else ""
21+
22+
def values(self, **kwargs):
23+
self.set = kwargs
24+
return self
25+
26+
def where(self, expr):
27+
self.predicate = expr
28+
return self
29+
30+
31+
class WhenMergeMatchedUpdateClause(_OnMergeBaseClause):
32+
__visit_name__ = "when_merge_matched_update"
33+
34+
def __repr__(self):
35+
case_predicate = super()
36+
update_str = f"WHEN MATCHED{case_predicate} THEN UPDATE"
37+
if not self.set:
38+
return f"{update_str} *"
39+
40+
set_values = ", ".join([f"{set_item[0]} = {set_item[1]}" for set_item in self.set.items()])
41+
return f"{update_str} SET {str(set_values)}"
42+
43+
44+
class WhenMergeMatchedDeleteClause(_OnMergeBaseClause):
45+
__visit_name__ = "when_merge_matched_delete"
46+
47+
def __repr__(self):
48+
case_predicate = super()
49+
return f"WHEN MATCHED{case_predicate} THEN DELETE"
50+
51+
52+
class WhenMergeUnMatchedClause(_OnMergeBaseClause):
53+
__visit_name__ = "when_merge_unmatched"
54+
55+
def __repr__(self):
56+
case_predicate = super()
57+
insert_str = f"WHEN NOT MATCHED{case_predicate} THEN INSERT"
58+
if not self.set:
59+
return f"{insert_str} *"
60+
61+
sets, sets_tos = zip(*self.set.items())
62+
return "{} ({}) VALUES ({})".format(
63+
insert_str,
64+
", ".join(sets),
65+
", ".join(map(str, sets_tos)),
66+
)
67+
68+
69+
class Merge(UpdateBase):
70+
__visit_name__ = "merge"
71+
_bind = None
72+
73+
def __init__(self, target, source, on):
74+
if not isinstance(source, (TableClause, Select, Subquery)):
75+
raise Exception(f'Invalid type for merge source: {source}')
76+
self.target = target
77+
self.source = source
78+
self.on = on
79+
self.clauses = []
80+
81+
def __repr__(self):
82+
clauses = " ".join([repr(clause) for clause in self.clauses])
83+
return f"MERGE INTO {self.target} USING ({select(self.source)}) AS {self.source.name} ON {self.on}" + (
84+
f" {clauses}" if clauses else ""
85+
)
86+
87+
def when_matched_then_update(self):
88+
clause = WhenMergeMatchedUpdateClause()
89+
self.clauses.append(clause)
90+
return clause
91+
92+
def when_matched_then_delete(self):
93+
clause = WhenMergeMatchedDeleteClause()
94+
self.clauses.append(clause)
95+
return clause
96+
97+
def when_not_matched_then_insert(self):
98+
clause = WhenMergeUnMatchedClause()
99+
self.clauses.append(clause)
100+
return clause

0 commit comments

Comments
 (0)