Skip to content

Commit f4c91bf

Browse files
author
MeirShpilraien
committed
Small refactoring stage 2
Moved graph run into graph object, make use of refcount for tensors and for graph. Added utils directory with arr implementation that provide dynamic array object.
1 parent 7454f76 commit f4c91bf

File tree

7 files changed

+350
-85
lines changed

7 files changed

+350
-85
lines changed

src/graph.c

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
#include "graph.h"
2+
#include "utils/arr_rm_alloc.h"
23

34
RedisModuleType *RedisDL_GraphType = NULL;
45

6+
typedef struct RDL_GraphCtxParam{
7+
TF_Output name;
8+
RDL_Tensor* tensor;
9+
}RDL_GraphCtxParam;
10+
11+
typedef struct RDL_GraphRunCtx{
12+
RDL_Graph* graph;
13+
RDL_GraphCtxParam* inputs;
14+
RDL_GraphCtxParam* outputs;
15+
}RDL_GraphRunCtx;
16+
517
static void* Graph_RdbLoad(struct RedisModuleIO *io, int encver){
618
//todo
719
return NULL;
@@ -75,6 +87,9 @@ RDL_Graph* Graph_Create(const char* prefix, const char* graphdef, size_t graphle
7587
}
7688

7789
void Graph_Free(RDL_Graph* graph){
90+
if(--graph->refCount > 0){
91+
return;
92+
}
7893
TF_Status *status = TF_NewStatus();
7994
TF_CloseSession(graph->session, status);
8095

@@ -100,3 +115,102 @@ void Graph_Free(RDL_Graph* graph){
100115

101116
RedisModule_Free(graph);
102117
}
118+
119+
RDL_GraphRunCtx* Graph_RunCtxCreate(RDL_Graph* graph){
120+
#define PARAM_INITIAL_SIZE 10
121+
RDL_GraphRunCtx* gctx = RedisModule_Alloc(sizeof(*gctx));
122+
gctx->graph = Graph_GetShallowCopy(graph);
123+
gctx->inputs = array_new(RDL_GraphCtxParam, PARAM_INITIAL_SIZE);
124+
gctx->outputs = array_new(RDL_GraphCtxParam, PARAM_INITIAL_SIZE);
125+
return gctx;
126+
}
127+
128+
static int Graph_RunCtxAddParam(RDL_GraphRunCtx* gctx, RDL_GraphCtxParam* paramArr, const char* name, RDL_Tensor* tensor){
129+
TF_Output port;
130+
port.oper = TF_GraphOperationByName(gctx->graph->graph, name);
131+
port.index = 0;
132+
if(port.oper == NULL){
133+
return 0;
134+
}
135+
RDL_GraphCtxParam param = {
136+
.name = port,
137+
.tensor = tensor ? Tensor_GetShallowCopy(tensor): NULL,
138+
};
139+
paramArr = array_append(paramArr, param);
140+
return 1;
141+
}
142+
143+
int Graph_RunCtxAddInput(RDL_GraphRunCtx* gctx, const char* inputName, RDL_Tensor* inputTensor){
144+
return Graph_RunCtxAddParam(gctx, gctx->inputs, inputName, inputTensor);
145+
}
146+
147+
int Graph_RunCtxAddOutput(RDL_GraphRunCtx* gctx, const char* outputName){
148+
return Graph_RunCtxAddParam(gctx, gctx->outputs, outputName, NULL);
149+
}
150+
151+
size_t Graph_RunCtxNumOutputs(RDL_GraphRunCtx* gctx){
152+
return array_len(gctx->outputs);
153+
}
154+
155+
RDL_Tensor* Graph_RunCtxOutputTensor(RDL_GraphRunCtx* gctx, size_t index){
156+
assert(Graph_RunCtxNumOutputs(gctx) > index && index >= 0);
157+
return gctx->outputs[index].tensor;
158+
}
159+
160+
void Graph_RunCtxFreeInternals(RDL_GraphRunCtx* gctx){
161+
for(size_t i = 0 ; i < array_len(gctx->inputs) ; ++i){
162+
Tensor_Free(gctx->inputs[i].tensor);
163+
}
164+
array_free(gctx->inputs);
165+
166+
for(size_t i = 0 ; i < array_len(gctx->outputs) ; ++i){
167+
if(gctx->outputs[i].tensor){
168+
Tensor_Free(gctx->outputs[i].tensor);
169+
}
170+
}
171+
array_free(gctx->outputs);
172+
173+
Graph_Free(gctx->graph);
174+
}
175+
176+
int Graph_Run(RDL_GraphRunCtx* gctx){
177+
TF_Status *status = TF_NewStatus();
178+
179+
TF_Tensor* inputTensorsValues[array_len(gctx->inputs)];
180+
TF_Output inputs[array_len(gctx->inputs)];
181+
TF_Tensor* outputTensorsValues[array_len(gctx->outputs)];
182+
TF_Output outputs[array_len(gctx->outputs)];
183+
184+
for(size_t i = 0 ; i < array_len(gctx->inputs) ; ++i){
185+
inputTensorsValues[i] = gctx->inputs[i].tensor->tensor;
186+
inputs[i] = gctx->inputs[i].name;
187+
}
188+
189+
for(size_t i = 0 ; i < array_len(gctx->outputs) ; ++i){
190+
outputs[i] = gctx->outputs[i].name;
191+
}
192+
193+
TF_SessionRun(gctx->graph->session, NULL /* run_options */,
194+
inputs, inputTensorsValues, array_len(gctx->inputs),
195+
outputs, outputTensorsValues, array_len(gctx->outputs),
196+
NULL /* target_opers */, 0 /* ntargets */,
197+
NULL /* run_Metadata */,
198+
status);
199+
200+
if (TF_GetCode(status) != TF_OK) {
201+
TF_DeleteStatus(status);
202+
return 0;
203+
}
204+
205+
for(size_t i = 0 ; i < array_len(gctx->outputs) ; ++i){
206+
gctx->outputs[i].tensor = Tensor_CreateFromTensor(outputTensorsValues[i]);
207+
}
208+
209+
TF_DeleteStatus(status);
210+
return 1;
211+
}
212+
213+
RDL_Graph* Graph_GetShallowCopy(RDL_Graph* graph){
214+
++graph->refCount;
215+
return graph;
216+
}

src/graph.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "tensorflow/c/c_api.h"
1212
#include "redismodule.h"
13+
#include "tensor.h"
1314

1415
typedef struct RDL_Graph{
1516
TF_Graph* graph;
@@ -20,11 +21,23 @@ typedef struct RDL_Graph{
2021
size_t refCount;
2122
}RDL_Graph;
2223

24+
typedef struct RDL_GraphCtxParam RDL_GraphCtxParam;
25+
26+
typedef struct RDL_GraphRunCtx RDL_GraphRunCtx;
27+
2328
extern RedisModuleType *RedisDL_GraphType;
2429

2530
int Graph_Init(RedisModuleCtx* ctx);
2631
RDL_Graph* Graph_Create(const char* prefix, const char* graphdef, size_t graphlen);
2732
void Graph_Free(RDL_Graph* graph);
33+
RDL_GraphRunCtx* Graph_RunCtxCreate(RDL_Graph* graph);
34+
int Graph_RunCtxAddInput(RDL_GraphRunCtx* gctx, const char* inputName, RDL_Tensor* inputTensor);
35+
int Graph_RunCtxAddOutput(RDL_GraphRunCtx* gctx, const char* outputName);
36+
size_t Graph_RunCtxNumOutputs(RDL_GraphRunCtx* gctx);
37+
RDL_Tensor* Graph_RunCtxOutputTensor(RDL_GraphRunCtx* gctx, size_t index);
38+
void Graph_RunCtxFreeInternals(RDL_GraphRunCtx* gctx);
39+
int Graph_Run(RDL_GraphRunCtx* gctx);
40+
RDL_Graph* Graph_GetShallowCopy(RDL_Graph* graph);
2841

2942

3043

src/redisdl.c

Lines changed: 34 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,6 @@ mstime_t mstime(void) {
9393
return ustime()/1000;
9494
}
9595

96-
TF_Tensor* RedisDL_clone(RedisModuleCtx *ctx, const TF_Tensor *tensor) {
97-
int ndims = TF_NumDims(tensor);
98-
long long *dims = RedisModule_PoolAlloc(ctx, ndims * sizeof(long long));
99-
for (int j=0; j<ndims; j++) {
100-
dims[j] = TF_Dim(tensor, j);
101-
}
102-
size_t len = TF_TensorByteSize(tensor);
103-
void *data = TF_TensorData(tensor);
104-
TF_Tensor *out = TF_AllocateTensor(TF_TensorType(tensor),
105-
dims, ndims, len);
106-
memcpy(TF_TensorData(out), data, len);
107-
return out;
108-
}
109-
11096
enum RedisDL_DataFmt {
11197
REDISDL_DATA_BLOB = 0,
11298
REDISDL_DATA_VALUES
@@ -493,28 +479,19 @@ int RedisDL_GSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
493479

494480
struct RedisDL_RunInfo {
495481
RedisModuleBlockedClient *client;
496-
TF_Session *session;
497-
TF_Output *inputs;
498-
TF_Tensor **input_values;
499-
long long ninputs;
500-
TF_Output *outputs;
501-
TF_Tensor **output_values;
502-
long long noutputs;
503-
RedisModuleKey *graphkey;
504482
RedisModuleString **outkeys;
505-
TF_Status *status;
483+
RDL_GraphRunCtx* gctx;
484+
int status;
506485
};
507486

508-
void RedisDL_FreeRunInfo(struct RedisDL_RunInfo *rinfo) {
509-
RedisModule_Free(rinfo->inputs);
510-
for (int i=0; i<rinfo->ninputs; i++) {
511-
TF_DeleteTensor(rinfo->input_values[i]);
487+
void RedisDL_FreeRunInfo(RedisModuleCtx *ctx, struct RedisDL_RunInfo *rinfo) {
488+
for(int i = 0 ; i < Graph_RunCtxNumOutputs(rinfo->gctx) ; ++i){
489+
RedisModule_FreeString(ctx, rinfo->outkeys[i]);
512490
}
513-
RedisModule_Free(rinfo->input_values);
514-
RedisModule_Free(rinfo->outputs);
515-
RedisModule_Free(rinfo->output_values);
516491
RedisModule_Free(rinfo->outkeys);
517-
TF_DeleteStatus(rinfo->status);
492+
493+
Graph_RunCtxFreeInternals(rinfo->gctx);
494+
518495
RedisModule_Free(rinfo);
519496
}
520497

@@ -525,24 +502,19 @@ void *RedisDL_RunSession(void *arg) {
525502

526503
mstime_t start = mstime();
527504

528-
TF_Status *status = TF_NewStatus();
529505

530-
TF_SessionRun(rinfo->session, NULL /* run_options */,
531-
rinfo->inputs, rinfo->input_values, rinfo->ninputs,
532-
rinfo->outputs, rinfo->output_values, rinfo->noutputs,
533-
NULL /* target_opers */, 0 /* ntargets */,
534-
NULL /* run_Metadata */,
535-
rinfo->status);
506+
rinfo->status = Graph_Run(rinfo->gctx);
536507

537508
mstime_t end = mstime();
538509

539510
RedisModule_ThreadSafeContextLock(ctx);
540511
RedisModule_Log(ctx, "notice", "TF_SessionRun took %fs", (end - start) / 1000.0);
541512
RedisModule_ThreadSafeContextUnlock(ctx);
542513

543-
RedisModule_FreeThreadSafeContext(ctx);
544514
RedisModule_UnblockClient(rinfo->client, rinfo);
545515

516+
RedisModule_FreeThreadSafeContext(ctx);
517+
546518
return NULL;
547519
}
548520

@@ -557,31 +529,33 @@ int RedisDL_Run_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
557529
REDISMODULE_NOT_USED(argc);
558530
struct RedisDL_RunInfo *rinfo = RedisModule_GetBlockedClientPrivateData(ctx);
559531

560-
if (TF_GetCode(rinfo->status) != TF_OK) {
561-
int ret = RedisModule_ReplyWithError(ctx, TF_Message(rinfo->status));
562-
RedisDL_FreeRunInfo(rinfo);
532+
if (!rinfo->status) {
533+
int ret = RedisModule_ReplyWithError(ctx, "graph run failed");
534+
RedisDL_FreeRunInfo(ctx, rinfo);
563535
return ret;
564536
}
565537

566-
for (int i=0; i<rinfo->noutputs; i++) {
538+
for (size_t i=0; i<Graph_RunCtxNumOutputs(rinfo->gctx); ++i){
567539
RedisModuleKey *outkey = RedisModule_OpenKey(ctx, rinfo->outkeys[i],
568540
REDISMODULE_READ|REDISMODULE_WRITE);
569541
int type = RedisModule_KeyType(outkey);
570542
if (type != REDISMODULE_KEYTYPE_EMPTY &&
571543
RedisModule_ModuleTypeGetType(outkey) != RedisDL_TensorType) {
572544
RedisModule_CloseKey(outkey);
573-
RedisDL_FreeRunInfo(rinfo);
545+
RedisDL_FreeRunInfo(ctx, rinfo);
574546
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
575547
}
576-
RDL_Tensor *t = Tensor_CreateFromTensor(rinfo->output_values[i]);
577-
RedisModule_ModuleTypeSetValue(outkey, RedisDL_TensorType, t);
548+
RDL_Tensor *t = Graph_RunCtxOutputTensor(rinfo->gctx, i);
549+
if(t){
550+
RedisModule_ModuleTypeSetValue(outkey, RedisDL_TensorType, Tensor_GetShallowCopy(t));
551+
}
578552
RedisModule_CloseKey(outkey);
579553
}
580554

581555
// FIXME This crashes Redis, we need to investigate.
582556
//RedisModule_CloseKey(rinfo->graphkey);
583557

584-
RedisDL_FreeRunInfo(rinfo);
558+
RedisDL_FreeRunInfo(ctx, rinfo);
585559

586560
return RedisModule_ReplyWithSimpleString(ctx, "OK");
587561
}
@@ -610,9 +584,6 @@ int RedisDL_GRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
610584

611585
RDL_Graph *gto = RedisModule_ModuleTypeGetValue(key);
612586

613-
TF_Graph* graph = gto->graph;
614-
TF_Session* session = gto->session;
615-
616587
long long ninputs;
617588
if ((RedisModule_StringToLongLong(argv[2], &ninputs) != REDISMODULE_OK)
618589
|| ninputs < 0) {
@@ -634,64 +605,42 @@ int RedisDL_GRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
634605
return RedisModule_ReplyWithError(ctx, "ERR odd key/name pairs");
635606
}
636607

637-
TF_Output *inputs = RedisModule_Alloc(ninputs*sizeof(TF_Output));
638-
TF_Output *outputs = RedisModule_Alloc(noutputs*sizeof(TF_Output));
639-
TF_Tensor **input_values = RedisModule_Alloc(ninputs*sizeof(TF_Tensor*));
640-
TF_Tensor **output_values = RedisModule_Alloc(noutputs*sizeof(TF_Tensor*));
608+
struct RedisDL_RunInfo *rinfo = RedisModule_Alloc(sizeof(struct RedisDL_RunInfo));
609+
rinfo->gctx = Graph_RunCtxCreate(gto);
641610

642-
RedisModuleString **outkeys = RedisModule_Alloc(noutputs*sizeof(RedisModuleString*));
611+
rinfo->outkeys = RedisModule_Alloc(noutputs*sizeof(RedisModuleString*));
643612

644613
for (int i=pairoffset; i<argc; i+=2) {
645614
int isinput = i < pairoffset + 2 * ninputs;
646615

647-
size_t namelen;
648616
RedisModuleString* argname = argv[i+1];
649617

650618
if (isinput) {
651619
RedisModuleKey *argkey = RedisModule_OpenKey(ctx, argv[i], REDISMODULE_READ);
652620
if (RedisModule_ModuleTypeGetType(argkey) != RedisDL_TensorType) {
621+
// todo free rinfo
653622
RedisModule_CloseKey(argkey);
654623
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
655624
}
656625
RDL_Tensor *t = RedisModule_ModuleTypeGetValue(argkey);
657-
input_values[(i-pairoffset)/2] = RedisDL_clone(ctx, t->tensor);
658626
RedisModule_CloseKey(argkey);
659-
const char* opname = RedisModule_StringPtrLen(argname, &namelen);
660-
// RedisModule_Log(ctx, "warning", "%s", opname);
661-
TF_Output port;
662-
port.oper = TF_GraphOperationByName(graph, opname);
663-
port.index = 0;
664-
if (port.oper == NULL) {
627+
const char* opname = RedisModule_StringPtrLen(argname, NULL);
628+
if(!Graph_RunCtxAddInput(rinfo->gctx, opname, t)){
629+
// todo free rinfo
665630
return RedisModule_ReplyWithError(ctx, "Input key not found.");
666631
}
667-
inputs[(i-pairoffset)/2] = port;
668632
} else {
669-
const char* opname = RedisModule_StringPtrLen(argname, &namelen);
670-
TF_Output port;
671-
port.oper = TF_GraphOperationByName(graph, opname);
672-
port.index = 0;
673-
if (port.oper == NULL) {
633+
const char* opname = RedisModule_StringPtrLen(argname, NULL);
634+
if(!Graph_RunCtxAddOutput(rinfo->gctx, opname)){
635+
// todo free rinfo
674636
return RedisModule_ReplyWithError(ctx, "Output key not found.");
675637
}
676-
outputs[(i-pairoffset)/2-ninputs] = port;
677-
outkeys[(i-pairoffset)/2-ninputs] = argv[i];
638+
RedisModule_RetainString(ctx, argv[i]);
639+
rinfo->outkeys[(i-pairoffset)/2-ninputs] = argv[i];
678640
}
679641
}
680642

681-
RedisModuleBlockedClient *bc = RedisModule_BlockClient(ctx, RedisDL_Run_Reply, NULL, NULL, 0);
682-
683-
struct RedisDL_RunInfo *rinfo = RedisModule_Alloc(sizeof(struct RedisDL_RunInfo));
684-
rinfo->client = bc;
685-
rinfo->session = session;
686-
rinfo->inputs = inputs;
687-
rinfo->input_values = input_values;
688-
rinfo->ninputs = ninputs;
689-
rinfo->outputs = outputs;
690-
rinfo->output_values = output_values;
691-
rinfo->noutputs = noutputs;
692-
rinfo->graphkey = key;
693-
rinfo->outkeys = outkeys;
694-
rinfo->status = TF_NewStatus();
643+
rinfo->client = RedisModule_BlockClient(ctx, RedisDL_Run_Reply, NULL, NULL, 0);
695644

696645
// RedisModule_AbortBlock(bc);
697646
// return RedisModule_ReplyWithError(ctx, "-ERR Can't start thread");

src/tensor.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,8 @@ int Tensor_GetValueAsLongLong(RDL_Tensor* t, long long i, long long* val) {
215215
}
216216
return 1;
217217
}
218+
219+
RDL_Tensor* Tensor_GetShallowCopy(RDL_Tensor* t){
220+
++t->refCount;
221+
return t;
222+
}

src/tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ int Tensor_SetValueFromLongLong(RDL_Tensor* tensor, long long i, long long val);
2929
int Tensor_SetValueFromDouble(RDL_Tensor* tensor, long long i, double val);
3030
int Tensor_GetValueAsDouble(RDL_Tensor* t, long long i, double* val);
3131
int Tensor_GetValueAsLongLong(RDL_Tensor* t, long long i, long long* val);
32+
RDL_Tensor* Tensor_GetShallowCopy(RDL_Tensor* t);
3233

3334

3435

0 commit comments

Comments
 (0)