Skip to content

Commit be89d81

Browse files
author
MeirShpilraien
committed
made tensor and graph objects opaque
1 parent f4c91bf commit be89d81

File tree

5 files changed

+51
-22
lines changed

5 files changed

+51
-22
lines changed

src/graph.c

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33

44
RedisModuleType *RedisDL_GraphType = NULL;
55

6+
typedef struct RDL_Graph{
7+
TF_Graph* graph;
8+
// TODO: use session pool? The ideal would be to use one session per client.
9+
// If a client disconnects, we dispose the session or reuse it for
10+
// another client.
11+
void *session;
12+
size_t refCount;
13+
}RDL_Graph;
14+
615
typedef struct RDL_GraphCtxParam{
716
TF_Output name;
817
RDL_Tensor* tensor;
@@ -182,7 +191,7 @@ int Graph_Run(RDL_GraphRunCtx* gctx){
182191
TF_Output outputs[array_len(gctx->outputs)];
183192

184193
for(size_t i = 0 ; i < array_len(gctx->inputs) ; ++i){
185-
inputTensorsValues[i] = gctx->inputs[i].tensor->tensor;
194+
inputTensorsValues[i] = Tensor_GetTensor(gctx->inputs[i].tensor);
186195
inputs[i] = gctx->inputs[i].name;
187196
}
188197

src/graph.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,7 @@
1212
#include "redismodule.h"
1313
#include "tensor.h"
1414

15-
typedef struct RDL_Graph{
16-
TF_Graph* graph;
17-
// TODO: use session pool? The ideal would be to use one session per client.
18-
// If a client disconnects, we dispose the session or reuse it for
19-
// another client.
20-
void *session;
21-
size_t refCount;
22-
}RDL_Graph;
15+
typedef struct RDL_Graph RDL_Graph;
2316

2417
typedef struct RDL_GraphCtxParam RDL_GraphCtxParam;
2518

src/redisdl.c

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ int RedisDL_TDim_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
291291

292292
RDL_Tensor *t = RedisModule_ModuleTypeGetValue(key);
293293

294-
long long ndims = TF_NumDims(t->tensor);
294+
long long ndims = Tensor_NumDims(t);
295295

296296
RedisModule_CloseKey(key);
297297

@@ -313,11 +313,11 @@ int RedisDL_TShape_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, i
313313

314314
RDL_Tensor *t = RedisModule_ModuleTypeGetValue(key);
315315

316-
long long ndims = TF_NumDims(t->tensor);
316+
long long ndims = Tensor_NumDims(t);
317317

318318
RedisModule_ReplyWithArray(ctx, ndims);
319319
for (long long i=0; i<ndims; i++) {
320-
long long dim = TF_Dim(t->tensor, i);
320+
long long dim = Tensor_Dim(t, i);
321321
RedisModule_ReplyWithLongLong(ctx, dim);
322322
}
323323

@@ -341,7 +341,7 @@ int RedisDL_TByteSize_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
341341

342342
RDL_Tensor *t = RedisModule_ModuleTypeGetValue(key);
343343

344-
long long size = TF_TensorByteSize(t->tensor);
344+
long long size = Tensor_ByteSize(t);
345345

346346
RedisModule_CloseKey(key);
347347

@@ -375,8 +375,8 @@ int RedisDL_TGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
375375
RDL_Tensor *t = RedisModule_ModuleTypeGetValue(key);
376376

377377
if (datafmt == REDISDL_DATA_BLOB) {
378-
long long size = TF_TensorByteSize(t->tensor);
379-
char *data = TF_TensorData(t->tensor);
378+
long long size = Tensor_ByteSize(t);
379+
char *data = Tensor_Data(t);
380380

381381
int ret = RedisModule_ReplyWithStringBuffer(ctx, data, size);
382382

@@ -386,14 +386,14 @@ int RedisDL_TGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int
386386
}
387387
}
388388
else { // datafmt == REDISDL_DATA_VALUES
389-
long long ndims = TF_NumDims(t->tensor);
389+
long long ndims = Tensor_NumDims(t);
390390
long long len = 1;
391391
long long i;
392392
for (i=0; i<ndims; i++) {
393-
len *= TF_Dim(t->tensor, i);
393+
len *= Tensor_Dim(t, i);
394394
}
395395

396-
TF_DataType datatype = TF_TensorType(t->tensor);
396+
TF_DataType datatype = Tensor_DataType(t);
397397

398398
RedisModule_ReplyWithArray(ctx, len);
399399

src/tensor.c

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55

66
RedisModuleType *RedisDL_TensorType = NULL;
77

8+
typedef struct RDL_Tensor {
9+
TF_Tensor* tensor;
10+
size_t refCount;
11+
}RDL_Tensor;
12+
813
static TF_DataType Tensor_GetDataType(const char* typestr){
914
if (strcasecmp(typestr, "FLOAT") == 0){
1015
return TF_FLOAT;
@@ -220,3 +225,23 @@ RDL_Tensor* Tensor_GetShallowCopy(RDL_Tensor* t){
220225
++t->refCount;
221226
return t;
222227
}
228+
229+
int Tensor_NumDims(RDL_Tensor* t){
230+
return TF_NumDims(t->tensor);
231+
}
232+
233+
long long Tensor_Dim(RDL_Tensor* t, int dim){
234+
return TF_Dim(t->tensor, dim);
235+
}
236+
237+
size_t Tensor_ByteSize(RDL_Tensor* t){
238+
return TF_TensorByteSize(t->tensor);
239+
}
240+
241+
char* Tensor_Data(RDL_Tensor* t){
242+
return TF_TensorData(t->tensor);
243+
}
244+
245+
TF_Tensor* Tensor_GetTensor(RDL_Tensor* t){
246+
return t->tensor;
247+
}

src/tensor.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111
#include "tensorflow/c/c_api.h"
1212
#include "redismodule.h"
1313

14-
typedef struct RDL_Tensor {
15-
TF_Tensor* tensor;
16-
size_t refCount;
17-
}RDL_Tensor;
14+
typedef struct RDL_Tensor RDL_Tensor;
1815

1916
extern RedisModuleType *RedisDL_TensorType;
2017

@@ -30,6 +27,11 @@ int Tensor_SetValueFromDouble(RDL_Tensor* tensor, long long i, double val);
3027
int Tensor_GetValueAsDouble(RDL_Tensor* t, long long i, double* val);
3128
int Tensor_GetValueAsLongLong(RDL_Tensor* t, long long i, long long* val);
3229
RDL_Tensor* Tensor_GetShallowCopy(RDL_Tensor* t);
30+
int Tensor_NumDims(RDL_Tensor* t);
31+
long long Tensor_Dim(RDL_Tensor* t, int dim);
32+
size_t Tensor_ByteSize(RDL_Tensor* t);
33+
char* Tensor_Data(RDL_Tensor* t);
34+
TF_Tensor* Tensor_GetTensor(RDL_Tensor* t);
3335

3436

3537

0 commit comments

Comments
 (0)