Support neural type highlighter
Signed-off-by: Junqiu Lei <junqiu@amazon.com>
junqiu-lei committed Mar 7, 2025
1 parent 5f25d6c commit eeafdff
Showing 13 changed files with 619 additions and 32 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 3.x](https://github.com/opensearch-project/neural-search/compare/main...HEAD)
### Features
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
- Support neural sentence highlighter ([#1193](https://github.com/opensearch-project/neural-search/pull/1193))
### Enhancements
### Bug Fixes
### Infrastructure
24 changes: 12 additions & 12 deletions DEVELOPER_GUIDE.md
@@ -351,9 +351,9 @@ through the same build issue.

### Class and package names

Class names should use `CamelCase`.

Try to put new classes into existing packages if package name abstracts the purpose of the class.

Example of good class file name and package utilization:

@@ -371,7 +371,7 @@ methods rather than a long single one and does everything.
### Documentation

Document your code. That includes the purpose of new classes, every public method, and code sections that have critical or non-trivial
logic (see this example: https://github.com/opensearch-project/neural-search/blob/main/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java#L238).

When you submit a feature PR, please submit a new
[documentation issue](https://github.com/opensearch-project/documentation-website/issues/new/choose). This is the path for the documentation to be published on the https://opensearch.org/docs/latest/ documentation site.
@@ -384,17 +384,17 @@ For the most part, we're using common conventions for Java projects. Here are a

1. Use descriptive names for classes, methods, fields, and variables.
2. Avoid abbreviations unless they are widely accepted.
3. Use `final` on all method arguments unless omitting it is absolutely necessary.
4. Wildcard imports are not allowed.
5. Static imports are preferred over qualified imports when using static methods.
6. Prefer creating non-static public methods whenever possible. Avoid static methods in general, as they can often serve as shortcuts.
Static methods are acceptable if they are private and do not access class state.
7. Use functional programming style inside methods unless it's a performance-critical section.
8. For lambda expression parameters, use meaningful names instead of shortened, cryptic ones.
9. Use Optional for return values if the value may not be present. This should be preferred to returning null.
10. Do not create checked exceptions, and do not throw checked exceptions from public methods whenever possible. In general, if you call a method with a checked exception, you should wrap that exception into an unchecked exception.
11. Throwing checked exceptions from private methods is acceptable.
12. Use String.format when a string includes parameters, and prefer this over direct string concatenation. Always specify a Locale with String.format;
as a rule of thumb, use Locale.ROOT.
13. Prefer Lombok annotations to manually written boilerplate code.
14. When throwing an exception, avoid including user-provided content in the exception message. For secure coding practices,
@@ -440,17 +440,17 @@ Fix any new warnings before submitting your PR to ensure proper code documentati

### Tests

Write unit and integration tests for your new functionality.

Unit tests are preferred because they are cheap and fast; try to use them to cover all possible
combinations of parameters. Utilize mocks to mimic dependencies.

Integration tests should be used sparingly, focusing primarily on the main (happy path) scenario or cases where extensive
mocking is impractical. Include one or two unhappy paths to confirm that correct response codes are returned to the user.
Whenever possible, favor scenarios that do not require model deployment. If model deployment is necessary, use an existing
model, as tests involving new model deployments are the most resource-intensive.

If your changes could affect backward compatibility, please include relevant backward compatibility tests along with your
PR. For guidance on adding these tests, refer to the [Backwards Compatibility Testing](#backwards-compatibility-testing) section in this guide.
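
As a rough illustration of the unit-test guidance above, the sketch below swaps a dependency for a hand-written stub, so the check needs no model deployment. `EmbeddingClient`, `QueryScorer`, and the fixed vector are hypothetical names invented for this example and are not classes from this repository; real tests here would normally use the project's test framework and a mocking library instead of a `main` method.

```java
import java.util.List;

public class StubbedDependencyExample {
    /** Hypothetical dependency that would normally call a deployed model. */
    public interface EmbeddingClient {
        List<Float> embed(String text);
    }

    /** Hypothetical class under test; it depends only on the interface. */
    public static final class QueryScorer {
        private final EmbeddingClient client;

        public QueryScorer(final EmbeddingClient client) {
            this.client = client;
        }

        /** Sums the embedding components; a stand-in for real scoring logic. */
        public float score(final String query) {
            float total = 0f;
            for (final float component : client.embed(query)) {
                total += component;
            }
            return total;
        }
    }

    public static void main(final String[] args) {
        // The stub returns a fixed vector, so the check is fast and deterministic.
        final EmbeddingClient stub = text -> List.of(1.0f, 2.0f, 3.0f);
        final QueryScorer scorer = new QueryScorer(stub);
        System.out.println(scorer.score("any query"));
    }
}
```

Because `QueryScorer` receives its dependency through the constructor, the same class can be exercised with many stubbed inputs without touching a cluster.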

### Outdated or irrelevant code
14 changes: 7 additions & 7 deletions build.gradle
@@ -376,15 +376,15 @@ testClusters.integTest {
    // Install K-NN/ml-commons plugins on the integTest cluster nodes except security
    configurations.zipArchive.asFileTree.each {
        plugin(provider(new Callable<RegularFile>(){
            @Override
            RegularFile call() throws Exception {
                return new RegularFile() {
                    @Override
                    File getAsFile() {
                        return it
                    }
                }
            }
        }))
    }

@@ -0,0 +1,216 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */
package org.opensearch.neuralsearch.highlight;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.Query;
import org.opensearch.OpenSearchException;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.query.NeuralKNNQuery;
import org.opensearch.search.fetch.subphase.highlight.FieldHighlightContext;
import org.opensearch.search.fetch.subphase.highlight.HighlightField;
import org.opensearch.search.fetch.subphase.highlight.Highlighter;
import org.opensearch.core.common.text.Text;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.processor.SentenceHighlightingRequest;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Neural highlighter that uses ML models to identify relevant text spans for highlighting
 */
@Log4j2
public class NeuralHighlighter implements Highlighter {
    public static final String NAME = "neural";
    private static final String MODEL_ID_FIELD = "model_id";
    private static final String PRE_TAG = "<em>";
    private static final String POST_TAG = "</em>";
    // Only fields of type "text" are supported for now
    private static final String SUPPORTED_FIELD_TYPE = "text";

    private static MLCommonsClientAccessor mlCommonsClient;

    public static void initialize(MLCommonsClientAccessor mlClient) {
        NeuralHighlighter.mlCommonsClient = mlClient;
    }

    @Override
    public boolean canHighlight(MappedFieldType fieldType) {
        return SUPPORTED_FIELD_TYPE.equals(fieldType.typeName());
    }

    @Override
    public HighlightField highlight(FieldHighlightContext fieldContext) {
        try {
            MappedFieldType fieldType = fieldContext.fieldType;
            if (canHighlight(fieldType) == false) {
                throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "Field %s is not supported for neural highlighting", fieldContext.fieldName)
                );
            }

            String fieldText = getFieldText(fieldContext);

            String searchQuery = extractOriginalQuery(fieldContext.query);

            if (fieldContext.field.fieldOptions().options() == null) {
                throw new IllegalArgumentException("Field options cannot be null");
            }

            Map<String, Object> options = fieldContext.field.fieldOptions().options();
            String modelId = getModelId(options);

            // Get highlighted text from ML model
            String highlightedText = getHighlightedText(modelId, searchQuery, fieldText);

            // Return highlight field
            Text[] fragments = new Text[] { new Text(highlightedText) };
            return new HighlightField(fieldContext.fieldName, fragments);
        } catch (Exception e) {
            throw new OpenSearchException(
                String.format(Locale.ROOT, "Failed to perform neural highlighting for field %s", fieldContext.fieldName),
                e
            );
        }
    }

    /**
     * Gets highlighted text from the ML model.
     *
     * @param modelId The ID of the model to use
     * @param question The search query
     * @param context The document text
     * @return Formatted text with highlighting
     */
    private String getHighlightedText(String modelId, String question, String context) {
        if (mlCommonsClient == null) {
            throw new IllegalStateException("ML Commons client is not initialized");
        }

        // Use CountDownLatch to wait for async response
        CountDownLatch latch = new CountDownLatch(1);
        AtomicReference<List<Map<String, Object>>> resultRef = new AtomicReference<>();
        AtomicReference<Exception> exceptionRef = new AtomicReference<>();

        // Create SentenceHighlightingRequest
        SentenceHighlightingRequest request = SentenceHighlightingRequest.builder()
            .modelId(modelId)
            .question(question)
            .context(context)
            .build();

        // Call ML model with the request
        mlCommonsClient.inferenceSentenceHighlighting(request, ActionListener.wrap(result -> {
            resultRef.set(result);
            latch.countDown();
        }, exception -> {
            exceptionRef.set(exception);
            latch.countDown();
        }));

        // Block until the listener has reported a result or a failure; without
        // this wait, the references below could be read before they are set
        try {
            latch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new OpenSearchException("Interrupted while waiting for sentence highlighting inference", e);
        }

        // Check for exceptions
        if (exceptionRef.get() != null) {
            throw new OpenSearchException("Error during sentence highlighting inference", exceptionRef.get());
        }

        // Process result
        List<Map<String, Object>> result = resultRef.get();

        // Apply highlighting to the original context
        return applyHighlighting(context, result);
    }

    /**
     * Applies highlighting to the original context based on the ML model response.
     *
     * @param context The original document text
     * @param highlightResults The highlighting results from the ML model
     * @return Formatted text with highlighting
     */
    private String applyHighlighting(String context, List<Map<String, Object>> highlightResults) {
        if (highlightResults == null || highlightResults.isEmpty()) {
            return context;
        }

        StringBuilder highlightedText = new StringBuilder(context);

        // Process each highlight result
        for (int resultIndex = highlightResults.size() - 1; resultIndex >= 0; resultIndex--) {
            Map<String, Object> result = highlightResults.get(resultIndex);

            // Get the "highlights" list from the result
            @SuppressWarnings("unchecked")
            List<Map<String, Object>> highlights = (List<Map<String, Object>>) result.get("highlights");

            if (highlights == null || highlights.isEmpty()) {
                log.warn("No highlights found in result: {}", result);
                continue;
            }

            // Process each highlight in reverse order to avoid position shifts
            for (int i = highlights.size() - 1; i >= 0; i--) {
                Map<String, Object> highlight = highlights.get(i);

                // Extract start and end positions
                Number startNum = (Number) highlight.get("start");
                Number endNum = (Number) highlight.get("end");

                if (startNum == null || endNum == null) {
                    log.warn("Missing start or end position in highlight: {}", highlight);
                    continue;
                }

                int start = startNum.intValue();
                int end = endNum.intValue();

                // Validate positions against the original text: offsets refer to the
                // unmodified context, and the builder grows as tags are inserted
                if (start < 0 || end > context.length() || start >= end) {
                    log.warn("Invalid highlight position: start={}, end={}, text length={}", start, end, context.length());
                    continue;
                }

                // Insert highlighting tags
                highlightedText.insert(end, POST_TAG);
                highlightedText.insert(start, PRE_TAG);
            }
        }

        return highlightedText.toString();
    }

    // Reads the required model ID from the highlighter options
    private String getModelId(Map<String, Object> options) {
        Object modelId = options.get(MODEL_ID_FIELD);
        if (Objects.isNull(modelId)) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Missing required option: %s", MODEL_ID_FIELD));
        }
        return modelId.toString();
    }

    // Reads the text of the highlighted field from the hit's source
    private String getFieldText(FieldHighlightContext fieldContext) {
        // Extract each query hit's field value
        String hitValue = (String) fieldContext.hitContext.sourceLookup().extractValue(fieldContext.fieldName, null);
        // extractValue returns the supplied null default when the field is absent
        if (hitValue == null || hitValue.isEmpty()) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "Field %s is empty", fieldContext.fieldName));
        }
        return hitValue;
    }

    // Recovers the query text, falling back to a cleaned-up Query.toString() for non-neural queries
    private String extractOriginalQuery(Query query) {
        String queryText = (query instanceof NeuralKNNQuery neuralQuery)
            ? neuralQuery.getOriginalQueryText()
            : query.toString().replaceAll("\\w+:", "").replaceAll("\\s+", " ").trim();

        if (queryText.isEmpty()) {
            throw new IllegalArgumentException("Original neural query text is empty");
        }
        return queryText;
    }
}
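
The reverse-order tag insertion used in `applyHighlighting` can be tried in isolation with the minimal sketch below. It assumes the spans are half-open `[start, end)` character offsets into the original text, sorted by start and non-overlapping. `ReverseInsertExample`, `Span`, and `wrap` are hypothetical names made up for this illustration and are not part of this commit.

```java
import java.util.List;

public class ReverseInsertExample {
    /** Hypothetical holder for a half-open character span [start, end). */
    public record Span(int start, int end) {}

    /**
     * Wraps each span in tags, walking the spans from last to first so that
     * an insertion never shifts the offsets of spans still to be processed.
     * Assumes the spans are sorted by start offset and do not overlap.
     */
    public static String wrap(final String text, final List<Span> spans, final String preTag, final String postTag) {
        final StringBuilder builder = new StringBuilder(text);
        for (int i = spans.size() - 1; i >= 0; i--) {
            final Span span = spans.get(i);
            // Insert the closing tag first so the start offset stays valid
            builder.insert(span.end(), postTag);
            builder.insert(span.start(), preTag);
        }
        return builder.toString();
    }

    public static void main(final String[] args) {
        final String text = "Cats sleep. Dogs bark.";
        final List<Span> spans = List.of(new Span(0, 11), new Span(12, 22));
        // prints "<em>Cats sleep.</em> <em>Dogs bark.</em>"
        System.out.println(wrap(text, spans, "<em>", "</em>"));
    }
}
```

Walking forward instead would require adding the length of every tag already inserted to each later offset, which is easy to get wrong.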