Skip to content

Commit 6da8010

Browse files
authored
fix(tracing): respect the user-provided log options regardless of tracing configuration
1 parent 6ba7832 commit 6da8010

File tree

2 files changed

+132
-17
lines changed

2 files changed

+132
-17
lines changed

nemoguardrails/rails/llm/llmrails.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,7 @@ async def generate_async(
10281028
await streaming_handler.push_chunk(END_OF_STREAM)
10291029

10301030
# IF tracing is enabled we need to set GenerationLog attrs
1031+
original_log_options = None
10311032
if self.config.tracing.enabled:
10321033
if options is None:
10331034
options = GenerationOptions()
@@ -1038,6 +1039,7 @@ async def generate_async(
10381039
else:
10391040
# If options is a dict, convert it to GenerationOptions
10401041
options = GenerationOptions(**options)
1042+
original_log_options = options.log.model_copy(deep=True)
10411043

10421044
# enable log options
10431045
# it is aggressive, but these are required for tracing
@@ -1155,6 +1157,25 @@ async def generate_async(
11551157
)
11561158
await tracer.export_async()
11571159

1160+
# respect original log specification, if tracing added information to the output
1161+
if original_log_options:
1162+
if not any(
1163+
(
1164+
original_log_options.internal_events,
1165+
original_log_options.activated_rails,
1166+
original_log_options.llm_calls,
1167+
original_log_options.colang_history,
1168+
)
1169+
):
1170+
res.log = None
1171+
else:
1172+
if not original_log_options.internal_events:
1173+
res.log.internal_events = []
1174+
if not original_log_options.activated_rails:
1175+
res.log.activated_rails = []
1176+
if not original_log_options.llm_calls:
1177+
res.log.llm_calls = []
1178+
11581179
return res
11591180
else:
11601181
# If a prompt is used, we only return the content of the message.

tests/test_tracing.py

Lines changed: 111 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import asyncio
17+
import itertools
1718
import unittest
1819
from unittest.mock import AsyncMock, MagicMock, patch
1920

@@ -312,10 +313,9 @@ async def test_tracing_does_not_mutate_user_options():
312313
), "User's original options were modified! This causes instability."
313314

314315
# verify that tracing still works
315-
assert response.log is not None, "Tracing should still work correctly"
316316
assert (
317-
response.log.activated_rails is not None
318-
), "Should have activated rails data"
317+
response.log is None
318+
), "Tracing should still work correctly, without affecting returned log"
319319

320320

321321
@pytest.mark.asyncio
@@ -358,17 +358,16 @@ async def test_tracing_with_none_options():
358358
messages=[{"role": "user", "content": "hello"}], options=None
359359
)
360360

361-
assert response.log is not None
362-
assert response.log.activated_rails is not None
363-
assert response.log.stats is not None
361+
assert response.log is None
364362

365363

366364
@pytest.mark.asyncio
367365
async def test_tracing_aggressive_override_when_all_disabled():
368366
"""Test that tracing aggressively enables all logging when user disables all options.
369367
370368
When user disables all three tracing related options, tracing still enables
371-
ALL of them to ensure comprehensive logging data.
369+
ALL of them to ensure comprehensive logging data. However, this should not contaminate the
370+
returned response object
372371
"""
373372

