From c5e8539d617aa01e12e9d1c9e0380a47fb8da3bc Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Thu, 9 May 2024 15:11:48 -0700 Subject: [PATCH] [awscurl] Allows override stream parameter for dataset --- awscurl/src/main/java/ai/djl/awscurl/AwsCurl.java | 13 ++++++++++--- .../src/test/java/ai/djl/awscurl/AwsCurlTest.java | 13 +++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/awscurl/src/main/java/ai/djl/awscurl/AwsCurl.java b/awscurl/src/main/java/ai/djl/awscurl/AwsCurl.java index b203f2499..a531a43e9 100644 --- a/awscurl/src/main/java/ai/djl/awscurl/AwsCurl.java +++ b/awscurl/src/main/java/ai/djl/awscurl/AwsCurl.java @@ -596,13 +596,20 @@ private void addToDataSet(String data) { if (element.isJsonObject()) { JsonObject obj = element.getAsJsonObject(); JsonObject param = obj.getAsJsonObject("parameters"); + JsonObject targetParam = extraParameters; + if (extraParameters.has("parameters")) { + targetParam = extraParameters.getAsJsonObject("parameters"); + } if (param == null) { - obj.add("parameters", extraParameters); + obj.add("parameters", targetParam); } else { - for (Map.Entry entry : extraParameters.entrySet()) { + for (Map.Entry entry : targetParam.entrySet()) { param.add(entry.getKey(), entry.getValue()); } } + if (extraParameters.has("stream")) { + obj.add("stream", extraParameters.get("stream")); + } data = JsonUtils.GSON.toJson(obj); } } @@ -650,7 +657,7 @@ static Options getOptions() { Option.builder() .longOpt("extra-parameters") .hasArg() - .argName("extra-parameters") + .argName("EXTRA-PARAMETERS") .desc("extra parameters for json dataset") .build()); options.addOption( diff --git a/awscurl/src/test/java/ai/djl/awscurl/AwsCurlTest.java b/awscurl/src/test/java/ai/djl/awscurl/AwsCurlTest.java index 5b79081b3..62c80421b 100644 --- a/awscurl/src/test/java/ai/djl/awscurl/AwsCurlTest.java +++ b/awscurl/src/test/java/ai/djl/awscurl/AwsCurlTest.java @@ -173,6 +173,19 @@ public void testDataset() { ret = AwsCurl.run(args); Assert.assertFalse(ret.hasError()); + args = + new String[] { + "http://localhost:18080/invocations", + "-H", + "Content-Type: application/json", + "--dataset", + "src/test/resources/prompts", + "--extra-parameters", + "{\"parameters\": {\"top_k\": 25}, \"stream\": true}" + }; + ret = AwsCurl.run(args); + Assert.assertFalse(ret.hasError()); + args = new String[] { "http://localhost:18080/invocations",