diff --git a/buildSrc/src/main/groovy/org/opensearch/gradle/precommit/LicenseHeadersTask.groovy b/buildSrc/src/main/groovy/org/opensearch/gradle/precommit/LicenseHeadersTask.groovy index b8d0ed2b9c43c..f6c8da1191c5d 100644 --- a/buildSrc/src/main/groovy/org/opensearch/gradle/precommit/LicenseHeadersTask.groovy +++ b/buildSrc/src/main/groovy/org/opensearch/gradle/precommit/LicenseHeadersTask.groovy @@ -149,6 +149,8 @@ class LicenseHeadersTask extends AntTask { licenseFamilyName: "Generated") { // parsers generated by antlr pattern(substring: "ANTLR GENERATED CODE") + // Protobuf + pattern(substring: "Generated by the protocol buffer compiler") } // Vendored Code diff --git a/server/src/main/java/org/opensearch/search/SearchHit.java b/server/src/main/java/org/opensearch/search/SearchHit.java index 6391353cfe5b1..bcab575022add 100644 --- a/server/src/main/java/org/opensearch/search/SearchHit.java +++ b/server/src/main/java/org/opensearch/search/SearchHit.java @@ -98,46 +98,74 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public final class SearchHit implements Writeable, ToXContentObject, Iterable { +public class SearchHit implements Writeable, ToXContentObject, Iterable { - private final transient int docId; + protected transient int docId; private static final float DEFAULT_SCORE = Float.NaN; - private float score = DEFAULT_SCORE; + protected float score = DEFAULT_SCORE; - private final Text id; + @Nullable + protected Text id; - private final NestedIdentity nestedIdentity; + @Nullable + protected NestedIdentity nestedIdentity; - private long version = -1; - private long seqNo = SequenceNumbers.UNASSIGNED_SEQ_NO; - private long primaryTerm = SequenceNumbers.UNASSIGNED_PRIMARY_TERM; + protected long version = -1; + protected long seqNo = SequenceNumbers.UNASSIGNED_SEQ_NO; + protected long primaryTerm = SequenceNumbers.UNASSIGNED_PRIMARY_TERM; - private BytesReference source; + @Nullable + protected BytesReference source; - private Map documentFields; - private 
final Map metaFields; + protected Map documentFields; + protected Map metaFields; - private Map highlightFields = null; + @Nullable + protected Map highlightFields = null; - private SearchSortValues sortValues = SearchSortValues.EMPTY; + protected SearchSortValues sortValues = SearchSortValues.EMPTY; - private Map matchedQueries = new HashMap<>(); + protected Map matchedQueries = new HashMap<>(); - private Explanation explanation; + @Nullable + protected Explanation explanation; @Nullable - private SearchShardTarget shard; + protected SearchShardTarget shard; // These two fields normally get set when setting the shard target, so they hold the same values as the target thus don't get // serialized over the wire. When parsing hits back from xcontent though, in most of the cases (whenever explanation is disabled) // we can't rebuild the shard target object so we need to set these manually for users retrieval. - private transient String index; - private transient String clusterAlias; + protected transient String index; + protected transient String clusterAlias; private Map sourceAsMap; - private Map innerHits; + @Nullable + protected Map innerHits; + + public SearchHit(SearchHit hit) { + this.docId = hit.docId; + this.id = hit.id; + this.nestedIdentity = hit.nestedIdentity; + this.version = hit.version; + this.seqNo = hit.seqNo; + this.primaryTerm = hit.primaryTerm; + this.source = hit.source; + this.documentFields = hit.documentFields; + this.metaFields = hit.metaFields; + this.highlightFields = hit.highlightFields; + this.sortValues = hit.sortValues; + this.matchedQueries = hit.matchedQueries; + this.explanation = hit.explanation; + this.shard = hit.shard; + this.index = hit.index; + this.clusterAlias = hit.clusterAlias; + this.innerHits = hit.innerHits; + this.score = hit.score; + this.sourceAsMap = hit.sourceAsMap; + } // used only in tests public SearchHit(int docId) { @@ -236,7 +264,9 @@ public SearchHit(StreamInput in) throws IOException { } } - private static 
final Text SINGLE_MAPPING_TYPE = new Text(MapperService.SINGLE_MAPPING_NAME); + protected SearchHit() {} + + protected static final Text SINGLE_MAPPING_TYPE = new Text(MapperService.SINGLE_MAPPING_NAME); @Override public void writeTo(StreamOutput out) throws IOException { @@ -993,7 +1023,7 @@ private void buildExplanation(XContentBuilder builder, Explanation explanation) @Override public boolean equals(Object obj) { - if (obj == null || getClass() != obj.getClass()) { + if (!(obj instanceof SearchHit)) { return false; } SearchHit other = (SearchHit) obj; @@ -1057,7 +1087,7 @@ public NestedIdentity(String field, int offset, NestedIdentity child) { this.child = child; } - NestedIdentity(StreamInput in) throws IOException { + public NestedIdentity(StreamInput in) throws IOException { field = in.readOptionalText(); offset = in.readInt(); child = in.readOptionalWriteable(NestedIdentity::new); diff --git a/server/src/main/java/org/opensearch/search/SearchHits.java b/server/src/main/java/org/opensearch/search/SearchHits.java index 8232643b353f5..58b12fd9cd628 100644 --- a/server/src/main/java/org/opensearch/search/SearchHits.java +++ b/server/src/main/java/org/opensearch/search/SearchHits.java @@ -61,7 +61,7 @@ * @opensearch.api */ @PublicApi(since = "1.0.0") -public final class SearchHits implements Writeable, ToXContentFragment, Iterable { +public class SearchHits implements Writeable, ToXContentFragment, Iterable { public static SearchHits empty() { return empty(true); } @@ -72,15 +72,25 @@ public static SearchHits empty(boolean withTotalHits) { public static final SearchHit[] EMPTY = new SearchHit[0]; - private final SearchHit[] hits; - private final TotalHits totalHits; - private final float maxScore; + protected SearchHit[] hits; + protected float maxScore; @Nullable - private final SortField[] sortFields; + protected TotalHits totalHits; @Nullable - private final String collapseField; + protected SortField[] sortFields; @Nullable - private final Object[] 
collapseValues; + protected String collapseField; + @Nullable + protected Object[] collapseValues; + + public SearchHits(SearchHits sHits) { + this.hits = sHits.hits; + this.totalHits = sHits.totalHits; + this.maxScore = sHits.maxScore; + this.sortFields = sHits.sortFields; + this.collapseField = sHits.collapseField; + this.collapseValues = sHits.collapseValues; + } public SearchHits(SearchHit[] hits, @Nullable TotalHits totalHits, float maxScore) { this(hits, totalHits, maxScore, null, null, null); @@ -124,6 +134,8 @@ public SearchHits(StreamInput in) throws IOException { collapseValues = in.readOptionalArray(Lucene::readSortValue, Object[]::new); } + protected SearchHits() {} + @Override public void writeTo(StreamOutput out) throws IOException { final boolean hasTotalHits = totalHits != null; @@ -288,7 +300,7 @@ public static SearchHits fromXContent(XContentParser parser) throws IOException @Override public boolean equals(Object obj) { - if (obj == null || getClass() != obj.getClass()) { + if (!(obj instanceof SearchHits)) { return false; } SearchHits other = (SearchHits) obj; diff --git a/server/src/main/java/org/opensearch/search/SearchShardTarget.java b/server/src/main/java/org/opensearch/search/SearchShardTarget.java index 80b4feda374c6..f3862e52d83e7 100644 --- a/server/src/main/java/org/opensearch/search/SearchShardTarget.java +++ b/server/src/main/java/org/opensearch/search/SearchShardTarget.java @@ -58,6 +58,8 @@ public final class SearchShardTarget implements Writeable, Comparable details = new ArrayList<>(); + + Number val = null; + switch (proto.getValueCase()) { + case INT_VALUE: + val = proto.getIntValue(); + break; + case LONG_VALUE: + val = proto.getLongValue(); + break; + case FLOAT_VALUE: + val = proto.getFloatValue(); + break; + case DOUBLE_VALUE: + val = proto.getDoubleValue(); + break; + default: + // No value, leave null + } + + for (ExplanationProto det : proto.getDetailsList()) { + details.add(explanationFromProto(det)); + } + + if 
(proto.getMatch()) { + assert val != null; + return Explanation.match(val, description, details); + } + + return Explanation.noMatch(description, details); + } + + public static DocumentFieldProto documentFieldToProto(DocumentField field) { + DocumentFieldProto.Builder builder = DocumentFieldProto.newBuilder().setName(field.getName()); + + for (Object value : field.getValues()) { + builder.addValues(genericObjectToProto(value)); + } + + return builder.build(); + } + + public static DocumentField documentFieldFromProto(DocumentFieldProto proto) throws TransportSerializationException { + String name = proto.getName(); + ArrayList values = new ArrayList<>(); + + for (int i = 0; i < proto.getValuesCount(); i++) { + GenericObjectProto v = proto.getValues(i); + values.add(genericObjectFromProto(v)); + } + + return new DocumentField(name, values); + } + + public static HighlightFieldProto highlightFieldToProto(HighlightField field) { + HighlightFieldProto.Builder builder = HighlightFieldProto.newBuilder().setName(field.getName()).setFragsNull(true); + + if (field.getFragments() != null) { + builder.setFragsNull(false); + for (Text frag : field.getFragments()) { + builder.addFragments(frag.string()); + } + } + + return builder.build(); + } + + public static HighlightField highlightFieldFromProto(HighlightFieldProto proto) { + String name = proto.getName(); + Text[] fragments = null; + + if (!proto.getFragsNull()) { + fragments = new Text[proto.getFragmentsCount()]; + for (int i = 0; i < proto.getFragmentsCount(); i++) { + fragments[i] = new Text(proto.getFragments(i)); + } + } + + return new HighlightField(name, fragments); + } + + public static SortFieldProto sortFieldToProto(SortField sortField) { + SortFieldProto.Builder builder = SortFieldProto.newBuilder() + .setType(sortTypeToProto(sortField.getType())) + .setReverse(sortField.getReverse()); + + if (sortField.getMissingValue() != null) { + builder.setMissingValue(missingValueToProto(sortField.getMissingValue())); + } 
+ + if (sortField.getField() != null) { + builder.setField(sortField.getField()); + } + + return builder.build(); + } + + public static SortField sortFieldFromProto(SortFieldProto proto) { + String field = null; + if (proto.hasField()) { + field = proto.getField(); + } + + SortField sortField = new SortField(field, sortTypeFromProto(proto.getType()), proto.getReverse()); + + if (proto.hasMissingValue()) { + sortField.setMissingValue(missingValueFromProto(proto.getMissingValue())); + } + + return sortField; + } + + public static SortTypeProto sortTypeToProto(SortField.Type sortType) { + return SortTypeProto.forNumber(sortType.ordinal()); + } + + public static SortField.Type sortTypeFromProto(SortTypeProto proto) { + return SortField.Type.values()[proto.getNumber()]; + } + + public static MissingValueProto missingValueToProto(Object missingValue) { + MissingValueProto.Builder builder = MissingValueProto.newBuilder(); + + if (missingValue == SortField.STRING_FIRST) { + builder.setIntVal(1); + } else if (missingValue == SortField.STRING_LAST) { + builder.setIntVal(2); + } else { + builder.setObjVal(genericObjectToProto(missingValue)); + } + + return builder.build(); + } + + public static Object missingValueFromProto(MissingValueProto proto) { + switch (proto.getValueCase()) { + case INT_VAL: + if (proto.getIntVal() == 1) { + return SortField.STRING_FIRST; + } + if (proto.getIntVal() == 2) { + return SortField.STRING_LAST; + } + throw new TransportSerializationException("Unexpected sortField missingValue (INT_VAL): " + proto.getIntVal()); + case OBJ_VAL: + return genericObjectFromProto(proto.getObjVal()); + default: + throw new TransportSerializationException("Unexpected value case: " + proto.getValueCase()); + } + } + + public static SearchShardTargetProto searchShardTargetToProto(SearchShardTarget shardTarget) { + SearchShardTargetProto.Builder builder = SearchShardTargetProto.newBuilder() + .setNodeId(shardTarget.getNodeId()) + 
.setShardId(shardIdToProto(shardTarget.getShardId())); + + if (shardTarget.getClusterAlias() != null) { + builder.setClusterAlias(shardTarget.getClusterAlias()); + } + + return builder.build(); + } + + public static SearchShardTarget searchShardTargetFromProto(SearchShardTargetProto proto) { + String nodeId = proto.getNodeId(); + ShardId shardId = shardIdFromProto(proto.getShardId()); + + String clusterAlias = null; + if (proto.hasClusterAlias()) { + clusterAlias = proto.getClusterAlias(); + } + + return new SearchShardTarget(nodeId, shardId, clusterAlias); + } + + public static ShardIdProto shardIdToProto(ShardId shardId) { + return ShardIdProto.newBuilder() + .setIndex(indexToProto(shardId.getIndex())) + .setShardId(shardId.id()) + .setHashCode(shardId.hashCode()) + .build(); + } + + public static ShardId shardIdFromProto(ShardIdProto proto) { + Index index = indexFromProto(proto.getIndex()); + int shardId = proto.getShardId(); + return new ShardId(index, shardId); + } + + public static IndexProto indexToProto(Index index) { + return IndexProto.newBuilder().setName(index.getName()).setUuid(index.getUUID()).build(); + } + + public static Index indexFromProto(IndexProto proto) { + String name = proto.getName(); + String uuid = proto.getUuid(); + return new Index(name, uuid); + } +} diff --git a/server/src/main/java/org/opensearch/transport/protobuf/SearchHitProtobuf.java b/server/src/main/java/org/opensearch/transport/protobuf/SearchHitProtobuf.java new file mode 100644 index 0000000000000..5e4253b6db83e --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/protobuf/SearchHitProtobuf.java @@ -0,0 +1,268 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.transport.protobuf; + +import com.google.protobuf.ByteString; +import org.apache.lucene.util.BytesRef; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.text.Text; +import org.opensearch.proto.search.SearchHitsProtoDef; +import org.opensearch.proto.search.SearchHitsProtoDef.NestedIdentityProto; +import org.opensearch.proto.search.SearchHitsProtoDef.SearchHitProto; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchSortValues; +import org.opensearch.transport.TransportSerializationException; + +import java.io.IOException; +import java.math.BigInteger; +import java.util.HashMap; + +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.documentFieldFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.documentFieldToProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.explanationFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.explanationToProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldToProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetToProto; + +/** + * SearchHit which leverages protobuf for transport layer serialization. 
+ * @opensearch.internal + */ +public class SearchHitProtobuf extends SearchHit { + public SearchHitProtobuf(SearchHit hit) { + super(hit); + } + + public SearchHitProtobuf(StreamInput in) throws IOException { + fromProtobufStream(in); + } + + public SearchHitProtobuf(SearchHitProto proto) { + fromProto(proto); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + toProtobufStream(out); + } + + public void toProtobufStream(StreamOutput out) throws IOException { + toProto().writeTo(out); + } + + public void fromProtobufStream(StreamInput in) throws IOException { + SearchHitProto proto = SearchHitProto.parseFrom(in); + fromProto(proto); + } + + SearchHitProto toProto() { + SearchHitProto.Builder builder = SearchHitProto.newBuilder() + .setScore(score) + .setVersion(version) + .setSeqNo(seqNo) + .setPrimaryTerm(primaryTerm) + .setSortValues(searchSortValuesToProto(sortValues)); + + documentFields.forEach((key, value) -> builder.putDocumentFields(key, documentFieldToProto(value))); + metaFields.forEach((key, value) -> builder.putMetaFields(key, documentFieldToProto(value))); + matchedQueries.forEach(builder::putMatchedQueries); + + if (highlightFields != null) { + highlightFields.forEach((key, value) -> builder.putHighlightFields(key, highlightFieldToProto(value))); + } + if (innerHits != null) { + innerHits.forEach((key, value) -> builder.putInnerHits(key, new SearchHitsProtobuf(value).toProto())); + } + + if (source != null) { + builder.setSource(ByteString.copyFrom(source.toBytesRef().bytes)); + } + if (id != null) { + builder.setId(id.string()); + } + if (nestedIdentity != null) { + builder.setNestedIdentity(nestedIdentityToProto(nestedIdentity)); + } + if (shard != null) { + builder.setShard(searchShardTargetToProto(shard)); + } + if (explanation != null) { + builder.setExplanation(explanationToProto(explanation)); + } + + return builder.build(); + } + + void fromProto(SearchHitProto proto) { + docId = -1; + score = proto.getScore(); + 
version = proto.getVersion(); + seqNo = proto.getSeqNo(); + primaryTerm = proto.getPrimaryTerm(); + sortValues = searchSortValuesFromProto(proto.getSortValues()); + + documentFields = new HashMap<>(); + proto.getDocumentFieldsMap().forEach((key, value) -> documentFields.put(key, documentFieldFromProto(value))); + + metaFields = new HashMap<>(); + proto.getMetaFieldsMap().forEach((key, value) -> metaFields.put(key, documentFieldFromProto(value))); + + matchedQueries = proto.getMatchedQueriesMap(); + + highlightFields = new HashMap<>(); + proto.getHighlightFieldsMap().forEach((key, value) -> highlightFields.put(key, highlightFieldFromProto(value))); + + // innerHits is nullable + if (proto.getInnerHitsCount() < 1) { + innerHits = null; + } else { + innerHits = new HashMap<>(); + proto.getInnerHitsMap().forEach((key, value) -> innerHits.put(key, new SearchHitsProtobuf(value))); + } + + source = proto.hasSource() ? BytesReference.fromByteBuffer(proto.getSource().asReadOnlyByteBuffer()) : null; + id = proto.hasId() ? new Text(proto.getId()) : null; + nestedIdentity = proto.hasNestedIdentity() ? nestedIdentityFromProto(proto.getNestedIdentity()) : null; + explanation = proto.hasExplanation() ? 
explanationFromProto(proto.getExplanation()) : null; + + if (proto.hasShard()) { + shard = searchShardTargetFromProto(proto.getShard()); + index = shard.getIndex(); + clusterAlias = shard.getClusterAlias(); + } else { + shard = null; + index = null; + clusterAlias = null; + } + } + + public static NestedIdentityProto nestedIdentityToProto(SearchHit.NestedIdentity nestedIdentity) { + NestedIdentityProto.Builder builder = NestedIdentityProto.newBuilder() + .setField(nestedIdentity.getField().string()) + .setOffset(nestedIdentity.getOffset()); + + if (nestedIdentity.getChild() != null) { + builder.setChild(nestedIdentityToProto(nestedIdentity.getChild())); + } + + return builder.build(); + } + + public static SearchHit.NestedIdentity nestedIdentityFromProto(NestedIdentityProto proto) { + String field = proto.getField(); + int offset = proto.getOffset(); + + SearchHit.NestedIdentity child = null; + if (proto.hasChild()) { + child = nestedIdentityFromProto(proto.getChild()); + } + + return new SearchHit.NestedIdentity(field, offset, child); + } + + public static SearchHitsProtoDef.SearchSortValuesProto searchSortValuesToProto(SearchSortValues searchSortValues) { + SearchHitsProtoDef.SearchSortValuesProto.Builder builder = SearchHitsProtoDef.SearchSortValuesProto.newBuilder(); + + for (Object value : searchSortValues.getFormattedSortValues()) { + builder.addFormattedSortValues(sortValueToProto(value)); + } + + for (Object value : searchSortValues.getRawSortValues()) { + builder.addRawSortValues(sortValueToProto(value)); + } + + return builder.build(); + } + + public static SearchSortValues searchSortValuesFromProto(SearchHitsProtoDef.SearchSortValuesProto proto) + throws TransportSerializationException { + Object[] formattedSortValues = new Object[proto.getFormattedSortValuesCount()]; + Object[] rawSortValues = new Object[proto.getRawSortValuesCount()]; + + for (int i = 0; i < formattedSortValues.length; i++) { + SearchHitsProtoDef.SortValueProto sortProto = 
proto.getFormattedSortValues(i); + formattedSortValues[i] = sortValueFromProto(sortProto); + } + + for (int i = 0; i < rawSortValues.length; i++) { + SearchHitsProtoDef.SortValueProto sortProto = proto.getRawSortValues(i); + rawSortValues[i] = sortValueFromProto(sortProto); + } + + return new SearchSortValues(formattedSortValues, rawSortValues); + } + + public static SearchHitsProtoDef.SortValueProto sortValueToProto(Object sortValue) throws TransportSerializationException { + SearchHitsProtoDef.SortValueProto.Builder builder = SearchHitsProtoDef.SortValueProto.newBuilder(); + + if (sortValue == null) { + builder.setIsNull(true); + } else if (sortValue.getClass().equals(String.class)) { + builder.setStringValue((String) sortValue); + } else if (sortValue.getClass().equals(Integer.class)) { + builder.setIntValue((Integer) sortValue); + } else if (sortValue.getClass().equals(Long.class)) { + builder.setLongValue((Long) sortValue); + } else if (sortValue.getClass().equals(Float.class)) { + builder.setFloatValue((Float) sortValue); + } else if (sortValue.getClass().equals(Double.class)) { + builder.setDoubleValue((Double) sortValue); + } else if (sortValue.getClass().equals(Byte.class)) { + builder.setByteValue((Byte) sortValue); + } else if (sortValue.getClass().equals(Short.class)) { + builder.setShortValue((Short) sortValue); + } else if (sortValue.getClass().equals(Boolean.class)) { + builder.setBoolValue((Boolean) sortValue); + } else if (sortValue.getClass().equals(BytesRef.class)) { + builder.setBytesValue( + ByteString.copyFrom(((BytesRef) sortValue).bytes, ((BytesRef) sortValue).offset, ((BytesRef) sortValue).length) + ); + } else if (sortValue.getClass().equals(BigInteger.class)) { + builder.setBigIntegerValue(sortValue.toString()); + } else { + throw new TransportSerializationException("Unexpected sortValue: " + sortValue); + } + + return builder.build(); + } + + public static Object sortValueFromProto(SearchHitsProtoDef.SortValueProto proto) throws 
TransportSerializationException { + switch (proto.getValueCase()) { + case STRING_VALUE: + return proto.getStringValue(); + case INT_VALUE: + return proto.getIntValue(); + case LONG_VALUE: + return proto.getLongValue(); + case FLOAT_VALUE: + return proto.getFloatValue(); + case DOUBLE_VALUE: + return proto.getDoubleValue(); + case BYTE_VALUE: + return (byte) proto.getByteValue(); + case SHORT_VALUE: + return (short) proto.getShortValue(); + case BOOL_VALUE: + return proto.getBoolValue(); + case BYTES_VALUE: + ByteString byteString = proto.getBytesValue(); + return new BytesRef(byteString.toByteArray()); + case BIG_INTEGER_VALUE: + return new BigInteger(proto.getBigIntegerValue()); + case IS_NULL: + return null; + } + + throw new TransportSerializationException("Unexpected value case: " + proto.getValueCase()); + } +} diff --git a/server/src/main/java/org/opensearch/transport/protobuf/SearchHitsProtobuf.java b/server/src/main/java/org/opensearch/transport/protobuf/SearchHitsProtobuf.java new file mode 100644 index 0000000000000..413801c732922 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/protobuf/SearchHitsProtobuf.java @@ -0,0 +1,133 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.transport.protobuf; + +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TotalHits; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.proto.search.SearchHitsProtoDef.SearchHitsProto; +import org.opensearch.proto.search.SearchHitsProtoDef.SortFieldProto; +import org.opensearch.proto.search.SearchHitsProtoDef.SortValueProto; +import org.opensearch.proto.search.SearchHitsProtoDef.TotalHitsProto; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.transport.TransportSerializationException; + +import java.io.IOException; +import java.util.List; + +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortFieldFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortFieldToProto; +import static org.opensearch.transport.protobuf.SearchHitProtobuf.sortValueFromProto; +import static org.opensearch.transport.protobuf.SearchHitProtobuf.sortValueToProto; + +/** + * SearchHits which leverages protobuf for transport layer serialization. 
+ * @opensearch.internal + */ +public class SearchHitsProtobuf extends SearchHits { + public SearchHitsProtobuf(SearchHits hits) { + super(hits); + } + + public SearchHitsProtobuf(StreamInput in) throws IOException { + fromProtobufStream(in); + } + + public SearchHitsProtobuf(SearchHitsProto proto) { + fromProto(proto); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + toProtobufStream(out); + } + + public void toProtobufStream(StreamOutput out) throws IOException { + toProto().writeTo(out); + } + + public void fromProtobufStream(StreamInput in) throws IOException { + SearchHitsProto proto = SearchHitsProto.parseFrom(in); + fromProto(proto); + } + + SearchHitsProto toProto() { + SearchHitsProto.Builder builder = SearchHitsProto.newBuilder().setMaxScore(maxScore); + + for (SearchHit hit : hits) { + builder.addHits(new SearchHitProtobuf(hit).toProto()); + } + + if (collapseField != null) { + builder.setCollapseField(collapseField); + } + + if (totalHits != null) { + TotalHitsProto.Builder totHitsBuilder = TotalHitsProto.newBuilder() + .setRelation(totalHits.relation.ordinal()) + .setValue(totalHits.value); + builder.setTotalHits(totHitsBuilder); + } + + if (sortFields != null) { + for (SortField field : sortFields) { + builder.addSortFields(sortFieldToProto(field)); + } + } + + if (collapseValues != null) { + for (Object col : collapseValues) { + builder.addCollapseValues(sortValueToProto(col)); + } + } + + return builder.build(); + } + + void fromProto(SearchHitsProto proto) throws TransportSerializationException { + maxScore = proto.getMaxScore(); + + hits = new SearchHit[proto.getHitsCount()]; + for (int i = 0; i < hits.length; i++) { + hits[i] = new SearchHitProtobuf(proto.getHits(i)); + } + + collapseField = proto.hasCollapseField() ? proto.getCollapseField() : null; + totalHits = proto.hasTotalHits() ? totalHitsFromProto(proto.getTotalHits()) : null; + sortFields = proto.getSortFieldsCount() > 0 ? 
sortFieldsFromProto(proto.getSortFieldsList()) : null; + collapseValues = proto.getCollapseValuesCount() > 0 ? collapseValuesFromProto(proto.getCollapseValuesList()) : null; + } + + private TotalHits totalHitsFromProto(TotalHitsProto proto) { + long rel = proto.getRelation(); + long val = proto.getValue(); + if (rel < 0 || rel >= TotalHits.Relation.values().length) { + throw new TransportSerializationException("Failed to deserialize TotalHits from proto"); + } + return new TotalHits(val, TotalHits.Relation.values()[(int) rel]); + } + + private SortField[] sortFieldsFromProto(List protoList) { + SortField[] fields = new SortField[protoList.size()]; + for (int i = 0; i < protoList.size(); i++) { + fields[i] = sortFieldFromProto(protoList.get(i)); + } + return fields; + } + + private Object[] collapseValuesFromProto(List protoList) { + Object[] vals = new Object[protoList.size()]; + for (int i = 0; i < protoList.size(); i++) { + vals[i] = sortValueFromProto(protoList.get(i)); + } + return vals; + } +} diff --git a/server/src/main/java/org/opensearch/transport/protobuf/package-info.java b/server/src/main/java/org/opensearch/transport/protobuf/package-info.java new file mode 100644 index 0000000000000..16d52127e37e6 --- /dev/null +++ b/server/src/main/java/org/opensearch/transport/protobuf/package-info.java @@ -0,0 +1,10 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** Serialization/Deserialization implementations for the fetch package. 
*/ +package org.opensearch.transport.protobuf; diff --git a/server/src/main/proto/search/SearchHits.proto b/server/src/main/proto/search/SearchHits.proto new file mode 100644 index 0000000000000..9bfcb9849b873 --- /dev/null +++ b/server/src/main/proto/search/SearchHits.proto @@ -0,0 +1,137 @@ +syntax = "proto3"; + +package org.opensearch.proto.search; + +option java_outer_classname = "SearchHitsProtoDef"; + +// Generic java Object +// Ideally we are able to remove all usages of this message type +message GenericObjectProto { + bytes value = 1; +} + +message SearchHitsProto { + float max_score = 1; + optional string collapse_field = 2; + optional TotalHitsProto total_hits = 5; + repeated SortValueProto collapse_values = 4; + repeated SortFieldProto sort_fields = 3; + repeated SearchHitProto hits = 6; +} + +message TotalHitsProto { + int64 value = 1; + int64 relation = 2; +} + +message SortValueProto { + oneof value { + string string_value = 1; + int32 int_value = 2; + int64 long_value = 3; + float float_value = 4; + double double_value = 5; + int32 byte_value = 6; + int32 short_value = 7; + bool bool_value = 8; + bytes bytes_value = 9; + string big_integer_value = 10; + bool is_null = 11; + } +} + +message SortFieldProto { + bool reverse = 1; + SortTypeProto type = 2; + optional string field = 3; + optional MissingValueProto missing_value = 4; +} + +message MissingValueProto { + oneof value { + int32 int_val = 1; + GenericObjectProto obj_val = 2; + } +} + +enum SortTypeProto { + SCORE = 0; + DOC = 1; + STRING = 2; + INT = 3; + FLOAT = 4; + LONG = 5; + DOUBLE = 6; + CUSTOM = 7; + STRING_VAL = 8; + REWRITEABLE = 9; +} + +message SearchHitProto { + float score = 1; + int64 version = 2; + int64 seq_no = 3; + int64 primary_term = 4; + SearchSortValuesProto sort_values = 5; + map document_fields = 6; + map meta_fields = 7; + map highlight_fields = 8; + map matched_queries = 9; + map inner_hits = 10; + optional bytes source = 11; // compressible map + optional string id = 
12; + optional NestedIdentityProto nested_identity = 13; + optional SearchShardTargetProto shard = 14; + optional ExplanationProto explanation = 15; +} + +message NestedIdentityProto { + string field = 1; + int32 offset = 2; + NestedIdentityProto child = 3; +} + +message DocumentFieldProto { + string name = 1; + repeated GenericObjectProto values = 2; +} + +message HighlightFieldProto { + string name = 1; + bool frags_null = 2; // fragments can be null OR empty + repeated string fragments = 3; +} + +message SearchSortValuesProto { + repeated SortValueProto formatted_sort_values = 1; + repeated SortValueProto raw_sort_values = 2; +} + +message ExplanationProto { + bool match = 1; + oneof value { + int32 int_value = 2; + int64 long_value = 3; + float float_value = 4; + double double_value = 5; + } + string description = 6; + repeated ExplanationProto details = 7; +} + +message SearchShardTargetProto { + string node_id = 1; + ShardIdProto shard_id = 2; + optional string cluster_alias = 3; +} + +message ShardIdProto { + IndexProto index = 1; + int32 shard_id = 2; + int32 hash_code = 3; +} + +message IndexProto { + string name = 1; + string uuid = 2; +} diff --git a/server/src/main/proto/search/fetch/FetchSearchResult.proto b/server/src/main/proto/search/fetch/FetchSearchResult.proto new file mode 100644 index 0000000000000..e04e041167ed0 --- /dev/null +++ b/server/src/main/proto/search/fetch/FetchSearchResult.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package org.opensearch.proto.search.fetch; + +import "search/SearchHits.proto"; + +option java_outer_classname = "FetchSearchResultProtoDef"; + +message FetchSearchResultProto { + org.opensearch.proto.search.SearchHitsProto hits = 1; + int64 counter = 2; +} diff --git a/server/src/test/java/org/opensearch/search/SearchHitProtobufTests.java b/server/src/test/java/org/opensearch/search/SearchHitProtobufTests.java new file mode 100644 index 0000000000000..9d043efcdc4ea --- /dev/null +++ 
b/server/src/test/java/org/opensearch/search/SearchHitProtobufTests.java @@ -0,0 +1,132 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.search; + +import org.apache.lucene.search.Explanation; +import org.opensearch.action.OriginalIndices; +import org.opensearch.common.document.DocumentField; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.proto.search.SearchHitsProtoDef; +import org.opensearch.search.fetch.subphase.highlight.HighlightField; +import org.opensearch.search.fetch.subphase.highlight.HighlightFieldTests; +import org.opensearch.test.AbstractWireSerializingTestCase; +import org.opensearch.transport.protobuf.SearchHitProtobuf; + +import static org.opensearch.index.get.DocumentFieldTests.randomDocumentField; +import static org.opensearch.search.SearchHitTests.createExplanation; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.documentFieldFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.documentFieldToProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.explanationFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.explanationToProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldToProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetToProto; +import static org.opensearch.transport.protobuf.SearchHitProtobuf.searchSortValuesFromProto; +import static org.opensearch.transport.protobuf.SearchHitProtobuf.searchSortValuesToProto; + +public class SearchHitProtobufTests extends AbstractWireSerializingTestCase<SearchHitProtobuf> { // each test converts a value to its proto message and back, then compares with the original + public void testDocumentFieldProtoSerialization() { + DocumentField orig = randomDocumentField(randomFrom(XContentType.values()),
randomBoolean(), fieldName -> false).v1(); + SearchHitsProtoDef.DocumentFieldProto proto = documentFieldToProto(orig); + DocumentField cpy = documentFieldFromProto(proto); + assertEquals(orig, cpy); + assertEquals(orig.hashCode(), cpy.hashCode()); + assertNotSame(orig, cpy); + } + + public void testHighlightFieldProtoSerialization() { + HighlightField orig = HighlightFieldTests.createTestItem(); + SearchHitsProtoDef.HighlightFieldProto proto = highlightFieldToProto(orig); + HighlightField cpy = highlightFieldFromProto(proto); + assertEquals(orig, cpy); + assertEquals(orig.hashCode(), cpy.hashCode()); + assertNotSame(orig, cpy); + } + + public void testSearchSortValuesProtoSerialization() { + SearchSortValues orig = SearchSortValuesTests.createTestItem(randomFrom(XContentType.values()), true); + SearchHitsProtoDef.SearchSortValuesProto proto = searchSortValuesToProto(orig); + SearchSortValues cpy = searchSortValuesFromProto(proto); + assertEquals(orig, cpy); + assertEquals(orig.hashCode(), cpy.hashCode()); + assertNotSame(orig, cpy); + } + + public void testNestedIdentityProtoSerialization() { + SearchHit.NestedIdentity orig = NestedIdentityTests.createTestItem(randomIntBetween(0, 2)); + SearchHitsProtoDef.NestedIdentityProto proto = SearchHitProtobuf.nestedIdentityToProto(orig); + SearchHit.NestedIdentity cpy = SearchHitProtobuf.nestedIdentityFromProto(proto); + assertEquals(orig, cpy); + assertEquals(orig.hashCode(), cpy.hashCode()); + assertNotSame(orig, cpy); + } + + public void testSearchShardTargetProtoSerialization() { + String index = randomAlphaOfLengthBetween(5, 10); + String clusterAlias = randomBoolean() ?
null : randomAlphaOfLengthBetween(5, 10); + SearchShardTarget orig = new SearchShardTarget( + randomAlphaOfLengthBetween(5, 10), + new ShardId(new Index(index, randomAlphaOfLengthBetween(5, 10)), randomInt()), + clusterAlias, + OriginalIndices.NONE + ); + SearchHitsProtoDef.SearchShardTargetProto proto = searchShardTargetToProto(orig); + SearchShardTarget cpy = searchShardTargetFromProto(proto); + assertEquals(orig, cpy); + assertEquals(orig.hashCode(), cpy.hashCode()); + assertNotSame(orig, cpy); + } + + public void testExplanationProtoSerialization() { + Explanation orig = createExplanation(randomIntBetween(0, 5)); + SearchHitsProtoDef.ExplanationProto proto = explanationToProto(orig); + Explanation cpy = explanationFromProto(proto); + assertEquals(orig, cpy); + assertEquals(orig.hashCode(), cpy.hashCode()); + assertNotSame(orig, cpy); + } + + @Override + protected Writeable.Reader<SearchHitProtobuf> instanceReader() { // typed reader; a raw Reader would make the base-class checks unchecked + return SearchHitProtobuf::new; + } + + @Override + protected SearchHitProtobuf createTestInstance() { // wraps a random SearchHit from SearchHitTests in the protobuf-backed subclass + return new SearchHitProtobuf(SearchHitTests.createTestItem(randomFrom(XContentType.values()), randomBoolean(), randomBoolean())); + } +} diff --git a/server/src/test/java/org/opensearch/search/SearchHitsProtobufTests.java b/server/src/test/java/org/opensearch/search/SearchHitsProtobufTests.java new file mode 100644 index 0000000000000..b2f52e87aadd2 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/SearchHitsProtobufTests.java @@ -0,0 +1,86 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership.
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.search; + +import org.apache.lucene.search.SortField; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.proto.search.SearchHitsProtoDef; +import org.opensearch.test.AbstractWireSerializingTestCase; +import org.opensearch.transport.protobuf.SearchHitsProtobuf; + +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortFieldFromProto; +import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.sortFieldToProto; +import static org.opensearch.transport.protobuf.SearchHitProtobuf.sortValueFromProto; +import static org.opensearch.transport.protobuf.SearchHitProtobuf.sortValueToProto; + +public class SearchHitsProtobufTests extends AbstractWireSerializingTestCase<SearchHitsProtobuf> { // type argument is required: mutateInstance(SearchHitsProtobuf) only overrides the base method when T is bound + public void testSortFieldProtoSerialization() { + SortField[] fields = SearchHitsTests.createSortFields(randomIntBetween(1, 5)); + for (SortField orig : fields) { + SearchHitsProtoDef.SortFieldProto proto = sortFieldToProto(orig); + SortField cpy = sortFieldFromProto(proto); + assertEquals(orig, cpy); + assertEquals(orig.hashCode(), cpy.hashCode()); + assertNotSame(orig, cpy); + } + } + + public void testSortValueProtoSerialization() { + Object[] values =
SearchHitsTests.createCollapseValues(randomIntBetween(1, 10)); + for (Object orig : values) { + SearchHitsProtoDef.SortValueProto proto = sortValueToProto(orig); + Object cpy = sortValueFromProto(proto); + + assertEquals(orig, cpy); + if (orig != null && cpy != null) { // guard hashCode() against null values + assertEquals(orig.hashCode(), cpy.hashCode()); + } + } + } + + @Override + protected Writeable.Reader<SearchHitsProtobuf> instanceReader() { + return SearchHitsProtobuf::new; + } + + @Override + protected SearchHitsProtobuf createTestInstance() { + return new SearchHitsProtobuf(SearchHitsTests.createTestItem(randomFrom(XContentType.values()), true, true)); + } + + @Override + protected SearchHitsProtobuf mutateInstance(SearchHitsProtobuf instance) { // reuses SearchHitsTests.mutate and re-wraps the result + return new SearchHitsProtobuf(SearchHitsTests.mutate(instance)); + } +} diff --git a/server/src/test/java/org/opensearch/search/SearchHitsTests.java b/server/src/test/java/org/opensearch/search/SearchHitsTests.java index fd3ba35a4d3bb..cf00ba768da7c 100644 --- a/server/src/test/java/org/opensearch/search/SearchHitsTests.java +++ b/server/src/test/java/org/opensearch/search/SearchHitsTests.java @@ -103,7 +103,7 @@ private static SearchHits createTestItem( return new SearchHits(hits, totalHits, maxScore, sortFields, collapseField, collapseValues); } - private static SortField[] createSortFields(int size) { + public static SortField[] createSortFields(int size) { SortField[] sortFields = new SortField[size]; for (int i = 0; i < sortFields.length; i++) { // sort fields are simplified before serialization, we write directly the simplified version @@ -113,7 +113,7 @@ private static SortField[] createSortFields(int size) { return sortFields; } - private static Object[] createCollapseValues(int size) { + public static Object[] createCollapseValues(int size) { Object[] collapseValues = new Object[size]; for (int i = 0; i < collapseValues.length; i++) { collapseValues[i] = LuceneTests.randomSortValue(); @@ -123,6 +123,10 @@ private static Object[] createCollapseValues(int size) {
@Override protected SearchHits mutateInstance(SearchHits instance) { + return mutate(instance); + } + + public static SearchHits mutate(SearchHits instance) { switch (randomIntBetween(0, 5)) { case 0: return new SearchHits(