Skip to content

Commit b44349a

Browse files
authored
Merge pull request #34 from Azure-Samples/readmedos
Update token in do_connect event
2 parents e30ea96 + eb9ec6d commit b44349a

File tree

5 files changed

+37
-42
lines changed

5 files changed

+37
-42
lines changed

src/fastapi_app/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import os
44

5-
import azure.identity.aio
5+
import azure.identity
66
from dotenv import load_dotenv
77
from environs import Env
88
from fastapi import FastAPI
@@ -27,9 +27,9 @@ async def lifespan(app: FastAPI):
2727
"Using managed identity for client ID %s",
2828
client_id,
2929
)
30-
azure_credential = azure.identity.aio.ManagedIdentityCredential(client_id=client_id)
30+
azure_credential = azure.identity.ManagedIdentityCredential(client_id=client_id)
3131
else:
32-
azure_credential = azure.identity.aio.DefaultAzureCredential()
32+
azure_credential = azure.identity.DefaultAzureCredential()
3333
except Exception as e:
3434
logger.warning("Failed to authenticate to Azure: %s", e)
3535

src/fastapi_app/openai_clients.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33

4-
import azure.identity.aio
4+
import azure.identity
55
import openai
66

77
logger = logging.getLogger("ragapp")
@@ -12,7 +12,7 @@ async def create_openai_chat_client(azure_credential):
1212
if OPENAI_CHAT_HOST == "azure":
1313
logger.info("Authenticating to OpenAI using Azure Identity...")
1414

