diff --git a/Dockerfile b/Dockerfile index b29397f..512bf99 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,7 +9,7 @@ WORKDIR /sagemaker-sparkml-model-server RUN mvn clean package -RUN cp ./target/sparkml-serving-2.3.jar /usr/local/lib/sparkml-serving-2.3.jar +RUN cp ./target/sparkml-serving-2.4.jar /usr/local/lib/sparkml-serving-2.4.jar RUN cp ./serve.sh /usr/local/bin/serve.sh RUN chmod a+x /usr/local/bin/serve.sh diff --git a/README.md b/README.md index 9ae6c42..dd39be1 100644 --- a/README.md +++ b/README.md @@ -223,20 +223,20 @@ Calling `CreateModel` is required for creating a `Model` in SageMaker with this SageMaker works with Docker images stored in [Amazon ECR](https://aws.amazon.com/ecr/). SageMaker team has prepared and uploaded the Docker images for SageMaker SparkML Serving Container in all regions where SageMaker operates. Region to ECR container URL mapping can be found below. For a mapping from Region to Region Name, please see [here](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html). 
-* us-west-1 = 746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-sparkml-serving:2.2 -* us-west-2 = 246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.2 -* us-east-1 = 683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-sparkml-serving:2.2 -* us-east-2 = 257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-sparkml-serving:2.2 -* ap-northeast-1 = 354813040037.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-sparkml-serving:2.2 -* ap-northeast-2 = 366743142698.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-sparkml-serving:2.2 -* ap-southeast-1 = 121021644041.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-sparkml-serving:2.2 -* ap-southeast-2 = 783357654285.dkr.ecr.ap-southeast-2.amazonaws.com/sagemaker-sparkml-serving:2.2 -* ap-south-1 = 720646828776.dkr.ecr.ap-south-1.amazonaws.com/sagemaker-sparkml-serving:2.2 -* eu-west-1 = 141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-sparkml-serving:2.2 -* eu-west-2 = 764974769150.dkr.ecr.eu-west-2.amazonaws.com/sagemaker-sparkml-serving:2.2 -* eu-central-1 = 492215442770.dkr.ecr.eu-central-1.amazonaws.com/sagemaker-sparkml-serving:2.2 -* ca-central-1 = 341280168497.dkr.ecr.ca-central-1.amazonaws.com/sagemaker-sparkml-serving:2.2 -* us-gov-west-1 = 414596584902.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-sparkml-serving:2.2 +* us-west-1 = 746614075791.dkr.ecr.us-west-1.amazonaws.com/sagemaker-sparkml-serving:2.4 +* us-west-2 = 246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4 +* us-east-1 = 683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-sparkml-serving:2.4 +* us-east-2 = 257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-sparkml-serving:2.4 +* ap-northeast-1 = 354813040037.dkr.ecr.ap-northeast-1.amazonaws.com/sagemaker-sparkml-serving:2.4 +* ap-northeast-2 = 366743142698.dkr.ecr.ap-northeast-2.amazonaws.com/sagemaker-sparkml-serving:2.4 +* ap-southeast-1 = 121021644041.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-sparkml-serving:2.4 +* ap-southeast-2 
= 783357654285.dkr.ecr.ap-southeast-2.amazonaws.com/sagemaker-sparkml-serving:2.4 +* ap-south-1 = 720646828776.dkr.ecr.ap-south-1.amazonaws.com/sagemaker-sparkml-serving:2.4 +* eu-west-1 = 141502667606.dkr.ecr.eu-west-1.amazonaws.com/sagemaker-sparkml-serving:2.4 +* eu-west-2 = 764974769150.dkr.ecr.eu-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4 +* eu-central-1 = 492215442770.dkr.ecr.eu-central-1.amazonaws.com/sagemaker-sparkml-serving:2.4 +* ca-central-1 = 341280168497.dkr.ecr.ca-central-1.amazonaws.com/sagemaker-sparkml-serving:2.4 +* us-gov-west-1 = 414596584902.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-sparkml-serving:2.4 With [SageMaker Python SDK](https://github.com/aws/sagemaker-python-sdk) ------------------------------------------------------------------------ @@ -263,7 +263,7 @@ First you need to ensure that have installed [Docker](https://www.docker.com/) o In order to build the Docker image, you need to run a single Docker command: ``` -docker build -t sagemaker-sparkml-serving:2.2 . +docker build -t sagemaker-sparkml-serving:2.4 . ``` #### Running the image locally @@ -272,7 +272,7 @@ In order to run the Docker image, you need to run the following command. Please The command will start the server on port 8080 and will also pass the schema as an environment variable to the Docker container. Alternatively, you can edit the `Dockerfile` to add `ENV SAGEMAKER_SPARKML_SCHEMA=schema` as well before building the Docker image. ``` -docker run -p 8080:8080 -e SAGEMAKER_SPARKML_SCHEMA=schema -v /tmp/model:/opt/ml/model sagemaker-sparkml-serving:2.2 serve +docker run -p 8080:8080 -e SAGEMAKER_SPARKML_SCHEMA=schema -v /tmp/model:/opt/ml/model sagemaker-sparkml-serving:2.4 serve ``` #### Invoking with a payload @@ -287,7 +287,7 @@ or curl -i -H "content-type:application/json" -d "{\"data\":[feature_1,\"feature_2\",feature_3]}" http://localhost:8080/invocations ``` -The `Dockerfile` can be found at the root directory of the package. 
SageMaker SparkML Serving Container tags the Docker images using the Spark major version it is compatible with. Right now, it only supports Spark 2.2 and as a result, the Docker image is tagged with 2.2. +The `Dockerfile` can be found at the root directory of the package. SageMaker SparkML Serving Container tags the Docker images using the Spark major version it is compatible with. Right now, it only supports Spark 2.4 and as a result, the Docker image is tagged with 2.4. In order to save the effort of building the Docker image everytime you are making a code change, you can also install [Maven](http://maven.apache.org/) and run `mvn clean package` at your project root to verify if the code is compiling fine and unit tests are running without any issue. @@ -310,7 +310,7 @@ aws ecr get-login --region us-west-2 --registry-ids 246618743249 --no-include-em * Download the Docker image with the following command: ``` -docker pull 246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.2 +docker pull 246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4 ``` For running the Docker image, please see the Running the image locally section from above. diff --git a/ci/buildspec.yml b/ci/buildspec.yml index 79f030e..422e014 100644 --- a/ci/buildspec.yml +++ b/ci/buildspec.yml @@ -9,10 +9,10 @@ phases: commands: - echo Build started on `date` - echo Building the Docker image... - - docker build -t sagemaker-sparkml-serving:2.3 . - - docker tag sagemaker-sparkml-serving:2.3 515193369038.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.3 + - docker build -t sagemaker-sparkml-serving:2.4 . + - docker tag sagemaker-sparkml-serving:2.4 515193369038.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4 post_build: commands: - echo Build completed on `date` - echo Pushing the Docker image... 
- - docker push 515193369038.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.3 + - docker push 515193369038.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4 diff --git a/pom.xml b/pom.xml index c8af276..2579cc4 100644 --- a/pom.xml +++ b/pom.xml @@ -24,7 +24,7 @@ 4.0.0 org.amazonaws.sagemaker sparkml-serving - 2.3 + 2.4 @@ -154,7 +154,7 @@ ml.combust.mleap mleap-runtime_2.11 - 0.13.0 + 0.14.0 org.apache.commons @@ -199,4 +199,4 @@ 1.8 - \ No newline at end of file + diff --git a/serve.sh b/serve.sh index ae4b3e6..4e61a83 100644 --- a/serve.sh +++ b/serve.sh @@ -1,3 +1,3 @@ #!/bin/bash # This is needed to make sure Java correctly detects CPU/Memory set by the container limits -java -XX:+UnlockExperimentalVMOptions -XX:+UseCGroupMemoryLimitForHeap -jar /usr/local/lib/sparkml-serving-2.3.jar \ No newline at end of file +java -XX:+UnlockExperimentalVMOptions -XX:+UseCGroupMemoryLimitForHeap -jar /usr/local/lib/sparkml-serving-2.4.jar \ No newline at end of file diff --git a/src/main/java/com/amazonaws/sagemaker/controller/ServingController.java b/src/main/java/com/amazonaws/sagemaker/controller/ServingController.java index b899191..35d6916 100644 --- a/src/main/java/com/amazonaws/sagemaker/controller/ServingController.java +++ b/src/main/java/com/amazonaws/sagemaker/controller/ServingController.java @@ -34,9 +34,14 @@ import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; + import ml.combust.mleap.runtime.frame.ArrayRow; import ml.combust.mleap.runtime.frame.DefaultLeapFrame; +import ml.combust.mleap.runtime.frame.Row; import ml.combust.mleap.runtime.frame.Transformer; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; @@ -99,7 +104,7 @@ public ResponseEntity returnBatchExecutionParameter() throws JsonProcessingExcep } /** - * Implements the 
invocations POST API for application/json input + * Implements the invocations POST API for application/json input * * @param sro, the request object * @param accept, accept parameter from request */ @RequestMapping(path = "/invocations", method = POST, consumes = MediaType.APPLICATION_JSON_VALUE) public ResponseEntity transformRequestJson(@RequestBody final SageMakerRequestObject sro, - @RequestHeader(value = HttpHeaders.ACCEPT, required = false) final String accept) { + @RequestHeader(value = HttpHeaders.ACCEPT, required = false) final String accept) { if (sro == null) { LOG.error("Input passed to the request is empty"); return ResponseEntity.noContent().build(); @@ -115,7 +120,38 @@ public ResponseEntity transformRequestJson(@RequestBody final SageMakerR try { final String acceptVal = this.retrieveAndVerifyAccept(accept); final DataSchema schema = this.retrieveAndVerifySchema(sro.getSchema(), mapper); - return this.processInputData(sro.getData(), schema, acceptVal); + return this.processInputData(Collections.singletonList(sro.getData()), schema, acceptVal); + } catch (final Exception ex) { + LOG.error("Error in processing current request", ex); + return ResponseEntity.badRequest().body(ex.getMessage()); + } + } + + /** + * Implements the invocations POST API for application/jsonlines input + * + * @param jsonLines, lines of json values + * @param accept, accept parameter from request + * @return ResponseEntity with body as the expected payload JSON & proper statuscode based on the input + */ + @RequestMapping(path = "/invocations", method = POST, consumes = AdditionalMediaType.APPLICATION_JSONLINES_VALUE) + public ResponseEntity transformRequestJsonLines(@RequestBody final byte[] jsonLines, + @RequestHeader(value = HttpHeaders.ACCEPT, required = false) final String accept) { + if (jsonLines == null) { + LOG.error("Input passed to the request is empty"); + return 
ResponseEntity.noContent().build(); + } + try { + final String acceptVal = this.retrieveAndVerifyAccept(accept); + final DataSchema schema = this.retrieveAndVerifySchema(null, mapper); + final String jsonStringLines[] = new String(jsonLines).split("\\r?\\n"); + final List> inputDatas = new ArrayList(); + for(String jsonStringLine : jsonStringLines) { + final ObjectMapper mapper = new ObjectMapper(); + final SageMakerRequestObject sro = mapper.readValue(jsonStringLine, SageMakerRequestObject.class); + inputDatas.add(sro.getData()); + } + return this.processInputData(inputDatas, schema, acceptVal); } catch (final Exception ex) { LOG.error("Error in processing current request", ex); return ResponseEntity.badRequest().body(ex.getMessage()); @@ -169,14 +205,14 @@ protected DataSchema retrieveAndVerifySchema(final DataSchema schemaFromPayload, : mapper.readValue(SystemUtils.getEnvironmentVariable("SAGEMAKER_SPARKML_SCHEMA"), DataSchema.class); } - private ResponseEntity processInputData(final List inputData, final DataSchema schema, + private ResponseEntity processInputData(final List> inputDatas, final DataSchema schema, final String acceptVal) throws JsonProcessingException { - final DefaultLeapFrame leapFrame = dataConversionHelper.convertInputToLeapFrame(schema, inputData); + final DefaultLeapFrame leapFrame = dataConversionHelper.convertInputToLeapFrame(schema, inputDatas); // Making call to the MLeap executor to get the output final DefaultLeapFrame totalLeapFrame = ScalaUtils.transformLeapFrame(mleapTransformer, leapFrame); final DefaultLeapFrame predictionsLeapFrame = ScalaUtils .selectFromLeapFrame(totalLeapFrame, schema.getOutput().getName()); - final ArrayRow outputArrayRow = ScalaUtils.getOutputArrayRow(predictionsLeapFrame); + final List outputArrayRow = ScalaUtils.getOutputArrayRow(predictionsLeapFrame); return transformToHttpResponse(schema, outputArrayRow, acceptVal); } @@ -186,17 +222,18 @@ private boolean checkEmptyAccept(final String acceptFromRequest) 
{ return (StringUtils.isBlank(acceptFromRequest) || StringUtils.equals(acceptFromRequest, MediaType.ALL_VALUE)); } - private ResponseEntity transformToHttpResponse(final DataSchema schema, final ArrayRow predictionRow, + private ResponseEntity transformToHttpResponse(final DataSchema schema, final List predictionsRow, final String accept) throws JsonProcessingException { if (StringUtils.equals(schema.getOutput().getStruct(), DataStructureType.BASIC)) { final Object output = dataConversionHelper - .convertMLeapBasicTypeToJavaType(predictionRow, schema.getOutput().getType()); + .convertMLeapBasicTypeToJavaType(predictionsRow.get(0), schema.getOutput().getType()); return responseHelper.sendResponseForSingleValue(output.toString(), accept); } else { // If not basic type, it can be vector or array type from Spark return responseHelper.sendResponseForList( - ScalaUtils.getJavaObjectIteratorFromArrayRow(predictionRow, schema.getOutput().getStruct()), accept); + predictionsRow.stream().map(predictionRow -> ScalaUtils.getJavaObjectIteratorFromArrayRow(predictionRow, schema.getOutput().getStruct())).collect(Collectors.toList()) + , accept); } } diff --git a/src/main/java/com/amazonaws/sagemaker/helper/DataConversionHelper.java b/src/main/java/com/amazonaws/sagemaker/helper/DataConversionHelper.java index b412685..af82be0 100644 --- a/src/main/java/com/amazonaws/sagemaker/helper/DataConversionHelper.java +++ b/src/main/java/com/amazonaws/sagemaker/helper/DataConversionHelper.java @@ -70,19 +70,26 @@ public DataConversionHelper(final LeapFrameBuilderSupport support, final LeapFra * @return List of Objects, where each Object correspond to one feature of the input data * @throws IOException, if there is an exception thrown in the try-with-resources block */ - public List convertCsvToObjectList(final String csvInput, final DataSchema schema) throws IOException { + public List> convertCsvToObjectList(final String csvInput, final DataSchema schema) throws IOException { try (final 
StringReader sr = new StringReader(csvInput)) { - final List valueList = Lists.newArrayList(); final CSVParser parser = CSVFormat.DEFAULT.parse(sr); - // We don not supporting multiple CSV lines as input currently + // Each CSV line is parsed as a separate record of input features - final CSVRecord record = parser.getRecords().get(0); + final List records = parser.getRecords(); final int inputLength = schema.getInput().size(); - for (int idx = 0; idx < inputLength; ++idx) { - ColumnSchema sc = schema.getInput().get(idx); - // For CSV input, each value is treated as an individual feature by default - valueList.add(this.convertInputDataToJavaType(sc.getType(), DataStructureType.BASIC, record.get(idx))); + + final List> returnList = Lists.newArrayList(); + + for(CSVRecord record : records) { + final List valueList = Lists.newArrayList(); + for (int idx = 0; idx < inputLength; ++idx) { + ColumnSchema sc = schema.getInput().get(idx); + // For CSV input, each value is treated as an individual feature by default + valueList.add(this.convertInputDataToJavaType(sc.getType(), DataStructureType.BASIC, record.get(idx))); + } + returnList.add(valueList); } - return valueList; + + return returnList; } } @@ -91,30 +98,44 @@ public List convertCsvToObjectList(final String csvInput, final DataSche * Convert input object to DefaultLeapFrame * * @param schema, the input schema received from request or environment variable - * @param data , the input data received from request as a list of objects + * @param datas , the input data received from request as a list of rows, each row being a list of objects * @return the DefaultLeapFrame object which MLeap transformer expects */ - public DefaultLeapFrame convertInputToLeapFrame(final DataSchema schema, final List data) { + public DefaultLeapFrame convertInputToLeapFrame(final DataSchema schema, final List> datas) { final int inputLength = schema.getInput().size(); final List structFieldList = Lists.newArrayList(); - final List valueList = Lists.newArrayList(); + for (int idx = 0; idx < inputLength; ++idx) { ColumnSchema sc = 
schema.getInput().get(idx); structFieldList - .add(new StructField(sc.getName(), this.convertInputToMLeapInputType(sc.getType(), sc.getStruct()))); - valueList.add(this.convertInputDataToJavaType(sc.getType(), sc.getStruct(), data.get(idx))); + .add(new StructField(sc.getName(), this.convertInputToMLeapInputType(sc.getType(), sc.getStruct()))); } - final StructType mleapSchema = leapFrameBuilder.createSchema(structFieldList); - final Row currentRow = support.createRowFromIterable(valueList); final List rows = Lists.newArrayList(); - rows.add(currentRow); + + for(Object data : datas) + { + final Row currentRow = getRow(schema, (List) data, inputLength); + + rows.add(currentRow); + } return leapFrameBuilder.createFrame(mleapSchema, rows); } + private Row getRow(DataSchema schema, List data, int inputLength) { + final List valueList = Lists.newArrayList(); + + for (int idx = 0; idx < inputLength; ++idx) { + ColumnSchema sc = schema.getInput().get(idx); + valueList.add(this.convertInputDataToJavaType(sc.getType(), sc.getStruct(), data.get(idx))); + } + + return support.createRowFromIterable(valueList); + } + /** * Convert basic types in the MLeap helper to Java types for output. 
* @@ -122,7 +143,7 @@ public DefaultLeapFrame convertInputToLeapFrame(final DataSchema schema, final L * @param type, the basic type to which the helper should be casted, provided by user via input * @return the proper Java type */ - public Object convertMLeapBasicTypeToJavaType(final ArrayRow predictionRow, final String type) { + public Object convertMLeapBasicTypeToJavaType(final Row predictionRow, final String type) { switch (type) { case BasicDataType.INTEGER: return predictionRow.getInt(0); diff --git a/src/main/java/com/amazonaws/sagemaker/helper/ResponseHelper.java b/src/main/java/com/amazonaws/sagemaker/helper/ResponseHelper.java index eb67b35..1ba8daa 100644 --- a/src/main/java/com/amazonaws/sagemaker/helper/ResponseHelper.java +++ b/src/main/java/com/amazonaws/sagemaker/helper/ResponseHelper.java @@ -23,6 +23,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; + +import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.StringJoiner; @@ -66,48 +68,67 @@ public ResponseEntity sendResponseForSingleValue(final String value, Str * test/resources/com/amazonaws/sagemaker/dto for example output format or SageMaker built-in algorithms - * documentaiton to know about the output format. + * documentation to know about the output format. * - * @param outputDataIterator, data iterator for raw output values in case output is an Array or Vector + * @param outputDatasIterator, data iterator for raw output values in case output is an Array or Vector * @param acceptVal, the accept customer has passed or default (text/csv) if not passed * @return Spring ResponseEntity which contains the body and the header. 
*/ - public ResponseEntity sendResponseForList(final Iterator outputDataIterator, String acceptVal) + public ResponseEntity sendResponseForList(final List> outputDatasIterator, String acceptVal) throws JsonProcessingException { if (StringUtils.equals(acceptVal, AdditionalMediaType.APPLICATION_JSONLINES_VALUE)) { - return this.buildStandardJsonOutputForList(outputDataIterator); + return this.buildStandardJsonOutputForList(outputDatasIterator); } else if (StringUtils.equals(acceptVal, AdditionalMediaType.APPLICATION_JSONLINES_TEXT_VALUE)) { - return this.buildTextJsonOutputForList(outputDataIterator); + return this.buildTextJsonOutputForList(outputDatasIterator); } else { - return this.buildCsvOutputForList(outputDataIterator); + return this.buildCsvOutputForList(outputDatasIterator); } } - private ResponseEntity buildCsvOutputForList(final Iterator outputDataIterator) { - final StringJoiner sj = new StringJoiner(","); - while (outputDataIterator.hasNext()) { - sj.add(outputDataIterator.next().toString()); + private ResponseEntity buildCsvOutputForList(final List> outputDatasIterator) { + + final StringJoiner sjLineBreaks = new StringJoiner("\n"); + + for(Iterator outputDataIterator : outputDatasIterator) + { + final StringJoiner sj = new StringJoiner(","); + while (outputDataIterator.hasNext()) { + sj.add(outputDataIterator.next().toString()); + } + sjLineBreaks.add(sj.toString()); } - return this.getCsvOkResponse(sj.toString()); + + return this.getCsvOkResponse(sjLineBreaks.toString()); } - private ResponseEntity buildStandardJsonOutputForList(final Iterator outputDataIterator) + private ResponseEntity buildStandardJsonOutputForList(final List> outputDatasIterator) throws JsonProcessingException { - final List columns = Lists.newArrayList(); - while (outputDataIterator.hasNext()) { - columns.add(outputDataIterator.next()); + + List jsonLinesList = new ArrayList<>(); + for(Iterator outputDataIterator : outputDatasIterator) { + final List columns = 
Lists.newArrayList(); + while (outputDataIterator.hasNext()) { + columns.add(outputDataIterator.next()); + } + final JsonlinesStandardOutput jsonOutput = new JsonlinesStandardOutput(columns); + jsonLinesList.add(jsonOutput); } - final JsonlinesStandardOutput jsonOutput = new JsonlinesStandardOutput(columns); - final String jsonRepresentation = mapper.writeValueAsString(jsonOutput); + final String jsonRepresentation = mapper.writeValueAsString(jsonLinesList); return this.getJsonlinesOkResponse(jsonRepresentation); } - private ResponseEntity buildTextJsonOutputForList(final Iterator outputDataIterator) + private ResponseEntity buildTextJsonOutputForList(final List> outputDatasIterator) throws JsonProcessingException { - final StringJoiner stringJoiner = new StringJoiner(" "); - while (outputDataIterator.hasNext()) { - stringJoiner.add(outputDataIterator.next().toString()); + + List jsonLinesList = new ArrayList<>(); + for(Iterator outputDataIterator : outputDatasIterator) { + final StringJoiner stringJoiner = new StringJoiner(" "); + while (outputDataIterator.hasNext()) { + stringJoiner.add(outputDataIterator.next().toString()); + } + final JsonlinesTextOutput jsonOutput = new JsonlinesTextOutput(stringJoiner.toString()); + jsonLinesList.add(jsonOutput); } - final JsonlinesTextOutput jsonOutput = new JsonlinesTextOutput(stringJoiner.toString()); - final String jsonRepresentation = mapper.writeValueAsString(jsonOutput); + + final String jsonRepresentation = mapper.writeValueAsString(jsonLinesList); return this.getJsonlinesOkResponse(jsonRepresentation); } diff --git a/src/main/java/com/amazonaws/sagemaker/utils/ScalaUtils.java b/src/main/java/com/amazonaws/sagemaker/utils/ScalaUtils.java index 8669d49..37df561 100644 --- a/src/main/java/com/amazonaws/sagemaker/utils/ScalaUtils.java +++ b/src/main/java/com/amazonaws/sagemaker/utils/ScalaUtils.java @@ -19,7 +19,8 @@ import com.amazonaws.sagemaker.type.DataStructureType; import java.util.Collections; import 
java.util.Iterator; -import ml.combust.mleap.runtime.frame.ArrayRow; +import java.util.List; + import ml.combust.mleap.runtime.frame.DefaultLeapFrame; import ml.combust.mleap.runtime.frame.Row; import ml.combust.mleap.runtime.frame.Transformer; @@ -65,10 +66,8 @@ public static DefaultLeapFrame selectFromLeapFrame(final DefaultLeapFrame leapFr * @param leapFrame, the DefaultLeapFrame from which output to be extracted - * @return ArrayRow which can be used to retrieve the original output + * @return List of Row objects which can be used to retrieve the original output */ - public static ArrayRow getOutputArrayRow(final DefaultLeapFrame leapFrame) { - final Iterator rowIterator = leapFrameSupport.collect(leapFrame).iterator(); - // SageMaker input structure only allows to call MLeap transformer for single data point - return (ArrayRow) (rowIterator.next()); + public static List getOutputArrayRow(final DefaultLeapFrame leapFrame) { + return leapFrameSupport.collect(leapFrame); } /** @@ -78,7 +77,7 @@ public static ArrayRow getOutputArrayRow(final DefaultLeapFrame leapFrame) { * @param structure, whether it is Spark Vector or Array * @return Iterator to raw values of the Vector or Array */ - public static Iterator getJavaObjectIteratorFromArrayRow(final ArrayRow predictionRow, + public static Iterator getJavaObjectIteratorFromArrayRow(final Row predictionRow, final String structure) { return (StringUtils.equals(structure, DataStructureType.VECTOR)) ? 
JavaConverters .asJavaIteratorConverter(predictionRow.getTensor(0).toDense().rawValuesIterator()).asJava() diff --git a/src/test/java/com/amazonaws/sagemaker/controller/ServingControllerTest.java b/src/test/java/com/amazonaws/sagemaker/controller/ServingControllerTest.java index 406b8fa..40837f3 100644 --- a/src/test/java/com/amazonaws/sagemaker/controller/ServingControllerTest.java +++ b/src/test/java/com/amazonaws/sagemaker/controller/ServingControllerTest.java @@ -28,10 +28,12 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Lists; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Objects; import ml.combust.mleap.runtime.frame.ArrayRow; import ml.combust.mleap.runtime.frame.DefaultLeapFrame; +import ml.combust.mleap.runtime.frame.Row; import ml.combust.mleap.runtime.frame.Transformer; import ml.combust.mleap.runtime.javadsl.LeapFrameBuilder; import ml.combust.mleap.runtime.javadsl.LeapFrameBuilderSupport; @@ -56,7 +58,7 @@ class ServingControllerTest { private Transformer mleapTransformerMock; private SageMakerRequestObject sro; private DefaultLeapFrame responseLeapFrame; - private ArrayRow outputArrayRow; + private Row outputArrayRow; private List inputColumns; private ColumnSchema outputColumn; private List inputData; @@ -81,7 +83,7 @@ private void buildDefaultSageMakerRequestObject() { private void buildResponseLeapFrame() { responseLeapFrame = new DataConversionHelper(new LeapFrameBuilderSupport(), new LeapFrameBuilder()) - .convertInputToLeapFrame(sro.getSchema(), sro.getData()); + .convertInputToLeapFrame(sro.getSchema(), Collections.singletonList(sro.getData())); outputArrayRow = new ArrayRow(Lists.newArrayList(new Integer("1"))); } @@ -99,7 +101,7 @@ public void setup() { .thenReturn(responseLeapFrame); PowerMockito.when(ScalaUtils.selectFromLeapFrame(Mockito.any(DefaultLeapFrame.class), Mockito.anyString())) .thenReturn(responseLeapFrame); - 
PowerMockito.when(ScalaUtils.getOutputArrayRow(Mockito.any(DefaultLeapFrame.class))).thenReturn(outputArrayRow); + PowerMockito.when(ScalaUtils.getOutputArrayRow(Mockito.any(DefaultLeapFrame.class))).thenReturn(Collections.singletonList(outputArrayRow)); } @Test @@ -164,7 +166,7 @@ public void testListValueJsonLinesAcceptResponse() { .when(ScalaUtils.getJavaObjectIteratorFromArrayRow(Mockito.any(ArrayRow.class), Mockito.anyString())) .thenReturn(outputResponse.iterator()); final ResponseEntity output = controller.transformRequestJson(sro, "application/jsonlines"); - Assert.assertEquals(output.getBody(), "{\"features\":[1,2]}"); + Assert.assertEquals(output.getBody(), "[{\"features\":[1,2]}]"); } @Test @@ -210,6 +212,19 @@ public void testCsvApiWithListInput() { Assert.assertEquals(output.getBody(), "1,2"); } + @Test + public void testJsonLinesApiWithListInput() { + schemaInJson = "{\"input\":[{\"name\":\"test_name_1\",\"type\":\"int\"},{\"name\":\"test_name_2\"," + + "\"type\":\"double\"},{\"name\":\"test_name_3\",\"type\":\"string\"}],\"output\":{\"name\":\"out_name\",\"type\":\"int\",\"struct\":\"vector\"}}"; + List outputResponse = Lists.newArrayList(1, 2, 0.345); + PowerMockito.when(SystemUtils.getEnvironmentVariable("SAGEMAKER_SPARKML_SCHEMA")).thenReturn(schemaInJson); + PowerMockito + .when(ScalaUtils.getJavaObjectIteratorFromArrayRow(Mockito.any(ArrayRow.class), Mockito.anyString())) + .thenReturn(outputResponse.iterator()); + final ResponseEntity output = controller.transformRequestJsonLines("{\"data\":[1,2.0,\"TEST\"]}".getBytes(), "text/csv"); + Assert.assertEquals(output.getBody(), "1,2,0.345"); + } + @Test public void testCsvApiWithNullInput() { PowerMockito.when(SystemUtils.getEnvironmentVariable("SAGEMAKER_SPARKML_SCHEMA")).thenReturn(schemaInJson); diff --git a/src/test/java/com/amazonaws/sagemaker/helper/DataConversionHelperTest.java b/src/test/java/com/amazonaws/sagemaker/helper/DataConversionHelperTest.java index d070ae3..2204ea5 100644 --- 
a/src/test/java/com/amazonaws/sagemaker/helper/DataConversionHelperTest.java +++ b/src/test/java/com/amazonaws/sagemaker/helper/DataConversionHelperTest.java @@ -23,6 +23,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Lists; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import ml.combust.mleap.core.types.ListType; import ml.combust.mleap.core.types.ScalarType; @@ -47,7 +49,9 @@ public void testParseCsvToObjectList() throws IOException { String inputJson = IOUtils .toString(this.getClass().getResourceAsStream("../dto/basic_input_schema.json"), "UTF-8"); DataSchema schema = mapper.readValue(inputJson, DataSchema.class); - List expectedOutput = Lists.newArrayList(new Integer("2"), "C", new Double("34.5")); + List expectedElement = Lists.newArrayList(new Integer("2"), "C", new Double("34.5")); + List> expectedOutput = Lists.newArrayList(); + expectedOutput.add(expectedElement); Assert.assertEquals(dataConversionHelper.convertCsvToObjectList(csvInput, schema), expectedOutput); } @@ -57,7 +61,9 @@ public void testParseCsvQuotesToObjectList() throws IOException { String inputJson = IOUtils .toString(this.getClass().getResourceAsStream("../dto/basic_input_schema.json"), "UTF-8"); DataSchema schema = mapper.readValue(inputJson, DataSchema.class); - List expectedOutput = Lists.newArrayList(new Integer("2"), "C", new Double("34.5")); + List expectedElement = Lists.newArrayList(new Integer("2"), "C", new Double("34.5")); + List> expectedOutput = Lists.newArrayList(); + expectedOutput.add(expectedElement); Assert.assertEquals(dataConversionHelper.convertCsvToObjectList(csvInput, schema), expectedOutput); } @@ -66,7 +72,7 @@ public void testCastingInputToLeapFrame() throws Exception { String inputJson = IOUtils .toString(this.getClass().getResourceAsStream("../dto/complete_input.json"), "UTF-8"); SageMakerRequestObject sro = mapper.readValue(inputJson, 
SageMakerRequestObject.class); - DefaultLeapFrame leapframeTest = dataConversionHelper.convertInputToLeapFrame(sro.getSchema(), sro.getData()); + DefaultLeapFrame leapframeTest = dataConversionHelper.convertInputToLeapFrame(sro.getSchema(), Collections.singletonList(sro.getData())); Assert.assertNotNull(leapframeTest.schema()); Assert.assertNotNull(leapframeTest.dataset()); } diff --git a/src/test/java/com/amazonaws/sagemaker/helper/ResponseHelperTest.java b/src/test/java/com/amazonaws/sagemaker/helper/ResponseHelperTest.java index 1d786cb..512a981 100644 --- a/src/test/java/com/amazonaws/sagemaker/helper/ResponseHelperTest.java +++ b/src/test/java/com/amazonaws/sagemaker/helper/ResponseHelperTest.java @@ -19,6 +19,8 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Lists; + +import java.util.Collections; import java.util.List; import java.util.Objects; import org.junit.Assert; @@ -65,7 +67,7 @@ public void testSingleOutputNoContentType() { @Test public void testListOutputCsv() throws JsonProcessingException { ResponseEntity outputTest = responseHelperTest - .sendResponseForList(dummyResponse.iterator(), "text/csv"); + .sendResponseForList(Collections.singletonList(dummyResponse.iterator()), "text/csv"); Assert.assertEquals(outputTest.getBody(), "1,0.2"); Assert.assertEquals(Objects.requireNonNull(outputTest.getHeaders().get(HttpHeaders.CONTENT_TYPE)).get(0), "text/csv"); @@ -74,8 +76,8 @@ public void testListOutputCsv() throws JsonProcessingException { @Test public void testListOutputJsonlines() throws JsonProcessingException { ResponseEntity outputTest = responseHelperTest - .sendResponseForList(dummyResponse.iterator(), "application/jsonlines"); - Assert.assertEquals(outputTest.getBody(), "{\"features\":[1,0.2]}"); + .sendResponseForList(Collections.singletonList(dummyResponse.iterator()), "application/jsonlines"); + Assert.assertEquals(outputTest.getBody(), 
"[{\"features\":[1,0.2]}]"); Assert.assertEquals(Objects.requireNonNull(outputTest.getHeaders().get(HttpHeaders.CONTENT_TYPE)).get(0), "application/jsonlines"); } @@ -84,8 +86,8 @@ public void testListOutputJsonlines() throws JsonProcessingException { public void testTextOutputJsonlines() throws JsonProcessingException { dummyResponse = Lists.newArrayList("this", "is", "spark", "ml", "server"); ResponseEntity outputTest = responseHelperTest - .sendResponseForList(dummyResponse.iterator(), "application/jsonlines;data=text"); - Assert.assertEquals(outputTest.getBody(), "{\"source\":\"this is spark ml server\"}"); + .sendResponseForList(Collections.singletonList(dummyResponse.iterator()), "application/jsonlines;data=text"); + Assert.assertEquals(outputTest.getBody(), "[{\"source\":\"this is spark ml server\"}]"); Assert.assertEquals(Objects.requireNonNull(outputTest.getHeaders().get(HttpHeaders.CONTENT_TYPE)).get(0), "application/jsonlines"); } @@ -93,7 +95,7 @@ public void testTextOutputJsonlines() throws JsonProcessingException { @Test public void testListOutputInvalidAccept() throws JsonProcessingException { ResponseEntity outputTest = responseHelperTest - .sendResponseForList(dummyResponse.iterator(), "application/json"); + .sendResponseForList(Collections.singletonList(dummyResponse.iterator()), "application/json"); Assert.assertEquals(outputTest.getBody(), "1,0.2"); Assert.assertEquals(Objects.requireNonNull(outputTest.getHeaders().get(HttpHeaders.CONTENT_TYPE)).get(0), "text/csv"); @@ -103,7 +105,7 @@ public void testListOutputInvalidAccept() throws JsonProcessingException { public void testTextOutputInvalidAccept() throws JsonProcessingException { dummyResponse = Lists.newArrayList("this", "is", "spark", "ml", "server"); ResponseEntity outputTest = responseHelperTest - .sendResponseForList(dummyResponse.iterator(), "application/json"); + .sendResponseForList(Collections.singletonList(dummyResponse.iterator()), "application/json"); 
Assert.assertEquals(outputTest.getBody(), "this,is,spark,ml,server"); Assert.assertEquals(Objects.requireNonNull(outputTest.getHeaders().get(HttpHeaders.CONTENT_TYPE)).get(0), "text/csv");