Skip to content

Commit

Permalink
Introduce neural highlighter framework
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
  • Loading branch information
junqiu-lei committed Feb 19, 2025
1 parent 628cb64 commit 1bc7505
Show file tree
Hide file tree
Showing 9 changed files with 298 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.highlight;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.Query;
import org.opensearch.OpenSearchException;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.query.NeuralKNNQuery;
import org.opensearch.search.fetch.subphase.highlight.FieldHighlightContext;
import org.opensearch.search.fetch.subphase.highlight.HighlightField;
import org.opensearch.search.fetch.subphase.highlight.Highlighter;
import org.opensearch.core.common.text.Text;

import java.util.Locale;
import java.util.Map;
import java.util.Objects;

/**
 * Neural highlighter that uses ML models to identify relevant text spans for highlighting.
 *
 * <p>This is the framework stage of the highlighter: the model id and ML client are resolved
 * but not yet used, and the whole field text is returned wrapped in {@code <em>} tags.
 */
@Log4j2
public class NeuralHighlighter implements Highlighter {
    /** Highlighter type name under which this implementation is exposed. */
    public static final String NAME = "neural";
    /** Key inside the highlight {@code options} map that carries the model id. */
    private static final String MODEL_ID_FIELD = "model_id";

    // Shared across all instances; assigned through initialize(...).
    private static MLCommonsClientAccessor mlCommonsClient;

    /**
     * Stores the ML client accessor that highlighting will use.
     *
     * @param mlClient accessor for ML Commons
     */
    public static void initialize(MLCommonsClientAccessor mlClient) {
        NeuralHighlighter.mlCommonsClient = mlClient;
    }

    /**
     * Reports whether the given field type can be highlighted.
     *
     * @param fieldType mapped type of the candidate field
     * @return always {@code true} for now
     */
    @Override
    public boolean canHighlight(MappedFieldType fieldType) {
        // TODO: Implement actual condition check in subsequent PR
        return true;
    }

    /**
     * Produces the highlight for one field of one hit.
     *
     * @param fieldContext per-field highlight context (field name, hit, query and options)
     * @return a single-fragment {@link HighlightField}, or {@code null} when either the field
     *         text or the extracted query text is empty
     * @throws OpenSearchException wrapping any failure, including a missing {@code model_id} option
     */
    @Override
    public HighlightField highlight(FieldHighlightContext fieldContext) {
        try {
            String fieldText = getFieldText(fieldContext);
            if (fieldText.isEmpty()) {
                return null;
            }

            String searchQuery = extractOriginalQuery(fieldContext.query);
            if (searchQuery.isEmpty()) {
                return null;
            }

            Map<String, Object> options = fieldContext.field.fieldOptions().options();
            String modelId = getModelId(options);
            // Debug level: this method runs for every highlighted field of every hit, so logging
            // at info would flood the node logs.
            log.debug("Using model ID: {}", modelId); // Will be replaced with actual model loading logic
            log.debug("Using ML client: {}", mlCommonsClient); // Will be replaced with actual model loading logic

            // TODO: Implement actual highlighting logic in subsequent PR
            // For now, return a basic highlight of the field text
            Text[] fragments = new Text[] { new Text(formatHighlight(fieldText)) };
            return new HighlightField(fieldContext.fieldName, fragments);
        } catch (Exception e) {
            throw new OpenSearchException(
                String.format(Locale.ROOT, "Failed to perform neural highlighting for field %s", fieldContext.fieldName),
                e
            );
        }
    }

    /**
     * Reads the required {@code model_id} entry from the highlight options.
     *
     * @param options custom highlighter options; may be {@code null} when the request supplied none
     * @return the model id rendered with {@code toString()}
     * @throws IllegalArgumentException if the options are absent or contain no {@code model_id}
     */
    private String getModelId(Map<String, Object> options) {
        // A request without an "options" object yields a null map; report that as the same
        // missing-option error instead of failing with a NullPointerException.
        Object modelId = Objects.isNull(options) ? null : options.get(MODEL_ID_FIELD);
        if (Objects.isNull(modelId)) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Missing required option: %s", MODEL_ID_FIELD));
        }
        return modelId.toString();
    }

    /**
     * Extracts the field's value from the hit source and renders it as a string.
     *
     * @param fieldContext per-field highlight context
     * @return the value's {@code toString()}, or an empty string when the field has no value
     */
    private String getFieldText(FieldHighlightContext fieldContext) {
        // NOTE(review): a multi-valued field would presumably come back as a collection and be
        // rendered via its toString() - confirm desired handling when real highlighting lands.
        Object value = fieldContext.hitContext.sourceLookup().extractValue(fieldContext.fieldName, null);
        return value != null ? value.toString() : "";
    }

    /**
     * Wraps the text in the default highlight tags.
     *
     * @param text text to mark as highlighted
     * @return {@code text} surrounded by {@code <em>} and {@code </em>}
     */
    private String formatHighlight(String text) {
        // TODO: Implement user provided format options in subsequent PR
        return "<em>" + text + "</em>";
    }

    /**
     * Recovers the user's query text from the Lucene query.
     *
     * @param query query attached to the highlight context
     * @return the original text carried by a {@link NeuralKNNQuery} when present; otherwise the
     *         query's string form with {@code word:} prefixes removed and whitespace collapsed
     */
    private String extractOriginalQuery(Query query) {
        if (query instanceof NeuralKNNQuery neuralQuery) {
            String originalText = neuralQuery.getOriginalQueryText();
            if (originalText != null) {
                return originalText;
            }
        }

        // Fallback: strip "field:" style prefixes from the textual form and normalise spacing.
        return query.toString().replaceAll("\\w+:", "").replaceAll("\\s+", " ").trim();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.transport.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
Expand All @@ -25,8 +27,8 @@
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.executors.HybridQueryExecutor;
import org.opensearch.neuralsearch.highlight.NeuralHighlighter;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
Expand Down Expand Up @@ -65,6 +67,7 @@
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.fetch.subphase.highlight.Highlighter;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
Expand Down Expand Up @@ -103,6 +106,7 @@ public Collection<Object> createComponents(
NeuralSearchClusterUtil.instance().initialize(clusterService);
NeuralQueryBuilder.initialize(clientAccessor);
NeuralSparseQueryBuilder.initialize(clientAccessor);
NeuralHighlighter.initialize(clientAccessor);
HybridQueryExecutor.initialize(threadPool);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
Expand Down Expand Up @@ -204,4 +208,12 @@ public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
)
);
}