15-
token_provider = azure.identity.aio.get_bearer_token_provider(
15+
token_provider = azure.identity.get_bearer_token_provider(
1616
azure_credential, "https://cognitiveservices.azure.com/.default"
1717
)
1818
openai_chat_client = openai.AsyncAzureOpenAI(
@@ -40,7 +40,7 @@ async def create_openai_chat_client(azure_credential):
4040
async def create_openai_embed_client(azure_credential):
4141
OPENAI_EMBED_HOST = os.getenv("OPENAI_EMBED_HOST")
4242
if OPENAI_EMBED_HOST == "azure":
43-
token_provider = azure.identity.aio.get_bearer_token_provider(
43+
token_provider = azure.identity.get_bearer_token_provider(
4444
azure_credential, "https://cognitiveservices.azure.com/.default"
4545
)
4646
openai_embed_client = openai.AsyncAzureOpenAI(

src/fastapi_app/postgres_engine.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
11
import logging
22
import os
33

4-
from azure.identity.aio import DefaultAzureCredential
4+
from azure.identity import DefaultAzureCredential
5+
from sqlalchemy import event
56
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
67

78
logger = logging.getLogger("ragapp")
89

910

1011
async def create_postgres_engine(*, host, username, database, password, sslmode, azure_credential) -> AsyncEngine:
12+
def get_password_from_azure_credential():
13+
token = azure_credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
14+
return token.token
15+
16+
token_based_password = False
1117
if host.endswith(".database.azure.com"):
18+
token_based_password = True
1219
logger.info("Authenticating to Azure Database for PostgreSQL using Azure Identity...")
1320
if azure_credential is None:
1421
raise ValueError("Azure credential must be provided for Azure Database for PostgreSQL")
15-
token = await azure_credential.get_token("https://ossrdbms-aad.database.windows.net/.default")
16-
password = token.token
22+
password = get_password_from_azure_credential()
1723
else:
1824
logger.info("Authenticating to PostgreSQL using password...")
1925

@@ -27,16 +33,20 @@ async def create_postgres_engine(*, host, username, database, password, sslmode,
2733
echo=False,
2834
)
2935

36+
@event.listens_for(engine.sync_engine, "do_connect")
37+
def update_password_token(dialect, conn_rec, cargs, cparams):
38+
if token_based_password:
39+
logger.info("Updating password token for Azure Database for PostgreSQL")
40+
cparams["password"] = get_password_from_azure_credential()
41+
3042
return engine
3143

3244

3345
async def create_postgres_engine_from_env(azure_credential=None) -> AsyncEngine:
34-
must_close = False
3546
if azure_credential is None and os.environ["POSTGRES_HOST"].endswith(".database.azure.com"):
3647
azure_credential = DefaultAzureCredential()
37-
must_close = True
3848

39-
engine = await create_postgres_engine(
49+
return await create_postgres_engine(
4050
host=os.environ["POSTGRES_HOST"],
4151
username=os.environ["POSTGRES_USERNAME"],
4252
database=os.environ["POSTGRES_DATABASE"],
@@ -45,28 +55,16 @@ async def create_postgres_engine_from_env(azure_credential=None) -> AsyncEngine:
4555
azure_credential=azure_credential,
4656
)
4757

48-
if must_close:
49-
await azure_credential.close()
50-
51-
return engine
52-
5358

5459
async def create_postgres_engine_from_args(args, azure_credential=None) -> AsyncEngine:
55-
must_close = False
5660
if azure_credential is None and args.host.endswith(".database.azure.com"):
5761
azure_credential = DefaultAzureCredential()
58-
must_close = True
5962

60-
engine = await create_postgres_engine(
63+
return await create_postgres_engine(
6164
host=args.host,
6265
username=args.username,
6366
database=args.database,
6467
password=args.password,
6568
sslmode=args.sslmode,
6669
azure_credential=azure_credential,
6770
)
68-
69-
if must_close:
70-
await azure_credential.close()
71-
72-
return engine

src/fastapi_app/query_rewriter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def build_search_function() -> list[ChatCompletionToolParam]:
2626
"properties": {
2727
"comparison_operator": {
2828
"type": "string",
29-
"description": "Operator to compare the column value, either '>', '<', '>=', '<=', '=='", # noqa
29+
"description": "Operator to compare the column value, either '>', '<', '>=', '<=', '='", # noqa
3030
},
3131
"value": {
3232
"type": "number",
@@ -40,7 +40,7 @@ def build_search_function() -> list[ChatCompletionToolParam]:
4040
"properties": {
4141
"comparison_operator": {
4242
"type": "string",
43-
"description": "Operator to compare the column value, either '==' or '!='",
43+
"description": "Operator to compare the column value, either '=' or '!='",
4444
},
4545
"value": {
4646
"type": "string",

src/requirements.txt

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ aiohttp==3.9.5
88
# via fastapi_app (pyproject.toml)
99
aiosignal==1.3.1
1010
# via aiohttp
11-
annotated-types==0.6.0
11+
annotated-types==0.7.0
1212
# via pydantic
13-
anyio==4.3.0
13+
anyio==4.4.0
1414
# via
1515
# httpx
1616
# openai
@@ -53,10 +53,8 @@ email-validator==2.1.1
5353
environs==11.0.0
5454
# via fastapi_app (pyproject.toml)
5555
fastapi==0.111.0
56-
# via
57-
# fastapi-cli
58-
# fastapi_app (pyproject.toml)
59-
fastapi-cli==0.0.3
56+
# via fastapi_app (pyproject.toml)
57+
fastapi-cli==0.0.4
6058
# via fastapi
6159
frozenlist==1.4.1
6260
# via
@@ -107,7 +105,7 @@ multidict==6.0.5
107105
# yarl
108106
numpy==1.26.4
109107
# via pgvector
110-
openai==1.30.1
108+
openai==1.30.4
111109
# via
112110
# fastapi_app (pyproject.toml)
113111
# openai-messages-token-helper
@@ -128,11 +126,11 @@ portalocker==2.8.2
128126
# via msal-extensions
129127
pycparser==2.22
130128
# via cffi
131-
pydantic==2.7.1
129+
pydantic==2.7.2
132130
# via
133131
# fastapi
134132
# openai
135-
pydantic-core==2.18.2
133+
pydantic-core==2.18.3
136134
# via pydantic
137135
pygments==2.18.0
138136
# via rich
@@ -149,9 +147,9 @@ python-multipart==0.0.9
149147
# via fastapi
150148
pyyaml==6.0.1
151149
# via uvicorn
152-
regex==2024.5.10
150+
regex==2024.5.15
153151
# via tiktoken
154-
requests==2.31.0
152+
requests==2.32.2
155153
# via
156154
# azure-core
157155
# msal
@@ -179,7 +177,7 @@ tqdm==4.66.4
179177
# via openai
180178
typer==0.12.3
181179
# via fastapi-cli
182-
typing-extensions==4.11.0
180+
typing-extensions==4.12.0
183181
# via
184182
# azure-core
185183
# fastapi
@@ -192,14 +190,13 @@ ujson==5.10.0
192190
# via fastapi
193191
urllib3==2.2.1
194192
# via requests
195-
uvicorn[standard]==0.29.0
193+
uvicorn[standard]==0.30.0
196194
# via
197195
# fastapi
198-
# fastapi-cli
199196
# fastapi_app (pyproject.toml)
200197
uvloop==0.19.0
201198
# via uvicorn
202-
watchfiles==0.21.0
199+
watchfiles==0.22.0
203200
# via uvicorn
204201
websockets==12.0
205202
# via uvicorn

0 commit comments

Comments
 (0)