Skip to content

Reuse memory in TENSORSET #540

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 23, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 63 additions & 23 deletions src/tensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,60 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
return ret;
}

void RAI_RStringDataTensorDeleter(DLManagedTensor *arg) {
if (arg->dl_tensor.shape) {
RedisModule_Free(arg->dl_tensor.shape);
}
if (arg->dl_tensor.strides) {
RedisModule_Free(arg->dl_tensor.strides);
}
if (arg->manager_ctx) {
RedisModuleString *rstr = (RedisModuleString *)arg->manager_ctx;
RedisModule_FreeString(NULL, rstr);
}

RedisModule_Free(arg);
}

RAI_Tensor *RAI_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, long long *dims, int ndims,
RedisModuleString *rstr) {
const size_t dtypeSize = Tensor_DataTypeSize(dtype);
if (dtypeSize == 0) {
return NULL;
}

RAI_Tensor *ret = RedisModule_Alloc(sizeof(*ret));
int64_t *shape = RedisModule_Alloc(ndims * sizeof(*shape));
int64_t *strides = RedisModule_Alloc(ndims * sizeof(*strides));

size_t len = 1;
for (int64_t i = 0; i < ndims; ++i) {
shape[i] = dims[i];
strides[i] = 1;
len *= dims[i];
}
for (int64_t i = ndims - 2; i >= 0; --i) {
strides[i] *= strides[i + 1] * shape[i + 1];
}

DLContext ctx = (DLContext){.device_type = kDLCPU, .device_id = 0};

char *data = (char *)RedisModule_StringPtrLen(rstr, NULL);

ret->tensor = (DLManagedTensor){.dl_tensor = (DLTensor){.ctx = ctx,
.data = data,
.ndim = ndims,
.dtype = dtype,
.shape = shape,
.strides = strides,
.byte_offset = 0},
.manager_ctx = rstr,
.deleter = RAI_RStringDataTensorDeleter};

ret->refCount = 1;
return ret;
}

RAI_Tensor *RAI_TensorCreate(const char *dataType, long long *dims, int ndims, int hasdata) {
DLDataType dtype = RAI_TensorDataTypeFromString(dataType);
return RAI_TensorCreateWithDLDataType(dtype, dims, ndims, TENSORALLOC_ALLOC);
Expand Down Expand Up @@ -815,7 +869,14 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
size_t datalen;
const char *data;
DLDataType datatype = RAI_TensorDataTypeFromString(typestr);
*t = RAI_TensorCreateWithDLDataType(datatype, dims, ndims, tensorAllocMode);
if (datafmt == REDISAI_DATA_BLOB) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since you are checking it here
the switch in line 889 is redundant since it checks only a single case. I think you can move its content to the else block in line 875

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

RedisModuleString *rstr = argv[argpos];
RedisModule_RetainString(NULL, rstr);
*t = RAI_TensorCreateWithDLDataTypeAndRString(datatype, dims, ndims, rstr);
} else {
*t = RAI_TensorCreateWithDLDataType(datatype, dims, ndims, tensorAllocMode);
}

if (!t) {
array_free(dims);
if (ctx == NULL) {
Expand All @@ -826,24 +887,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
return -1;
}
long i = 0;
switch (datafmt) {
case REDISAI_DATA_BLOB: {
const char *blob = RedisModule_StringPtrLen(argv[argpos], &datalen);
if (datalen != nbytes) {
RAI_TensorFree(*t);
array_free(dims);
if (ctx == NULL) {
RAI_SetError(error, RAI_ETENSORSET,
"ERR data length does not match tensor shape and type");
} else {
RedisModule_ReplyWithError(ctx,
"ERR data length does not match tensor shape and type");
}
return -1;
}
RAI_TensorSetData(*t, blob, datalen);
} break;
case REDISAI_DATA_VALUES:
if (datafmt == REDISAI_DATA_VALUES) {
for (; (argpos <= argc - 1) && (i < len); argpos++) {
if (datatype.code == kDLFloat) {
double val;
Expand Down Expand Up @@ -900,10 +944,6 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
}
i++;
}
break;
default:
// default does not require tensor data setting since calloc setted it to 0
break;
}
array_free(dims);
return argpos;
Expand Down