Skip to content

Support SCRIPTRUN with variadic INPUTS #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions src/main/java/com/redislabs/redisai/RedisAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.List;
import java.util.Map;
import redis.clients.jedis.BinaryClient;
import redis.clients.jedis.Client;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
Expand Down Expand Up @@ -85,6 +86,22 @@ private static JedisPoolConfig initPoolConfig(int poolSize) {
return conf;
}

private Jedis getConnection() {
return pool.getResource();
}

private BinaryClient sendCommand(Jedis conn, Command command, byte[]... args) {
BinaryClient client = conn.getClient();
client.sendCommand(command, args);
return client;
}

private Client sendCommand(Jedis conn, Command command, String... args) {
Client client = conn.getClient();
client.sendCommand(command, args);
return client;
}

/**
* Direct mapping to AI.TENSORSET
*
Expand Down Expand Up @@ -336,20 +353,54 @@ public boolean runModel(String key, String[] inputs, String[] outputs) {
}
}

/** AI.SCRIPTRUN script_key fn_name INPUTS input_key1 ... OUTPUTS output_key1 ... */
/** {@code AI.SCRIPTRUN script_key fn_name INPUTS input_key1 ... OUTPUTS output_key1 ...} */
public boolean runScript(String key, String function, String[] inputs, String[] outputs) {
return runScript(key, function, inputs, false, outputs);
}

/**
* {@code AI.SCRIPTRUN <key> <function> INPUTS <input> [input ...] [$ input ...] OUTPUTS <output>
* [output ...]}
*/
public boolean runScript(
String key, String function, String[] inputs, boolean variadicInputs, String[] outputs) {
try (Jedis conn = getConnection()) {
List<byte[]> args = Script.scriptRunFlatArgs(key, function, inputs, outputs, false);
return sendCommand(conn, Command.SCRIPT_RUN, args.toArray(new byte[args.size()][]))
.getStatusCodeReply()
.equals("OK");

String[] args = scriptRunFlatArgs(key, function, inputs, variadicInputs, outputs);
return sendCommand(conn, Command.SCRIPT_RUN, args).getStatusCodeReply().equals("OK");
} catch (JedisDataException ex) {
throw new RedisAIException(ex);
}
}

private String[] scriptRunFlatArgs(
String key, String function, String[] inputs, boolean variadicInputs, String[] outputs) {

if (variadicInputs) {
if (inputs.length < 2) {
throw new IllegalArgumentException(
"At least two inputs are required to support variadic format.");
}
}

int length = 2 + 2 + inputs.length + (variadicInputs ? 1 : 0) + outputs.length;
String[] args = new String[length];
int index = 0;
args[index++] = key;
args[index++] = function;
args[index++] = Keyword.INPUTS.name();

args[index++] = inputs[0];
if (variadicInputs) {
args[index++] = "$";
}
System.arraycopy(inputs, 1, args, index, inputs.length - 1);
index += inputs.length - 1;

args[index++] = Keyword.OUTPUTS.name();
System.arraycopy(outputs, 0, args, index, outputs.length);
return args;
}

/**
* Direct mapping to AI.DAGRUN specifies a direct acyclic graph of operations to run within
* RedisAI
Expand Down Expand Up @@ -434,16 +485,6 @@ public boolean resetStat(String key) {
}
}

private Jedis getConnection() {
return pool.getResource();
}

private BinaryClient sendCommand(Jedis conn, Command command, byte[]... args) {
BinaryClient client = conn.getClient();
client.sendCommand(command, args);
return client;
}

/**
* AI.CONFIG <BACKENDSPATH <path>>
*
Expand Down
22 changes: 22 additions & 0 deletions src/test/java/com/redislabs/redisai/RedisAITest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package com.redislabs.redisai;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
Expand Down Expand Up @@ -298,6 +301,25 @@ public void testRunScriptNegative() {
}
}

@Test
public void runScriptVariadicInputs() {
String script =
"def addn(a, args : List[Tensor]):\n" + " return a + torch.stack(args).sum()\n";
client.setScript("var_in", new Script(Device.CPU, script));

client.setTensor("t1", new float[] {40}, new int[] {1});
client.setTensor("t2", new float[] {1}, new int[] {1});
client.setTensor("t3", new float[] {1}, new int[] {1});

Assert.assertTrue(
client.runScript(
"var_in", "addn", new String[] {"t1", "t2", "t3"}, true, new String[] {"r"}));
Tensor result = client.getTensor("r");
assertEquals(DataType.FLOAT, result.getDataType());
assertArrayEquals(new long[] {1}, result.getShape());
assertArrayEquals(new float[] {42}, (float[]) result.getValues(), 0f);
}

@Test
public void testGetTensor() {
Assert.assertTrue(client.setTensor("t1", new float[][] {{1, 2}, {3, 4}}, new int[] {2, 2}));
Expand Down