Skip to content

Commit

Permalink
[PoC] Expose term frequency in Painless script score context
Browse files Browse the repository at this point in the history
Signed-off-by: Louis Chu <clingzhi@amazon.com>
  • Loading branch information
noCharger committed Jul 28, 2023
1 parent 0003bd8 commit 5991108
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 10 deletions.
10 changes: 7 additions & 3 deletions .idea/runConfigurations/Debug_OpenSearch.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public boolean needs_score() {

@Override
public ScoreScript newInstance(final LeafReaderContext leaf) throws IOException {
return new ScoreScript(null, null, null) {
return new ScoreScript(null, null, null, null) {
// Fake the scorer until setScorer is called.
DoubleValues values = source.getValues(leaf, new DoubleValues() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.lucene.expressions.js.JavascriptCompiler;
import org.apache.lucene.expressions.js.VariableContext;
import org.apache.lucene.search.DoubleValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.SpecialPermission;
import org.opensearch.common.Nullable;
import org.opensearch.index.fielddata.IndexFieldData;
Expand Down Expand Up @@ -110,7 +111,7 @@ public FilterScript.LeafFactory newFactory(Map<String, Object> params, SearchLoo

contexts.put(ScoreScript.CONTEXT, (Expression expr) -> new ScoreScript.Factory() {
@Override
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup) {
public ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher) {
return newScoreScript(expr, lookup, params);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,11 @@ static Response innerShardOperation(Request request, ScriptService scriptService
} else if (scriptContext == ScoreScript.CONTEXT) {
return prepareRamIndex(request, (context, leafReaderContext) -> {
ScoreScript.Factory factory = scriptService.compile(request.script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory leafFactory = factory.newFactory(request.getScript().getParams(), context.lookup());
ScoreScript.LeafFactory leafFactory = factory.newFactory(
request.getScript().getParams(),
context.lookup(),
context.searcher()
);
ScoreScript scoreScript = leafFactory.newInstance(leafReaderContext);
scoreScript.setDocument(0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class org.opensearch.script.ScoreScript @no_import {
}

static_import {
int termFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TermFreq
float tf(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TF
long totalTermFreq(org.opensearch.script.ScoreScript, String, String) bound_to org.opensearch.script.ScoreScriptUtils$TotalTermFreq
# sumTotalTermFreq is bound to SumTotalTermFreq(ScoreScript).sumTotalTermFreq(String): it takes the field name only, no term.
long sumTotalTermFreq(org.opensearch.script.ScoreScript, String) bound_to org.opensearch.script.ScoreScriptUtils$SumTotalTermFreq
double saturation(double, double) from_class org.opensearch.script.ScoreScriptUtils
double sigmoid(double, double, double) from_class org.opensearch.script.ScoreScriptUtils
double randomScore(org.opensearch.script.ScoreScript, int, String) bound_to org.opensearch.script.ScoreScriptUtils$RandomScoreField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ protected int doHashCode() {
protected ScoreFunction doToFunction(QueryShardContext context) {
try {
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup(), context.searcher());
return new ScriptScoreFunction(
script,
searchScript,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
);
}
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup());
ScoreScript.LeafFactory scoreScriptFactory = factory.newFactory(script.getParams(), context.lookup(), context.searcher());
final QueryBuilder queryBuilder = this.query;
Query query = queryBuilder.toQuery(context);
return new ScriptScoreQuery(
Expand Down
41 changes: 39 additions & 2 deletions server/src/main/java/org/opensearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Scorable;
import org.opensearch.Version;
import org.opensearch.common.logging.DeprecationLogger;
Expand Down Expand Up @@ -115,20 +116,24 @@ public Explanation get(double score, Explanation subQueryExplanation) {
private String indexName = null;
private Version indexVersion = null;

public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
private final IndexSearcher indexSearcher;

public ScoreScript(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher, LeafReaderContext leafContext) {
// null check needed b/c of expression engine subclass
if (lookup == null) {
assert params == null;
assert leafContext == null;
this.params = null;
this.leafLookup = null;
this.docBase = 0;
this.indexSearcher = null;
} else {
this.leafLookup = lookup.getLeafSearchLookup(leafContext);
params = new HashMap<>(params);
params.putAll(leafLookup.asMap());
this.params = new DynamicMap(params, PARAMS_FUNCTIONS);
this.docBase = leafContext.docBase;
this.indexSearcher = indexSearcher;
}
}

Expand All @@ -144,6 +149,38 @@ public Map<String, ScriptDocValues<?>> getDoc() {
return leafLookup.doc();
}

/**
 * Returns the term frequency reported by the leaf lookup for {@code term} in {@code field},
 * evaluated for the document this script is currently set to.
 *
 * @param field name of the field to inspect
 * @param term  term to look up
 * @return the value returned by {@code leafLookup.termFreq} for the current {@code docId}
 * @throws UncheckedIOException wrapping any {@link IOException} raised by the lookup
 */
public int termFreq(String field, String term) {
    final int frequency;
    try {
        frequency = leafLookup.termFreq(field, term, docId);
    } catch (IOException ioFailure) {
        throw new UncheckedIOException(ioFailure);
    }
    return frequency;
}

/**
 * Returns the {@code tf} value reported by the leaf lookup for {@code term} in {@code field},
 * evaluated for the current document using the searcher held by this script.
 *
 * @param field name of the field to inspect
 * @param term  term to look up
 * @return the value returned by {@code leafLookup.tf} for the current {@code docId}
 * @throws UncheckedIOException wrapping any {@link IOException} raised by the lookup
 */
public float tf(String field, String term) {
    final float factor;
    try {
        factor = leafLookup.tf(field, term, docId, indexSearcher);
    } catch (IOException ioFailure) {
        throw new UncheckedIOException(ioFailure);
    }
    return factor;
}

/**
 * Returns the total term frequency reported by the leaf lookup for {@code term} in {@code field},
 * using the searcher held by this script.
 *
 * <p>Like {@code termFreq} and {@code tf}, this method never lets a checked {@link IOException}
 * escape: it is always rethrown as an {@link UncheckedIOException}, so no {@code throws} clause is declared.
 *
 * @param field name of the field to inspect
 * @param term  term to look up
 * @return the value returned by {@code leafLookup.totalTermFreq}
 * @throws UncheckedIOException wrapping any {@link IOException} raised by the lookup
 */
public long totalTermFreq(String field, String term) {
    try {
        return leafLookup.totalTermFreq(field, term, docId, indexSearcher);
    } catch (IOException e) {
        throw new UncheckedIOException(e);
    }
}

/**
 * Returns the sum of total term frequencies reported by the leaf lookup for {@code field},
 * using the searcher held by this script.
 *
 * <p>Like {@code termFreq} and {@code tf}, this method never lets a checked {@link IOException}
 * escape: it is always rethrown as an {@link UncheckedIOException}, so no {@code throws} clause is declared.
 *
 * @param field name of the field to inspect
 * @return the value returned by {@code leafLookup.sumTotalTermFreq}
 * @throws UncheckedIOException wrapping any {@link IOException} raised by the lookup
 */
public long sumTotalTermFreq(String field) {
    try {
        return leafLookup.sumTotalTermFreq(field, docId, indexSearcher);
    } catch (IOException e) {
        throw new UncheckedIOException(e);
    }
}

/** Set the current document to run the script on next. */
public void setDocument(int docid) {
this.docId = docid;
Expand Down Expand Up @@ -268,7 +305,7 @@ public interface LeafFactory {
*/
public interface Factory extends ScriptFactory {

ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup);
ScoreScript.LeafFactory newFactory(Map<String, Object> params, SearchLookup lookup, IndexSearcher indexSearcher);

}

Expand Down
64 changes: 64 additions & 0 deletions server/src/main/java/org/opensearch/script/ScoreScriptUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,70 @@ public static double sigmoid(double value, double k, double a) {
return Math.pow(value, a) / (Math.pow(k, a) + Math.pow(value, a));
}

/**
 * Binding class that forwards {@code termFreq(field, term)} calls to the wrapped {@link ScoreScript}.
 */
public static final class TermFreq {
    private final ScoreScript script;

    /**
     * @param scoreScript the script whose {@code termFreq} method is invoked
     */
    public TermFreq(ScoreScript scoreScript) {
        script = scoreScript;
    }

    /**
     * Delegates to {@link ScoreScript#termFreq(String, String)}; any exception is converted
     * through {@code ExceptionsHelper.convertToOpenSearchException} before being rethrown.
     */
    public int termFreq(String fieldName, String term) {
        final int frequency;
        try {
            frequency = script.termFreq(fieldName, term);
        } catch (Exception failure) {
            throw ExceptionsHelper.convertToOpenSearchException(failure);
        }
        return frequency;
    }
}

/**
 * Binding class that forwards {@code tf(field, term)} calls to the wrapped {@link ScoreScript}.
 */
public static final class TF {
    private final ScoreScript script;

    /**
     * @param scoreScript the script whose {@code tf} method is invoked
     */
    public TF(ScoreScript scoreScript) {
        script = scoreScript;
    }

    /**
     * Delegates to {@link ScoreScript#tf(String, String)}; any exception is converted
     * through {@code ExceptionsHelper.convertToOpenSearchException} before being rethrown.
     */
    public float tf(String fieldName, String term) {
        final float factor;
        try {
            factor = script.tf(fieldName, term);
        } catch (Exception failure) {
            throw ExceptionsHelper.convertToOpenSearchException(failure);
        }
        return factor;
    }
}

/**
 * Binding class that forwards {@code totalTermFreq(field, term)} calls to the wrapped {@link ScoreScript}.
 */
public static final class TotalTermFreq {
    private final ScoreScript script;

    /**
     * @param scoreScript the script whose {@code totalTermFreq} method is invoked
     */
    public TotalTermFreq(ScoreScript scoreScript) {
        script = scoreScript;
    }

    /**
     * Delegates to {@code ScoreScript.totalTermFreq(fieldName, term)}; any exception is converted
     * through {@code ExceptionsHelper.convertToOpenSearchException} before being rethrown.
     */
    public long totalTermFreq(String fieldName, String term) {
        final long total;
        try {
            total = script.totalTermFreq(fieldName, term);
        } catch (Exception failure) {
            throw ExceptionsHelper.convertToOpenSearchException(failure);
        }
        return total;
    }
}

/**
 * Binding class that forwards {@code sumTotalTermFreq(field)} calls to the wrapped {@link ScoreScript}.
 */
public static final class SumTotalTermFreq {
    private final ScoreScript script;

    /**
     * @param scoreScript the script whose {@code sumTotalTermFreq} method is invoked
     */
    public SumTotalTermFreq(ScoreScript scoreScript) {
        script = scoreScript;
    }

    /**
     * Delegates to {@code ScoreScript.sumTotalTermFreq(fieldName)}; any exception is converted
     * through {@code ExceptionsHelper.convertToOpenSearchException} before being rethrown.
     */
    public long sumTotalTermFreq(String fieldName) {
        final long sum;
        try {
            sum = script.sumTotalTermFreq(fieldName);
        } catch (Exception failure) {
            throw ExceptionsHelper.convertToOpenSearchException(failure);
        }
        return sum;
    }
}

/**
* random score based on the documents' values of the given field
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,14 @@
package org.opensearch.search.lookup;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.valuesource.SumTotalTermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TFValueSource;
import org.apache.lucene.queries.function.valuesource.TermFreqValueSource;
import org.apache.lucene.queries.function.valuesource.TotalTermFreqValueSource;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.common.lucene.BytesRefs;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -87,4 +94,30 @@ public void setDocument(int docId) {
sourceLookup.setSegmentAndDocument(ctx, docId);
fieldsLookup.setDocument(docId);
}

/**
 * Evaluates a {@link TermFreqValueSource} for {@code term} in {@code field} against this lookup's
 * leaf context and reads its integer value for {@code docId}.
 *
 * @param field field name, used both as the field and the indexed-field argument of the value source
 * @param term  term text, also converted to bytes via {@code BytesRefs.toBytesRef}
 * @param docId document id passed to {@code intVal}
 * @return the integer value produced by the value source for {@code docId}
 * @throws IOException if obtaining the values fails
 */
public int termFreq(String field, String term, int docId) throws IOException {
    // No context map is supplied here; only the leaf reader context is handed to getValues.
    return new TermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term)).getValues(null, ctx).intVal(docId);
}

/**
 * Evaluates a {@link TFValueSource} for {@code term} in {@code field} against this lookup's
 * leaf context and reads its float value for {@code docId}.
 *
 * @param field         field name, used both as the field and the indexed-field argument of the value source
 * @param term          term text, also converted to bytes via {@code BytesRefs.toBytesRef}
 * @param docId         document id passed to {@code floatVal}
 * @param indexSearcher searcher made available to the value source through the context map
 * @return the float value produced by the value source for {@code docId}
 * @throws IOException if obtaining the values fails
 */
public float tf(String field, String term, int docId, IndexSearcher indexSearcher) throws IOException {
    TFValueSource valueSource = new TFValueSource(field, term, field, BytesRefs.toBytesRef(term));
    // The searcher is passed to the value source via the context map under the "searcher" key.
    Map<Object, Object> context = new HashMap<>();
    context.put("searcher", indexSearcher);
    return valueSource.getValues(context, ctx).floatVal(docId);
}

/**
 * Evaluates a {@link TotalTermFreqValueSource} for {@code term} in {@code field} and reads its
 * long value for {@code docId}.
 *
 * @param field         field name, used both as the field and the indexed-field argument of the value source
 * @param term          term text, also converted to bytes via {@code BytesRefs.toBytesRef}
 * @param docId         document id passed to {@code longVal}
 * @param indexSearcher searcher handed to {@code createWeight}
 * @return the long value produced by the value source for {@code docId}
 * @throws IOException if preparing or obtaining the values fails
 */
public long totalTermFreq(String field, String term, int docId, IndexSearcher indexSearcher) throws IOException {
    TotalTermFreqValueSource valueSource = new TotalTermFreqValueSource(field, term, field, BytesRefs.toBytesRef(term));
    Map<Object, Object> context = new HashMap<>();
    // createWeight runs before getValues with the same context map.
    // NOTE(review): presumably it stores searcher-level state in the map that getValues then reads; confirm against Lucene docs.
    valueSource.createWeight(context, indexSearcher);
    return valueSource.getValues(context, ctx).longVal(docId);
}

/**
 * Evaluates a {@link SumTotalTermFreqValueSource} for {@code field} and reads its long value for {@code docId}.
 *
 * @param field         field name given to the value source
 * @param docId         document id passed to {@code longVal}
 * @param indexSearcher searcher handed to {@code createWeight}
 * @return the long value produced by the value source for {@code docId}
 * @throws IOException if preparing or obtaining the values fails
 */
public long sumTotalTermFreq(String field, int docId, IndexSearcher indexSearcher) throws IOException {
    SumTotalTermFreqValueSource valueSource = new SumTotalTermFreqValueSource(field);
    Map<Object, Object> context = new HashMap<>();
    // createWeight runs before getValues with the same context map.
    // NOTE(review): presumably it stores searcher-level state in the map that getValues then reads; confirm against Lucene docs.
    valueSource.createWeight(context, indexSearcher);
    return valueSource.getValues(context, ctx).longVal(docId);
}
}

0 comments on commit 5991108

Please sign in to comment.