Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions src/backends/onnxruntime.c
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, RAI_Error *erro
return NULL;
}

RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_t batch_size, RAI_Error *error) {
RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, long long batch_size, RAI_Error *error) {
OrtStatus* status = NULL;
const OrtApi* ort = OrtGetApiBase()->GetApi(1);

Expand Down Expand Up @@ -214,7 +214,12 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_
shape[i] = dims[i];
strides[i] = 1;
}
shape[0] = batch_size;
if (batch_size != -1) {
shape[0] = batch_size;
}
else {
batch_size = total_batch_size;
}
for (int64_t i = ndims - 2; i >= 0; --i)
{
strides[i] *= strides[i + 1] * shape[i + 1];
Expand Down Expand Up @@ -411,9 +416,11 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)

size_t batch_sizes[nbatches];
size_t batch_offsets[nbatches];
size_t total_batch_size = 0;
if (array_len(mctxs[0]->inputs) > 0) {
for (size_t b=0; b<nbatches; ++b) {
batch_sizes[b] = RAI_TensorDim(mctxs[b]->inputs[0].tensor, 0);
total_batch_size += batch_sizes[b];
}
batch_offsets[0] = 0;
for (size_t b=1; b<nbatches; ++b) {
Expand Down Expand Up @@ -529,14 +536,48 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
}

for (size_t i = 0; i < n_output_nodes; i++) {
for (size_t b=0; b<nbatches; b++) {
RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
if (nbatches > 1) {
OrtTensorTypeAndShapeInfo* info;
status = ort->GetTensorTypeAndShape(outputs[i], &info);
if (status != NULL) goto error;

size_t ndims;
status = ort->GetDimensionsCount(info, &ndims);
if (status != NULL) goto error;

int64_t dims[ndims];
status = ort->GetDimensions(info, dims, ndims);
if (status != NULL) goto error;

if (dims[0] != total_batch_size) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size");
ort->ReleaseStatus(status);
return 1;
}

for (size_t b=0; b<nbatches; b++) {
RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
if (error->code != RAI_OK) {
ort->ReleaseStatus(status);
return 1;
}
if (output_tensor) {
mctxs[b]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
RAI_TensorFree(output_tensor);
}
else {
printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported)");
}
}
}
else {
RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], 0, -1, error);
if (error->code != RAI_OK) {
ort->ReleaseStatus(status);
return 1;
}
if (output_tensor) {
mctxs[b]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
mctxs[0]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
RAI_TensorFree(output_tensor);
}
else {
Expand Down
30 changes: 26 additions & 4 deletions src/backends/tensorflow.c
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ DLDataType RAI_GetDLDataTypeFromTF(TF_DataType dtype) {
return (DLDataType){ .bits = 0 };
}

RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset, size_t batch_size) {
RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset, long long batch_size) {
RAI_Tensor* ret = RedisModule_Calloc(1, sizeof(*ret));

DLContext ctx = (DLContext){
Expand All @@ -97,7 +97,12 @@ RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor, size_t batch_offset,
shape[i] = TF_Dim(tensor, i);
strides[i] = 1;
}
shape[0] = batch_size;
if (batch_size != -1) {
shape[0] = batch_size;
}
else {
batch_size = total_batch_size;
}
for (int64_t i = ndims-2 ; i >= 0 ; --i) {
strides[i] *= strides[i+1] * shape[i+1];
}
Expand Down Expand Up @@ -475,9 +480,11 @@ int RAI_ModelRunTF(RAI_ModelRunCtx** mctxs, RAI_Error *error) {

size_t batch_sizes[nbatches];
size_t batch_offsets[nbatches];
size_t total_batch_size = 0;
if (ninputs > 0) {
for (size_t b=0; b<nbatches; ++b) {
batch_sizes[b] = RAI_TensorDim(mctxs[b]->inputs[0].tensor, 0);
total_batch_size += batch_sizes[b];
}
batch_offsets[0] = 0;
for (size_t b=1; b<nbatches; ++b) {
Expand Down Expand Up @@ -531,8 +538,23 @@ int RAI_ModelRunTF(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
}

for(size_t i=0; i<noutputs; ++i) {
for (size_t b=0; b<nbatches; b++) {
mctxs[b]->outputs[i].tensor = RAI_TensorCreateFromTFTensor(outputTensorsValues[i], batch_offsets[b], batch_sizes[b]);
if (nbatches > 1) {
if (TF_NumDims(outputTensorsValues[i]) == 0) {
continue;
}
if (TF_Dim(outputTensorsValues[i], 0) != total_batch_size) {
TF_DeleteTensor(outputTensorsValues[i]);
TF_DeleteStatus(status);
RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size");
return 1;
}

for (size_t b=0; b<nbatches; b++) {
mctxs[b]->outputs[i].tensor = RAI_TensorCreateFromTFTensor(outputTensorsValues[i], batch_offsets[b], batch_sizes[b]);
}
}
else {
mctxs[0]->outputs[i].tensor = RAI_TensorCreateFromTFTensor(outputTensorsValues[i], 0, -1);
}
TF_DeleteTensor(outputTensorsValues[i]);
}
Expand Down
6 changes: 5 additions & 1 deletion src/backends/torch.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx** mctxs, RAI_Error *error) {

size_t batch_sizes[nbatches];
size_t batch_offsets[nbatches];
size_t total_batch_size = 0;

if (nbatches > 1) {
size_t total_batch_size = 0;
if (array_len(mctxs[0]->inputs) > 0) {
for (size_t b=0; b<nbatches; ++b) {
batch_sizes[b] = RAI_TensorDim(mctxs[b]->inputs[0].tensor, 0);
Expand Down Expand Up @@ -147,6 +147,10 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
}
RAI_Tensor* output_tensor = RAI_TensorCreateFromDLTensor(outputs_dl[i]);
if (nbatches > 1) {
if (outputs_dl[i]->dl_tensor.shape[0] != total_batch_size) {
RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size");
return 1;
}
for (size_t b=0; b<nbatches; b++) {
mctxs[b]->outputs[i].tensor = RAI_TensorCreateBySlicingTensor(output_tensor, batch_offsets[b], batch_sizes[b]);
}
Expand Down
3 changes: 3 additions & 0 deletions test/test_data/pt-minimal-bb.pt
Git LFS file not shown
15 changes: 15 additions & 0 deletions test/test_data/pt_minimal_bb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch


class MyModule(torch.jit.ScriptModule):
    # Minimal two-input TorchScript module returning two outputs.
    #
    # The first output is the element-wise sum of the inputs, so it has the
    # same shape as them. The second output is a constant tensor of shape
    # (1,): its leading dimension does not depend on the inputs.
    # NOTE(review): this looks intentional -- the saved model
    # (pt-minimal-bb.pt) is loaded by the "badbatch" auto-batching test, which
    # expects the backend to reject an output whose first dimension differs
    # from the total batch size. Confirm before "fixing" the second output.
    def __init__(self):
        super(MyModule, self).__init__()

    @torch.jit.script_method
    def forward(self, a, b):
        # Returns (a + b, ones(1)); see the class comment for why the second
        # element has a fixed shape.
        return a + b, torch.ones(1)


# Instantiate the scripted module, run it once on two random length-2 vectors
# as a smoke check (the result is printed), then serialize it to
# "pt-minimal-bb.pt" relative to the current working directory.
my_script_module = MyModule()
sample_a = torch.rand(2)
sample_b = torch.rand(2)
print(my_script_module(sample_a, sample_b))
my_script_module.save("pt-minimal-bb.pt")
45 changes: 45 additions & 0 deletions test/tests_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,51 @@ def run():
env.assertEqual(values, [b'4', b'6', b'4', b'6'])


def test_pytorch_modelrun_autobatch_badbatch(env):
    # Checks that AI.MODELRUN fails with a ResponseError when an auto-batched
    # run produces an output whose batch dimension is not the expected one.
    #
    # Two MODELRUN commands are issued concurrently (one from a worker thread,
    # one from this thread) against a model stored with BATCHSIZE 4 and
    # MINBATCHSIZE 3; each request carries 2x2 input tensors.
    # NOTE(review): presumably these settings make the server merge both
    # requests into a single batch of 4 -- confirm against the batching docs.
    if not TEST_PT:
        env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
        return

    con = env.getConnection()

    test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
    model_filename = os.path.join(test_data_path, 'pt-minimal-bb.pt')

    with open(model_filename, 'rb') as f:
        model_pb = f.read()

    ret = con.execute_command('AI.MODELSET', 'm', 'TORCH', 'CPU',
                              'BATCHSIZE', 4, 'MINBATCHSIZE', 3, 'BLOB', model_pb)
    env.assertEqual(ret, b'OK')

    con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
    con.execute_command('AI.TENSORSET', 'b', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)

    con.execute_command('AI.TENSORSET', 'd', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)
    con.execute_command('AI.TENSORSET', 'e', 'FLOAT', 2, 2, 'VALUES', 2, 3, 2, 3)

    ensureSlaveSynced(con, env)

    def expect_bad_batch_error(connection, *command):
        # Start from None so that a command which unexpectedly succeeds shows
        # up as a failed assertion (type is NoneType) rather than as a
        # NameError or a silent pass.
        exception = None
        try:
            connection.execute_command(*command)
        except Exception as e:
            exception = e
        env.assertEqual(type(exception), redis.exceptions.ResponseError)
        env.assertEqual("Model did not generate the expected batch size", exception.__str__())

    def run():
        # The worker uses its own connection so both requests are in flight
        # at the same time.
        expect_bad_batch_error(env.getConnection(),
                               'AI.MODELRUN', 'm', 'INPUTS', 'd', 'e', 'OUTPUTS', 'f1', 'f2')

    t = threading.Thread(target=run)
    t.start()

    try:
        expect_bad_batch_error(con,
                               'AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c1', 'c2')
    finally:
        # Wait for the worker so its assertions are recorded before the test
        # function returns.
        t.join()



def test_pytorch_modelinfo(env):
if not TEST_PT:
env.debugPrint("skipping {} since TEST_PT=0".format(sys._getframe().f_code.co_name), force=True)
Expand Down