diff --git a/docs/commands.md b/docs/commands.md index 1c22c3c61..f09440596 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -79,7 +79,7 @@ Depending on the specified reply format: 1. The tensor's shape as an Array consisting of an item per dimension * **BLOB**: the tensor's binary data as a String. If used together with the **META** option, the binary data string will put after the metadata in the array reply. * **VALUES**: Array containing the numerical representation of the tensor's data. If used together with the **META** option, the binary data string will put after the metadata in the array reply. - +* Default: **META** and **BLOB** are returned by default, in case that non of the arguments above is specified. **Examples** diff --git a/src/tensor.c b/src/tensor.c index 167c5b81a..2f6c2e293 100644 --- a/src/tensor.c +++ b/src/tensor.c @@ -851,6 +851,9 @@ uint ParseTensorGetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) RedisModule_WrongArity(ctx); return fmt; } + if (argc == 2) { + return TENSOR_BLOB | TENSOR_META; + } for (int i = 2; i < argc; i++) { const char *fmtstr = RedisModule_StringPtrLen(argv[i], NULL); if (!strcasecmp(fmtstr, "BLOB")) { diff --git a/tests/flow/tests_common.py b/tests/flow/tests_common.py index 6fcb080c7..c58f2d873 100644 --- a/tests/flow/tests_common.py +++ b/tests/flow/tests_common.py @@ -205,6 +205,16 @@ def test_common_tensorget(env): env.assertEqual(datatype.encode('utf-8'), tensor_dtype) env.assertEqual([2], tensor_dim) + # Confirm that default reply format is META BLOB + for datatype in tested_datatypes: + _, tensor_dtype, _, tensor_dim, _, tensor_blob = con.execute_command('AI.TENSORGET', 'tensor_{0}'.format(datatype), + 'META', 'BLOB') + _, tensor_dtype_default, _, tensor_dim_default, _, tensor_blob_default = con.execute_command('AI.TENSORGET', + 'tensor_{0}'.format(datatype)) + env.assertEqual(tensor_dtype, tensor_dtype_default) + env.assertEqual(tensor_dim, tensor_dim_default) + env.assertEqual(tensor_blob, tensor_blob_default) + def test_common_tensorget_error_replies(env): con = env.getConnection()