diff --git a/alloydb/notebooks/generate_batch_embeddings.ipynb b/alloydb/notebooks/generate_batch_embeddings.ipynb new file mode 100644 index 00000000000..406f563b355 --- /dev/null +++ b/alloydb/notebooks/generate_batch_embeddings.ipynb @@ -0,0 +1,482 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "WU0ruULySEbZ", + "metadata": { + "id": "WU0ruULySEbZ" + }, + "source": [ + "# Generate and store batch embeddings in AlloyDB\n", + "\n", + "This notebook shows you how to generate batch vector embeddings and store them in an AlloyDB database using the [GenWealth Demo App](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/gemini/sample-apps/genwealth) as the sample dataset. \n", + "\n", + "With the steps listed here, you can dynamically build a batch of text chunks to embed based on character length of the source data in order to get more results per inference, leading to much more efficient embeddings generation. The process uses the `psycopg` library to efficiently load the embeddings into AlloyDB after they are generated. These techniques can significantly speed up the process of generating large batches of embeddings and storing them in AlloyDB vs using the native embedding() function (about 6.5x faster based on limited testing).\n", + "\n", + "### Before you Begin\n", + "* Download and set up the [GenWealth Demo App](https://github.com/GoogleCloudPlatform/generative-ai/tree/main/gemini/sample-apps/genwealth).\n", + "* Running the steps in this notebook will incur Google Cloud charges. You may also be billed for Vertex AI API usages." + ] + }, + { + "cell_type": "markdown", + "id": "jJGlsq4QSnvZ", + "metadata": { + "id": "jJGlsq4QSnvZ" + }, + "source": [ + "\n", + "\n", + "### Install and Import Necessary Packages." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "r2576UY5dWfj", + "metadata": { + "id": "r2576UY5dWfj" + }, + "outputs": [], + "source": [ + "# Install required libraries\n", + "%pip install psycopg2-binary --quiet\n", + "\n", + "# Import necessary modules\n", + "import psycopg2\n", + "from tabulate import tabulate\n", + "\n", + "import os\n", + "import tempfile\n", + "import time\n", + "import csv\n", + "\n", + "from typing import List, Optional\n", + "from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel\n" + ] + }, + { + "cell_type": "markdown", + "id": "r-AxCjk0SuF8", + "metadata": { + "id": "r-AxCjk0SuF8" + }, + "source": [ + "### Define Environment Variables \n", + "\n", + "Update the variable values in this cell to match your environment. \n", + "\n", + "> NOTE: This step assumes you have a secret stored in Secret Manager called `alloydb-secret`. You can use an alternate method to define your password if desired." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "EzmvgTvrda4y", + "metadata": { + "id": "EzmvgTvrda4y" + }, + "outputs": [], + "source": [ + "# Set GCP and AlloyDB configuration variables\n", + "\n", + "# GCP vars\n", + "region = \"us-central1\"\n", + "project_id = \"YOUR-PROJECT-ID\"\n", + "\n", + "# AlloyDB vars\n", + "alloydb_ip = \"X.X.X.X\"\n", + "database = \"ragdemos\"\n", + "user = \"postgres\"\n", + "password = !gcloud secrets versions access latest --secret=\"alloydb-secret\"\n", + "password = str(password[0])\n", + "\n", + "# Embedding vars\n", + "text_embedding_model_name = 'textembedding-gecko@003'\n", + "model = TextEmbeddingModel.from_pretrained(text_embedding_model_name)\n", + "task = \"SEMANTIC_SIMILARITY\"\n", + "max_tokens = 20000" + ] + }, + { + "cell_type": "markdown", + "id": "YyMiuKhiTHhh", + "metadata": { + "id": "YyMiuKhiTHhh" + }, + "source": [ + "### Establish a Connection to AlloyDB." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "Sr_KMxiUdbi_", + "metadata": { + "id": "Sr_KMxiUdbi_" + }, + "outputs": [], + "source": [ + "# Establish a connection to AlloyDB\n", + "def getconn():\n", + " conn = psycopg2.connect(\n", + " host=alloydb_ip,\n", + " database=database,\n", + " user=user,\n", + " password=password,\n", + " )\n", + " return conn\n", + "\n", + "conn = getconn()" + ] + }, + { + "cell_type": "markdown", + "id": "GTRUhl0dTNlG", + "metadata": { + "id": "GTRUhl0dTNlG" + }, + "source": [ + "### Fetch Source Data to Embed\n", + "\n", + "This step loads the text data you want to embed from AlloyDB (i.e. the `overview` and `analysis` columns in the example below). You can update the SQL query below to match your target environment. \n", + "\n", + "> IMPORTANT: Make sure to retrieve the primary key along with the columns you want to embed (i.e. `id` in the example below). This key will be used to uniquely identify the rows you are embedding during the bulk load and update process later in this notebook. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ZgTxSb_dePq", + "metadata": { + "id": "3ZgTxSb_dePq" + }, + "outputs": [], + "source": [ + "# Store output in array of serializable dictionaries\n", + "source_array = []\n", + "\n", + "# Define database query to get primary key plus text data to embed\n", + "# Ensure you retrieve the id key to uniquely identify the row you are embedding\n", + "sql = f\"\"\"\n", + " SELECT id, overview, analysis FROM investments;\n", + " \"\"\"\n", + "\n", + "# Run the query\n", + "print(f\"Running SQL query: {sql}\")\n", + "with conn.cursor() as cur:\n", + " cur.execute(sql)\n", + " for row in cur.fetchall():\n", + " source_array.append(dict(zip([col.name for col in cur.description], row)))" + ] + }, + { + "cell_type": "markdown", + "id": "ZCmF90VkTWAq", + "metadata": { + "id": "ZCmF90VkTWAq" + }, + "source": [ + "### Define Batching Function\n", + "\n", + "This helper function dynamically builds batches of text chunks to efficiently generate multiple embeddings with each call to the API." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3Z3uz-PTgoRK", + "metadata": { + "id": "3Z3uz-PTgoRK" + }, + "outputs": [], + "source": [ + "# Function to build batches for embedding based on max tokens/characters\n", + "def build_batch_array(source_array, column_to_embed):\n", + " batch_array = []\n", + " current_chars = 0\n", + " max_chars = max_tokens * 3\n", + "\n", + " global index_pointer\n", + " global batch_char_count\n", + " global total_char_count\n", + "\n", + " batch_char_count = 0\n", + "\n", + " while current_chars < max_chars:\n", + " if index_pointer >= len(source_array):\n", + " return batch_array\n", + " current_chars = current_chars + len(source_array[index_pointer][column_to_embed])\n", + " if current_chars > max_chars:\n", + " batch_char_count = current_chars - len(source_array[index_pointer][column_to_embed])\n", + " total_char_count = total_char_count + batch_char_count\n", + " return batch_array\n", + " else:\n", + " batch_array.append(source_array[index_pointer])\n", + " index_pointer = index_pointer + 1" + ] + }, + { + "cell_type": "markdown", + "id": "xJtnLK4DTVRB", + "metadata": { + "id": "xJtnLK4DTVRB" + }, + "source": [ + "### Define Embedding Functions. \n", + "\n", + "This example uses the `textembedding-gecko@003` model to generate embeddings. You can update the embedding model by setting the `model_name` variable if desired. \n", + "\n", + "If you are using `text-embedding-004` or above, uncomment the lines starting with `dimensionality`, `kwargs`, and `embeddings` in the cell below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "xEFzYPxuncqw", + "metadata": { + "id": "xEFzYPxuncqw" + }, + "outputs": [], + "source": [ + "# Functions for text embedding and object embedding\n", + "def embed_text(\n", + " texts: List[str],\n", + " task: str = \"SEMANTIC_SIMILARITY\",\n", + " model_name: str = \"textembedding-gecko@003\",\n", + " #dimensionality: Optional[int] = 768,\n", + ") -> List[List[float]]:\n", + " \"\"\"Embeds texts with a pre-trained, foundational model.\"\"\"\n", + " model = TextEmbeddingModel.from_pretrained(model_name)\n", + " inputs = [TextEmbeddingInput(text, task) for text in texts]\n", + " #kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {}\n", + " #embeddings = model.get_embeddings(inputs, **kwargs)\n", + " embeddings = model.get_embeddings(inputs)\n", + " return [embedding.values for embedding in embeddings]\n", + "\n", + "def embed_objects(source_array, column_to_embed):\n", + " source_array_length = len(source_array)\n", + "\n", + " print(f\"Beginning source_array size: {source_array_length}\")\n", + " result_array = []\n", + "\n", + " # Define global variables to track progress and estimate cost\n", + " global index_pointer\n", + " global batch_count\n", + " global batch_char_count\n", + " global total_char_count\n", + "\n", + " while index_pointer < len(source_array):\n", + " # Get a batch of up to batch_size objects\n", + " # Objects in batch are removed from source_array by build_batch_array function\n", + " batch_array = build_batch_array(source_array, column_to_embed)\n", + "\n", + " if batch_array:\n", + " batch_count = batch_count + 1\n", + " print(f\"Processing batch {batch_count} with size: {len(batch_array)}. Progress: {index_pointer} / {source_array_length}. Character count (batch): {batch_char_count}. Character count (cumulative): {total_char_count}\")\n", + "\n", + " texts_to_embed = [obj[column_to_embed] for obj in batch_array]\n", + " embeddings = embed_text(texts_to_embed, model_name = text_embedding_model_name)\n", + "\n", + " for i, obj in enumerate(batch_array):\n", + " obj['embedding'] = embeddings[i]\n", + " result_array.append(obj)\n", + "\n", + " return result_array" + ] + }, + { + "cell_type": "markdown", + "id": "SjLBXucwTmCr", + "metadata": { + "id": "SjLBXucwTmCr" + }, + "source": [ + "### Define Bulk Load and Update Functions\n", + "\n", + "This step defines functions that update embeddings in the target database using the following process:\n", + "\n", + "1. Create a temp table\n", + "2. Bulk load the generated embeddings (along with the primary key of each embedding) into the temp table\n", + "3. Update the target table by joining to the temp table based on the primary key. \n", + "4. Drop the temp table.\n", + "\n", + "This method is faster and more efficient than updating rows one by one with multiple round trips to the database.\n", + "\n", + "You can modify these functions to generate embeddings in your own environment by changing the SQL queries to match your table structure. Be sure to also update the primary key name in function `insert_to_temp_table` to match your primary key (the example uses `id` as the PK)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ditCQSN66xOx", + "metadata": { + "id": "ditCQSN66xOx" + }, + "outputs": [], + "source": [ + "# Functions to manage temporary table and update the target table\n", + "def create_temp_table(column_to_embed):\n", + " temp_table_name = f\"{column_to_embed}_embeddings_temp\"\n", + " # Update the SQL query below to match your environment\n", + " sql = f\"\"\"\n", + " DROP TABLE IF EXISTS {temp_table_name};\n", + " CREATE TABLE {temp_table_name} (\n", + " id INTEGER PRIMARY KEY,\n", + " col_name TEXT,\n", + " embedding REAL[]\n", + " );\n", + " \"\"\"\n", + "\n", + " with conn.cursor() as cur:\n", + " cur.execute(sql)\n", + " conn.commit()\n", + "\n", + " return temp_table_name\n", + "\n", + "\n", + "def insert_to_temp_table(temp_table_name, column_to_embed, object_array):\n", + " with tempfile.NamedTemporaryFile(mode='w', delete=False) as temp_file:\n", + " writer = csv.writer(temp_file, delimiter='|', quotechar=\"'\", escapechar=\"'\")\n", + " for obj in object_array:\n", + " # Ensure the embedding is represented as an array literal with curly braces\n", + " embedding_str = \"{\" + \", \".join(map(str, obj['embedding'])) + \"}\"\n", + " # Update primary key name here (this example uses 'id' as the PK)\n", + " writer.writerow([obj['id'], column_to_embed, embedding_str])\n", + "\n", + " with conn.cursor() as cur:\n", + " with open(temp_file.name, 'r') as f:\n", + " cur.copy_expert(\n", + " # Use the COPY command for efficient loading of the temp table\n", + " # Update the query below to match your environment\n", + " f\"\"\"COPY {temp_table_name} (id, col_name, embedding)\n", + " FROM STDIN\n", + " WITH (FORMAT csv, DELIMITER '|', QUOTE '''', ESCAPE '''')\"\"\",\n", + " f\n", + " )\n", + " conn.commit()\n", + "\n", + " # Cleanup the temporary file\n", + " os.remove(temp_file.name)\n", + "\n", + "\n", + "def update_target_table(temp_table_name, target_table_name, column_to_embed):\n", + " # Update the SQL query below to match your environment\n", + " sql = f\"\"\"\n", + " UPDATE {target_table_name}\n", + " SET {column_to_embed}_embedding = {temp_table_name}.embedding\n", + " FROM {temp_table_name}\n", + " WHERE {target_table_name}.id = {temp_table_name}.id;\n", + " \"\"\"\n", + "\n", + " #print(f\"Running sql statement: {sql}\")\n", + " with conn.cursor() as cur:\n", + " cur.execute(sql)\n", + " conn.commit()\n", + "\n", + "\n", + "def drop_temp_table(temp_table_name):\n", + " sql = f\"\"\"\n", + " DROP TABLE {temp_table_name};\n", + " \"\"\"\n", + "\n", + " with conn.cursor() as cur:\n", + " cur.execute(sql)\n", + " conn.commit()\n" + ] + }, + { + "cell_type": "markdown", + "id": "lQrPX-ugT2jw", + "metadata": { + "id": "lQrPX-ugT2jw" + }, + "source": [ + "### Run the Embedding Process\n", + "\n", + "This step runs the embedding process using the table structure and functions defined in the cells above. \n", + "\n", + "To modify this code to run in your environment, update the target table (`investments` in this example) and the source text chunk columns (`analysis` and `overview` in this example) to match your data structure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "QFgXZyjAidiC", + "metadata": { + "id": "QFgXZyjAidiC" + }, + "outputs": [], + "source": [ + "# Define table where embeddings will be written and columns to be embedded\n", + "target_table_name = 'investments'\n", + "columns_to_embed = ['analysis','overview']\n", + "\n", + "# Define global variables to track progress and estimate cost\n", + "global index_pointer\n", + "global batch_count\n", + "global batch_char_count\n", + "global total_char_count\n", + "\n", + "# Define batch variables\n", + "batch_array = None\n", + "batch_size = None\n", + "batch_count = 0\n", + "total_char_count = 0\n", + "\n", + "# Keep track of job timing\n", + "start_time = time.time()\n", + "\n", + "for column_to_embed in columns_to_embed:\n", + " # Initialize the index pointer for batch processing\n", + " index_pointer = 0\n", + "\n", + " print(f\"Creating embeddings for column {column_to_embed}...\")\n", + " results = embed_objects(source_array, column_to_embed)\n", + "\n", + " print(f\"Creating temp table to store intermediate results...\")\n", + " temp_table_name = create_temp_table(column_to_embed)\n", + "\n", + " print(f\"Inserting embeddings into temp table: {temp_table_name}...\")\n", + " insert_to_temp_table(temp_table_name, column_to_embed, results)\n", + "\n", + " print(f\"Merging temp table {temp_table_name} with target table {target_table_name}...\")\n", + " update_target_table(temp_table_name, target_table_name, column_to_embed)\n", + "\n", + " print(f\"Dropping temp table temp_table_name...\")\n", + " drop_temp_table(temp_table_name)\n", + "\n", + "end_time = time.time()\n", + "elapsed_time = end_time - start_time\n", + "print(f\"Job started at: {time.ctime(start_time)}\")\n", + "print(f\"Job ended at: {time.ctime(end_time)}\")\n", + "print(f\"Total run time: {elapsed_time:.2f} seconds\")" + ] + } + ], + "metadata": { + "colab": { + "name": "fast_embeddings_api_psycopg", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}