diff --git a/src/main/java/com/redislabs/redisai/RedisAI.java b/src/main/java/com/redislabs/redisai/RedisAI.java index 581e630..21ea147 100644 --- a/src/main/java/com/redislabs/redisai/RedisAI.java +++ b/src/main/java/com/redislabs/redisai/RedisAI.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Map; import redis.clients.jedis.BinaryClient; +import redis.clients.jedis.Client; import redis.clients.jedis.Jedis; import redis.clients.jedis.JedisPool; import redis.clients.jedis.JedisPoolConfig; @@ -85,6 +86,22 @@ private static JedisPoolConfig initPoolConfig(int poolSize) { return conf; } + private Jedis getConnection() { + return pool.getResource(); + } + + private BinaryClient sendCommand(Jedis conn, Command command, byte[]... args) { + BinaryClient client = conn.getClient(); + client.sendCommand(command, args); + return client; + } + + private Client sendCommand(Jedis conn, Command command, String... args) { + Client client = conn.getClient(); + client.sendCommand(command, args); + return client; + } + /** * Direct mapping to AI.TENSORSET * @@ -336,20 +353,54 @@ public boolean runModel(String key, String[] inputs, String[] outputs) { } } - /** AI.SCRIPTRUN script_key fn_name INPUTS input_key1 ... OUTPUTS output_key1 ... */ + /** {@code AI.SCRIPTRUN script_key fn_name INPUTS input_key1 ... OUTPUTS output_key1 ...} */ public boolean runScript(String key, String function, String[] inputs, String[] outputs) { + return runScript(key, function, inputs, false, outputs); + } + /** + * {@code AI.SCRIPTRUN INPUTS [input ...] [$ input ...] OUTPUTS + * [output ...]} + */ + public boolean runScript( + String key, String function, String[] inputs, boolean variadicInputs, String[] outputs) { try (Jedis conn = getConnection()) { - List args = Script.scriptRunFlatArgs(key, function, inputs, outputs, false); - return sendCommand(conn, Command.SCRIPT_RUN, args.toArray(new byte[args.size()][])) - .getStatusCodeReply() - .equals("OK"); - + String[] args = scriptRunFlatArgs(key, function, inputs, variadicInputs, outputs); + return sendCommand(conn, Command.SCRIPT_RUN, args).getStatusCodeReply().equals("OK"); } catch (JedisDataException ex) { throw new RedisAIException(ex); } } + private String[] scriptRunFlatArgs( + String key, String function, String[] inputs, boolean variadicInputs, String[] outputs) { + + if (variadicInputs) { + if (inputs.length < 2) { + throw new IllegalArgumentException( + "At least two inputs are required to support variadic format."); + } + } + + int length = 2 + 2 + inputs.length + (variadicInputs ? 1 : 0) + outputs.length; + String[] args = new String[length]; + int index = 0; + args[index++] = key; + args[index++] = function; + args[index++] = Keyword.INPUTS.name(); + + args[index++] = inputs[0]; + if (variadicInputs) { + args[index++] = "$"; + } + System.arraycopy(inputs, 1, args, index, inputs.length - 1); + index += inputs.length - 1; + + args[index++] = Keyword.OUTPUTS.name(); + System.arraycopy(outputs, 0, args, index, outputs.length); + return args; + } + /** * Direct mapping to AI.DAGRUN specifies a direct acyclic graph of operations to run within * RedisAI @@ -434,16 +485,6 @@ public boolean resetStat(String key) { } } - private Jedis getConnection() { - return pool.getResource(); - } - - private BinaryClient sendCommand(Jedis conn, Command command, byte[]... args) { - BinaryClient client = conn.getClient(); - client.sendCommand(command, args); - return client; - } - /** * AI.CONFIG > * diff --git a/src/test/java/com/redislabs/redisai/RedisAITest.java b/src/test/java/com/redislabs/redisai/RedisAITest.java index 848573a..612b4ec 100644 --- a/src/test/java/com/redislabs/redisai/RedisAITest.java +++ b/src/test/java/com/redislabs/redisai/RedisAITest.java @@ -1,5 +1,8 @@ package com.redislabs.redisai; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; @@ -298,6 +301,25 @@ public void testRunScriptNegative() { } } + @Test + public void runScriptVariadicInputs() { + String script = + "def addn(a, args : List[Tensor]):\n" + " return a + torch.stack(args).sum()\n"; + client.setScript("var_in", new Script(Device.CPU, script)); + + client.setTensor("t1", new float[] {40}, new int[] {1}); + client.setTensor("t2", new float[] {1}, new int[] {1}); + client.setTensor("t3", new float[] {1}, new int[] {1}); + + Assert.assertTrue( + client.runScript( + "var_in", "addn", new String[] {"t1", "t2", "t3"}, true, new String[] {"r"})); + Tensor result = client.getTensor("r"); + assertEquals(DataType.FLOAT, result.getDataType()); + assertArrayEquals(new long[] {1}, result.getShape()); + assertArrayEquals(new float[] {42}, (float[]) result.getValues(), 0f); + } + @Test public void testGetTensor() { Assert.assertTrue(client.setTensor("t1", new float[][] {{1, 2}, {3, 4}}, new int[] {2, 2}));