diff --git a/core/src/main/java/org/opensearch/sql/executor/PaginatedPlanCache.java b/core/src/main/java/org/opensearch/sql/executor/PaginatedPlanCache.java index 7a876e695a..89cb6b0255 100644 --- a/core/src/main/java/org/opensearch/sql/executor/PaginatedPlanCache.java +++ b/core/src/main/java/org/opensearch/sql/executor/PaginatedPlanCache.java @@ -17,7 +17,7 @@ import lombok.SneakyThrows; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.expression.NamedExpression; -import org.opensearch.sql.expression.serialization.DefaultExpressionSerializer; +import org.opensearch.sql.expression.serialization.NoEncodeExpressionSerializer; import org.opensearch.sql.opensearch.executor.Cursor; import org.opensearch.sql.planner.physical.PaginateOperator; import org.opensearch.sql.planner.physical.PhysicalPlan; @@ -28,6 +28,7 @@ @RequiredArgsConstructor public class PaginatedPlanCache { public static final String CURSOR_PREFIX = "n:"; + public static final String UNSUPPORTED_CURSOR = "Unsupported cursor"; private final StorageEngine storageEngine; public static final PaginatedPlanCache None = new PaginatedPlanCache(null); @@ -82,7 +83,7 @@ public static String compress(String str) { @SneakyThrows public static String decompress(String input) { if (input == null || input.length() == 0) { - return null; + return ""; } GZIPInputStream gzip = new GZIPInputStream(new ByteArrayInputStream( HashCode.fromString(input).asBytes())); @@ -96,14 +97,14 @@ public static String decompress(String input) { * @return Remaining part of the cursor. */ private String parseNamedExpressions(List listToFill, String cursor) { - var serializer = new DefaultExpressionSerializer(); + var serializer = new NoEncodeExpressionSerializer(); if (cursor.startsWith(")")) { //empty list return cursor.substring(cursor.indexOf(',') + 1); } while (!cursor.startsWith("(")) { listToFill.add((NamedExpression) serializer.deserialize(cursor.substring(0, - Math.min(cursor.indexOf(','), cursor.indexOf(')'))))); + Math.min(cursor.indexOf(','), cursor.indexOf(')'))).getBytes())); cursor = cursor.substring(cursor.indexOf(',') + 1); } return cursor; @@ -120,7 +121,7 @@ public PhysicalPlan convertToPlan(String cursor) { // TODO Parse with ANTLR or serialize as JSON/XML if (!cursor.startsWith("(Paginate,")) { - throw new UnsupportedOperationException("Unsupported cursor"); + throw new UnsupportedOperationException(UNSUPPORTED_CURSOR); } // TODO add checks for > 0 cursor = cursor.substring(cursor.indexOf(',') + 1); @@ -131,11 +132,11 @@ public PhysicalPlan convertToPlan(String cursor) { cursor = cursor.substring(cursor.indexOf(',') + 1); if (!cursor.startsWith("(Project,")) { - throw new UnsupportedOperationException("Unsupported cursor"); + throw new UnsupportedOperationException(UNSUPPORTED_CURSOR); } cursor = cursor.substring(cursor.indexOf(',') + 1); if (!cursor.startsWith("(namedParseExpressions,")) { - throw new UnsupportedOperationException("Unsupported cursor"); + throw new UnsupportedOperationException(UNSUPPORTED_CURSOR); } cursor = cursor.substring(cursor.indexOf(',') + 1); @@ -144,13 +145,13 @@ public PhysicalPlan convertToPlan(String cursor) { List projectList = new ArrayList<>(); if (!cursor.startsWith("(projectList,")) { - throw new UnsupportedOperationException("Unsupported cursor"); + throw new UnsupportedOperationException(UNSUPPORTED_CURSOR); } cursor = cursor.substring(cursor.indexOf(',') + 1); cursor = parseNamedExpressions(projectList, cursor); if (!cursor.startsWith("(OpenSearchPagedIndexScan,")) { - throw new UnsupportedOperationException("Unsupported cursor"); + throw new UnsupportedOperationException(UNSUPPORTED_CURSOR); } cursor = cursor.substring(cursor.indexOf(',') + 1); var indexName = cursor.substring(0, cursor.indexOf(',')); @@ -161,10 +162,10 @@ public PhysicalPlan convertToPlan(String cursor) { return new PaginateOperator(new ProjectOperator(scan, projectList, namedParseExpressions), pageSize, currentPageIndex); } catch (Exception e) { - throw new UnsupportedOperationException("Unsupported cursor", e); + throw new UnsupportedOperationException(UNSUPPORTED_CURSOR, e); } } else { - throw new UnsupportedOperationException("Unsupported cursor"); + throw new UnsupportedOperationException(UNSUPPORTED_CURSOR); } } } diff --git a/core/src/main/java/org/opensearch/sql/expression/serialization/DefaultExpressionSerializer.java b/core/src/main/java/org/opensearch/sql/expression/serialization/DefaultExpressionSerializer.java index 33c22b2ea5..7dbbbaa94e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/serialization/DefaultExpressionSerializer.java +++ b/core/src/main/java/org/opensearch/sql/expression/serialization/DefaultExpressionSerializer.java @@ -6,38 +6,26 @@ package org.opensearch.sql.expression.serialization; -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.util.Base64; import org.opensearch.sql.expression.Expression; + /** * Default serializer that (de-)serialize expressions by JDK serialization. */ public class DefaultExpressionSerializer implements ExpressionSerializer { + NoEncodeExpressionSerializer noEncodeSerializer = new NoEncodeExpressionSerializer(); + @Override public String serialize(Expression expr) { - try { - ByteArrayOutputStream output = new ByteArrayOutputStream(); - ObjectOutputStream objectOutput = new ObjectOutputStream(output); - objectOutput.writeObject(expr); - objectOutput.flush(); - return Base64.getEncoder().encodeToString(output.toByteArray()); - } catch (IOException e) { - throw new IllegalStateException("Failed to serialize expression: " + expr, e); - } + return Base64.getEncoder().encodeToString(noEncodeSerializer.serialize(expr)); } @Override public Expression deserialize(String code) { try { - ByteArrayInputStream input = new ByteArrayInputStream(Base64.getDecoder().decode(code)); - ObjectInputStream objectInput = new ObjectInputStream(input); - return (Expression) objectInput.readObject(); + return noEncodeSerializer.deserialize(Base64.getDecoder().decode(code)); } catch (Exception e) { throw new IllegalStateException("Failed to deserialize expression code: " + code, e); } diff --git a/core/src/main/java/org/opensearch/sql/expression/serialization/NoEncodeExpressionSerializer.java b/core/src/main/java/org/opensearch/sql/expression/serialization/NoEncodeExpressionSerializer.java new file mode 100644 index 0000000000..3e6bf490f1 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/serialization/NoEncodeExpressionSerializer.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.serialization; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import org.opensearch.sql.expression.Expression; + +public class NoEncodeExpressionSerializer { + + /** + * Serialize an expression into a byte array. + */ + public byte[] serialize(Expression expr) { + try { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(expr); + objectOutput.flush(); + return output.toByteArray(); + } catch (IOException e) { + throw new IllegalStateException("Failed to serialize expression: " + expr, e); + } + } + + /** + * Create an expression from a serialized byte array. + */ + public Expression deserialize(byte[] code) { + try { + ByteArrayInputStream input = new ByteArrayInputStream(code); + ObjectInputStream objectInput = new ObjectInputStream(input); + return (Expression) objectInput.readObject(); + } catch (Exception e) { + throw new IllegalStateException("Failed to deserialize expression code: " + code, e); + } + } + +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java index c61b35e0cb..ab9e78d64a 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/ProjectOperator.java @@ -22,7 +22,7 @@ import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.parse.ParseExpression; -import org.opensearch.sql.expression.serialization.DefaultExpressionSerializer; +import org.opensearch.sql.expression.serialization.NoEncodeExpressionSerializer; /** * Project the fields specified in {@link ProjectOperator#projectList} from input. @@ -102,11 +102,13 @@ public String toCursor() { if (child == null || child.isEmpty()) { return null; } - var serializer = new DefaultExpressionSerializer(); + var serializer = new NoEncodeExpressionSerializer(); String projects = createSection("projectList", - projectList.stream().map(serializer::serialize).toArray(String[]::new)); + projectList.stream().map(serializer::serialize) + .map(Object::toString).toArray(String[]::new)); String namedExpressions = createSection("namedParseExpressions", - namedParseExpressions.stream().map(serializer::serialize).toArray(String[]::new)); + namedParseExpressions.stream().map(serializer::serialize) + .map(Object::toString).toArray(String[]::new)); return createSection("Project", namedExpressions, projects, child); } } diff --git a/core/src/test/java/org/opensearch/sql/expression/serialization/NoEncodeExpressionSerializerTest.java b/core/src/test/java/org/opensearch/sql/expression/serialization/NoEncodeExpressionSerializerTest.java new file mode 100644 index 0000000000..195a0eaf53 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/serialization/NoEncodeExpressionSerializerTest.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.serialization; + + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.expression.DSL.literal; +import static org.opensearch.sql.expression.DSL.ref; + +import org.junit.jupiter.api.Test; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.ExpressionNodeVisitor; +import org.opensearch.sql.expression.env.Environment; + +class NoEncodeExpressionSerializerTest { + + private final NoEncodeExpressionSerializer serializer = new NoEncodeExpressionSerializer(); + + @Test + void can_serialize_and_deserialize_literals() { + Expression original = literal(10); + Expression actual = serializer.deserialize(serializer.serialize(original)); + assertEquals(original, actual); + } + + @Test + void can_serialize_and_deserialize_references() { + Expression original = ref("name", STRING); + Expression actual = serializer.deserialize(serializer.serialize(original)); + assertEquals(original, actual); + } + + @Test + void can_serialize_and_deserialize_predicates() { + Expression original = DSL.or(literal(true), DSL.less(literal(1), literal(2))); + Expression actual = serializer.deserialize(serializer.serialize(original)); + assertEquals(original, actual); + } + + @Test + void can_serialize_and_deserialize_functions() { + Expression original = DSL.abs(literal(30.0)); + Expression actual = serializer.deserialize(serializer.serialize(original)); + assertEquals(original, actual); + } + + @Test + void cannot_serialize_illegal_expression() { + Expression illegalExpr = new Expression() { + private final Object object = new Object(); // non-serializable + @Override + public ExprValue valueOf(Environment valueEnv) { + return null; + } + + @Override + public ExprType type() { + return null; + } + + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return null; + } + }; + assertThrows(IllegalStateException.class, () -> serializer.serialize(illegalExpr)); + } + + @Test + void cannot_deserialize_illegal_expression_code() { + var arr = "hello world".getBytes(); + assertThrows(IllegalStateException.class, () -> serializer.deserialize(arr)); + } +}