diff --git a/CHANGELOG.md b/CHANGELOG.md
index 31bc6faea..6e4875bdd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Other changes
 
 * A few changes for `ui.Chat()`, including:
+  * The `.messages()` method no longer trims messages by default (i.e., the default value of `token_limits` is now `None` instead of the overly generic and conservative value of `(4096, 1000)`). See the new generative AI in production templates (via `shiny create`) for examples of setting `token_limits` based on the model being used. (#1657)
   * User input that contains markdown now renders the expected HTML. (#1607)
   * Busy indication is now visible/apparent during the entire lifecycle of response generation. (#1607)
diff --git a/shiny/_main_create.py b/shiny/_main_create.py
index 0230fac4b..b9b4524ec 100644
--- a/shiny/_main_create.py
+++ b/shiny/_main_create.py
@@ -231,6 +231,10 @@ def chat_hello_providers(self) -> list[ShinyTemplate]:
     def chat_enterprise(self) -> list[ShinyTemplate]:
         return self._templates("templates/chat/enterprise")
 
+    @property
+    def chat_production(self) -> list[ShinyTemplate]:
+        return self._templates("templates/chat/production")
+
 
 shiny_internal_templates = ShinyInternalTemplates()
 
@@ -260,6 +264,7 @@ def use_internal_template(
     chat_templates = [
         *shiny_internal_templates.chat_hello_providers,
         *shiny_internal_templates.chat_enterprise,
+        *shiny_internal_templates.chat_production,
     ]
 
     menu_choices = [
@@ -351,6 +356,7 @@ def use_internal_chat_ai_template(
         choices=[
             Choice(title="By provider...", value="_chat-ai_hello-providers"),
             Choice(title="Enterprise providers...", value="_chat-ai_enterprise"),
+            Choice(title="Production-ready chat AI", value="_chat-ai_production"),
             back_choice,
             cancel_choice,
         ],
@@ -369,11 +375,12 @@ def use_internal_chat_ai_template(
         )
         return
 
-    template_choices = (
-        shiny_internal_templates.chat_enterprise
-        if input == "_chat-ai_enterprise"
-        else shiny_internal_templates.chat_hello_providers
-    )
+    if input == "_chat-ai_production":
+        template_choices = shiny_internal_templates.chat_production
+    elif input == "_chat-ai_enterprise":
+        template_choices = shiny_internal_templates.chat_enterprise
+    else:
+        template_choices = shiny_internal_templates.chat_hello_providers
 
     choice = question_choose_template(template_choices, back_choice)
 
@@ -385,6 +392,7 @@ def use_internal_chat_ai_template(
         [
             *shiny_internal_templates.chat_hello_providers,
             *shiny_internal_templates.chat_enterprise,
+            *shiny_internal_templates.chat_production,
         ],
         choice,
     )
diff --git a/shiny/templates/chat/production/anthropic/_template.json b/shiny/templates/chat/production/anthropic/_template.json
new file mode 100644
index 000000000..271e2c040
--- /dev/null
+++ b/shiny/templates/chat/production/anthropic/_template.json
@@ -0,0 +1,5 @@
+{
+  "type": "app",
+  "id": "chat-ai-anthropic-prod",
+  "title": "Chat in production with Anthropic"
+}
diff --git a/shiny/templates/chat/production/anthropic/app.py b/shiny/templates/chat/production/anthropic/app.py
new file mode 100644
index 000000000..cac452a32
--- /dev/null
+++ b/shiny/templates/chat/production/anthropic/app.py
@@ -0,0 +1,59 @@
+# ------------------------------------------------------------------------------------
+# When putting a Chat into production, there are at least a couple additional
+# considerations to keep in mind:
+# - Token Limits: LLMs have (varying) limits on how many tokens can be included in
+#   a single request and response. To accurately respect these limits, you'll want
+#   to find the relevant limits and tokenizer for the model you're using, and inform
+#   Chat about them.
+# - Reproducibility: Consider pinning a snapshot of the LLM model to ensure that the
+#   same model is used each time the app is run.
+#
+# See the MODEL_INFO dictionary below for an example of how to set these values for
+# Anthropic's Claude model.
+# https://docs.anthropic.com/en/docs/about-claude/models#model-comparison-table
+# ------------------------------------------------------------------------------------
+import os
+
+from anthropic import AsyncAnthropic
+from app_utils import load_dotenv
+
+from shiny.express import ui
+
+load_dotenv()
+llm = AsyncAnthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
+
+
+MODEL_INFO = {
+    "name": "claude-3-5-sonnet-20240620",
+    # DISCLAIMER: Anthropic has not yet released a public tokenizer for Claude models,
+    # so this uses the generic default provided by Chat() (for now). That is probably
+    # ok though since the default tokenizer likely overestimates the token count.
+    "tokenizer": None,
+    "token_limits": (200000, 8192),
+}
+
+
+ui.page_opts(
+    title="Hello Anthropic Chat",
+    fillable=True,
+    fillable_mobile=True,
+)
+
+chat = ui.Chat(
+    id="chat",
+    messages=[
+        {"content": "Hello! How can I help you today?", "role": "assistant"},
+    ],
+    tokenizer=MODEL_INFO["tokenizer"],
+)
+
+chat.ui()
+
+
+@chat.on_user_submit
+async def _():
+    messages = chat.messages(
+        format="anthropic", token_limits=MODEL_INFO["token_limits"]
+    )
+    response = await llm.messages.create(
+        model=MODEL_INFO["name"],
+        messages=messages,
+        max_tokens=MODEL_INFO["token_limits"][1],
+        stream=True,
+    )
+    await chat.append_message_stream(response)
diff --git a/shiny/templates/chat/production/anthropic/app_utils.py b/shiny/templates/chat/production/anthropic/app_utils.py
new file mode 100644
index 000000000..404a13730
--- /dev/null
+++ b/shiny/templates/chat/production/anthropic/app_utils.py
@@ -0,0 +1,26 @@
+import os
+from pathlib import Path
+from typing import Any
+
+app_dir = Path(__file__).parent
+env_file = app_dir / ".env"
+
+
+def load_dotenv(dotenv_path: os.PathLike[str] = env_file, **kwargs: Any) -> None:
+    """
+    A convenience wrapper around `dotenv.load_dotenv` that warns if `dotenv` is not installed.
+    It also returns `None` to make it easier to ignore the return value.
+    """
+    try:
+        import dotenv
+
+        dotenv.load_dotenv(dotenv_path=dotenv_path, **kwargs)
+    except ImportError:
+        import warnings
+
+        warnings.warn(
+            "Could not import `dotenv`. If you want to use `.env` files to "
+            "load environment variables, please install it using "
+            "`pip install python-dotenv`.",
+            stacklevel=2,
+        )
diff --git a/shiny/templates/chat/production/anthropic/requirements.txt b/shiny/templates/chat/production/anthropic/requirements.txt
new file mode 100644
index 000000000..fb3b67026
--- /dev/null
+++ b/shiny/templates/chat/production/anthropic/requirements.txt
@@ -0,0 +1,4 @@
+shiny
+python-dotenv
+tokenizers
+anthropic
diff --git a/shiny/templates/chat/production/openai/_template.json b/shiny/templates/chat/production/openai/_template.json
new file mode 100644
index 000000000..1a64e5211
--- /dev/null
+++ b/shiny/templates/chat/production/openai/_template.json
@@ -0,0 +1,5 @@
+{
+  "type": "app",
+  "id": "chat-ai-openai-prod",
+  "title": "Chat in production with OpenAI"
+}
diff --git a/shiny/templates/chat/production/openai/app.py b/shiny/templates/chat/production/openai/app.py
new file mode 100644
index 000000000..7b1274c95
--- /dev/null
+++ b/shiny/templates/chat/production/openai/app.py
@@ -0,0 +1,56 @@
+# ------------------------------------------------------------------------------------
+# When putting a Chat into production, there are at least a couple additional
+# considerations to keep in mind:
+# - Token Limits: LLMs have (varying) limits on how many tokens can be included in
+#   a single request and response. To accurately respect these limits, you'll want
+#   to find the relevant limits and tokenizer for the model you're using, and inform
+#   Chat about them.
+# - Reproducibility: Consider pinning a snapshot of the LLM model to ensure that the
+#   same model is used each time the app is run.
+#
+# See the MODEL_INFO dictionary below for an example of how to set these values for
+# OpenAI's GPT-4o model.
+# ------------------------------------------------------------------------------------
+import os
+
+import tiktoken
+from app_utils import load_dotenv
+from openai import AsyncOpenAI
+
+from shiny.express import ui
+
+load_dotenv()
+llm = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+
+
+MODEL_INFO = {
+    "name": "gpt-4o-2024-08-06",
+    "tokenizer": tiktoken.encoding_for_model("gpt-4o-2024-08-06"),
+    "token_limits": (128000, 16000),
+}
+
+
+ui.page_opts(
+    title="Hello OpenAI Chat",
+    fillable=True,
+    fillable_mobile=True,
+)
+
+chat = ui.Chat(
+    id="chat",
+    messages=[
+        {"content": "Hello! How can I help you today?", "role": "assistant"},
+    ],
+    tokenizer=MODEL_INFO["tokenizer"],
+)
+
+chat.ui()
+
+
+@chat.on_user_submit
+async def _():
+    messages = chat.messages(format="openai", token_limits=MODEL_INFO["token_limits"])
+    response = await llm.chat.completions.create(
+        model=MODEL_INFO["name"], messages=messages, stream=True
+    )
+    await chat.append_message_stream(response)
diff --git a/shiny/templates/chat/production/openai/app_utils.py b/shiny/templates/chat/production/openai/app_utils.py
new file mode 100644
index 000000000..404a13730
--- /dev/null
+++ b/shiny/templates/chat/production/openai/app_utils.py
@@ -0,0 +1,26 @@
+import os
+from pathlib import Path
+from typing import Any
+
+app_dir = Path(__file__).parent
+env_file = app_dir / ".env"
+
+
+def load_dotenv(dotenv_path: os.PathLike[str] = env_file, **kwargs: Any) -> None:
+    """
+    A convenience wrapper around `dotenv.load_dotenv` that warns if `dotenv` is not installed.
+    It also returns `None` to make it easier to ignore the return value.
+ """ + try: + import dotenv + + dotenv.load_dotenv(dotenv_path=dotenv_path, **kwargs) + except ImportError: + import warnings + + warnings.warn( + "Could not import `dotenv`. If you want to use `.env` files to " + "load environment variables, please install it using " + "`pip install python-dotenv`.", + stacklevel=2, + ) diff --git a/shiny/templates/chat/production/openai/requirements.txt b/shiny/templates/chat/production/openai/requirements.txt new file mode 100644 index 000000000..4e5ab6b2a --- /dev/null +++ b/shiny/templates/chat/production/openai/requirements.txt @@ -0,0 +1,4 @@ +shiny +python-dotenv +tiktoken +openai diff --git a/shiny/ui/_chat.py b/shiny/ui/_chat.py index 4940c0290..0cc07d8ca 100644 --- a/shiny/ui/_chat.py +++ b/shiny/ui/_chat.py @@ -37,7 +37,7 @@ as_provider_message, ) from ._chat_tokenizer import TokenEncoding, TokenizersEncoding, get_default_tokenizer -from ._chat_types import ChatMessage, ClientMessage, StoredMessage, TransformedMessage +from ._chat_types import ChatMessage, ClientMessage, TransformedMessage from ._html_deps_py_shiny import chat_deps from .fill import as_fill_item, as_fillable_container @@ -134,10 +134,10 @@ async def _(): * `"unhandled"`: Do not display any error message to the user. tokenizer The tokenizer to use for calculating token counts, which is required to impose - `token_limits` in `.messages()`. By default, a pre-trained tokenizer is - attempted to be loaded the tokenizers library (if available). A custom tokenizer - can be provided by following the `TokenEncoding` (tiktoken or tozenizer) - protocol. If token limits are of no concern, provide `None`. + `token_limits` in `.messages()`. If not provided, a default generic tokenizer + is attempted to be loaded from the tokenizers library. A specific tokenizer + may also be provided by following the `TokenEncoding` (tiktoken or tozenizers) + protocol (e.g., `tiktoken.encoding_for_model("gpt-4o")`). 
""" def __init__( @@ -146,7 +146,7 @@ def __init__( *, messages: Sequence[Any] = (), on_error: Literal["auto", "actual", "sanitize", "unhandled"] = "auto", - tokenizer: TokenEncoding | MISSING_TYPE | None = MISSING, + tokenizer: TokenEncoding | None = None, ): if not isinstance(id, str): raise TypeError("`id` must be a string.") @@ -155,10 +155,8 @@ def __init__( self.user_input_id = ResolvedId(f"{self.id}_user_input") self._transform_user: TransformUserInputAsync | None = None self._transform_assistant: TransformAssistantResponseChunkAsync | None = None - if isinstance(tokenizer, MISSING_TYPE): - self._tokenizer = get_default_tokenizer() - else: - self._tokenizer = tokenizer + self._tokenizer = tokenizer + # TODO: remove the `None` when this PR lands: # https://github.com/posit-dev/py-shiny/pull/793/files self._session = require_active_session(None) @@ -187,11 +185,11 @@ def __init__( # Initialize chat state and user input effect with session_context(self._session): # Initialize message state - self._messages: reactive.Value[tuple[StoredMessage, ...]] = reactive.Value( - () + self._messages: reactive.Value[tuple[TransformedMessage, ...]] = ( + reactive.Value(()) ) - self._latest_user_input: reactive.Value[StoredMessage | None] = ( + self._latest_user_input: reactive.Value[TransformedMessage | None] = ( reactive.Value(None) ) @@ -351,7 +349,7 @@ def messages( self, *, format: Literal["anthropic"] = "anthropic", - token_limits: tuple[int, int] | None = (4096, 1000), + token_limits: tuple[int, int] | None = None, transform_user: Literal["all", "last", "none"] = "all", transform_assistant: bool = False, ) -> tuple[AnthropicMessage, ...]: ... @@ -361,7 +359,7 @@ def messages( self, *, format: Literal["google"] = "google", - token_limits: tuple[int, int] | None = (4096, 1000), + token_limits: tuple[int, int] | None = None, transform_user: Literal["all", "last", "none"] = "all", transform_assistant: bool = False, ) -> tuple[GoogleMessage, ...]: ... @@ -371,7 +369,7 @@ def messages( self, *, format: Literal["langchain"] = "langchain", - token_limits: tuple[int, int] | None = (4096, 1000), + token_limits: tuple[int, int] | None = None, transform_user: Literal["all", "last", "none"] = "all", transform_assistant: bool = False, ) -> tuple[LangChainMessage, ...]: ... @@ -381,7 +379,7 @@ def messages( self, *, format: Literal["openai"] = "openai", - token_limits: tuple[int, int] | None = (4096, 1000), + token_limits: tuple[int, int] | None = None, transform_user: Literal["all", "last", "none"] = "all", transform_assistant: bool = False, ) -> tuple[OpenAIMessage, ...]: ... @@ -391,7 +389,7 @@ def messages( self, *, format: Literal["ollama"] = "ollama", - token_limits: tuple[int, int] | None = (4096, 1000), + token_limits: tuple[int, int] | None = None, transform_user: Literal["all", "last", "none"] = "all", transform_assistant: bool = False, ) -> tuple[OllamaMessage, ...]: ... @@ -401,7 +399,7 @@ def messages( self, *, format: MISSING_TYPE = MISSING, - token_limits: tuple[int, int] | None = (4096, 1000), + token_limits: tuple[int, int] | None = None, transform_user: Literal["all", "last", "none"] = "all", transform_assistant: bool = False, ) -> tuple[ChatMessage, ...]: ... 
@@ -410,7 +408,7 @@ def messages(
         self,
         *,
         format: MISSING_TYPE | ProviderMessageFormat = MISSING,
-        token_limits: tuple[int, int] | None = (4096, 1000),
+        token_limits: tuple[int, int] | None = None,
         transform_user: Literal["all", "last", "none"] = "all",
         transform_assistant: bool = False,
     ) -> tuple[ChatMessage | ProviderMessage, ...]:
@@ -440,10 +438,15 @@ def messages(
            * `"openai"`: OpenAI message format.
            * `"ollama"`: Ollama message format.
        token_limits
-            A tuple of two integers. The first integer is the maximum number of tokens
-            that can be sent to the model in a single request. The second integer is the
-            amount of tokens to reserve for the model's response.
-            Can also be `None` to disable message trimming based on token counts.
+            Limit the conversation history based on token limits. If specified, only
+            the most recent messages that fit within the token limits are returned. This
+            is useful for avoiding "exceeded token limit" errors when sending messages
+            to the relevant model, while still providing the most recent context available.
+            A specified value must be a tuple of two integers. The first integer is the
+            maximum number of tokens that can be sent to the model in a single request.
+            The second integer is the number of tokens to reserve for the model's response.
+            Note that token counts are based on the `tokenizer` provided to the `Chat`
+            constructor.
        transform_user
            Whether to return user input messages with transformation applied. This only
            matters if a `transform_user_input` was provided to the chat constructor.
@@ -536,7 +539,7 @@ async def _append_message(
         )
         if msg is None:
             return
-        msg = self._store_message(msg, chunk=chunk)
+        self._store_message(msg, chunk=chunk)
         await self._send_append_message(msg, chunk=chunk)
 
     async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any]):
@@ -612,7 +615,7 @@ def _can_append_message(self, stream_id: str | None) -> bool:
     # Send a message to the UI
     async def _send_append_message(
         self,
-        message: StoredMessage,
+        message: TransformedMessage,
         chunk: ChunkOption = False,
     ):
         if message["role"] == "system":
@@ -799,24 +802,11 @@ def _store_message(
         message: TransformedMessage,
         chunk: ChunkOption = False,
         index: int | None = None,
-    ) -> StoredMessage:
-
-        msg: StoredMessage = {
-            **message,
-            "token_count": None,
-        }
+    ) -> None:
 
         # Don't actually store chunks until the end
         if chunk is True or chunk == "start":
-            return msg
-
-        if self._tokenizer is not None:
-            encoded = self._tokenizer.encode(msg["content_server"])
-            if isinstance(encoded, TokenizersEncoding):
-                token_count = len(encoded.ids)
-            else:
-                token_count = len(encoded)
-            msg["token_count"] = token_count
+            return None
 
         with reactive.isolate():
             messages = self._messages()
@@ -825,20 +815,20 @@ def _store_message(
             index = len(messages)
 
         messages = list(messages)
-        messages.insert(index, msg)
+        messages.insert(index, message)
         self._messages.set(tuple(messages))
 
-        if msg["role"] == "user":
-            self._latest_user_input.set(msg)
+        if message["role"] == "user":
+            self._latest_user_input.set(message)
 
-        return msg
+        return None
 
-    @staticmethod
     def _trim_messages(
-        messages: tuple[StoredMessage, ...],
+        self,
+        messages: tuple[TransformedMessage, ...],
         token_limits: tuple[int, int],
         format: MISSING_TYPE | ProviderMessageFormat,
-    ) -> tuple[StoredMessage, ...]:
+    ) -> tuple[TransformedMessage, ...]:
 
         n_total, n_reserve = token_limits
         if n_total <= n_reserve:
@@ -852,11 +842,10 @@ def _trim_messages(
         n_system_tokens: int = 0
         n_system_messages: int = 0
         n_other_messages: int = 0
+        token_counts: list[int] = []
         for m in messages:
-            count = m["token_count"]
-            # Count can be None if the tokenizer is None
-            if count is None:
-                return messages
+            count = self._get_token_count(m["content_server"])
+            token_counts.append(count)
             if m["role"] == "system":
                 n_system_tokens += count
                 n_system_messages += 1
@@ -872,14 +861,16 @@ def _trim_messages(
                     "`token_limit=None` to disable token limits."
                 )
 
-        messages2: list[StoredMessage] = []
+        # Now, iterate through the messages in reverse order, appending them
+        # until we run out of tokens
+        messages2: list[TransformedMessage] = []
         n_other_messages2: int = 0
-        for m in reversed(messages):
+        token_counts.reverse()
+        for i, m in enumerate(reversed(messages)):
             if m["role"] == "system":
                 messages2.append(m)
                 continue
-            count = cast(int, m["token_count"])  # Already checked this
-            remaining_non_system_tokens -= count
+            remaining_non_system_tokens -= token_counts[i]
             if remaining_non_system_tokens >= 0:
                 messages2.append(m)
                 n_other_messages2 += 1
@@ -906,6 +897,28 @@ def _trim_messages(
 
         return tuple(messages2)
 
+    def _get_token_count(
+        self,
+        content: str,
+    ) -> int:
+        if self._tokenizer is None:
+            self._tokenizer = get_default_tokenizer()
+
+        if self._tokenizer is None:
+            raise ValueError(
+                "A tokenizer is required to impose `token_limits` on messages. "
+                "To get a generic default tokenizer, install the `tokenizers` "
+                "package (`pip install tokenizers`). "
+                "To get a more precise token count, provide a specific tokenizer "
+                "to the `Chat` constructor."
+            )
+
+        encoded = self._tokenizer.encode(content)
+        if isinstance(encoded, TokenizersEncoding):
+            return len(encoded.ids)
+        else:
+            return len(encoded)
+
     def user_input(self, transform: bool = False) -> str | None:
         """
         Reactively read the user's message.
diff --git a/shiny/ui/_chat_tokenizer.py b/shiny/ui/_chat_tokenizer.py
index aa6f515cb..eabf83179 100644
--- a/shiny/ui/_chat_tokenizer.py
+++ b/shiny/ui/_chat_tokenizer.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import warnings
 from typing import (
     AbstractSet,
     Any,
@@ -51,19 +50,7 @@ def get_default_tokenizer() -> TokenizersTokenizer | None:
         from tokenizers import Tokenizer
 
         return Tokenizer.from_pretrained("bert-base-cased")  # type: ignore
-    except ImportError:
-        warnings.warn(
-            "`Chat` is unable obtain a default tokenizer without the `tokenizers` "
-            "package installed. Please `pip install tokenizers` or set "
-            "`Chat(tokenizer=None)` to disable tokenization.",
-            stacklevel=2,
-        )
-        return None
     except Exception:
-        warnings.warn(
-            "Unable to obtain a default tokenizer. "
" - "Consider providing one to `Chat()`'s `tokenizer` parameter " - "(or set it to `None` to disable tokenization).", - stacklevel=2, - ) - return None + pass + + return None diff --git a/shiny/ui/_chat_types.py b/shiny/ui/_chat_types.py index 696a3b7b3..458286720 100644 --- a/shiny/ui/_chat_types.py +++ b/shiny/ui/_chat_types.py @@ -21,12 +21,6 @@ class TransformedMessage(TypedDict): pre_transform_key: Literal["content_client", "content_server"] -# A message that has been stored in the server-side chat history -class StoredMessage(TransformedMessage): - # Number of tokens in the content - token_count: int | None - - # A message that can be sent to the client class ClientMessage(ChatMessage): content_type: Literal["markdown", "html"] diff --git a/tests/pytest/test_chat.py b/tests/pytest/test_chat.py index f2b27fde5..615912a5e 100644 --- a/tests/pytest/test_chat.py +++ b/tests/pytest/test_chat.py @@ -13,7 +13,7 @@ from shiny.ui import Chat from shiny.ui._chat import as_transformed_message from shiny.ui._chat_normalize import normalize_message, normalize_message_chunk -from shiny.ui._chat_types import ChatMessage, StoredMessage +from shiny.ui._chat_types import ChatMessage # ---------------------------------------------------------------------- # Helpers @@ -35,14 +35,6 @@ def _increment_busy_count(self) -> None: test_session = cast(Session, _MockSession()) -def as_stored_message(message: ChatMessage, token_count: int) -> StoredMessage: - msg = as_transformed_message(message) - return StoredMessage( - **msg, - token_count=token_count, - ) - - # Check if a type is part of a Union def is_type_in_union(type: object, union: object) -> bool: if get_origin(union) is Union: @@ -50,18 +42,21 @@ def is_type_in_union(type: object, union: object) -> bool: return False -# ---------------------------------------------------------------------- -# Unit tests for Chat._get_trimmed_messages() -# ---------------------------------------------------------------------- - - def test_chat_message_trimming(): with session_context(test_session): chat = Chat(id="chat") + # Default tokenizer gives a token count + def generate_content(token_count: int) -> str: + n = int(token_count / 2) + return " ".join(["foo" for _ in range(1, n)]) + msgs = ( - as_stored_message( - {"content": "System message", "role": "system"}, token_count=101 + as_transformed_message( + { + "content": generate_content(102), + "role": "system", + } ), ) @@ -70,11 +65,17 @@ def test_chat_message_trimming(): chat._trim_messages(msgs, token_limits=(100, 0), format=MISSING) msgs = ( - as_stored_message( - {"content": "System message", "role": "system"}, token_count=100 + as_transformed_message( + { + "content": generate_content(100), + "role": "system", + } ), - as_stored_message( - {"content": "User message", "role": "user"}, token_count=1 + as_transformed_message( + { + "content": generate_content(2), + "role": "user", + } ), ) @@ -83,64 +84,79 @@ def test_chat_message_trimming(): chat._trim_messages(msgs, token_limits=(100, 0), format=MISSING) # Raising the limit should allow both messages to fit - trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format=MISSING) + trimmed = chat._trim_messages(msgs, token_limits=(103, 0), format=MISSING) assert len(trimmed) == 2 - contents = [msg["content_server"] for msg in trimmed] - assert contents == ["System message", "User message"] + + content1 = generate_content(100) + content2 = generate_content(10) + content3 = generate_content(2) msgs = ( - as_stored_message( - {"content": "System message", 
"role": "system"}, token_count=100 + as_transformed_message( + { + "content": content1, + "role": "system", + } ), - as_stored_message( - {"content": "User message", "role": "user"}, token_count=10 + as_transformed_message( + { + "content": content2, + "role": "user", + } ), - as_stored_message( - {"content": "User message 2", "role": "user"}, token_count=1 + as_transformed_message( + { + "content": content3, + "role": "user", + } ), ) # Should discard the 1st user message - trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format=MISSING) + trimmed = chat._trim_messages(msgs, token_limits=(103, 0), format=MISSING) assert len(trimmed) == 2 contents = [msg["content_server"] for msg in trimmed] - assert contents == ["System message", "User message 2"] + assert contents == [content1, content3] + + content1 = generate_content(50) + content2 = generate_content(10) + content3 = generate_content(50) + content4 = generate_content(2) msgs = ( - as_stored_message( - {"content": "System message", "role": "system"}, token_count=50 + as_transformed_message( + {"content": content1, "role": "system"}, ), - as_stored_message( - {"content": "User message", "role": "user"}, token_count=10 + as_transformed_message( + {"content": content2, "role": "user"}, ), - as_stored_message( - {"content": "System message 2", "role": "system"}, token_count=50 + as_transformed_message( + {"content": content3, "role": "system"}, ), - as_stored_message( - {"content": "User message 2", "role": "user"}, token_count=1 + as_transformed_message( + {"content": content4, "role": "user"}, ), ) # Should discard the 1st user message - trimmed = chat._trim_messages(msgs, token_limits=(102, 0), format=MISSING) + trimmed = chat._trim_messages(msgs, token_limits=(103, 0), format=MISSING) assert len(trimmed) == 3 contents = [msg["content_server"] for msg in trimmed] - assert contents == ["System message", "System message 2", "User message 2"] + assert contents == [content1, content3, content4] + + content1 = generate_content(50) + content2 = generate_content(10) msgs = ( - as_stored_message( - {"content": "Assistant message", "role": "assistant"}, token_count=50 - ), - as_stored_message( - {"content": "User message", "role": "user"}, token_count=10 - ), + as_transformed_message({"content": content1, "role": "assistant"}), + as_transformed_message({"content": content2, "role": "user"}), ) # Anthropic requires 1st message to be a user message trimmed = chat._trim_messages(msgs, token_limits=(30, 0), format="anthropic") assert len(trimmed) == 1 contents = [msg["content_server"] for msg in trimmed] - assert contents == ["User message"] + assert contents == [content2] # ------------------------------------------------------------------------------------