
Commit 4582012

Fix botocore tests & re-structure
1 parent bb2ebe3 commit 4582012

File tree

7 files changed (+5, -4733 lines)


newrelic/hooks/external_botocore.py

Lines changed: 2 additions & 364 deletions
@@ -12,25 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
-import logging
-import uuid
-from io import BytesIO
-
-from botocore.response import StreamingBody
-
 from newrelic.api.datastore_trace import datastore_trace
 from newrelic.api.external_trace import ExternalTrace
-from newrelic.api.function_trace import FunctionTrace
 from newrelic.api.message_trace import message_trace
-from newrelic.api.time_trace import get_trace_linking_metadata
-from newrelic.api.transaction import current_transaction
-from newrelic.common.object_names import callable_name
-from newrelic.common.object_wrapper import function_wrapper, wrap_function_wrapper
-from newrelic.core.config import global_settings
-
-_logger = logging.getLogger(__name__)
-UNSUPPORTED_MODEL_WARNING_SENT = False
+from newrelic.common.object_wrapper import wrap_function_wrapper
 
 
 def extract_sqs(*args, **kwargs):
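
After this hunk, `wrap_function_wrapper` is the only wrapping helper the module still imports. For orientation, here is a minimal sketch (not part of this commit) of the registration pattern such hook modules use; the hook name `instrument_botocore_endpoint` follows this module's convention, while `_nr_example_wrapper` and its body are placeholders:

from newrelic.common.object_wrapper import wrap_function_wrapper

def _nr_example_wrapper(wrapped, instance, args, kwargs):
    # The agent's standard wrapper contract: receive the original callable,
    # its bound instance, and the call arguments, then delegate.
    return wrapped(*args, **kwargs)

def instrument_botocore_endpoint(module):
    # Patch botocore's Endpoint.make_request with the wrapper above.
    wrap_function_wrapper(module, "Endpoint.make_request", _nr_example_wrapper)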
@@ -55,353 +40,6 @@ def extractor_string(*args, **kwargs):
     return extractor_list
 
 
-def bedrock_error_attributes(exception, request_args, client, extractor):
-    response = getattr(exception, "response", None)
-    if not response:
-        return {}
-
-    request_body = request_args.get("body", "")
-    error_attributes = extractor(request_body)[1]
-
-    error_attributes.update({
-        "request_id": response.get("ResponseMetadata", {}).get("RequestId", ""),
-        "api_key_last_four_digits": client._request_signer._credentials.access_key[-4:],
-        "request.model": request_args.get("modelId", ""),
-        "vendor": "Bedrock",
-        "ingest_source": "Python",
-        "http.statusCode": response.get("ResponseMetadata", "").get("HTTPStatusCode", ""),
-        "error.message": response.get("Error", "").get("Message", ""),
-        "error.code": response.get("Error", "").get("Code", ""),
-    })
-    return error_attributes
-
-
-def create_chat_completion_message_event(
-    transaction,
-    app_name,
-    message_list,
-    chat_completion_id,
-    span_id,
-    trace_id,
-    request_model,
-    request_id,
-    conversation_id,
-    response_id="",
-):
-    if not transaction:
-        return
-
-    for index, message in enumerate(message_list):
-        if response_id:
-            id_ = "%s-%d" % (response_id, index)  # Response ID was set, append message index to it.
-        else:
-            id_ = str(uuid.uuid4())  # No response IDs, use random UUID
-
-        chat_completion_message_dict = {
-            "id": id_,
-            "appName": app_name,
-            "conversation_id": conversation_id,
-            "request_id": request_id,
-            "span_id": span_id,
-            "trace_id": trace_id,
-            "transaction_id": transaction._transaction_id,
-            "content": message.get("content", ""),
-            "role": message.get("role"),
-            "completion_id": chat_completion_id,
-            "sequence": index,
-            "response.model": request_model,
-            "vendor": "bedrock",
-            "ingest_source": "Python",
-        }
-        transaction.record_ml_event("LlmChatCompletionMessage", chat_completion_message_dict)
-
-
-def extract_bedrock_titan_text_model(request_body, response_body=None):
-    request_body = json.loads(request_body)
-    if response_body:
-        response_body = json.loads(response_body)
-
-    request_config = request_body.get("textGenerationConfig", {})
-
-    chat_completion_summary_dict = {
-        "request.max_tokens": request_config.get("maxTokenCount", ""),
-        "request.temperature": request_config.get("temperature", ""),
-    }
-
-    if response_body:
-        input_tokens = response_body["inputTextTokenCount"]
-        completion_tokens = sum(result["tokenCount"] for result in response_body.get("results", []))
-        total_tokens = input_tokens + completion_tokens
-
-        message_list = [{"role": "user", "content": request_body.get("inputText", "")}]
-        message_list.extend(
-            {"role": "assistant", "content": result["outputText"]} for result in response_body.get("results", [])
-        )
-
-        chat_completion_summary_dict.update({
-            "response.choices.finish_reason": response_body["results"][0]["completionReason"],
-            "response.usage.completion_tokens": completion_tokens,
-            "response.usage.prompt_tokens": input_tokens,
-            "response.usage.total_tokens": total_tokens,
-            "response.number_of_messages": len(message_list),
-        })
-    else:
-        message_list = []
-
-    return message_list, chat_completion_summary_dict
-
-
-def extract_bedrock_titan_embedding_model(request_body, response_body=None):
-    if not response_body:
-        return [], {}  # No extracted information necessary for embedding
-
-    request_body = json.loads(request_body)
-    response_body = json.loads(response_body)
-
-    input_tokens = response_body.get("inputTextTokenCount", None)
-
-    embedding_dict = {
-        "input": request_body.get("inputText", ""),
-        "response.usage.prompt_tokens": input_tokens,
-        "response.usage.total_tokens": input_tokens,
-    }
-    return [], embedding_dict
-
-
-def extract_bedrock_ai21_j2_model(request_body, response_body=None):
-    request_body = json.loads(request_body)
-    if response_body:
-        response_body = json.loads(response_body)
-
-    chat_completion_summary_dict = {
-        "request.max_tokens": request_body.get("maxTokens", ""),
-        "request.temperature": request_body.get("temperature", ""),
-    }
-
-    if response_body:
-        message_list = [{"role": "user", "content": request_body.get("prompt", "")}]
-        message_list.extend(
-            {"role": "assistant", "content": result["data"]["text"]} for result in response_body.get("completions", [])
-        )
-
-        chat_completion_summary_dict.update({
-            "response.choices.finish_reason": response_body["completions"][0]["finishReason"]["reason"],
-            "response.number_of_messages": len(message_list),
-            "response_id": str(response_body.get("id", "")),
-        })
-    else:
-        message_list = []
-
-    return message_list, chat_completion_summary_dict
-
-
-def extract_bedrock_claude_model(request_body, response_body=None):
-    request_body = json.loads(request_body)
-    if response_body:
-        response_body = json.loads(response_body)
-
-    chat_completion_summary_dict = {
-        "request.max_tokens": request_body.get("max_tokens_to_sample", ""),
-        "request.temperature": request_body.get("temperature", ""),
-    }
-
-    if response_body:
-        message_list = [
-            {"role": "user", "content": request_body.get("prompt", "")},
-            {"role": "assistant", "content": response_body.get("completion", "")},
-        ]
-
-        chat_completion_summary_dict.update({
-            "response.choices.finish_reason": response_body.get("stop_reason", ""),
-            "response.number_of_messages": len(message_list),
-        })
-    else:
-        message_list = []
-
-    return message_list, chat_completion_summary_dict
-
-
-def extract_bedrock_cohere_model(request_body, response_body=None):
-    request_body = json.loads(request_body)
-    if response_body:
-        response_body = json.loads(response_body)
-
-    chat_completion_summary_dict = {
-        "request.max_tokens": request_body.get("max_tokens", ""),
-        "request.temperature": request_body.get("temperature", ""),
-    }
-
-    if response_body:
-        message_list = [{"role": "user", "content": request_body.get("prompt", "")}]
-        message_list.extend(
-            {"role": "assistant", "content": result["text"]} for result in response_body.get("generations", [])
-        )
-
-        chat_completion_summary_dict.update({
-            "request.max_tokens": request_body.get("max_tokens", ""),
-            "request.temperature": request_body.get("temperature", ""),
-            "response.choices.finish_reason": response_body["generations"][0]["finish_reason"],
-            "response.number_of_messages": len(message_list),
-            "response_id": str(response_body.get("id", "")),
-        })
-    else:
-        message_list = []
-
-    return message_list, chat_completion_summary_dict
-
-
-MODEL_EXTRACTORS = [  # Order is important here, avoiding dictionaries
-    ("amazon.titan-embed", extract_bedrock_titan_embedding_model),
-    ("amazon.titan", extract_bedrock_titan_text_model),
-    ("ai21.j2", extract_bedrock_ai21_j2_model),
-    ("cohere", extract_bedrock_cohere_model),
-    ("anthropic.claude", extract_bedrock_claude_model),
-]
-
-
-@function_wrapper
-def wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
-    # Wrapped function only takes keyword arguments, no need for binding
-
-    transaction = current_transaction()
-
-    if not transaction:
-        return wrapped(*args, **kwargs)
-
-    # Read and replace request file stream bodies
-    request_body = kwargs["body"]
-    if hasattr(request_body, "read"):
-        request_body = request_body.read()
-        kwargs["body"] = request_body
-
-    # Determine model to be used with extractor
-    model = kwargs.get("modelId")
-    if not model:
-        return wrapped(*args, **kwargs)
-
-    # Determine extractor by model type
-    for extractor_name, extractor in MODEL_EXTRACTORS:
-        if model.startswith(extractor_name):
-            break
-    else:
-        # Model was not found in extractor list
-        global UNSUPPORTED_MODEL_WARNING_SENT
-        if not UNSUPPORTED_MODEL_WARNING_SENT:
-            # Only send warning once to avoid spam
-            _logger.warning(
-                "Unsupported Amazon Bedrock model in use (%s). Upgrade to a newer version of the agent, and contact New Relic support if the issue persists.",
-                model,
-            )
-            UNSUPPORTED_MODEL_WARNING_SENT = True
-
-        extractor = lambda *args: ([], {})  # Empty extractor that returns nothing
-
-    ft_name = callable_name(wrapped)
-    with FunctionTrace(ft_name) as ft:
-        try:
-            response = wrapped(*args, **kwargs)
-        except Exception as exc:
-            try:
-                error_attributes = extractor(request_body)
-                error_attributes = bedrock_error_attributes(exc, kwargs, instance, extractor)
-                ft.notice_error(
-                    attributes=error_attributes,
-                )
-            finally:
-                raise
-
-    if not response:
-        return response
-
-    # Read and replace response streaming bodies
-    response_body = response["body"].read()
-    response["body"] = StreamingBody(BytesIO(response_body), len(response_body))
-    response_headers = response["ResponseMetadata"]["HTTPHeaders"]
-
-    if model.startswith("amazon.titan-embed"):  # Only available embedding models
-        handle_embedding_event(instance, transaction, extractor, model, response_body, response_headers, request_body, ft.duration)
-    else:
-        handle_chat_completion_event(instance, transaction, extractor, model, response_body, response_headers, request_body, ft.duration)
-
-    return response
-
-def handle_embedding_event(client, transaction, extractor, model, response_body, response_headers, request_body, duration):
-    embedding_id = str(uuid.uuid4())
-    available_metadata = get_trace_linking_metadata()
-    span_id = available_metadata.get("span.id", "")
-    trace_id = available_metadata.get("trace.id", "")
-
-    request_id = response_headers.get("x-amzn-requestid", "")
-    settings = transaction.settings if transaction.settings is not None else global_settings()
-
-    _, embedding_dict = extractor(request_body, response_body)
-
-    embedding_dict.update({
-        "vendor": "bedrock",
-        "ingest_source": "Python",
-        "id": embedding_id,
-        "appName": settings.app_name,
-        "span_id": span_id,
-        "trace_id": trace_id,
-        "request_id": request_id,
-        "transaction_id": transaction._transaction_id,
-        "api_key_last_four_digits": client._request_signer._credentials.access_key[-4:],
-        "duration": duration,
-        "request.model": model,
-        "response.model": model,
-    })
-
-    transaction.record_ml_event("LlmEmbedding", embedding_dict)
-
-
-def handle_chat_completion_event(client, transaction, extractor, model, response_body, response_headers, request_body, duration):
-    custom_attrs_dict = transaction._custom_params
-    conversation_id = custom_attrs_dict.get("conversation_id", "")
-
-    chat_completion_id = str(uuid.uuid4())
-    available_metadata = get_trace_linking_metadata()
-    span_id = available_metadata.get("span.id", "")
-    trace_id = available_metadata.get("trace.id", "")
-
-    request_id = response_headers.get("x-amzn-requestid", "")
-    settings = transaction.settings if transaction.settings is not None else global_settings()
-
-    message_list, chat_completion_summary_dict = extractor(request_body, response_body)
-    response_id = chat_completion_summary_dict.get("response_id", "")
-    chat_completion_summary_dict.update(
-        {
-            "vendor": "bedrock",
-            "ingest_source": "Python",
-            "api_key_last_four_digits": client._request_signer._credentials.access_key[-4:],
-            "id": chat_completion_id,
-            "appName": settings.app_name,
-            "conversation_id": conversation_id,
-            "span_id": span_id,
-            "trace_id": trace_id,
-            "transaction_id": transaction._transaction_id,
-            "request_id": request_id,
-            "duration": duration,
-            "request.model": model,
-            "response.model": model,  # Duplicate data required by the UI
-        }
-    )
-
-    transaction.record_ml_event("LlmChatCompletionSummary", chat_completion_summary_dict)
-
-    create_chat_completion_message_event(
-        transaction=transaction,
-        app_name=settings.app_name,
-        message_list=message_list,
-        chat_completion_id=chat_completion_id,
-        span_id=span_id,
-        trace_id=trace_id,
-        request_model=model,
-        request_id=request_id,
-        conversation_id=conversation_id,
-        response_id=response_id,
-    )
-
-
 CUSTOM_TRACE_POINTS = {
     ("sns", "publish"): message_trace("SNS", "Produce", "Topic", extract(("TopicArn", "TargetArn"), "PhoneNumber")),
     ("dynamodb", "put_item"): datastore_trace("DynamoDB", extract("TableName"), "put_item"),
@@ -415,7 +53,6 @@ def handle_chat_completion_event(client, transaction, extractor, model, response
     ("sqs", "send_message"): message_trace("SQS", "Produce", "Queue", extract_sqs),
     ("sqs", "send_message_batch"): message_trace("SQS", "Produce", "Queue", extract_sqs),
     ("sqs", "receive_message"): message_trace("SQS", "Consume", "Queue", extract_sqs),
-    ("bedrock-runtime", "invoke_model"): wrap_bedrock_runtime_invoke_model,
 }
 
 
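With the `("bedrock-runtime", "invoke_model")` entry removed, `CUSTOM_TRACE_POINTS` again maps only messaging and datastore operations. As a hypothetical illustration of how such a table is consumed (the dispatch code itself is outside this diff), entries are keyed by `(service, method)` and each value is a decorator applied to the matching client method:

# Assumed name: untraced_send_message stands in for the bound client method.
tracer = CUSTOM_TRACE_POINTS.get(("sqs", "send_message"))
if tracer is not None:
    traced_send_message = tracer(untraced_send_message)
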
@@ -447,6 +84,7 @@ def _nr_endpoint_make_request_(wrapped, instance, args, kwargs):
     method = request_dict.get("method", None)
 
     with ExternalTrace(library="botocore", url=url, method=method, source=wrapped) as trace:
+
         try:
            trace._add_agent_attribute("aws.operation", operation_model.name)
        except:
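
The single addition in this hunk is a blank line inside `_nr_endpoint_make_request_`. For context, a sketch of the surrounding shape, with the continuation past the diff's cutoff assumed: the bare `except` exists so attribute collection can never break the customer's AWS call.

with ExternalTrace(library="botocore", url=url, method=method, source=wrapped) as trace:
    try:
        trace._add_agent_attribute("aws.operation", operation_model.name)
    except:
        pass  # assumed continuation: swallow attribute errors, never fail the request
    return wrapped(*args, **kwargs)  # assumed: the wrapper delegates inside the trace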

0 commit comments