Skip to content

Commit 53f068d

Browse files
committed
log the number milliseconds that other backends are running as well
1 parent b6c223e commit 53f068d

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/backends/tensorflow.c

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#define REDISMODULE_MAIN
2+
3+
#include <redis_ai_objects/stats.h>
24
#include "backends/util.h"
35
#include "backends/tensorflow.h"
46
#include "util/arr.h"
@@ -14,6 +16,7 @@ int RAI_InitBackendTF(int (*get_api_fn)(const char *, void *)) {
1416
get_api_fn("RedisModule_Free", ((void **)&RedisModule_Free));
1517
get_api_fn("RedisModule_Realloc", ((void **)&RedisModule_Realloc));
1618
get_api_fn("RedisModule_Strdup", ((void **)&RedisModule_Strdup));
19+
get_api_fn("RedisModule_Log", ((void **)&RedisModule_Log));
1720

1821
return REDISMODULE_OK;
1922
}
@@ -520,10 +523,12 @@ int RAI_ModelRunTF(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *error)
520523
}
521524
outputs[i] = port;
522525
}
523-
526+
long long time_before = mstime();
524527
TF_SessionRun(tfSession, NULL /* run_options */, inputs, inputTensorsValues, ninputs, outputs,
525528
outputTensorsValues, noutputs, NULL /* target_opers */, 0 /* ntargets */,
526529
NULL /* run_Metadata */, status);
530+
long long time_after = mstime();
531+
RedisModule_Log(NULL, "notice", "tf run time was: %lld", time_after - time_before);
527532

528533
for (size_t i = 0; i < ninputs; ++i) {
529534
TF_DeleteTensor(inputTensorsValues[i]);

src/backends/torch.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#define REDISMODULE_MAIN
2+
3+
#include <redis_ai_objects/stats.h>
24
#include "backends/util.h"
35
#include "backends/torch.h"
46
#include "util/arr.h"
@@ -29,6 +31,7 @@ int RAI_InitBackendTorch(int (*get_api_fn)(const char *, void *)) {
2931
((void **)&RedisModule_ThreadSafeContextUnlock));
3032
get_api_fn("RedisModule_FreeThreadSafeContext", ((void **)&RedisModule_FreeThreadSafeContext));
3133
get_api_fn("RedisModule_StringPtrLen", ((void **)&RedisModule_StringPtrLen));
34+
get_api_fn("RedisModule_Log", ((void **)&RedisModule_Log));
3235

3336
return REDISMODULE_OK;
3437
}
@@ -228,7 +231,10 @@ int RAI_ModelRunTorch(RAI_Model *model, RAI_ExecutionCtx **ectxs, RAI_Error *err
228231
}
229232

230233
char *error_descr = NULL;
234+
long long time_before = mstime();
231235
torchRunModel(RAI_ModelGetModel(model), ninputs, inputs_dl, noutputs, outputs_dl, &error_descr);
236+
long long time_after = mstime();
237+
RedisModule_Log(NULL, "notice", "torch run time was: %lld", time_after - time_before);
232238

233239
for (size_t i = 0; i < ninputs; ++i) {
234240
RAI_TensorFree(inputs[i]);

0 commit comments

Comments
 (0)