@@ -446,6 +446,61 @@ def test_pytorch_scriptrun(env):
446
446
values2 = con2 .execute_command ('AI.TENSORGET' , 'c' , 'VALUES' )
447
447
env .assertEqual (values2 , values )
448
448
449
+
450
+ def test_pytorch_scriptrun_variadic (env ):
451
+ if not TEST_PT :
452
+ env .debugPrint ("skipping {} since TEST_PT=0" .format (sys ._getframe ().f_code .co_name ), force = True )
453
+ return
454
+
455
+ con = env .getConnection ()
456
+
457
+ test_data_path = os .path .join (os .path .dirname (__file__ ), 'test_data' )
458
+ script_filename = os .path .join (test_data_path , 'script.txt' )
459
+
460
+ with open (script_filename , 'rb' ) as f :
461
+ script = f .read ()
462
+
463
+ ret = con .execute_command ('AI.SCRIPTSET' , 'myscript' , DEVICE , 'TAG' , 'version1' , 'SOURCE' , script )
464
+ env .assertEqual (ret , b'OK' )
465
+
466
+ ret = con .execute_command ('AI.TENSORSET' , 'a' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
467
+ env .assertEqual (ret , b'OK' )
468
+ ret = con .execute_command ('AI.TENSORSET' , 'b1' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
469
+ env .assertEqual (ret , b'OK' )
470
+ ret = con .execute_command ('AI.TENSORSET' , 'b2' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
471
+ env .assertEqual (ret , b'OK' )
472
+
473
+ ensureSlaveSynced (con , env )
474
+
475
+ for _ in range ( 0 ,100 ):
476
+ ret = con .execute_command ('AI.SCRIPTRUN' , 'myscript' , 'bar_variadic' , 'INPUTS' , 'a' , '$' , 'b1' , 'b2' , 'OUTPUTS' , 'c' )
477
+ env .assertEqual (ret , b'OK' )
478
+
479
+ ensureSlaveSynced (con , env )
480
+
481
+ info = con .execute_command ('AI.INFO' , 'myscript' )
482
+ info_dict_0 = info_to_dict (info )
483
+
484
+ env .assertEqual (info_dict_0 ['key' ], 'myscript' )
485
+ env .assertEqual (info_dict_0 ['type' ], 'SCRIPT' )
486
+ env .assertEqual (info_dict_0 ['backend' ], 'TORCH' )
487
+ env .assertEqual (info_dict_0 ['tag' ], 'version1' )
488
+ env .assertTrue (info_dict_0 ['duration' ] > 0 )
489
+ env .assertEqual (info_dict_0 ['samples' ], - 1 )
490
+ env .assertEqual (info_dict_0 ['calls' ], 100 )
491
+ env .assertEqual (info_dict_0 ['errors' ], 0 )
492
+
493
+ values = con .execute_command ('AI.TENSORGET' , 'c' , 'VALUES' )
494
+ env .assertEqual (values , [b'4' , b'6' , b'4' , b'6' ])
495
+
496
+ ensureSlaveSynced (con , env )
497
+
498
+ if env .useSlaves :
499
+ con2 = env .getSlaveConnection ()
500
+ values2 = con2 .execute_command ('AI.TENSORGET' , 'c' , 'VALUES' )
501
+ env .assertEqual (values2 , values )
502
+
503
+
449
504
def test_pytorch_scriptrun_errors (env ):
450
505
if not TEST_PT :
451
506
env .debugPrint ("skipping {} since TEST_PT=0" .format (sys ._getframe ().f_code .co_name ), force = True )
@@ -548,6 +603,66 @@ def test_pytorch_scriptrun_errors(env):
548
603
env .assertEqual (type (exception ), redis .exceptions .ResponseError )
549
604
550
605
606
+ def test_pytorch_scriptrun_errors (env ):
607
+ if not TEST_PT :
608
+ env .debugPrint ("skipping {} since TEST_PT=0" .format (sys ._getframe ().f_code .co_name ), force = True )
609
+ return
610
+
611
+ con = env .getConnection ()
612
+
613
+ test_data_path = os .path .join (os .path .dirname (__file__ ), 'test_data' )
614
+ script_filename = os .path .join (test_data_path , 'script.txt' )
615
+
616
+ with open (script_filename , 'rb' ) as f :
617
+ script = f .read ()
618
+
619
+ ret = con .execute_command ('AI.SCRIPTSET' , 'ket' , DEVICE , 'TAG' , 'asdf' , 'SOURCE' , script )
620
+ env .assertEqual (ret , b'OK' )
621
+
622
+ ret = con .execute_command ('AI.TENSORSET' , 'a' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
623
+ env .assertEqual (ret , b'OK' )
624
+ ret = con .execute_command ('AI.TENSORSET' , 'b' , 'FLOAT' , 2 , 2 , 'VALUES' , 2 , 3 , 2 , 3 )
625
+ env .assertEqual (ret , b'OK' )
626
+
627
+ ensureSlaveSynced (con , env )
628
+
629
+ # ERR Variadic input key is empty
630
+ try :
631
+ con .execute_command ('DEL' , 'EMPTY' )
632
+ con .execute_command ('AI.SCRIPTRUN' , 'ket' , 'bar_variadic' , 'INPUTS' , 'a' , '$' , 'EMPTY' , 'b' , 'OUTPUTS' , 'c' )
633
+ except Exception as e :
634
+ exception = e
635
+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
636
+ env .assertEqual ("tensor key is empty" , exception .__str__ ())
637
+
638
+ # ERR Variadic input key not tensor
639
+ try :
640
+ con .execute_command ('SET' , 'NOT_TENSOR' , 'BAR' )
641
+ con .execute_command ('AI.SCRIPTRUN' , 'ket' , 'bar_variadic' , 'INPUTS' , 'a' , '$' , 'NOT_TENSOR' , 'b' , 'OUTPUTS' , 'c' )
642
+ except Exception as e :
643
+ exception = e
644
+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
645
+ env .assertEqual ("WRONGTYPE Operation against a key holding the wrong kind of value" , exception .__str__ ())
646
+
647
+ try :
648
+ con .execute_command ('AI.SCRIPTRUN' , 'ket' , 'bar_variadic' , 'INPUTS' , 'b' , '$' , 'OUTPUTS' , 'c' )
649
+ except Exception as e :
650
+ exception = e
651
+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
652
+
653
+ try :
654
+ con .execute_command ('AI.SCRIPTRUN' , 'ket' , 'bar_variadic' , 'INPUTS' , 'b' , '$' , 'OUTPUTS' )
655
+ except Exception as e :
656
+ exception = e
657
+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
658
+
659
+ try :
660
+ con .execute_command ('AI.SCRIPTRUN' , 'ket' , 'bar_variadic' , 'INPUTS' , '$' , 'OUTPUTS' )
661
+ except Exception as e :
662
+ exception = e
663
+ env .assertEqual (type (exception ), redis .exceptions .ResponseError )
664
+
665
+
551
666
def test_pytorch_scriptinfo (env ):
552
667
if not TEST_PT :
553
668
env .debugPrint ("skipping {} since TEST_PT=0" .format (sys ._getframe ().f_code .co_name ), force = True )
0 commit comments