
Commit 08e0146

twishabansal, iennae, and glasnt authored

feat(alloydb): Added generate batch embeddings sample (GoogleCloudPlatform#12721)

Co-authored-by: Jennifer Davis <[email protected]>
Co-authored-by: Katie McLaughlin <[email protected]>

1 parent 0fdcba8 · commit 08e0146

7 files changed: +1592 -0 lines changed

alloydb/conftest.py (+225 lines)

@@ -0,0 +1,225 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import os
import re
import subprocess
import sys
import textwrap
import uuid
from collections.abc import Callable, Iterable
from datetime import datetime
from typing import AsyncIterator

import pytest
import pytest_asyncio


def get_env_var(key: str) -> str:
    v = os.environ.get(key)
    if v is None:
        raise ValueError(f"Must set env var {key}")
    return v


@pytest.fixture(scope="session")
def table_name() -> str:
    return "investments"


@pytest.fixture(scope="session")
def cluster_name() -> str:
    return get_env_var("ALLOYDB_CLUSTER")


@pytest.fixture(scope="session")
def instance_name() -> str:
    return get_env_var("ALLOYDB_INSTANCE")


@pytest.fixture(scope="session")
def region() -> str:
    return get_env_var("ALLOYDB_REGION")


@pytest.fixture(scope="session")
def database_name() -> str:
    return get_env_var("ALLOYDB_DATABASE_NAME")


@pytest.fixture(scope="session")
def password() -> str:
    return get_env_var("ALLOYDB_PASSWORD")


@pytest_asyncio.fixture(scope="session")
def project_id() -> str:
    gcp_project = get_env_var("GOOGLE_CLOUD_PROJECT")
    run_cmd("gcloud", "config", "set", "project", gcp_project)
    # Since everything requires the project, let's configure and show some
    # debugging information here.
    run_cmd("gcloud", "version")
    run_cmd("gcloud", "config", "list")
    return gcp_project


def run_cmd(*cmd: str) -> subprocess.CompletedProcess:
    try:
        print(f">> {cmd}")
        start = datetime.now()
        p = subprocess.run(
            cmd,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        print(p.stderr.decode("utf-8"))
        print(p.stdout.decode("utf-8"))
        elapsed = (datetime.now() - start).seconds
        minutes = int(elapsed / 60)
        seconds = elapsed - minutes * 60
        print(f"Command `{cmd[0]}` finished in {minutes}m {seconds}s")
        return p
    except subprocess.CalledProcessError as e:
        # Include the error message from the failed command.
        print(e.stderr.decode("utf-8"))
        print(e.stdout.decode("utf-8"))
        raise RuntimeError(f"{e}\n\n{e.stderr.decode('utf-8')}") from e


def run_notebook(
    ipynb_file: str,
    prelude: str = "",
    section: str = "",
    variables: dict = {},
    replace: dict[str, str] = {},
    preprocess: Callable[[str], str] = lambda source: source,
    skip_shell_commands: bool = False,
    until_end: bool = False,
) -> None:
    import nbformat
    from nbclient.client import NotebookClient
    from nbclient.exceptions import CellExecutionError

    def notebook_filter_section(
        start: str,
        end: str,
        cells: list[nbformat.NotebookNode],
        until_end: bool = False,
    ) -> Iterable[nbformat.NotebookNode]:
        in_section = False
        for cell in cells:
            if cell["cell_type"] == "markdown":
                if not in_section and cell["source"].startswith(start):
                    in_section = True
                elif in_section and not until_end and cell["source"].startswith(end):
                    return
            if in_section:
                yield cell

    # Regular expression to match and remove shell commands from the notebook.
    #   https://regex101.com/r/EHWBpT/1
    shell_command_re = re.compile(r"^!((?:[^\n]+\\\n)*(?:[^\n]+))$", re.MULTILINE)
    # Compile regular expressions for variable substitutions.
    #   https://regex101.com/r/e32vfW/1
    compiled_substitutions = [
        (
            re.compile(rf"""\b{name}\s*=\s*(?:f?'[^']*'|f?"[^"]*"|\w+)"""),
            f"{name} = {repr(value)}",
        )
        for name, value in variables.items()
    ]
    # Filter the section if any, otherwise use the entire notebook.
    nb = nbformat.read(ipynb_file, as_version=4)
    if section:
        start = section
        end = section.split(" ", 1)[0] + " "
        nb.cells = list(notebook_filter_section(start, end, nb.cells, until_end))
        if len(nb.cells) == 0:
            raise ValueError(
                f"Section {repr(section)} not found in notebook {repr(ipynb_file)}"
            )
    # Preprocess the cells.
    for cell in nb.cells:
        # Only preprocess code cells.
        if cell["cell_type"] != "code":
            continue
        # Run any custom preprocessing functions before.
        cell["source"] = preprocess(cell["source"])
        # Preprocess shell commands.
        if skip_shell_commands:
            cmd = "pass"
            cell["source"] = shell_command_re.sub(cmd, cell["source"])
        else:
            cell["source"] = shell_command_re.sub(r"_run(f'''\1''')", cell["source"])
        # Apply variable substitutions.
        for regex, new_value in compiled_substitutions:
            cell["source"] = regex.sub(new_value, cell["source"])
        # Apply replacements.
        for old, new in replace.items():
            cell["source"] = cell["source"].replace(old, new)
        # Clear outputs.
        cell["outputs"] = []
    # Prepend the prelude cell.
    prelude_src = textwrap.dedent(
        """\
        def _run(cmd):
            import subprocess as _sp
            import sys as _sys
            _p = _sp.run(cmd, shell=True, stdout=_sp.PIPE, stderr=_sp.PIPE)
            _stdout = _p.stdout.decode('utf-8').strip()
            _stderr = _p.stderr.decode('utf-8').strip()
            if _stdout:
                print(f'➜ !{cmd}')
                print(_stdout)
            if _stderr:
                print(f'➜ !{cmd}', file=_sys.stderr)
                print(_stderr, file=_sys.stderr)
            if _p.returncode:
                raise RuntimeError('\\n'.join([
                    f"Command returned non-zero exit status {_p.returncode}.",
                    f"-------- command --------",
                    f"{cmd}",
                    f"-------- stderr --------",
                    f"{_stderr}",
                    f"-------- stdout --------",
                    f"{_stdout}",
                ]))
        """
        + prelude
    )
    nb.cells = [nbformat.v4.new_code_cell(prelude_src)] + nb.cells
    # Run the notebook.
    error = ""
    client = NotebookClient(nb)
    try:
        client.execute()
    except CellExecutionError as e:
        # Remove colors and other escape characters to make it easier to read in the logs.
        #   https://stackoverflow.com/a/33925425
        color_chars = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
        error = color_chars.sub("", str(e))
        for cell in nb.cells:
            if cell["cell_type"] != "code":
                continue
            for output in cell["outputs"]:
                if output.get("name") == "stdout":
                    print(color_chars.sub("", output["text"]))
                elif output.get("name") == "stderr":
                    print(color_chars.sub("", output["text"]), file=sys.stderr)
    if error:
        raise RuntimeError(
            f"Error on {repr(ipynb_file)}, section {repr(section)}: {error}"
        )

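Note: the shell-command and variable-substitution regexes in run_notebook above are the least obvious part of this harness. The following is a minimal, self-contained sketch of how they rewrite a single cell; the cell source and the substituted project value are hypothetical and only for illustration.

    import re

    # Hypothetical notebook cell source, for illustration only.
    cell_source = 'project_id = "your-project-id"\n!gcloud config set project {project_id}'

    # Same patterns as in run_notebook: turn "!" shell lines into _run(...) calls
    # and pin variable assignments to the values supplied by the test.
    shell_command_re = re.compile(r"^!((?:[^\n]+\\\n)*(?:[^\n]+))$", re.MULTILINE)
    project_id_re = re.compile(r"""\bproject_id\s*=\s*(?:f?'[^']*'|f?"[^"]*"|\w+)""")

    rewritten = shell_command_re.sub(r"_run(f'''\1''')", cell_source)
    rewritten = project_id_re.sub("project_id = 'my-test-project'", rewritten)
    print(rewritten)
    # project_id = 'my-test-project'
    # _run(f'''gcloud config set project {project_id}''')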
alloydb/notebooks/e2e_test.py (+155 lines)

@@ -0,0 +1,155 @@
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Maintainer Note: this sample presumes data exists in
# ALLOYDB_TABLE_NAME within the ALLOYDB_(cluster/instance/database)

import asyncpg  # type: ignore
import conftest as conftest  # python-docs-samples/alloydb/conftest.py
from google.cloud.alloydb.connector import AsyncConnector, IPTypes
import pytest
import sqlalchemy
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine


def preprocess(source: str) -> str:
    # Skip the cells which add data to the table.
    if "df" in source:
        return ""
    # Skip the Colab auth cell.
    if "colab" in source:
        return ""
    return source


async def _init_connection_pool(
    connector: AsyncConnector,
    db_name: str,
    project_id: str,
    cluster_name: str,
    instance_name: str,
    region: str,
    password: str,
) -> AsyncEngine:
    connection_string = (
        f"projects/{project_id}/locations/"
        f"{region}/clusters/{cluster_name}/"
        f"instances/{instance_name}"
    )

    async def getconn() -> asyncpg.Connection:
        conn: asyncpg.Connection = await connector.connect(
            connection_string,
            "asyncpg",
            user="postgres",
            password=password,
            db=db_name,
            ip_type=IPTypes.PUBLIC,
        )
        return conn

    pool = create_async_engine(
        "postgresql+asyncpg://",
        async_creator=getconn,
        max_overflow=0,
    )
    return pool


@pytest.mark.asyncio
async def test_embeddings_batch_processing(
    project_id: str,
    cluster_name: str,
    instance_name: str,
    region: str,
    database_name: str,
    password: str,
    table_name: str,
) -> None:
    # TODO: Create new table
    # Populate the table with embeddings by running the notebook.
    conftest.run_notebook(
        "embeddings_batch_processing.ipynb",
        variables={
            "project_id": project_id,
            "cluster_name": cluster_name,
            "database_name": database_name,
            "region": region,
            "instance_name": instance_name,
            "table_name": table_name,
        },
        preprocess=preprocess,
        skip_shell_commands=True,
        replace={
            (
                "password = input(\"Please provide "
                "a password to be used for 'postgres' "
                "database user: \")"
            ): f"password = '{password}'",
            (
                "await create_db("
                "database_name=database_name, "
                "connector=connector)"
            ): "",
        },
        until_end=True,
    )

    # Connect to the populated table for validation and clean up.
    async with AsyncConnector() as connector:
        pool = await _init_connection_pool(
            connector,
            database_name,
            project_id,
            cluster_name,
            instance_name,
            region,
            password,
        )
        async with pool.connect() as conn:
            # Validate that embeddings are non-empty for all rows.
            result = await conn.execute(
                sqlalchemy.text(
                    f"SELECT COUNT(*) FROM "
                    f"{table_name} WHERE "
                    f"analysis_embedding IS NULL"
                )
            )
            row = result.fetchone()
            assert row[0] == 0
            result = await conn.execute(
                sqlalchemy.text(
                    f"SELECT COUNT(*) FROM "
                    f"{table_name} WHERE "
                    f"overview_embedding IS NULL"
                )
            )
            row = result.fetchone()
            assert row[0] == 0

            # Return the table to its original state.
            await conn.execute(
                sqlalchemy.text(
                    f"UPDATE {table_name} set "
                    f"analysis_embedding = NULL"
                )
            )
            await conn.execute(
                sqlalchemy.text(
                    f"UPDATE {table_name} set "
                    f"overview_embedding = NULL"
                )
            )
            await conn.commit()
        await pool.dispose()

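For context, the fixtures in alloydb/conftest.py resolve their values from environment variables, so the notebook test above only runs when those are set. Below is a minimal sketch of a local invocation; every value is a placeholder, and the real project, cluster, instance, and database names depend on your own AlloyDB setup.

    # Hypothetical local run; all values below are placeholders.
    import os
    import pytest

    os.environ.update({
        "GOOGLE_CLOUD_PROJECT": "my-test-project",
        "ALLOYDB_CLUSTER": "my-cluster",
        "ALLOYDB_INSTANCE": "my-instance",
        "ALLOYDB_REGION": "us-central1",
        "ALLOYDB_DATABASE_NAME": "investments_db",
        "ALLOYDB_PASSWORD": "my-password",
    })

    # Runs the async notebook test defined in alloydb/notebooks/e2e_test.py.
    raise SystemExit(pytest.main(["-v", "alloydb/notebooks/e2e_test.py"]))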