Skip to content

Commit

Permalink
Remove strategy wrapper. Just implement protobuf in new object, leave…
Browse files Browse the repository at this point in the history
… previous object for native serde.

Signed-off-by: Finn Carroll <carrofin@amazon.com>
  • Loading branch information
finnegancarroll committed Aug 27, 2024
1 parent 4dc97d8 commit d01540d
Show file tree
Hide file tree
Showing 8 changed files with 439 additions and 462 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.transport.protobuf;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.serde.proto.SearchHitsTransportProto.FetchSearchResultProto;

import java.io.IOException;

/**
* FetchSearchResult child which implements serde operations as protobuf.
* @opensearch.internal
*/
public class FetchSearchResultProtobuf extends FetchSearchResult {
public FetchSearchResultProtobuf(StreamInput in) throws IOException {
fromProtobufStream(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
toProtobufStream(out);
}

public void toProtobufStream(StreamOutput out) throws IOException {
toProto().writeTo(out);
}

public void fromProtobufStream(StreamInput in) throws IOException {
FetchSearchResultProto proto = FetchSearchResultProto.parseFrom(in);
fromProto(proto);
}

FetchSearchResultProto toProto() {
FetchSearchResultProto.Builder builder = FetchSearchResultProto.newBuilder()
.setHits(new SearchHitsProtobuf(hits).toProto())
.setCounter(this.counter);
return builder.build();
}

void fromProto(FetchSearchResultProto proto) {
hits = new SearchHitsProtobuf(proto.getHits());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* compatible open source license.
*/

package org.opensearch.transport.serde;
package org.opensearch.transport.protobuf;

import com.google.protobuf.ByteString;
import org.apache.lucene.search.Explanation;
Expand Down Expand Up @@ -40,12 +40,7 @@
* SerDe interfaces and protobuf SerDe implementations for some "primitive" types.
* @opensearch.internal
*/
public class SerDe {

public enum Strategy {
PROTOBUF,
NATIVE;
}
public class ProtoSerDeHelpers {

/**
* Serialization/Deserialization exception.
Expand All @@ -61,18 +56,6 @@ public SerializationException(String message, Throwable cause) {
}
}

interface nativeSerializer {
void toNativeStream(StreamOutput out) throws IOException;

void fromNativeStream(StreamInput in) throws IOException;
}

interface protobufSerializer {
void toProtobufStream(StreamOutput out) throws IOException;

void fromProtobufStream(StreamInput in) throws IOException;
}

// TODO: Lucene definitions should maybe be serialized as generic bytes arrays.
static ExplanationProto explanationToProto(Explanation explanation) {
ExplanationProto.Builder builder = ExplanationProto.newBuilder()
Expand Down Expand Up @@ -118,7 +101,7 @@ static DocumentFieldProto documentFieldToProto(DocumentField field) {
return builder.build();
}

static DocumentField documentFieldFromProto(DocumentFieldProto proto) throws SerDe.SerializationException {
static DocumentField documentFieldFromProto(DocumentFieldProto proto) throws ProtoSerDeHelpers.SerializationException {
String name = proto.getName();
List<Object> values = new ArrayList<>(0);

Expand All @@ -128,7 +111,7 @@ static DocumentField documentFieldFromProto(DocumentFieldProto proto) throws Ser
Object readValue = in.readGenericValue();
values.add(readValue);
} catch (IOException e) {
throw new SerDe.SerializationException("Failed to deserialize DocumentField values from proto object", e);
throw new ProtoSerDeHelpers.SerializationException("Failed to deserialize DocumentField values from proto object", e);
}
}

Expand Down Expand Up @@ -177,7 +160,7 @@ static SearchSortValuesProto searchSortValuesToProto(SearchSortValues searchSort
return builder.build();
}

static SearchSortValues searchSortValuesFromProto(SearchSortValuesProto proto) throws SerDe.SerializationException {
static SearchSortValues searchSortValuesFromProto(SearchSortValuesProto proto) throws ProtoSerDeHelpers.SerializationException {
Object[] formattedSortValues = null;
Object[] rawSortValues = null;

Expand All @@ -186,7 +169,7 @@ static SearchSortValues searchSortValuesFromProto(SearchSortValuesProto proto) t
try (StreamInput formattedIn = formattedBytes.streamInput()) {
formattedSortValues = formattedIn.readArray(Lucene::readSortValue, Object[]::new);
} catch (IOException e) {
throw new SerDe.SerializationException("Failed to deserialize SearchSortValues from proto object", e);
throw new ProtoSerDeHelpers.SerializationException("Failed to deserialize SearchSortValues from proto object", e);
}
}

Expand All @@ -195,7 +178,7 @@ static SearchSortValues searchSortValuesFromProto(SearchSortValuesProto proto) t
try (StreamInput rawIn = rawBytes.streamInput()) {
rawSortValues = rawIn.readArray(Lucene::readSortValue, Object[]::new);
} catch (IOException e) {
throw new SerDe.SerializationException("Failed to deserialize SearchSortValues from proto object", e);
throw new ProtoSerDeHelpers.SerializationException("Failed to deserialize SearchSortValues from proto object", e);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.transport.protobuf;

import com.google.protobuf.ByteString;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.text.Text;
import org.opensearch.search.SearchHit;
import org.opensearch.serde.proto.SearchHitsTransportProto.NestedIdentityProto;
import org.opensearch.serde.proto.SearchHitsTransportProto.SearchHitProto;

import java.io.IOException;
import java.util.HashMap;

import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.documentFieldFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.documentFieldToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.explanationFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.explanationToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.highlightFieldToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchShardTargetToProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchSortValuesFromProto;
import static org.opensearch.transport.protobuf.ProtoSerDeHelpers.searchSortValuesToProto;

/**
* Serialization/Deserialization implementations for SearchHit.
* @opensearch.internal
*/
public class SearchHitProtobuf extends SearchHit {
public SearchHitProtobuf(SearchHit hit) {
super(hit);
}

public SearchHitProtobuf(StreamInput in) throws IOException {
fromProtobufStream(in);
}

public SearchHitProtobuf(SearchHitProto proto) { fromProto(proto); }

@Override
public void writeTo(StreamOutput out) throws IOException {
toProtobufStream(out);
}

public void toProtobufStream(StreamOutput out) throws IOException {
toProto().writeTo(out);
}

public void fromProtobufStream(StreamInput in) throws IOException {
SearchHitProto proto = SearchHitProto.parseFrom(in);
fromProto(proto);
}

SearchHitProto toProto() {
SearchHitProto.Builder builder = SearchHitProto.newBuilder()
.setScore(score)
.setId(id.string())
.setVersion(version)
.setSeqNo(seqNo)
.setPrimaryTerm(primaryTerm);

builder.setNestedIdentity(nestedIdentityToProto(nestedIdentity));
builder.setSource(ByteString.copyFrom(source.toBytesRef().bytes));
builder.setExplanation(explanationToProto(explanation));
builder.setSortValues(searchSortValuesToProto(sortValues));

documentFields.forEach((key, value) -> builder.putDocumentFields(key, documentFieldToProto(value)));

metaFields.forEach((key, value) -> builder.putMetaFields(key, documentFieldToProto(value)));

highlightFields.forEach((key, value) -> builder.putHighlightFields(key, highlightFieldToProto(value)));

matchedQueries.forEach(builder::putMatchedQueries);

// shard is optional
if (shard != null) {
builder.setShard(searchShardTargetToProto(shard));
}

innerHits.forEach((key, value) -> builder.putInnerHits(key, new SearchHitsProtobuf(value).toProto()));

return builder.build();
}

void fromProto(SearchHitProto proto) {
docId = -1;
score = proto.getScore();
seqNo = proto.getSeqNo();
version = proto.getVersion();
primaryTerm = proto.getPrimaryTerm();
id = new Text(proto.getId());
source = BytesReference.fromByteBuffer(proto.getSource().asReadOnlyByteBuffer());
explanation = explanationFromProto(proto.getExplanation());
sortValues = searchSortValuesFromProto(proto.getSortValues());
nestedIdentity = nestedIdentityFromProto(proto.getNestedIdentity());
matchedQueries = proto.getMatchedQueriesMap();

documentFields = new HashMap<>();
proto.getDocumentFieldsMap().forEach((key, value) -> documentFields.put(key, documentFieldFromProto(value)));

metaFields = new HashMap<>();
proto.getMetaFieldsMap().forEach((key, value) -> metaFields.put(key, documentFieldFromProto(value)));

highlightFields = new HashMap<>();
proto.getHighlightFieldsMap().forEach((key, value) -> highlightFields.put(key, highlightFieldFromProto(value)));

innerHits = new HashMap<>();
proto.getInnerHitsMap().forEach((key, value) -> innerHits.put(key, new SearchHitsProtobuf(value)));

shard = searchShardTargetFromProto(proto.getShard());
index = shard.getIndex();
clusterAlias = shard.getClusterAlias();
}

static NestedIdentityProto nestedIdentityToProto(SearchHit.NestedIdentity nestedIdentity) {
NestedIdentityProto.Builder builder = NestedIdentityProto.newBuilder()
.setField(nestedIdentity.getField().string())
.setOffset(nestedIdentity.getOffset());

if (nestedIdentity.getChild() != null) {
builder.setChild(nestedIdentityToProto(nestedIdentity.getChild()));
}

return builder.build();
}

static SearchHit.NestedIdentity nestedIdentityFromProto(NestedIdentityProto proto) {
String field = proto.getField();
int offset = proto.getOffset();

SearchHit.NestedIdentity child = null;
if (proto.hasChild()) {
child = nestedIdentityFromProto(proto.getChild());
}

return new SearchHit.NestedIdentity(field, offset, child);
}
}
Loading

0 comments on commit d01540d

Please sign in to comment.