374373
config = RailsConfig.from_content(
@@ -424,12 +423,9 @@ async def test_tracing_aggressive_override_when_all_disabled():
424423
assert user_options.log.colang_history == original_colang_history
425424

426425
assert response.log is not None
427-
assert (
428-
response.log.activated_rails is not None
429-
and len(response.log.activated_rails) > 0
430-
)
431-
assert response.log.llm_calls is not None
432-
assert response.log.internal_events is not None
426+
assert response.log.activated_rails == []
427+
assert response.log.llm_calls == []
428+
assert response.log.internal_events == []
433429

434430
assert user_options.log.activated_rails == original_activated_rails
435431
assert user_options.log.llm_calls == original_llm_calls
@@ -439,6 +435,104 @@ async def test_tracing_aggressive_override_when_all_disabled():
439435
assert user_options.log.internal_events == False
440436

441437

438+
@pytest.mark.asyncio
439+
@pytest.mark.parametrize(
440+
"activated_rails,llm_calls,internal_events,colang_history",
441+
list(itertools.product([False, True], repeat=4)),
442+
)
443+
async def test_tracing_preserves_specific_log_fields(
444+
activated_rails, llm_calls, internal_events, colang_history
445+
):
446+
"""Test that adding tracing respects the original user logging options in the response object"""
447+
448+
config = RailsConfig.from_content(
449+
colang_content="""
450+
define user express greeting
451+
"hello"
452+
453+
define flow
454+
user express greeting
455+
bot express greeting
456+
457+
define bot express greeting
458+
"Hello! How can I assist you today?"
459+
""",
460+
config={
461+
"models": [],
462+
"tracing": {"enabled": True, "adapters": [{"name": "FileSystem"}]},
463+
},
464+
)
465+
466+
chat = TestChat(
467+
config,
468+
llm_completions=[
469+
"user express greeting",
470+
"bot express greeting",
471+
"Hello! How can I assist you today?",
472+
],
473+
)
474+
475+
# user enables some subset of log options
476+
user_options = GenerationOptions(
477+
log=GenerationLogOptions(
478+
activated_rails=activated_rails,
479+
llm_calls=llm_calls,
480+
internal_events=internal_events,
481+
colang_history=colang_history,
482+
)
483+
)
484+
485+
original_activated_rails = user_options.log.activated_rails
486+
original_llm_calls = user_options.log.llm_calls
487+
original_internal_events = user_options.log.internal_events
488+
original_colang_history = user_options.log.colang_history
489+
490+
with patch.object(Tracer, "export_async", return_value=None):
491+
response = await chat.app.generate_async(
492+
messages=[{"role": "user", "content": "hello"}], options=user_options
493+
)
494+
495+
assert user_options.log.activated_rails == original_activated_rails
496+
assert user_options.log.llm_calls == original_llm_calls
497+
assert user_options.log.internal_events == original_internal_events
498+
assert user_options.log.colang_history == original_colang_history
499+
500+
# verify that only the requested log options are returned in the response
501+
if not any(
502+
(
503+
user_options.log.activated_rails,
504+
user_options.log.llm_calls,
505+
user_options.log.internal_events,
506+
user_options.log.colang_history,
507+
)
508+
):
509+
assert response.log is None
510+
else:
511+
assert response.log is not None
512+
513+
if user_options.log.activated_rails:
514+
assert len(response.log.activated_rails) > 0
515+
else:
516+
assert len(response.log.activated_rails) == 0
517+
518+
if user_options.log.llm_calls:
519+
assert len(response.log.llm_calls) > 0
520+
else:
521+
assert len(response.log.llm_calls) == 0
522+
523+
if user_options.log.internal_events:
524+
assert len(response.log.internal_events) > 0
525+
else:
526+
assert len(response.log.internal_events) == 0
527+
528+
assert user_options.log.activated_rails == original_activated_rails
529+
assert user_options.log.llm_calls == original_llm_calls
530+
assert user_options.log.internal_events == original_internal_events
531+
assert user_options.log.activated_rails == activated_rails
532+
assert user_options.log.llm_calls == llm_calls
533+
assert user_options.log.internal_events == internal_events
534+
535+
442536
@pytest.mark.asyncio
443537
async def test_tracing_aggressive_override_with_dict_options():
444538
"""Test that tracing works correctly when options are passed as a dict.
@@ -502,11 +596,11 @@ async def test_tracing_aggressive_override_with_dict_options():
502596

503597
assert response.log is not None
504598
assert (
505-
response.log.activated_rails is not None
506-
and len(response.log.activated_rails) > 0
599+
response.log.activated_rails == []
600+
and len(response.log.activated_rails) == 0
507601
)
508-
assert response.log.llm_calls is not None
509-
assert response.log.internal_events is not None
602+
assert response.log.llm_calls == []
603+
assert response.log.internal_events == []
510604

511605

512606
if __name__ == "__main__":

0 commit comments

Comments
 (0)