@@ -825,6 +825,22 @@ int RAI_parseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
825
825
}
826
826
RedisModule_ReplyWithDouble (ctx , val );
827
827
}
828
+ argpos ++ ;
829
+ break ;
830
+ } else {
831
+ long long dimension = 1 ;
832
+ const int retval = RedisModule_StringToLongLong (argv [argpos ],& dimension );
833
+ if (retval != REDISMODULE_OK || dimension <= 0 ) {
834
+ RedisModule_Free (dims );
835
+ RedisModule_CloseKey (key );
836
+ return RedisModule_ReplyWithError (ctx ,
837
+ "ERR invalid or negative value found in tensor shape" );
838
+ }
839
+
840
+ ndims ++ ;
841
+ dims = RedisModule_Realloc (dims ,ndims * sizeof (long long ));
842
+ dims [ndims - 1 ]= dimension ;
843
+ len *= dimension ;
828
844
}
829
845
else {
830
846
long long val ;
@@ -903,6 +919,13 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
903
919
if (strlen (devicestr ) > 10 ) {
904
920
return RedisModule_ReplyWithError (ctx , "ERR Invalid DEVICE" );
905
921
}
922
+ else if (datafmt == REDISAI_DATA_VALUES ) {
923
+ long long ndims = RAI_TensorNumDims (t );
924
+ long long len = 1 ;
925
+ long long i ;
926
+ for (i = 0 ; i < ndims ; i ++ ) {
927
+ len *= RAI_TensorDim (t , i );
928
+ }
906
929
907
930
const char * tag = "" ;
908
931
if (AC_AdvanceIfMatch (& ac , "TAG" )) {
@@ -928,6 +951,10 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
928
951
return RedisModule_ReplyWithError (ctx , "ERR Invalid argument for MINBATCHSIZE" );
929
952
}
930
953
}
954
+ RedisModule_CloseKey (key );
955
+
956
+ return REDISMODULE_OK ;
957
+ }
931
958
932
959
933
960
if (AC_IsAtEnd (& ac )) {
@@ -1053,8 +1080,11 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
1053
1080
1054
1081
RedisModule_ReplicateVerbatim (ctx );
1055
1082
1056
- return REDISMODULE_OK ;
1057
- }
1083
+ size_t ninputs = inac .argc ;
1084
+ const char * inputs [ninputs ];
1085
+ for (size_t i = 0 ; i < ninputs ; i ++ ) {
1086
+ AC_GetString (& inac , inputs + i , NULL , 0 );
1087
+ }
1058
1088
1059
1089
/**
1060
1090
* AI.MODELGET model_key [META | BLOB]
@@ -1077,6 +1107,8 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
1077
1107
else if (!strcasecmp (optstr , "BLOB" )) {
1078
1108
blob = 1 ;
1079
1109
}
1110
+ RAI_ClearError (& err );
1111
+ model = RAI_ModelCreate (backend , devicestr , tag , opts , ninputs , inputs , noutputs , outputs , modeldef , modellen , & err );
1080
1112
}
1081
1113
1082
1114
RAI_Error err = {0 };
0 commit comments