diff --git a/src/tensor.c b/src/tensor.c
index d46554ed3..6dcea82a3 100644
--- a/src/tensor.c
+++ b/src/tensor.c
@@ -289,6 +289,60 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
     return ret;
 }
 
+void RAI_RStringDataTensorDeleter(DLManagedTensor *arg) {
+    if (arg->dl_tensor.shape) {
+        RedisModule_Free(arg->dl_tensor.shape);
+    }
+    if (arg->dl_tensor.strides) {
+        RedisModule_Free(arg->dl_tensor.strides);
+    }
+    if (arg->manager_ctx) {
+        RedisModuleString *rstr = (RedisModuleString *)arg->manager_ctx;
+        RedisModule_FreeString(NULL, rstr);
+    }
+
+    RedisModule_Free(arg);
+}
+
+RAI_Tensor *RAI_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, long long *dims, int ndims,
+                                                     RedisModuleString *rstr) {
+    const size_t dtypeSize = Tensor_DataTypeSize(dtype);
+    if (dtypeSize == 0) {
+        return NULL;
+    }
+
+    RAI_Tensor *ret = RedisModule_Alloc(sizeof(*ret));
+    int64_t *shape = RedisModule_Alloc(ndims * sizeof(*shape));
+    int64_t *strides = RedisModule_Alloc(ndims * sizeof(*strides));
+
+    size_t len = 1;
+    for (int64_t i = 0; i < ndims; ++i) {
+        shape[i] = dims[i];
+        strides[i] = 1;
+        len *= dims[i];
+    }
+    for (int64_t i = ndims - 2; i >= 0; --i) {
+        strides[i] *= strides[i + 1] * shape[i + 1];
+    }
+
+    DLContext ctx = (DLContext){.device_type = kDLCPU, .device_id = 0};
+
+    char *data = (char *)RedisModule_StringPtrLen(rstr, NULL);
+
+    ret->tensor = (DLManagedTensor){.dl_tensor = (DLTensor){.ctx = ctx,
+                                                            .data = data,
+                                                            .ndim = ndims,
+                                                            .dtype = dtype,
+                                                            .shape = shape,
+                                                            .strides = strides,
+                                                            .byte_offset = 0},
+                                    .manager_ctx = rstr,
+                                    .deleter = RAI_RStringDataTensorDeleter};
+
+    ret->refCount = 1;
+    return ret;
+}
+
 RAI_Tensor *RAI_TensorCreate(const char *dataType, long long *dims, int ndims, int hasdata) {
     DLDataType dtype = RAI_TensorDataTypeFromString(dataType);
     return RAI_TensorCreateWithDLDataType(dtype, dims, ndims, TENSORALLOC_ALLOC);
@@ -815,7 +869,14 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
     size_t datalen;
     const char *data;
     DLDataType datatype = RAI_TensorDataTypeFromString(typestr);
-    *t = RAI_TensorCreateWithDLDataType(datatype, dims, ndims, tensorAllocMode);
+    if (datafmt == REDISAI_DATA_BLOB) {
+        RedisModuleString *rstr = argv[argpos];
+        RedisModule_RetainString(NULL, rstr);
+        *t = RAI_TensorCreateWithDLDataTypeAndRString(datatype, dims, ndims, rstr);
+    } else {
+        *t = RAI_TensorCreateWithDLDataType(datatype, dims, ndims, tensorAllocMode);
+    }
+
     if (!t) {
         array_free(dims);
         if (ctx == NULL) {
@@ -826,24 +887,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
         return -1;
     }
     long i = 0;
-    switch (datafmt) {
-    case REDISAI_DATA_BLOB: {
-        const char *blob = RedisModule_StringPtrLen(argv[argpos], &datalen);
-        if (datalen != nbytes) {
-            RAI_TensorFree(*t);
-            array_free(dims);
-            if (ctx == NULL) {
-                RAI_SetError(error, RAI_ETENSORSET,
-                             "ERR data length does not match tensor shape and type");
-            } else {
-                RedisModule_ReplyWithError(ctx,
-                                           "ERR data length does not match tensor shape and type");
-            }
-            return -1;
-        }
-        RAI_TensorSetData(*t, blob, datalen);
-    } break;
-    case REDISAI_DATA_VALUES:
+    if (datafmt == REDISAI_DATA_VALUES) {
         for (; (argpos <= argc - 1) && (i < len); argpos++) {
             if (datatype.code == kDLFloat) {
                 double val;
@@ -900,10 +944,6 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
             }
             i++;
         }
-        break;
-    default:
-        // default does not require tensor data setting since calloc setted it to 0
-        break;
     }
     array_free(dims);
     return argpos;
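
Reviewer note (not part of the patch): below is a minimal sketch of how the new zero-copy constructor is intended to be called. The helper name tensor_from_float_blob and the explicit length check are illustrative assumptions, not code from this diff; note that the hunks above no longer show the old "data length does not match tensor shape and type" validation on the BLOB path, so a caller (or a follow-up change) would need to enforce it before aliasing the client buffer.

/* Sketch only. Assumes dlpack's DLDataType, the RedisAI declarations from
 * tensor.h, and the Redis module API are in scope. */
static RAI_Tensor *tensor_from_float_blob(RedisModuleString *blob, long long nelems) {
    size_t blob_len;
    RedisModule_StringPtrLen(blob, &blob_len);
    if (blob_len != (size_t)nelems * sizeof(float)) {
        return NULL; /* refuse to alias a buffer whose size does not match the shape */
    }

    long long dims[1] = {nelems};
    DLDataType dtype = RAI_TensorDataTypeFromString("FLOAT");

    /* The tensor references the string's buffer instead of copying it;
     * RAI_RStringDataTensorDeleter frees the retained string (along with
     * shape and strides) when the tensor is released. */
    RedisModule_RetainString(NULL, blob);
    RAI_Tensor *t = RAI_TensorCreateWithDLDataTypeAndRString(dtype, dims, 1, blob);
    if (t == NULL) {
        RedisModule_FreeString(NULL, blob); /* undo the retain if creation failed */
    }
    return t;
}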