@@ -17,6 +17,7 @@
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.metrics import get_meter
from opentelemetry.semconv_ai import Meters, SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
from opentelemetry.semconv._incubating.metrics import gen_ai_metrics as GenAIMetrics
from opentelemetry.trace import get_tracer
from opentelemetry.trace.propagation import set_span_in_context
from opentelemetry.trace.propagation.tracecontext import (
@@ -68,14 +69,43 @@ def _instrument(self, **kwargs):
description="Measures number of input and output tokens used",
)

# Create streaming time to first token histogram
ttft_histogram = meter.create_histogram(
name=GenAIMetrics.GEN_AI_SERVER_TIME_TO_FIRST_TOKEN,
unit="s",
description="Time to first token in streaming responses",
)

# Create streaming time to generate histogram
streaming_time_histogram = meter.create_histogram(
name=Meters.LLM_STREAMING_TIME_TO_GENERATE,
unit="s",
description="Time between first token and completion in streaming responses",
)

# Create generation choices counter
choices_counter = meter.create_counter(
name=Meters.LLM_GENERATION_CHOICES,
unit="choice",
description="Number of choices returned by completions call",
)

# Create exception counter
exception_counter = meter.create_counter(
name="llm.langchain.completions.exceptions",
unit="time",
description="Number of exceptions occurred during LangChain completions",
)
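# Note: together with the duration and token histograms above, this mirrors
# the instrument set of the OpenAI instrumentation, so shared GenAI dashboards
# can consume both (see _create_shared_attributes in the callback handler).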

if not Config.use_legacy_attributes:
event_logger_provider = kwargs.get("event_logger_provider")
Config.event_logger = get_event_logger(
__name__, __version__, event_logger_provider=event_logger_provider
)

traceloopCallbackHandler = TraceloopCallbackHandler(
tracer, duration_histogram, token_histogram
tracer, duration_histogram, token_histogram, ttft_histogram,
streaming_time_histogram, choices_counter, exception_counter
)
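# The new instruments are passed positionally; keep this argument order in
# sync with TraceloopCallbackHandler.__init__.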
wrap_function_wrapper(
module="langchain_core.callbacks",
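A quick way to exercise the new instruments end to end (a minimal sketch, assuming the standard OTel SDK console exporter and that this package exposes LangchainInstrumentor as its entry point):

from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import (
    ConsoleMetricExporter,
    PeriodicExportingMetricReader,
)

from opentelemetry.instrumentation.langchain import LangchainInstrumentor

# Print metrics to stdout so the new histograms and counters are visible.
reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
LangchainInstrumentor().instrument(meter_provider=MeterProvider(metric_readers=[reader]))

# Streaming chat calls made through LangChain now record time-to-first-token
# and time-to-generate histograms; all completions record choice counts, and
# failures increment the exception counter.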
@@ -54,7 +54,7 @@
should_send_prompts,
)
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.metrics import Histogram
from opentelemetry.metrics import Histogram, Counter
from opentelemetry.semconv._incubating.attributes.gen_ai_attributes import (
GEN_AI_RESPONSE_ID,
)
@@ -146,16 +146,65 @@ def _extract_tool_call_data(

class TraceloopCallbackHandler(BaseCallbackHandler):
def __init__(
self, tracer: Tracer, duration_histogram: Histogram, token_histogram: Histogram
self,
tracer: Tracer,
duration_histogram: Histogram,
token_histogram: Histogram,
ttft_histogram: Histogram,
streaming_time_histogram: Histogram,
choices_counter: Counter,
exception_counter: Counter
) -> None:
super().__init__()
self.tracer = tracer
self.duration_histogram = duration_histogram
self.token_histogram = token_histogram
self.ttft_histogram = ttft_histogram
self.streaming_time_histogram = streaming_time_histogram
self.choices_counter = choices_counter
self.exception_counter = exception_counter
self.spans: dict[UUID, SpanHolder] = {}
self.run_inline = True
self._callback_manager: CallbackManager | AsyncCallbackManager = None

def _create_shared_attributes(
self, span, model_name: str, operation_type: Optional[str] = None, is_streaming: bool = False
) -> dict:
"""Create shared attributes for metrics matching OpenAI SDK structure."""
vendor = span.attributes.get(SpanAttributes.LLM_SYSTEM, "Langchain")
attributes = {
SpanAttributes.LLM_SYSTEM: vendor,
SpanAttributes.LLM_RESPONSE_MODEL: model_name,
}
# Add operation name if available
if operation_type:
attributes["gen_ai.operation.name"] = operation_type
elif span.attributes.get(SpanAttributes.LLM_REQUEST_TYPE):
attributes["gen_ai.operation.name"] = span.attributes.get(SpanAttributes.LLM_REQUEST_TYPE)
server_address = None
try:
association_properties = context_api.get_value("association_properties") or {}
server_address = (
association_properties.get("api_base") or
association_properties.get("endpoint") or
association_properties.get("base_url") or
association_properties.get("server_address")
)
except (AttributeError, KeyError, TypeError):
pass

if not server_address:
# Check if we can get it from span attributes
server_address = span.attributes.get("server.address")

if server_address:
attributes["server.address"] = server_address

if is_streaming:
attributes["stream"] = True

return attributes

@staticmethod
def _get_name_from_callback(
serialized: dict[str, Any],
@@ -494,7 +543,7 @@ def on_chat_model_start(
metadata=metadata,
serialized=serialized,
)
set_request_params(span, kwargs, self.spans[run_id])
set_request_params(span, kwargs, self.spans[run_id], serialized, metadata)
if should_emit_events():
self._emit_chat_input_events(messages)
else:
@@ -524,13 +573,53 @@ def on_llm_start(
LLMRequestTypeValues.COMPLETION,
serialized=serialized,
)
set_request_params(span, kwargs, self.spans[run_id])
set_request_params(span, kwargs, self.spans[run_id], serialized, metadata)
if should_emit_events():
for prompt in prompts:
emit_event(MessageEvent(content=prompt, role="user"))
else:
set_llm_request(span, serialized, prompts, kwargs, self.spans[run_id])

@dont_throw
def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run on new LLM token. Track TTFT and streaming metrics."""
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
return

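# Runs that were never tracked (e.g. instrumentation was suppressed when the
# run started) have no span holder, so there is nothing to record for them.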
if run_id not in self.spans:
return

span_holder = self.spans[run_id]
current_time = time.time()

# Track time to first token
if span_holder.first_token_time is None:
span_holder.first_token_time = current_time
ttft = current_time - span_holder.start_time

# Record TTFT metric
span = span_holder.span
from opentelemetry.instrumentation.langchain.span_utils import _get_unified_unknown_model

model_name = (
span.attributes.get(SpanAttributes.LLM_RESPONSE_MODEL) or
span_holder.request_model or
_get_unified_unknown_model(existing_model=span_holder.request_model)
)

self.ttft_histogram.record(
ttft,
attributes=self._create_shared_attributes(span, model_name, is_streaming=True)
)
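# Subsequent tokens skip the branch above: first_token_time is already set,
# so TTFT is recorded exactly once per run.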

@dont_throw
def on_llm_end(
self,
@@ -552,7 +641,7 @@ def on_llm_end(
) or response.llm_output.get("model_id")
if model_name is not None:
_set_span_attribute(
span, SpanAttributes.LLM_RESPONSE_MODEL, model_name or "unknown"
span, SpanAttributes.LLM_RESPONSE_MODEL, model_name
)

if self.spans[run_id].request_model is None:
@@ -571,6 +660,20 @@ def on_llm_end(
model_name = _extract_model_name_from_association_metadata(
association_properties
)

# Final fallback: use model name from request if all else fails
if model_name is None and run_id in self.spans and self.spans[run_id].request_model:
model_name = self.spans[run_id].request_model

# Ensure model_name is never None for downstream usage
if model_name is None:
from opentelemetry.instrumentation.langchain.span_utils import _get_unified_unknown_model
existing_model = self.spans[run_id].request_model if run_id in self.spans else None
model_name = _get_unified_unknown_model(existing_model=existing_model)

# Update span attribute with final resolved model name
_set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model_name)

token_usage = (response.llm_output or {}).get("token_usage") or (
response.llm_output or {}
).get("usage")
@@ -600,26 +703,17 @@ def on_llm_end(
)

# Record token usage metrics
vendor = span.attributes.get(SpanAttributes.LLM_SYSTEM, "Langchain")
base_attrs = self._create_shared_attributes(span, model_name)

if prompt_tokens > 0:
self.token_histogram.record(
prompt_tokens,
attributes={
SpanAttributes.LLM_SYSTEM: vendor,
SpanAttributes.LLM_TOKEN_TYPE: "input",
SpanAttributes.LLM_RESPONSE_MODEL: model_name or "unknown",
},
)
input_attrs = {**base_attrs, SpanAttributes.LLM_TOKEN_TYPE: "input"}
self.token_histogram.record(prompt_tokens, attributes=input_attrs)

if completion_tokens > 0:
self.token_histogram.record(
completion_tokens,
attributes={
SpanAttributes.LLM_SYSTEM: vendor,
SpanAttributes.LLM_TOKEN_TYPE: "output",
SpanAttributes.LLM_RESPONSE_MODEL: model_name or "unknown",
},
)
output_attrs = {**base_attrs, SpanAttributes.LLM_TOKEN_TYPE: "output"}
self.token_histogram.record(completion_tokens, attributes=output_attrs)
# Always call set_chat_response_usage for complete usage metadata extraction
# The function handles duplicate recording prevention internally
set_chat_response_usage(
span, response, self.token_histogram, token_usage is None, model_name
)
@@ -628,16 +722,27 @@ def on_llm_end(
else:
set_chat_response(span, response)

# Record duration before ending span
duration = time.time() - self.spans[run_id].start_time
vendor = span.attributes.get(SpanAttributes.LLM_SYSTEM, "Langchain")
self.duration_histogram.record(
duration,
attributes={
SpanAttributes.LLM_SYSTEM: vendor,
SpanAttributes.LLM_RESPONSE_MODEL: model_name or "unknown",
},
)
# Record generation choices count
total_choices = sum(
len(generation_list) for generation_list in response.generations
)

span_holder = self.spans[run_id]
current_time = time.time()
is_streaming_request = span_holder.first_token_time is not None

shared_attrs = self._create_shared_attributes(span, model_name, is_streaming=is_streaming_request)

if total_choices > 0:
self.choices_counter.add(total_choices, attributes=shared_attrs)

# Record streaming time to generate if TTFT was tracked
if is_streaming_request:
streaming_time = current_time - span_holder.first_token_time
self.streaming_time_histogram.record(streaming_time, attributes=shared_attrs)
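# By construction, the recorded TTFT plus streaming_time equals duration
# exactly: all three derive from start_time, first_token_time, and the same
# end timestamp.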

duration = current_time - span_holder.start_time
self.duration_histogram.record(duration, attributes=shared_attrs)

self._end_span(span, run_id)

@@ -752,6 +857,23 @@ def _handle_error(
span = self._get_span(run_id)
span.set_status(Status(StatusCode.ERROR))
span.record_exception(error)

# Record exception metric for LLM errors
if run_id in self.spans:
span_holder = self.spans[run_id]
from opentelemetry.instrumentation.langchain.span_utils import _get_unified_unknown_model

model_name = (
span.attributes.get(SpanAttributes.LLM_RESPONSE_MODEL) or
span_holder.request_model or
_get_unified_unknown_model(existing_model=span_holder.request_model)
)

exception_attrs = self._create_shared_attributes(span, model_name)
exception_attrs[ERROR_TYPE] = type(error).__name__

self.exception_counter.add(1, attributes=exception_attrs)
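# ERROR_TYPE ("error.type") carries the exception class name so backends
# can break exception counts down by failure mode.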

self._end_span(span, run_id)

@dont_throw