Merged
124 changes: 80 additions & 44 deletions llama-index-core/llama_index/core/callbacks/token_counting.py
@@ -1,5 +1,15 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)

from llama_index.core.callbacks.pythonically_printing_base_handler import (
PythonicallyPrintingBaseHandler,
@@ -9,6 +19,9 @@
from llama_index.core.utils import get_tokenizer
import logging

if TYPE_CHECKING:
from llama_index.core.llms import ChatResponse, CompletionResponse


@dataclass
class TokenCountingEvent:
@@ -23,21 +36,65 @@ def __post_init__(self) -> None:
self.total_token_count = self.prompt_token_count + self.completion_token_count


def get_tokens_from_response(
response: Union["CompletionResponse", "ChatResponse"]
) -> Tuple[int, int]:
"""Get the token counts from a raw response."""
usage = response.raw.get("usage", {})

Reviewer:

This will error out saying the following:

'ChatCompletion' object has no attribute 'get'

We will have to access the usage attribute like response.raw.usage

Collaborator (Author):

We can just model-dump the raw response if it's not already a dict.
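
A minimal sketch of that normalization, assuming a pydantic-style raw object that exposes model_dump() (the helper name and usage below are illustrative, not the merged implementation):

from typing import Any, Dict

def _as_dict(obj: Any) -> Dict[str, Any]:
    # Normalize a raw response (plain dict or pydantic model) to a plain dict.
    if isinstance(obj, dict):
        return obj
    # Pydantic v2 models (e.g. openai's ChatCompletion) expose model_dump().
    return obj.model_dump()

# usage can then be read uniformly regardless of the client library:
# usage = _as_dict(response.raw).get("usage", {})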

if usage is None:
usage = response.additional_kwargs

if not usage:
return 0, 0

if not isinstance(usage, dict):
usage = usage.model_dump()

possible_input_keys = ("prompt_tokens", "input_tokens")
possible_output_keys = ("completion_tokens", "output_tokens")
Comment on lines +53 to +54

Contributor:

Looks like Anthropic uses input_tokens and output_tokens.

Collaborator (Author):

Yup!
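
For illustration, these are the two usage shapes the possible_input_keys / possible_output_keys probing is meant to cover (values invented):

# OpenAI-style usage payload (illustrative values)
openai_usage = {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}

# Anthropic-style usage payload (illustrative values)
anthropic_usage = {"input_tokens": 12, "output_tokens": 34}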


prompt_tokens = 0
for input_key in possible_input_keys:
if input_key in usage:
prompt_tokens = usage[input_key]
break

completion_tokens = 0
for output_key in possible_output_keys:
if output_key in usage:
completion_tokens = usage[output_key]
break

return prompt_tokens, completion_tokens


def get_llm_token_counts(
token_counter: TokenCounter, payload: Dict[str, Any], event_id: str = ""
) -> TokenCountingEvent:
from llama_index.core.llms import ChatMessage

if EventPayload.PROMPT in payload:
prompt = str(payload.get(EventPayload.PROMPT))
completion = str(payload.get(EventPayload.COMPLETION))
prompt = payload.get(EventPayload.PROMPT)
completion = payload.get(EventPayload.COMPLETION)

if completion:
# get from raw or additional_kwargs
prompt_tokens, completion_tokens = get_tokens_from_response(completion)
else:
prompt_tokens, completion_tokens = 0, 0

if prompt_tokens == 0:
prompt_tokens = token_counter.get_string_tokens(str(prompt))

if completion_tokens == 0:
completion_tokens = token_counter.get_string_tokens(str(completion))

return TokenCountingEvent(
event_id=event_id,
prompt=prompt,
prompt_token_count=token_counter.get_string_tokens(prompt),
completion=completion,
completion_token_count=token_counter.get_string_tokens(completion),
prompt=str(prompt),
prompt_token_count=prompt_tokens,
completion=str(completion),
completion_token_count=completion_tokens,
)

elif EventPayload.MESSAGES in payload:
@@ -47,52 +104,31 @@ def get_llm_token_counts(
response = payload.get(EventPayload.RESPONSE)
response_str = str(response)

# try getting attached token counts first
try:
messages_tokens = 0
response_tokens = 0

if response is not None and response.raw is not None:
if isinstance(response.raw, dict):
raw_dict = response.raw
else:
raw_dict = response.raw.model_dump()
if response:
prompt_tokens, completion_tokens = get_tokens_from_response(response)
else:
prompt_tokens, completion_tokens = 0, 0

usage = raw_dict.get("usage", None)
if prompt_tokens == 0:
prompt_tokens = token_counter.estimate_tokens_in_messages(messages)

if usage is not None:
messages_tokens = usage.get("prompt_tokens", 0)
response_tokens = usage.get("completion_tokens", 0)

if messages_tokens == 0 or response_tokens == 0:
raise ValueError("Invalid token counts!")

return TokenCountingEvent(
event_id=event_id,
prompt=messages_str,
prompt_token_count=messages_tokens,
completion=response_str,
completion_token_count=response_tokens,
)

except (ValueError, KeyError):
# Invalid token counts, or no token counts attached
pass

# Should count tokens ourselves
messages_tokens = token_counter.estimate_tokens_in_messages(messages)
response_tokens = token_counter.get_string_tokens(response_str)
if completion_tokens == 0:
completion_tokens = token_counter.get_string_tokens(response_str)

return TokenCountingEvent(
event_id=event_id,
prompt=messages_str,
prompt_token_count=messages_tokens,
prompt_token_count=prompt_tokens,
completion=response_str,
completion_token_count=response_tokens,
completion_token_count=completion_tokens,
)
else:
raise ValueError(
"Invalid payload! Need prompt and completion or messages and response."
return TokenCountingEvent(
event_id=event_id,
prompt="",
prompt_token_count=0,
completion="",
completion_token_count=0,
)


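For reference, a quick sketch of how the new helper behaves on a dict-style raw payload (the response and counts below are invented for illustration):

from llama_index.core.callbacks.token_counting import get_tokens_from_response
from llama_index.core.llms import CompletionResponse

resp = CompletionResponse(
    text="Hello!",
    raw={"usage": {"prompt_tokens": 5, "completion_tokens": 2}},
)

# The attached counts are used instead of re-tokenizing the text.
prompt_toks, completion_toks = get_tokens_from_response(resp)
assert (prompt_toks, completion_toks) == (5, 2)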