Commit 4eff3b5

Add getter functions for TLM defaults (#59)
Co-authored-by: Jonas Mueller <[email protected]>
1 parent 6258826 commit 4eff3b5

8 files changed: 131 additions & 15 deletions

CHANGELOG.md

Lines changed: 7 additions & 1 deletion
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.1.2] - 2025-05-01
+
+- Add getter functions for `_TLM_DEFAULT_MODEL`, `_DEFAULT_TLM_QUALITY_PRESET`, `_TLM_DEFAULT_CONTEXT_LIMIT`, `_DEFAULT_TLM_MAX_TOKENS`.
+- Add unit tests for the getter functions.
+
 ## [1.1.1] - 2025-04-23
 
 ### Changed
@@ -141,7 +146,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Release of the Cleanlab TLM Python client.
 
-[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.1...HEAD
+[Unreleased]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.2...HEAD
+[1.1.2]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.1...v1.1.2
 [1.1.1]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.1.0...v1.1.1
 [1.1.0]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.0.23...v1.1.0
 [1.0.23]: https://github.com/cleanlab/cleanlab-tlm/compare/v1.0.22...v1.0.23

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,7 @@ extra-dependencies = [
   "pytest",
   "pytest-asyncio",
   "python-dotenv",
+  "tiktoken",
 ]
 [tool.hatch.envs.types.scripts]
 check = "mypy --strict --install-types --non-interactive {args:src/cleanlab_tlm tests}"
@@ -57,6 +58,7 @@ allow-direct-references = true
 extra-dependencies = [
   "python-dotenv",
   "pytest-asyncio",
+  "tiktoken",
 ]
 
 [tool.hatch.envs.hatch-test.env-vars]

src/cleanlab_tlm/__about__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 # SPDX-License-Identifier: MIT
-__version__ = "1.1.1"
+__version__ = "1.1.2"

src/cleanlab_tlm/internal/constants.py

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,7 @@
 _VALID_TLM_QUALITY_PRESETS: list[str] = ["best", "high", "medium", "low", "base"]
 _VALID_TLM_QUALITY_PRESETS_RAG: list[str] = ["medium", "low", "base"]
 _DEFAULT_TLM_QUALITY_PRESET: TLMQualityPreset = "medium"
+_DEFAULT_TLM_MAX_TOKENS: int = 512
 _VALID_TLM_MODELS: list[str] = [
     "gpt-3.5-turbo-16k",
     "gpt-4",
@@ -32,6 +33,7 @@
     "nova-pro",
 ]
 _TLM_DEFAULT_MODEL: str = "gpt-4o-mini"
+_TLM_DEFAULT_CONTEXT_LIMIT: int = 70000
 _VALID_TLM_TASKS: set[str] = {task.value for task in Task}
 TLM_TASK_SUPPORTING_CONSTRAIN_OUTPUTS: set[Task] = {
     Task.DEFAULT,

src/cleanlab_tlm/utils/config.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+from cleanlab_tlm.internal.constants import (
+    _DEFAULT_TLM_MAX_TOKENS,
+    _DEFAULT_TLM_QUALITY_PRESET,
+    _TLM_DEFAULT_CONTEXT_LIMIT,
+    _TLM_DEFAULT_MODEL,
+)
+
+
+def get_default_model() -> str:
+    """
+    Get the default model name for TLM.
+
+    Returns:
+        str: The default model name for TLM.
+    """
+    return _TLM_DEFAULT_MODEL
+
+
+def get_default_quality_preset() -> str:
+    """
+    Get the default quality preset for TLM.
+
+    Returns:
+        str: The default quality preset for TLM.
+    """
+    return _DEFAULT_TLM_QUALITY_PRESET
+
+
+def get_default_context_limit() -> int:
+    """
+    Get the default context limit for TLM.
+
+    Returns:
+        int: The default context limit for TLM.
+    """
+    return _TLM_DEFAULT_CONTEXT_LIMIT
+
+
+def get_default_max_tokens() -> int:
+    """
+    Get the default maximum output tokens allowed.
+
+    Returns:
+        int: The default maximum output tokens.
+    """
+    return _DEFAULT_TLM_MAX_TOKENS
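
For reference, the new getters are plain module-level functions and can be called directly; here is a minimal usage sketch (not part of this commit's diff; the values in the comments are the defaults defined in internal/constants.py above):

from cleanlab_tlm.utils.config import (
    get_default_context_limit,
    get_default_max_tokens,
    get_default_model,
    get_default_quality_preset,
)

# Defaults as set in this commit's internal/constants.py.
print(get_default_model())           # "gpt-4o-mini"
print(get_default_quality_preset())  # "medium"
print(get_default_context_limit())   # 70000
print(get_default_max_tokens())      # 512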

tests/constants.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@
 MAX_COMBINED_LENGTH_TOKENS: int = 70_000
 
 CHARACTERS_PER_TOKEN: int = 4
+# 4 characters (3 characters + 1 space) = 1 token
+WORD_THAT_EQUALS_ONE_TOKEN = "orb "  # noqa: S105
 
 # Property tests for TLM
 excluded_tlm_models: list[str] = [
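
To sanity-check the one-token assumption behind WORD_THAT_EQUALS_ONE_TOKEN, you can count tokens with tiktoken (a rough sketch, assuming tiktoken is installed per the pyproject.toml change above and that the default model name resolves to a known encoding):

import tiktoken

WORD_THAT_EQUALS_ONE_TOKEN = "orb "

try:
    # "gpt-4o-mini" is the default model set in internal/constants.py.
    enc = tiktoken.encoding_for_model("gpt-4o-mini")
except KeyError:
    # Fall back to a known encoding, mirroring tests/test_config.py below.
    enc = tiktoken.encoding_for_model("gpt-4o")

n = 1_000
# Repeating the 4-character word should yield roughly n tokens,
# which is what the length-limit tests below rely on.
print(len(enc.encode(WORD_THAT_EQUALS_ONE_TOKEN * n)))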

tests/test_config.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+import pytest
+import tiktoken
+
+from cleanlab_tlm.errors import TlmBadRequestError
+from cleanlab_tlm.tlm import TLM
+from cleanlab_tlm.utils.config import (
+    get_default_context_limit,
+    get_default_max_tokens,
+    get_default_model,
+    get_default_quality_preset,
+)
+from tests.constants import WORD_THAT_EQUALS_ONE_TOKEN
+
+tlm_with_default_setting = TLM()
+
+
+def test_get_default_model(tlm: TLM) -> None:
+    assert tlm.get_model_name() == get_default_model()
+
+
+def test_get_default_quality_preset(tlm: TLM) -> None:
+    assert get_default_quality_preset() == tlm._quality_preset
+
+
+def test_prompt_too_long_exception_single_prompt(tlm: TLM) -> None:
+    """Tests that bad request error is raised when prompt is too long when calling tlm.prompt with a single prompt."""
+    with pytest.raises(TlmBadRequestError) as exc_info:
+        tlm.prompt(WORD_THAT_EQUALS_ONE_TOKEN * (get_default_context_limit() + 1))
+
+    assert exc_info.value.message.startswith("Prompt length exceeds")
+    assert exc_info.value.retryable is False
+
+
+def test_prompt_within_context_limit_returns_response(tlm: TLM) -> None:
+    """Tests that no error is raised when prompt length is within limit."""
+    response = tlm.prompt(WORD_THAT_EQUALS_ONE_TOKEN * (get_default_context_limit() - 1000))
+
+    assert isinstance(response, dict)
+    assert "response" in response
+    assert isinstance(response["response"], str)
+
+
+def test_response_within_max_tokens() -> None:
+    """Tests that response is within max tokens limit."""
+    tlm_base = TLM(quality_preset="base")
+    prompt = "write a 100 page book about computer science. make sure it is extremely long and comprehensive."
+
+    result = tlm_base.prompt(prompt)
+    assert isinstance(result, dict)
+    response = result["response"]
+    assert isinstance(response, str)
+
+    try:
+        enc = tiktoken.encoding_for_model(get_default_model())
+    except KeyError:
+        enc = tiktoken.encoding_for_model("gpt-4o")
+    tokens_in_response = len(enc.encode(response))
+    assert tokens_in_response <= get_default_max_tokens()

tests/test_validation.py

Lines changed: 13 additions & 13 deletions
@@ -12,13 +12,13 @@
 from cleanlab_tlm.utils.rag import Eval, TrustworthyRAG
 from tests.conftest import make_text_unique
 from tests.constants import (
-    CHARACTERS_PER_TOKEN,
     MAX_COMBINED_LENGTH_TOKENS,
     MAX_PROMPT_LENGTH_TOKENS,
     MAX_RESPONSE_LENGTH_TOKENS,
     TEST_PROMPT,
     TEST_PROMPT_BATCH,
     TEST_RESPONSE,
+    WORD_THAT_EQUALS_ONE_TOKEN,
 )
 from tests.test_get_trustworthiness_score import is_tlm_score_response_with_error
 from tests.test_prompt import is_tlm_response_with_error
@@ -208,7 +208,7 @@ def test_prompt_too_long_exception_single_prompt(tlm: TLM) -> None:
     """Tests that bad request error is raised when prompt is too long when calling tlm.prompt with a single prompt."""
     with pytest.raises(TlmBadRequestError) as exc_info:
         tlm.prompt(
-            "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN,
+            WORD_THAT_EQUALS_ONE_TOKEN * (MAX_PROMPT_LENGTH_TOKENS + 1),
         )
 
     assert exc_info.value.message.startswith("Prompt length exceeds")
@@ -221,7 +221,7 @@ def test_prompt_too_long_exception_prompt(tlm: TLM, num_prompts: int) -> None:
     # create batch of prompts with one prompt that is too long
     prompts = [test_prompt] * num_prompts
     prompt_too_long_index = np.random.randint(0, num_prompts)
-    prompts[prompt_too_long_index] = "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
+    prompts[prompt_too_long_index] = WORD_THAT_EQUALS_ONE_TOKEN * (MAX_PROMPT_LENGTH_TOKENS + 1)
 
     tlm_responses = cast(list[TLMResponse], tlm.prompt(prompts))
 
@@ -232,8 +232,8 @@ def test_response_too_long_exception_single_score(tlm: TLM) -> None:
     """Tests that bad request error is raised when response is too long when calling tlm.get_trustworthiness_score with a single prompt."""
     with pytest.raises(TlmBadRequestError) as exc_info:
         tlm.get_trustworthiness_score(
-            "a",
-            "a" * (MAX_RESPONSE_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN,
+            WORD_THAT_EQUALS_ONE_TOKEN,
+            WORD_THAT_EQUALS_ONE_TOKEN * (MAX_RESPONSE_LENGTH_TOKENS + 1),
         )
 
     assert exc_info.value.message.startswith("Response length exceeds")
@@ -247,7 +247,7 @@ def test_response_too_long_exception_score(tlm: TLM, num_prompts: int) -> None:
     prompts = [test_prompt] * num_prompts
     responses = [TEST_RESPONSE] * num_prompts
     response_too_long_index = np.random.randint(0, num_prompts)
-    responses[response_too_long_index] = "a" * (MAX_RESPONSE_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
+    responses[response_too_long_index] = WORD_THAT_EQUALS_ONE_TOKEN * (MAX_RESPONSE_LENGTH_TOKENS + 1)
 
     tlm_responses = cast(list[TLMScore], tlm.get_trustworthiness_score(prompts, responses))
 
@@ -258,8 +258,8 @@ def test_prompt_too_long_exception_single_score(tlm: TLM) -> None:
     """Tests that bad request error is raised when prompt is too long when calling tlm.get_trustworthiness_score with a single prompt."""
    with pytest.raises(TlmBadRequestError) as exc_info:
         tlm.get_trustworthiness_score(
-            "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN,
-            "a",
+            WORD_THAT_EQUALS_ONE_TOKEN * (MAX_PROMPT_LENGTH_TOKENS + 1),
+            WORD_THAT_EQUALS_ONE_TOKEN,
         )
 
     assert exc_info.value.message.startswith("Prompt length exceeds")
@@ -273,7 +273,7 @@ def test_prompt_too_long_exception_score(tlm: TLM, num_prompts: int) -> None:
     prompts = [test_prompt] * num_prompts
     responses = [TEST_RESPONSE] * num_prompts
     prompt_too_long_index = np.random.randint(0, num_prompts)
-    prompts[prompt_too_long_index] = "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
+    prompts[prompt_too_long_index] = WORD_THAT_EQUALS_ONE_TOKEN * (MAX_PROMPT_LENGTH_TOKENS + 1)
 
     tlm_responses = cast(list[TLMScore], tlm.get_trustworthiness_score(prompts, responses))
 
@@ -286,8 +286,8 @@ def test_combined_too_long_exception_single_score(tlm: TLM) -> None:
 
     with pytest.raises(TlmBadRequestError) as exc_info:
         tlm.get_trustworthiness_score(
-            "a" * max_prompt_length * CHARACTERS_PER_TOKEN,
-            "a" * MAX_RESPONSE_LENGTH_TOKENS * CHARACTERS_PER_TOKEN,
+            WORD_THAT_EQUALS_ONE_TOKEN * max_prompt_length,
+            WORD_THAT_EQUALS_ONE_TOKEN * MAX_RESPONSE_LENGTH_TOKENS,
         )
 
     assert exc_info.value.message.startswith("Prompt and response combined length exceeds")
@@ -306,8 +306,8 @@ def test_prompt_and_response_combined_too_long_exception_batch_score(tlm: TLM, n
     combined_too_long_index = np.random.randint(0, num_prompts)
 
     max_prompt_length = MAX_COMBINED_LENGTH_TOKENS - MAX_RESPONSE_LENGTH_TOKENS + 1
-    prompts[combined_too_long_index] = "a" * max_prompt_length * CHARACTERS_PER_TOKEN
-    responses[combined_too_long_index] = "a" * MAX_RESPONSE_LENGTH_TOKENS * CHARACTERS_PER_TOKEN
+    prompts[combined_too_long_index] = WORD_THAT_EQUALS_ONE_TOKEN * max_prompt_length
+    responses[combined_too_long_index] = WORD_THAT_EQUALS_ONE_TOKEN * MAX_RESPONSE_LENGTH_TOKENS
 
     tlm_responses = cast(list[TLMScore], tlm.get_trustworthiness_score(prompts, responses))
 