/**
 * Exposes the neural highlighter so it can be selected by its type name.
 *
 * @return an immutable single-entry map from {@link NeuralHighlighter#NAME} to a new
 *         {@link NeuralHighlighter} instance
 */
@Override
public Map<String, Highlighter> getHighlighters() {
    final Highlighter neuralHighlighter = new NeuralHighlighter();
    return Collections.singletonMap(NeuralHighlighter.NAME, neuralHighlighter);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
@Getter
public class NeuralKNNQuery extends Query {
private final Query knnQuery;
private final String originalQueryText;

public NeuralKNNQuery(Query knnQuery) {
public NeuralKNNQuery(Query knnQuery, String originalQueryText) {
this.knnQuery = knnQuery;
this.originalQueryText = originalQueryText;
}

@Override
Expand All @@ -49,19 +51,19 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (rewritten == knnQuery) {
return this;
}
return new NeuralKNNQuery(rewritten);
return new NeuralKNNQuery(rewritten, originalQueryText);
}

@Override
public boolean equals(Object other) {
if (this == other) return true;
if (other == null || getClass() != other.getClass()) return false;
NeuralKNNQuery that = (NeuralKNNQuery) other;
return Objects.equals(knnQuery, that.knnQuery);
return Objects.equals(knnQuery, that.knnQuery) && Objects.equals(originalQueryText, that.originalQueryText);
}

@Override
public int hashCode() {
return Objects.hash(knnQuery);
return Objects.hash(knnQuery, originalQueryText);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
@Getter
public class NeuralKNNQueryBuilder extends AbstractQueryBuilder<NeuralKNNQueryBuilder> {
private final KNNQueryBuilder knnQueryBuilder;
private final String originalQueryText;

/**
* Creates a new builder instance.
Expand Down Expand Up @@ -59,6 +60,7 @@ public static class Builder {
private Boolean expandNested;
private Map<String, ?> methodParameters;
private RescoreContext rescoreContext;
private String originalQueryText;

private Builder() {}

Expand Down Expand Up @@ -107,6 +109,11 @@ public Builder rescoreContext(RescoreContext rescoreContext) {
return this;
}

public Builder originalQueryText(String originalQueryText) {
this.originalQueryText = originalQueryText;
return this;
}

public NeuralKNNQueryBuilder build() {
KNNQueryBuilder knnBuilder = KNNQueryBuilder.builder()
.fieldName(fieldName)
Expand All @@ -119,12 +126,13 @@ public NeuralKNNQueryBuilder build() {
.methodParameters(methodParameters)
.rescoreContext(rescoreContext)
.build();
return new NeuralKNNQueryBuilder(knnBuilder);
return new NeuralKNNQueryBuilder(knnBuilder, originalQueryText);
}
}

private NeuralKNNQueryBuilder(KNNQueryBuilder knnQueryBuilder) {
private NeuralKNNQueryBuilder(KNNQueryBuilder knnQueryBuilder, String originalQueryText) {
this.knnQueryBuilder = knnQueryBuilder;
this.originalQueryText = originalQueryText;
}

@Override
Expand All @@ -143,23 +151,23 @@ protected QueryBuilder doRewrite(QueryRewriteContext context) throws IOException
if (rewritten == knnQueryBuilder) {
return this;
}
return new NeuralKNNQueryBuilder((KNNQueryBuilder) rewritten);
return new NeuralKNNQueryBuilder((KNNQueryBuilder) rewritten, originalQueryText);
}

@Override
protected Query doToQuery(QueryShardContext context) throws IOException {
Query knnQuery = knnQueryBuilder.toQuery(context);
return new NeuralKNNQuery(knnQuery);
return new NeuralKNNQuery(knnQuery, originalQueryText);
}

@Override
protected boolean doEquals(NeuralKNNQueryBuilder other) {
return Objects.equals(knnQueryBuilder, other.knnQueryBuilder);
return Objects.equals(knnQueryBuilder, other.knnQueryBuilder) && Objects.equals(originalQueryText, other.originalQueryText);
}

@Override
protected int doHashCode() {
return Objects.hash(knnQueryBuilder);
return Objects.hash(knnQueryBuilder, originalQueryText);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
.expandNested(expandNested())
.methodParameters(methodParameters())
.rescoreContext(rescoreContext())
.originalQueryText(queryText())
.build();
}

Expand Down
Loading

0 comments on commit 1bc7505

Please sign in to comment.