diff --git a/server/src/main/java/org/opensearch/transport/protobuf/FetchSearchResultProtobuf.java b/server/src/main/java/org/opensearch/transport/protobuf/FetchSearchResultProtobuf.java new file mode 100644 index 0000000000000..2f790a86aca32 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/protobuf/FetchSearchResultProtobuf.java @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.protobuf; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.serde.proto.SearchHitsTransportProto.FetchSearchResultProto; + +import java.io.IOException; + +/** + * FetchSearchResult child which implements serde operations as protobuf. + * @opensearch.internal + */ +public class FetchSearchResultProtobuf extends FetchSearchResult { + public FetchSearchResultProtobuf(StreamInput in) throws IOException { + fromProtobufStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + toProtobufStream(out); + } + + public void toProtobufStream(StreamOutput out) throws IOException { + toProto().writeTo(out); + } + + public void fromProtobufStream(StreamInput in) throws IOException { + FetchSearchResultProto proto = FetchSearchResultProto.parseFrom(in); + fromProto(proto); + } + + FetchSearchResultProto toProto() { + FetchSearchResultProto.Builder builder = FetchSearchResultProto.newBuilder() + .setHits(new SearchHitsProtobuf(hits).toProto()) + .setCounter(this.counter); + return builder.build(); + } + + void fromProto(FetchSearchResultProto proto) { + hits = new SearchHitsProtobuf(proto.getHits()); + } +} diff --git a/server/src/main/java/org/opensearch/transport/serde/SerDe.java 
b/server/src/main/java/org/opensearch/transport/protobuf/ProtoSerDeHelpers.java similarity index 90% rename from server/src/main/java/org/opensearch/transport/serde/SerDe.java rename to server/src/main/java/org/opensearch/transport/protobuf/ProtoSerDeHelpers.java index fac079c6b9a87..436676a661109 100644 --- a/server/src/main/java/org/opensearch/transport/serde/SerDe.java +++ b/server/src/main/java/org/opensearch/transport/protobuf/ProtoSerDeHelpers.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.transport.serde; +package org.opensearch.transport.protobuf; import com.google.protobuf.ByteString; import org.apache.lucene.search.Explanation; @@ -40,12 +40,7 @@ * SerDe interfaces and protobuf SerDe implementations for some "primitive" types. * @opensearch.internal */ -public class SerDe { - - public enum Strategy { - PROTOBUF, - NATIVE; - } +public class ProtoSerDeHelpers { /** * Serialization/Deserialization exception. @@ -61,18 +56,6 @@ public SerializationException(String message, Throwable cause) { } } - interface nativeSerializer { - void toNativeStream(StreamOutput out) throws IOException; - - void fromNativeStream(StreamInput in) throws IOException; - } - - interface protobufSerializer { - void toProtobufStream(StreamOutput out) throws IOException; - - void fromProtobufStream(StreamInput in) throws IOException; - } - // TODO: Lucene definitions should maybe be serialized as generic bytes arrays. 
static ExplanationProto explanationToProto(Explanation explanation) { ExplanationProto.Builder builder = ExplanationProto.newBuilder() @@ -118,7 +101,7 @@ static DocumentFieldProto documentFieldToProto(DocumentField field) { return builder.build(); } - static DocumentField documentFieldFromProto(DocumentFieldProto proto) throws SerDe.SerializationException { + static DocumentField documentFieldFromProto(DocumentFieldProto proto) throws ProtoSerDeHelpers.SerializationException { String name = proto.getName(); List values = new ArrayList<>(0); @@ -128,7 +111,7 @@ static DocumentField documentFieldFromProto(DocumentFieldProto proto) throws Ser Object readValue = in.readGenericValue(); values.add(readValue); } catch (IOException e) { - throw new SerDe.SerializationException("Failed to deserialize DocumentField values from proto object", e); + throw new ProtoSerDeHelpers.SerializationException("Failed to deserialize DocumentField values from proto object", e); } } @@ -177,7 +160,7 @@ static SearchSortValuesProto searchSortValuesToProto(SearchSortValues searchSort return builder.build(); } - static SearchSortValues searchSortValuesFromProto(SearchSortValuesProto proto) throws SerDe.SerializationException { + static SearchSortValues searchSortValuesFromProto(SearchSortValuesProto proto) throws ProtoSerDeHelpers.SerializationException { Object[] formattedSortValues = null; Object[] rawSortValues = null; @@ -186,7 +169,7 @@ static SearchSortValues searchSortValuesFromProto(SearchSortValuesProto proto) t try (StreamInput formattedIn = formattedBytes.streamInput()) { formattedSortValues = formattedIn.readArray(Lucene::readSortValue, Object[]::new); } catch (IOException e) { - throw new SerDe.SerializationException("Failed to deserialize SearchSortValues from proto object", e); + throw new ProtoSerDeHelpers.SerializationException("Failed to deserialize SearchSortValues from proto object", e); } } @@ -195,7 +178,7 @@ static SearchSortValues 
searchSortValuesFromProto(SearchSortValuesProto proto) t try (StreamInput rawIn = rawBytes.streamInput()) { rawSortValues = rawIn.readArray(Lucene::readSortValue, Object[]::new); } catch (IOException e) { - throw new SerDe.SerializationException("Failed to deserialize SearchSortValues from proto object", e); + throw new ProtoSerDeHelpers.SerializationException("Failed to deserialize SearchSortValues from proto object", e); } } diff --git a/server/src/main/java/org/opensearch/transport/protobuf/SearchHitProtobuf.java b/server/src/main/java/org/opensearch/transport/protobuf/SearchHitProtobuf.java new file mode 100644 index 0000000000000..d21f84edff686 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/protobuf/SearchHitProtobuf.java @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.transport.protobuf; + +import com.google.protobuf.ByteString; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.text.Text; +import org.opensearch.search.SearchHit; +import org.opensearch.serde.proto.SearchHitsTransportProto.NestedIdentityProto; +import org.opensearch.serde.proto.SearchHitsTransportProto.SearchHitProto; + +import java.io.IOException; +import java.util.HashMap; + +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.documentFieldFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.documentFieldToProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.explanationFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.explanationToProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldFromProto; 
+import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldToProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetToProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchSortValuesFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchSortValuesToProto; + +/** + * Serialization/Deserialization implementations for SearchHit. + * @opensearch.internal + */ +public class SearchHitProtobuf extends SearchHit { + public SearchHitProtobuf(SearchHit hit) { + super(hit); + } + + public SearchHitProtobuf(StreamInput in) throws IOException { + fromProtobufStream(in); + } + + public SearchHitProtobuf(SearchHitProto proto) { fromProto(proto); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + toProtobufStream(out); + } + + public void toProtobufStream(StreamOutput out) throws IOException { + toProto().writeTo(out); + } + + public void fromProtobufStream(StreamInput in) throws IOException { + SearchHitProto proto = SearchHitProto.parseFrom(in); + fromProto(proto); + } + + SearchHitProto toProto() { + SearchHitProto.Builder builder = SearchHitProto.newBuilder() + .setScore(score) + .setId(id.string()) + .setVersion(version) + .setSeqNo(seqNo) + .setPrimaryTerm(primaryTerm); + + builder.setNestedIdentity(nestedIdentityToProto(nestedIdentity)); + builder.setSource(ByteString.copyFrom(source.toBytesRef().bytes)); + builder.setExplanation(explanationToProto(explanation)); + builder.setSortValues(searchSortValuesToProto(sortValues)); + + documentFields.forEach((key, value) -> builder.putDocumentFields(key, documentFieldToProto(value))); + + metaFields.forEach((key, value) -> builder.putMetaFields(key, documentFieldToProto(value))); + + highlightFields.forEach((key, value) -> builder.putHighlightFields(key, highlightFieldToProto(value))); + + 
matchedQueries.forEach(builder::putMatchedQueries); + + // shard is optional + if (shard != null) { + builder.setShard(searchShardTargetToProto(shard)); + } + + innerHits.forEach((key, value) -> builder.putInnerHits(key, new SearchHitsProtobuf(value).toProto())); + + return builder.build(); + } + + void fromProto(SearchHitProto proto) { + docId = -1; + score = proto.getScore(); + seqNo = proto.getSeqNo(); + version = proto.getVersion(); + primaryTerm = proto.getPrimaryTerm(); + id = new Text(proto.getId()); + source = BytesReference.fromByteBuffer(proto.getSource().asReadOnlyByteBuffer()); + explanation = explanationFromProto(proto.getExplanation()); + sortValues = searchSortValuesFromProto(proto.getSortValues()); + nestedIdentity = nestedIdentityFromProto(proto.getNestedIdentity()); + matchedQueries = proto.getMatchedQueriesMap(); + + documentFields = new HashMap<>(); + proto.getDocumentFieldsMap().forEach((key, value) -> documentFields.put(key, documentFieldFromProto(value))); + + metaFields = new HashMap<>(); + proto.getMetaFieldsMap().forEach((key, value) -> metaFields.put(key, documentFieldFromProto(value))); + + highlightFields = new HashMap<>(); + proto.getHighlightFieldsMap().forEach((key, value) -> highlightFields.put(key, highlightFieldFromProto(value))); + + innerHits = new HashMap<>(); + proto.getInnerHitsMap().forEach((key, value) -> innerHits.put(key, new SearchHitsProtobuf(value))); + + shard = searchShardTargetFromProto(proto.getShard()); + index = shard.getIndex(); + clusterAlias = shard.getClusterAlias(); + } + + static NestedIdentityProto nestedIdentityToProto(SearchHit.NestedIdentity nestedIdentity) { + NestedIdentityProto.Builder builder = NestedIdentityProto.newBuilder() + .setField(nestedIdentity.getField().string()) + .setOffset(nestedIdentity.getOffset()); + + if (nestedIdentity.getChild() != null) { + builder.setChild(nestedIdentityToProto(nestedIdentity.getChild())); + } + + return builder.build(); + } + + static SearchHit.NestedIdentity 
nestedIdentityFromProto(NestedIdentityProto proto) { + String field = proto.getField(); + int offset = proto.getOffset(); + + SearchHit.NestedIdentity child = null; + if (proto.hasChild()) { + child = nestedIdentityFromProto(proto.getChild()); + } + + return new SearchHit.NestedIdentity(field, offset, child); + } +} diff --git a/server/src/main/java/org/opensearch/transport/serde/SearchHitsSerDe.java b/server/src/main/java/org/opensearch/transport/protobuf/SearchHitsProtobuf.java similarity index 51% rename from server/src/main/java/org/opensearch/transport/serde/SearchHitsSerDe.java rename to server/src/main/java/org/opensearch/transport/protobuf/SearchHitsProtobuf.java index 9433d5446e271..c862bbc0fbdec 100644 --- a/server/src/main/java/org/opensearch/transport/serde/SearchHitsSerDe.java +++ b/server/src/main/java/org/opensearch/transport/protobuf/SearchHitsProtobuf.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.transport.serde; +package org.opensearch.transport.protobuf; import com.google.protobuf.ByteString; import org.apache.lucene.search.SortField; @@ -24,97 +24,41 @@ import java.io.IOException; /** - * Serialization/Deserialization implementations for SearchHits. + * SearchHits child which implements serde operations as protobuf. 
* @opensearch.internal */ -public class SearchHitsSerDe extends SearchHits implements SerDe.nativeSerializer, SerDe.protobufSerializer { - SerDe.Strategy strategy = SerDe.Strategy.NATIVE; - - public SearchHitsSerDe(SearchHits hits, SerDe.Strategy strategy) { +public class SearchHitsProtobuf extends SearchHits { + public SearchHitsProtobuf(SearchHits hits) { super(hits); - this.strategy = strategy; - } - - public SearchHitsSerDe(SerDe.Strategy strategy, StreamInput in) throws IOException { - this.strategy = strategy; - switch (this.strategy) { - case NATIVE: - fromNativeStream(in); - break; - case PROTOBUF: - fromProtobufStream(in); - break; - default: - throw new AssertionError("This code should not be reachable"); - } } - public SearchHitsSerDe(StreamInput in) throws IOException { - fromNativeStream(in); + public SearchHitsProtobuf(StreamInput in) throws IOException { + fromProtobufStream(in); } - public SearchHitsSerDe(SearchHitsProto proto) { + public SearchHitsProtobuf(SearchHitsProto proto) { fromProto(proto); } @Override public void writeTo(StreamOutput out) throws IOException { - switch (this.strategy) { - case NATIVE: - toNativeStream(out); - break; - case PROTOBUF: - toProtobufStream(out); - break; - default: - throw new AssertionError("This code should not be reachable"); - } + toProtobufStream(out); } - @Override public void toProtobufStream(StreamOutput out) throws IOException { toProto().writeTo(out); } - @Override public void fromProtobufStream(StreamInput in) throws IOException { SearchHitsProto proto = SearchHitsProto.parseFrom(in); fromProto(proto); } - @Override - public void toNativeStream(StreamOutput out) throws IOException { - super.writeTo(out); - } - - @Override - public void fromNativeStream(StreamInput in) throws IOException { - if (in.readBoolean()) { - this.totalHits = Lucene.readTotalHits(in); - } else { - // track_total_hits is false - this.totalHits = null; - } - this.maxScore = in.readFloat(); - int size = in.readVInt(); - if (size 
== 0) { - hits = EMPTY; - } else { - hits = new SearchHit[size]; - for (int i = 0; i < hits.length; i++) { - hits[i] = new SearchHitSerDe(SerDe.Strategy.NATIVE, in); - } - } - this.sortFields = in.readOptionalArray(Lucene::readSortField, SortField[]::new); - this.collapseField = in.readOptionalString(); - this.collapseValues = in.readOptionalArray(Lucene::readSortValue, Object[]::new); - } - SearchHitsProto toProto() { SearchHitsProto.Builder builder = SearchHitsProto.newBuilder().setMaxScore(maxScore).setCollapseField(collapseField); for (SearchHit hit : hits) { - builder.addHits(new SearchHitSerDe(hit, strategy).toProto()); + builder.addHits(new SearchHitProtobuf(hit).toProto()); } TotalHits totHits = totalHits; @@ -125,20 +69,20 @@ SearchHitsProto toProto() { sortOut.writeOptionalArray(Lucene::writeSortField, sortFields); builder.setSortFields(ByteString.copyFrom(sortOut.bytes().toBytesRef().bytes)); } catch (IOException e) { - throw new SerDe.SerializationException("Failed to serialize SearchHits to proto", e); + throw new ProtoSerDeHelpers.SerializationException("Failed to serialize SearchHits to proto", e); } try (BytesStreamOutput collapseOut = new BytesStreamOutput()) { collapseOut.writeOptionalArray(Lucene::writeSortValue, collapseValues); builder.setCollapseValues(ByteString.copyFrom(collapseOut.bytes().toBytesRef().bytes)); } catch (IOException e) { - throw new SerDe.SerializationException("Failed to serialize SearchHits to proto", e); + throw new ProtoSerDeHelpers.SerializationException("Failed to serialize SearchHits to proto", e); } return builder.build(); } - void fromProto(SearchHitsProto proto) throws SerDe.SerializationException { + void fromProto(SearchHitsProto proto) throws ProtoSerDeHelpers.SerializationException { maxScore = proto.getMaxScore(); collapseField = proto.getCollapseField(); @@ -146,29 +90,25 @@ void fromProto(SearchHitsProto proto) throws SerDe.SerializationException { long rel = totHitsProto.getRelation(); long val = 
totHitsProto.getValue(); if (rel < 0 || rel >= TotalHits.Relation.values().length) { - throw new SerDe.SerializationException("Failed to deserialize TotalHits from proto"); + throw new ProtoSerDeHelpers.SerializationException("Failed to deserialize TotalHits from proto"); } totalHits = new TotalHits(val, TotalHits.Relation.values()[(int) rel]); try (StreamInput sortBytesInput = new BytesArray(proto.getSortFields().toByteArray()).streamInput()) { sortFields = sortBytesInput.readOptionalArray(Lucene::readSortField, SortField[]::new); } catch (IOException e) { - throw new SerDe.SerializationException("Failed to deserialize SearchHits from proto", e); + throw new ProtoSerDeHelpers.SerializationException("Failed to deserialize SearchHits from proto", e); } try (StreamInput collapseBytesInput = new BytesArray(proto.getCollapseValues().toByteArray()).streamInput()) { collapseValues = collapseBytesInput.readOptionalArray(Lucene::readSortValue, Object[]::new); } catch (IOException e) { - throw new SerDe.SerializationException("Failed to deserialize SearchHits from proto", e); + throw new ProtoSerDeHelpers.SerializationException("Failed to deserialize SearchHits from proto", e); } hits = new SearchHit[proto.getHitsCount()]; for (int i = 0; i < hits.length; i++) { - try { - hits[i] = new SearchHitSerDe(proto.getHits(i)); - } catch (IOException e) { - throw new SerDe.SerializationException("Failed to deserialize SearchHits from proto", e); - } + hits[i] = new SearchHitProtobuf(proto.getHits(i)); } } } diff --git a/server/src/main/java/org/opensearch/transport/serde/package-info.java b/server/src/main/java/org/opensearch/transport/protobuf/package-info.java similarity index 86% rename from server/src/main/java/org/opensearch/transport/serde/package-info.java rename to server/src/main/java/org/opensearch/transport/protobuf/package-info.java index fa144e7d4dcfd..16d52127e37e6 100644 --- a/server/src/main/java/org/opensearch/transport/serde/package-info.java +++ 
b/server/src/main/java/org/opensearch/transport/protobuf/package-info.java @@ -7,4 +7,4 @@ */ /** Serialization/Deserialization implementations for the fetch package. */ -package org.opensearch.transport.serde; +package org.opensearch.transport.protobuf; diff --git a/server/src/main/java/org/opensearch/transport/serde/FetchSearchResultSerDe.java b/server/src/main/java/org/opensearch/transport/serde/FetchSearchResultSerDe.java deleted file mode 100644 index f48d334ad3a2c..0000000000000 --- a/server/src/main/java/org/opensearch/transport/serde/FetchSearchResultSerDe.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.transport.serde; - -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.search.fetch.FetchSearchResult; -import org.opensearch.search.internal.ShardSearchContextId; -import org.opensearch.serde.proto.SearchHitsTransportProto.FetchSearchResultProto; - -import java.io.IOException; - -/** - * Serialization/Deserialization implementations for SearchHit. 
- * @opensearch.internal - */ -public class FetchSearchResultSerDe extends FetchSearchResult implements SerDe.nativeSerializer, SerDe.protobufSerializer { - SerDe.Strategy strategy = SerDe.Strategy.NATIVE; - - public FetchSearchResultSerDe(SerDe.Strategy strategy, StreamInput in) throws IOException { - this.strategy = strategy; - switch (this.strategy) { - case NATIVE: - fromNativeStream(in); - break; - case PROTOBUF: - fromProtobufStream(in); - break; - default: - throw new AssertionError("This code should not be reachable"); - } - } - - public FetchSearchResultSerDe(StreamInput in) throws IOException { - fromNativeStream(in); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - switch (this.strategy) { - case NATIVE: - toNativeStream(out); - break; - case PROTOBUF: - toProtobufStream(out); - break; - default: - throw new AssertionError("This code should not be reachable"); - } - } - - @Override - public void toProtobufStream(StreamOutput out) throws IOException { - toProto().writeTo(out); - } - - @Override - public void fromProtobufStream(StreamInput in) throws IOException { - FetchSearchResultProto proto = FetchSearchResultProto.parseFrom(in); - fromProto(proto); - } - - @Override - public void toNativeStream(StreamOutput out) throws IOException { - super.writeTo(out); - } - - @Override - public void fromNativeStream(StreamInput in) throws IOException { - this.hits = new SearchHitsSerDe(SerDe.Strategy.NATIVE, in); - this.contextId = new ShardSearchContextId(in); - } - - FetchSearchResultProto toProto() { - FetchSearchResultProto.Builder builder = FetchSearchResultProto.newBuilder() - .setHits(new SearchHitsSerDe(hits, strategy).toProto()) - .setCounter(this.counter); - return builder.build(); - } - - void fromProto(FetchSearchResultProto proto) { - hits = new SearchHitsSerDe(proto.getHits()); - } -} diff --git a/server/src/main/java/org/opensearch/transport/serde/SearchHitSerDe.java 
b/server/src/main/java/org/opensearch/transport/serde/SearchHitSerDe.java deleted file mode 100644 index 256622a4a033b..0000000000000 --- a/server/src/main/java/org/opensearch/transport/serde/SearchHitSerDe.java +++ /dev/null @@ -1,271 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.transport.serde; - -import com.google.protobuf.ByteString; -import org.opensearch.Version; -import org.opensearch.common.document.DocumentField; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.text.Text; -import org.opensearch.search.SearchHit; -import org.opensearch.search.SearchHits; -import org.opensearch.search.SearchShardTarget; -import org.opensearch.search.SearchSortValues; -import org.opensearch.search.fetch.subphase.highlight.HighlightField; -import org.opensearch.serde.proto.SearchHitsTransportProto.NestedIdentityProto; -import org.opensearch.serde.proto.SearchHitsTransportProto.SearchHitProto; - -import java.io.IOException; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.stream.Collectors; - -import static java.util.Collections.emptyMap; -import static java.util.Collections.singletonMap; -import static java.util.Collections.unmodifiableMap; -import static org.opensearch.common.lucene.Lucene.readExplanation; -import static org.opensearch.transport.serde.SerDe.documentFieldFromProto; -import static org.opensearch.transport.serde.SerDe.documentFieldToProto; -import static org.opensearch.transport.serde.SerDe.explanationFromProto; -import static org.opensearch.transport.serde.SerDe.explanationToProto; -import static org.opensearch.transport.serde.SerDe.highlightFieldFromProto; 
-import static org.opensearch.transport.serde.SerDe.highlightFieldToProto; -import static org.opensearch.transport.serde.SerDe.searchShardTargetFromProto; -import static org.opensearch.transport.serde.SerDe.searchShardTargetToProto; -import static org.opensearch.transport.serde.SerDe.searchSortValuesFromProto; -import static org.opensearch.transport.serde.SerDe.searchSortValuesToProto; - -/** - * Serialization/Deserialization implementations for SearchHit. - * @opensearch.internal - */ -public class SearchHitSerDe extends SearchHit implements SerDe.nativeSerializer, SerDe.protobufSerializer { - SerDe.Strategy strategy = SerDe.Strategy.NATIVE; - - public SearchHitSerDe(SearchHit hit, SerDe.Strategy strategy) { - super(hit); - this.strategy = strategy; - } - - public SearchHitSerDe(SerDe.Strategy strategy, StreamInput in) throws IOException { - this.strategy = strategy; - switch (this.strategy) { - case NATIVE: - fromNativeStream(in); - break; - case PROTOBUF: - fromProtobufStream(in); - break; - default: - throw new AssertionError("This code should not be reachable"); - } - } - - public SearchHitSerDe(StreamInput in) throws IOException { - fromNativeStream(in); - } - - public SearchHitSerDe(SearchHitProto proto) throws IOException { - fromProto(proto); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - switch (this.strategy) { - case NATIVE: - toNativeStream(out); - break; - case PROTOBUF: - toProtobufStream(out); - break; - default: - throw new AssertionError("This code should not be reachable"); - } - } - - @Override - public void toProtobufStream(StreamOutput out) throws IOException { - toProto().writeTo(out); - } - - @Override - public void fromProtobufStream(StreamInput in) throws IOException { - SearchHitProto proto = SearchHitProto.parseFrom(in); - fromProto(proto); - } - - @Override - public void toNativeStream(StreamOutput out) throws IOException { - super.writeTo(out); - } - - @Override - public void 
fromNativeStream(StreamInput in) throws IOException { - docId = -1; - score = in.readFloat(); - id = in.readOptionalText(); - if (in.getVersion().before(Version.V_2_0_0)) { - in.readOptionalText(); - } - - nestedIdentity = in.readOptionalWriteable(NestedIdentity::new); - version = in.readLong(); - seqNo = in.readZLong(); - primaryTerm = in.readVLong(); - source = in.readBytesReference(); - if (source.length() == 0) { - source = null; - } - if (in.readBoolean()) { - explanation = readExplanation(in); - } - documentFields = in.readMap(StreamInput::readString, DocumentField::new); - metaFields = in.readMap(StreamInput::readString, DocumentField::new); - - int size = in.readVInt(); - if (size == 0) { - highlightFields = emptyMap(); - } else if (size == 1) { - HighlightField field = new HighlightField(in); - highlightFields = singletonMap(field.name(), field); - } else { - Map hlFields = new HashMap<>(); - for (int i = 0; i < size; i++) { - HighlightField field = new HighlightField(in); - hlFields.put(field.name(), field); - } - highlightFields = unmodifiableMap(hlFields); - } - - sortValues = new SearchSortValues(in); - - size = in.readVInt(); - if (in.getVersion().onOrAfter(Version.V_2_13_0)) { - if (size > 0) { - Map tempMap = in.readMap(StreamInput::readString, StreamInput::readFloat); - matchedQueries = tempMap.entrySet() - .stream() - .sorted(Map.Entry.comparingByKey()) - .collect( - Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> oldValue, LinkedHashMap::new) - ); - } - } else { - matchedQueries = new LinkedHashMap<>(size); - for (int i = 0; i < size; i++) { - matchedQueries.put(in.readString(), Float.NaN); - } - } - shard = in.readOptionalWriteable(SearchShardTarget::new); - if (shard != null) { - index = shard.getIndex(); - clusterAlias = shard.getClusterAlias(); - } - - size = in.readVInt(); - if (size > 0) { - innerHits = new HashMap<>(size); - for (int i = 0; i < size; i++) { - String key = in.readString(); - SearchHits value 
= new SearchHitsSerDe(strategy, in); - innerHits.put(key, value); - } - } else { - innerHits = null; - } - } - - SearchHitProto toProto() { - SearchHitProto.Builder builder = SearchHitProto.newBuilder() - .setScore(score) - .setId(id.string()) - .setVersion(version) - .setSeqNo(seqNo) - .setPrimaryTerm(primaryTerm); - - builder.setNestedIdentity(nestedIdentityToProto(nestedIdentity)); - builder.setSource(ByteString.copyFrom(source.toBytesRef().bytes)); - builder.setExplanation(explanationToProto(explanation)); - builder.setSortValues(searchSortValuesToProto(sortValues)); - - documentFields.forEach((key, value) -> builder.putDocumentFields(key, documentFieldToProto(value))); - - metaFields.forEach((key, value) -> builder.putMetaFields(key, documentFieldToProto(value))); - - highlightFields.forEach((key, value) -> builder.putHighlightFields(key, highlightFieldToProto(value))); - - matchedQueries.forEach(builder::putMatchedQueries); - - // shard is optional - if (shard != null) { - builder.setShard(searchShardTargetToProto(shard)); - } - - innerHits.forEach((key, value) -> builder.putInnerHits(key, new SearchHitsSerDe(value, strategy).toProto())); - - return builder.build(); - } - - void fromProto(SearchHitProto proto) throws SerDe.SerializationException { - docId = -1; - score = proto.getScore(); - seqNo = proto.getSeqNo(); - version = proto.getVersion(); - primaryTerm = proto.getPrimaryTerm(); - id = new Text(proto.getId()); - source = BytesReference.fromByteBuffer(proto.getSource().asReadOnlyByteBuffer()); - explanation = explanationFromProto(proto.getExplanation()); - sortValues = searchSortValuesFromProto(proto.getSortValues()); - nestedIdentity = nestedIdentityFromProto(proto.getNestedIdentity()); - matchedQueries = proto.getMatchedQueriesMap(); - - documentFields = new HashMap<>(); - proto.getDocumentFieldsMap().forEach((key, value) -> documentFields.put(key, documentFieldFromProto(value))); - - metaFields = new HashMap<>(); - 
proto.getMetaFieldsMap().forEach((key, value) -> metaFields.put(key, documentFieldFromProto(value))); - - highlightFields = new HashMap<>(); - proto.getHighlightFieldsMap().forEach((key, value) -> highlightFields.put(key, highlightFieldFromProto(value))); - - innerHits = new HashMap<>(); - proto.getInnerHitsMap().forEach((key, value) -> innerHits.put(key, new SearchHitsSerDe(value))); - - shard = searchShardTargetFromProto(proto.getShard()); - index = shard.getIndex(); - clusterAlias = shard.getClusterAlias(); - } - - static NestedIdentityProto nestedIdentityToProto(SearchHit.NestedIdentity nestedIdentity) { - NestedIdentityProto.Builder builder = NestedIdentityProto.newBuilder() - .setField(nestedIdentity.getField().string()) - .setOffset(nestedIdentity.getOffset()); - - if (nestedIdentity.getChild() != null) { - builder.setChild(nestedIdentityToProto(nestedIdentity.getChild())); - } - - return builder.build(); - } - - static SearchHit.NestedIdentity nestedIdentityFromProto(NestedIdentityProto proto) { - String field = proto.getField(); - int offset = proto.getOffset(); - - SearchHit.NestedIdentity child = null; - if (proto.hasChild()) { - child = nestedIdentityFromProto(proto.getChild()); - } - - return new SearchHit.NestedIdentity(field, offset, child); - } -} diff --git a/server/src/test/java/org/opensearch/search/SearchHitsProtobufTests.java b/server/src/test/java/org/opensearch/search/SearchHitsProtobufTests.java new file mode 100644 index 0000000000000..0fba91828fac3 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/SearchHitsProtobufTests.java @@ -0,0 +1,217 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.search; + +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.tests.util.TestUtil; +import org.opensearch.action.OriginalIndices; +import org.opensearch.common.lucene.LuceneTests; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.MediaType; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.AbstractSerializingTestCase; +import org.opensearch.test.AbstractWireSerializingTestCase; +import org.opensearch.transport.protobuf.SearchHitsProtobuf; + +import java.io.IOException; +import java.util.Collections; +import java.util.function.Predicate; + +public class SearchHitsProtobufTests extends 
AbstractWireSerializingTestCase<SearchHitsProtobuf> { + + @Override + protected Writeable.Reader<SearchHitsProtobuf> instanceReader() { + return SearchHitsProtobuf::new; + } + + @Override + protected SearchHitsProtobuf createTestInstance() { + // This instance is used to test the transport serialization so it's fine + // to produce shard targets (withShardTarget is true) since they are serialized + // in this layer. + return createTestItem(randomFrom(XContentType.values()), true, true); + } + + public static SearchHitsProtobuf createTestItem(boolean withOptionalInnerHits, boolean withShardTarget) { + return createTestItem(randomFrom(XContentType.values()), withOptionalInnerHits, withShardTarget); + } + + private static SearchHit[] createSearchHitArray( + int size, + final MediaType mediaType, + boolean withOptionalInnerHits, + boolean transportSerialization + ) { + SearchHit[] hits = new SearchHit[size]; + for (int i = 0; i < hits.length; i++) { + hits[i] = SearchHitTests.createTestItem(mediaType, withOptionalInnerHits, transportSerialization); + } + return hits; + } + + private static TotalHits randomTotalHits(TotalHits.Relation relation) { + long totalHits = TestUtil.nextLong(random(), 0, Long.MAX_VALUE); + return new TotalHits(totalHits, relation); + } + + public static SearchHitsProtobuf createTestItem(final MediaType mediaType, boolean withOptionalInnerHits, boolean transportSerialization) { + return createTestItem(mediaType, withOptionalInnerHits, transportSerialization, randomFrom(TotalHits.Relation.values())); + } + + private static SearchHitsProtobuf createTestItem( + final MediaType mediaType, + boolean withOptionalInnerHits, + boolean transportSerialization, + TotalHits.Relation totalHitsRelation + ) { + int searchHits = randomIntBetween(0, 5); + SearchHit[] hits = createSearchHitArray(searchHits, mediaType, withOptionalInnerHits, transportSerialization); + TotalHits totalHits = frequently() ? randomTotalHits(totalHitsRelation) : null; + float maxScore = frequently() ?
randomFloat() : Float.NaN; + SortField[] sortFields = null; + String collapseField = null; + Object[] collapseValues = null; + if (transportSerialization) { + sortFields = randomBoolean() ? createSortFields(randomIntBetween(1, 5)) : null; + collapseField = randomAlphaOfLengthBetween(5, 10); + collapseValues = randomBoolean() ? createCollapseValues(randomIntBetween(1, 10)) : null; + } + return new SearchHitsProtobuf(new SearchHits(hits, totalHits, maxScore, sortFields, collapseField, collapseValues)); + } + + private static SortField[] createSortFields(int size) { + SortField[] sortFields = new SortField[size]; + for (int i = 0; i < sortFields.length; i++) { + // sort fields are simplified before serialization, we write directly the simplified version + // otherwise equality comparisons become complicated + sortFields[i] = LuceneTests.randomSortField().v2(); + } + return sortFields; + } + + private static Object[] createCollapseValues(int size) { + Object[] collapseValues = new Object[size]; + for (int i = 0; i < collapseValues.length; i++) { + collapseValues[i] = LuceneTests.randomSortValue(); + } + return collapseValues; + } + + @Override + protected SearchHitsProtobuf mutateInstance(SearchHitsProtobuf instance) { + return new SearchHitsProtobuf(mutate(instance)); + } + + protected SearchHits mutate(SearchHits instance) { + switch (randomIntBetween(0, 5)) { + case 0: + return new SearchHits( + createSearchHitArray(instance.getHits().length + 1, randomFrom(XContentType.values()), false, randomBoolean()), + instance.getTotalHits(), + instance.getMaxScore() + ); + case 1: + final TotalHits totalHits; + if (instance.getTotalHits() == null) { + totalHits = randomTotalHits(randomFrom(TotalHits.Relation.values())); + } else { + totalHits = null; + } + return new SearchHits(instance.getHits(), totalHits, instance.getMaxScore()); + case 2: + final float maxScore; + if (Float.isNaN(instance.getMaxScore())) { + maxScore = randomFloat(); + } else { + maxScore = Float.NaN; + } 
+ return new SearchHits(instance.getHits(), instance.getTotalHits(), maxScore); + case 3: + SortField[] sortFields; + if (instance.getSortFields() == null) { + sortFields = createSortFields(randomIntBetween(1, 5)); + } else { + sortFields = randomBoolean() ? createSortFields(instance.getSortFields().length + 1) : null; + } + return new SearchHits( + instance.getHits(), + instance.getTotalHits(), + instance.getMaxScore(), + sortFields, + instance.getCollapseField(), + instance.getCollapseValues() + ); + case 4: + String collapseField; + if (instance.getCollapseField() == null) { + collapseField = randomAlphaOfLengthBetween(5, 10); + } else { + collapseField = randomBoolean() ? instance.getCollapseField() + randomAlphaOfLengthBetween(2, 5) : null; + } + return new SearchHits( + instance.getHits(), + instance.getTotalHits(), + instance.getMaxScore(), + instance.getSortFields(), + collapseField, + instance.getCollapseValues() + ); + case 5: + Object[] collapseValues; + if (instance.getCollapseValues() == null) { + collapseValues = createCollapseValues(randomIntBetween(1, 5)); + } else { + collapseValues = randomBoolean() ? createCollapseValues(instance.getCollapseValues().length + 1) : null; + } + return new SearchHits( + instance.getHits(), + instance.getTotalHits(), + instance.getMaxScore(), + instance.getSortFields(), + instance.getCollapseField(), + collapseValues + ); + default: + throw new UnsupportedOperationException(); + } + } +}