14
14
# limitations under the License.
15
15
16
16
import asyncio
17
+ import itertools
17
18
import unittest
18
19
from unittest .mock import AsyncMock , MagicMock , patch
19
20
@@ -312,10 +313,9 @@ async def test_tracing_does_not_mutate_user_options():
312
313
), "User's original options were modified! This causes instability."
313
314
314
315
# verify that tracing still works
315
- assert response .log is not None , "Tracing should still work correctly"
316
316
assert (
317
- response .log . activated_rails is not None
318
- ), "Should have activated rails data "
317
+ response .log is None
318
+ ), "Tracing should still work correctly, without affecting returned log "
319
319
320
320
321
321
@pytest .mark .asyncio
@@ -358,17 +358,16 @@ async def test_tracing_with_none_options():
358
358
messages = [{"role" : "user" , "content" : "hello" }], options = None
359
359
)
360
360
361
- assert response .log is not None
362
- assert response .log .activated_rails is not None
363
- assert response .log .stats is not None
361
+ assert response .log is None
364
362
365
363
366
364
@pytest .mark .asyncio
367
365
async def test_tracing_aggressive_override_when_all_disabled ():
368
366
"""Test that tracing aggressively enables all logging when user disables all options.
369
367
370
368
When user disables all three tracing related options, tracing still enables
371
- ALL of them to ensure comprehensive logging data.
369
+ ALL of them to ensure comprehensive logging data. However, this should not contaminate the
370
+ returned response object
372
371
"""
373
372
374
373
config = RailsConfig .from_content (
@@ -424,12 +423,9 @@ async def test_tracing_aggressive_override_when_all_disabled():
424
423
assert user_options .log .colang_history == original_colang_history
425
424
426
425
assert response .log is not None
427
- assert (
428
- response .log .activated_rails is not None
429
- and len (response .log .activated_rails ) > 0
430
- )
431
- assert response .log .llm_calls is not None
432
- assert response .log .internal_events is not None
426
+ assert response .log .activated_rails == []
427
+ assert response .log .llm_calls == []
428
+ assert response .log .internal_events == []
433
429
434
430
assert user_options .log .activated_rails == original_activated_rails
435
431
assert user_options .log .llm_calls == original_llm_calls
@@ -439,6 +435,104 @@ async def test_tracing_aggressive_override_when_all_disabled():
439
435
assert user_options .log .internal_events == False
440
436
441
437
438
+ @pytest .mark .asyncio
439
+ @pytest .mark .parametrize (
440
+ "activated_rails,llm_calls,internal_events,colang_history" ,
441
+ list (itertools .product ([False , True ], repeat = 4 )),
442
+ )
443
+ async def test_tracing_preserves_specific_log_fields (
444
+ activated_rails , llm_calls , internal_events , colang_history
445
+ ):
446
+ """Test that adding tracing respects the original user logging options in the response object"""
447
+
448
+ config = RailsConfig .from_content (
449
+ colang_content = """
450
+ define user express greeting
451
+ "hello"
452
+
453
+ define flow
454
+ user express greeting
455
+ bot express greeting
456
+
457
+ define bot express greeting
458
+ "Hello! How can I assist you today?"
459
+ """ ,
460
+ config = {
461
+ "models" : [],
462
+ "tracing" : {"enabled" : True , "adapters" : [{"name" : "FileSystem" }]},
463
+ },
464
+ )
465
+
466
+ chat = TestChat (
467
+ config ,
468
+ llm_completions = [
469
+ "user express greeting" ,
470
+ "bot express greeting" ,
471
+ "Hello! How can I assist you today?" ,
472
+ ],
473
+ )
474
+
475
+ # user enables some subset of log options
476
+ user_options = GenerationOptions (
477
+ log = GenerationLogOptions (
478
+ activated_rails = activated_rails ,
479
+ llm_calls = llm_calls ,
480
+ internal_events = internal_events ,
481
+ colang_history = colang_history ,
482
+ )
483
+ )
484
+
485
+ original_activated_rails = user_options .log .activated_rails
486
+ original_llm_calls = user_options .log .llm_calls
487
+ original_internal_events = user_options .log .internal_events
488
+ original_colang_history = user_options .log .colang_history
489
+
490
+ with patch .object (Tracer , "export_async" , return_value = None ):
491
+ response = await chat .app .generate_async (
492
+ messages = [{"role" : "user" , "content" : "hello" }], options = user_options
493
+ )
494
+
495
+ assert user_options .log .activated_rails == original_activated_rails
496
+ assert user_options .log .llm_calls == original_llm_calls
497
+ assert user_options .log .internal_events == original_internal_events
498
+ assert user_options .log .colang_history == original_colang_history
499
+
500
+ # verify that only the requested log options are returned in the response
501
+ if not any (
502
+ (
503
+ user_options .log .activated_rails ,
504
+ user_options .log .llm_calls ,
505
+ user_options .log .internal_events ,
506
+ user_options .log .colang_history ,
507
+ )
508
+ ):
509
+ assert response .log is None
510
+ else :
511
+ assert response .log is not None
512
+
513
+ if user_options .log .activated_rails :
514
+ assert len (response .log .activated_rails ) > 0
515
+ else :
516
+ assert len (response .log .activated_rails ) == 0
517
+
518
+ if user_options .log .llm_calls :
519
+ assert len (response .log .llm_calls ) > 0
520
+ else :
521
+ assert len (response .log .llm_calls ) == 0
522
+
523
+ if user_options .log .internal_events :
524
+ assert len (response .log .internal_events ) > 0
525
+ else :
526
+ assert len (response .log .internal_events ) == 0
527
+
528
+ assert user_options .log .activated_rails == original_activated_rails
529
+ assert user_options .log .llm_calls == original_llm_calls
530
+ assert user_options .log .internal_events == original_internal_events
531
+ assert user_options .log .activated_rails == activated_rails
532
+ assert user_options .log .llm_calls == llm_calls
533
+ assert user_options .log .internal_events == internal_events
534
+
535
+
442
536
@pytest .mark .asyncio
443
537
async def test_tracing_aggressive_override_with_dict_options ():
444
538
"""Test that tracing works correctly when options are passed as a dict.
@@ -502,11 +596,11 @@ async def test_tracing_aggressive_override_with_dict_options():
502
596
503
597
assert response .log is not None
504
598
assert (
505
- response .log .activated_rails is not None
506
- and len (response .log .activated_rails ) > 0
599
+ response .log .activated_rails == []
600
+ and len (response .log .activated_rails ) == 0
507
601
)
508
- assert response .log .llm_calls is not None
509
- assert response .log .internal_events is not None
602
+ assert response .log .llm_calls == []
603
+ assert response .log .internal_events == []
510
604
511
605
512
606
if __name__ == "__main__" :
0 commit comments