@@ -93,20 +93,6 @@ mstime_t mstime(void) {
93
93
return ustime ()/1000 ;
94
94
}
95
95
96
- TF_Tensor * RedisDL_clone (RedisModuleCtx * ctx , const TF_Tensor * tensor ) {
97
- int ndims = TF_NumDims (tensor );
98
- long long * dims = RedisModule_PoolAlloc (ctx , ndims * sizeof (long long ));
99
- for (int j = 0 ; j < ndims ; j ++ ) {
100
- dims [j ] = TF_Dim (tensor , j );
101
- }
102
- size_t len = TF_TensorByteSize (tensor );
103
- void * data = TF_TensorData (tensor );
104
- TF_Tensor * out = TF_AllocateTensor (TF_TensorType (tensor ),
105
- dims , ndims , len );
106
- memcpy (TF_TensorData (out ), data , len );
107
- return out ;
108
- }
109
-
110
96
enum RedisDL_DataFmt {
111
97
REDISDL_DATA_BLOB = 0 ,
112
98
REDISDL_DATA_VALUES
@@ -493,28 +479,19 @@ int RedisDL_GSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
493
479
494
480
struct RedisDL_RunInfo {
495
481
RedisModuleBlockedClient * client ;
496
- TF_Session * session ;
497
- TF_Output * inputs ;
498
- TF_Tensor * * input_values ;
499
- long long ninputs ;
500
- TF_Output * outputs ;
501
- TF_Tensor * * output_values ;
502
- long long noutputs ;
503
- RedisModuleKey * graphkey ;
504
482
RedisModuleString * * outkeys ;
505
- TF_Status * status ;
483
+ RDL_GraphRunCtx * gctx ;
484
+ int status ;
506
485
};
507
486
508
- void RedisDL_FreeRunInfo (struct RedisDL_RunInfo * rinfo ) {
509
- RedisModule_Free (rinfo -> inputs );
510
- for (int i = 0 ; i < rinfo -> ninputs ; i ++ ) {
511
- TF_DeleteTensor (rinfo -> input_values [i ]);
487
+ void RedisDL_FreeRunInfo (RedisModuleCtx * ctx , struct RedisDL_RunInfo * rinfo ) {
488
+ for (int i = 0 ; i < Graph_RunCtxNumOutputs (rinfo -> gctx ) ; ++ i ){
489
+ RedisModule_FreeString (ctx , rinfo -> outkeys [i ]);
512
490
}
513
- RedisModule_Free (rinfo -> input_values );
514
- RedisModule_Free (rinfo -> outputs );
515
- RedisModule_Free (rinfo -> output_values );
516
491
RedisModule_Free (rinfo -> outkeys );
517
- TF_DeleteStatus (rinfo -> status );
492
+
493
+ Graph_RunCtxFreeInternals (rinfo -> gctx );
494
+
518
495
RedisModule_Free (rinfo );
519
496
}
520
497
@@ -525,24 +502,19 @@ void *RedisDL_RunSession(void *arg) {
525
502
526
503
mstime_t start = mstime ();
527
504
528
- TF_Status * status = TF_NewStatus ();
529
505
530
- TF_SessionRun (rinfo -> session , NULL /* run_options */ ,
531
- rinfo -> inputs , rinfo -> input_values , rinfo -> ninputs ,
532
- rinfo -> outputs , rinfo -> output_values , rinfo -> noutputs ,
533
- NULL /* target_opers */ , 0 /* ntargets */ ,
534
- NULL /* run_Metadata */ ,
535
- rinfo -> status );
506
+ rinfo -> status = Graph_Run (rinfo -> gctx );
536
507
537
508
mstime_t end = mstime ();
538
509
539
510
RedisModule_ThreadSafeContextLock (ctx );
540
511
RedisModule_Log (ctx , "notice" , "TF_SessionRun took %fs" , (end - start ) / 1000.0 );
541
512
RedisModule_ThreadSafeContextUnlock (ctx );
542
513
543
- RedisModule_FreeThreadSafeContext (ctx );
544
514
RedisModule_UnblockClient (rinfo -> client , rinfo );
545
515
516
+ RedisModule_FreeThreadSafeContext (ctx );
517
+
546
518
return NULL ;
547
519
}
548
520
@@ -557,31 +529,33 @@ int RedisDL_Run_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
557
529
REDISMODULE_NOT_USED (argc );
558
530
struct RedisDL_RunInfo * rinfo = RedisModule_GetBlockedClientPrivateData (ctx );
559
531
560
- if (TF_GetCode ( rinfo -> status ) != TF_OK ) {
561
- int ret = RedisModule_ReplyWithError (ctx , TF_Message ( rinfo -> status ) );
562
- RedisDL_FreeRunInfo (rinfo );
532
+ if (! rinfo -> status ) {
533
+ int ret = RedisModule_ReplyWithError (ctx , "graph run failed" );
534
+ RedisDL_FreeRunInfo (ctx , rinfo );
563
535
return ret ;
564
536
}
565
537
566
- for (int i = 0 ; i < rinfo -> noutputs ; i ++ ) {
538
+ for (size_t i = 0 ; i < Graph_RunCtxNumOutputs ( rinfo -> gctx ); ++ i ) {
567
539
RedisModuleKey * outkey = RedisModule_OpenKey (ctx , rinfo -> outkeys [i ],
568
540
REDISMODULE_READ |REDISMODULE_WRITE );
569
541
int type = RedisModule_KeyType (outkey );
570
542
if (type != REDISMODULE_KEYTYPE_EMPTY &&
571
543
RedisModule_ModuleTypeGetType (outkey ) != RedisDL_TensorType ) {
572
544
RedisModule_CloseKey (outkey );
573
- RedisDL_FreeRunInfo (rinfo );
545
+ RedisDL_FreeRunInfo (ctx , rinfo );
574
546
return RedisModule_ReplyWithError (ctx , REDISMODULE_ERRORMSG_WRONGTYPE );
575
547
}
576
- RDL_Tensor * t = Tensor_CreateFromTensor (rinfo -> output_values [i ]);
577
- RedisModule_ModuleTypeSetValue (outkey , RedisDL_TensorType , t );
548
+ RDL_Tensor * t = Graph_RunCtxOutputTensor (rinfo -> gctx , i );
549
+ if (t ){
550
+ RedisModule_ModuleTypeSetValue (outkey , RedisDL_TensorType , Tensor_GetShallowCopy (t ));
551
+ }
578
552
RedisModule_CloseKey (outkey );
579
553
}
580
554
581
555
// FIXME This crashes Redis, we need to investigate.
582
556
//RedisModule_CloseKey(rinfo->graphkey);
583
557
584
- RedisDL_FreeRunInfo (rinfo );
558
+ RedisDL_FreeRunInfo (ctx , rinfo );
585
559
586
560
return RedisModule_ReplyWithSimpleString (ctx , "OK" );
587
561
}
@@ -610,9 +584,6 @@ int RedisDL_GRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
610
584
611
585
RDL_Graph * gto = RedisModule_ModuleTypeGetValue (key );
612
586
613
- TF_Graph * graph = gto -> graph ;
614
- TF_Session * session = gto -> session ;
615
-
616
587
long long ninputs ;
617
588
if ((RedisModule_StringToLongLong (argv [2 ], & ninputs ) != REDISMODULE_OK )
618
589
|| ninputs < 0 ) {
@@ -634,64 +605,42 @@ int RedisDL_GRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
634
605
return RedisModule_ReplyWithError (ctx , "ERR odd key/name pairs" );
635
606
}
636
607
637
- TF_Output * inputs = RedisModule_Alloc (ninputs * sizeof (TF_Output ));
638
- TF_Output * outputs = RedisModule_Alloc (noutputs * sizeof (TF_Output ));
639
- TF_Tensor * * input_values = RedisModule_Alloc (ninputs * sizeof (TF_Tensor * ));
640
- TF_Tensor * * output_values = RedisModule_Alloc (noutputs * sizeof (TF_Tensor * ));
608
+ struct RedisDL_RunInfo * rinfo = RedisModule_Alloc (sizeof (struct RedisDL_RunInfo ));
609
+ rinfo -> gctx = Graph_RunCtxCreate (gto );
641
610
642
- RedisModuleString * * outkeys = RedisModule_Alloc (noutputs * sizeof (RedisModuleString * ));
611
+ rinfo -> outkeys = RedisModule_Alloc (noutputs * sizeof (RedisModuleString * ));
643
612
644
613
for (int i = pairoffset ; i < argc ; i += 2 ) {
645
614
int isinput = i < pairoffset + 2 * ninputs ;
646
615
647
- size_t namelen ;
648
616
RedisModuleString * argname = argv [i + 1 ];
649
617
650
618
if (isinput ) {
651
619
RedisModuleKey * argkey = RedisModule_OpenKey (ctx , argv [i ], REDISMODULE_READ );
652
620
if (RedisModule_ModuleTypeGetType (argkey ) != RedisDL_TensorType ) {
621
+ // todo free rinfo
653
622
RedisModule_CloseKey (argkey );
654
623
return RedisModule_ReplyWithError (ctx , REDISMODULE_ERRORMSG_WRONGTYPE );
655
624
}
656
625
RDL_Tensor * t = RedisModule_ModuleTypeGetValue (argkey );
657
- input_values [(i - pairoffset )/2 ] = RedisDL_clone (ctx , t -> tensor );
658
626
RedisModule_CloseKey (argkey );
659
- const char * opname = RedisModule_StringPtrLen (argname , & namelen );
660
- // RedisModule_Log(ctx, "warning", "%s", opname);
661
- TF_Output port ;
662
- port .oper = TF_GraphOperationByName (graph , opname );
663
- port .index = 0 ;
664
- if (port .oper == NULL ) {
627
+ const char * opname = RedisModule_StringPtrLen (argname , NULL );
628
+ if (!Graph_RunCtxAddInput (rinfo -> gctx , opname , t )){
629
+ // todo free rinfo
665
630
return RedisModule_ReplyWithError (ctx , "Input key not found." );
666
631
}
667
- inputs [(i - pairoffset )/2 ] = port ;
668
632
} else {
669
- const char * opname = RedisModule_StringPtrLen (argname , & namelen );
670
- TF_Output port ;
671
- port .oper = TF_GraphOperationByName (graph , opname );
672
- port .index = 0 ;
673
- if (port .oper == NULL ) {
633
+ const char * opname = RedisModule_StringPtrLen (argname , NULL );
634
+ if (!Graph_RunCtxAddOutput (rinfo -> gctx , opname )){
635
+ // todo free rinfo
674
636
return RedisModule_ReplyWithError (ctx , "Output key not found." );
675
637
}
676
- outputs [( i - pairoffset )/ 2 - ninputs ] = port ;
677
- outkeys [(i - pairoffset )/2 - ninputs ] = argv [i ];
638
+ RedisModule_RetainString ( ctx , argv [ i ]) ;
639
+ rinfo -> outkeys [(i - pairoffset )/2 - ninputs ] = argv [i ];
678
640
}
679
641
}
680
642
681
- RedisModuleBlockedClient * bc = RedisModule_BlockClient (ctx , RedisDL_Run_Reply , NULL , NULL , 0 );
682
-
683
- struct RedisDL_RunInfo * rinfo = RedisModule_Alloc (sizeof (struct RedisDL_RunInfo ));
684
- rinfo -> client = bc ;
685
- rinfo -> session = session ;
686
- rinfo -> inputs = inputs ;
687
- rinfo -> input_values = input_values ;
688
- rinfo -> ninputs = ninputs ;
689
- rinfo -> outputs = outputs ;
690
- rinfo -> output_values = output_values ;
691
- rinfo -> noutputs = noutputs ;
692
- rinfo -> graphkey = key ;
693
- rinfo -> outkeys = outkeys ;
694
- rinfo -> status = TF_NewStatus ();
643
+ rinfo -> client = RedisModule_BlockClient (ctx , RedisDL_Run_Reply , NULL , NULL , 0 );
695
644
696
645
// RedisModule_AbortBlock(bc);
697
646
// return RedisModule_ReplyWithError(ctx, "-ERR Can't start thread");
0 commit comments