Skip to content

Commit

Permalink
Add linear function to rank_feature query (#67438)
Browse files Browse the repository at this point in the history
This adds a linear function to the set of functions available
for rank_feature query

Closes #49859
  • Loading branch information
mayya-sharipova authored Jan 18, 2021
1 parent f449b8f commit 7648221
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 28 deletions.
44 changes: 40 additions & 4 deletions docs/reference/query-dsl/rank-feature-query.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ query supports the following mathematical functions:
* <<rank-feature-query-saturation,Saturation>>
* <<rank-feature-query-logarithm,Logarithm>>
* <<rank-feature-query-sigmoid,Sigmoid>>
* <<rank-feature-query-linear,Linear>>

If you don't know where to start, we recommend using the `saturation` function.
If no function is provided, the `rank_feature` query uses the `saturation`
Expand Down Expand Up @@ -126,7 +127,7 @@ The following query searches for `2016` and boosts relevance scores based on

[source,console]
----
GET /test/_search
GET /test/_search
{
"query": {
"bool": {
Expand Down Expand Up @@ -190,7 +191,7 @@ value of the rank feature `field`. If no function is provided, the `rank_feature
query defaults to the `saturation` function. See
<<rank-feature-query-saturation,Saturation>> for more information.

Only one function `saturation`, `log`, or `sigmoid` can be provided.
Only one function `saturation`, `log`, `sigmoid` or `linear` can be provided.
--

`log`::
Expand All @@ -201,7 +202,7 @@ function used to boost <<relevance-scores,relevance scores>> based on the
value of the rank feature `field`. See
<<rank-feature-query-logarithm,Logarithm>> for more information.

Only one function `saturation`, `log`, or `sigmoid` can be provided.
Only one function `saturation`, `log`, `sigmoid` or `linear` can be provided.
--

`sigmoid`::
Expand All @@ -212,7 +213,18 @@ to boost <<relevance-scores,relevance scores>> based on the value of the
rank feature `field`. See <<rank-feature-query-sigmoid,Sigmoid>> for more
information.

Only one function `saturation`, `log`, or `sigmoid` can be provided.
Only one function `saturation`, `log`, `sigmoid` or `linear` can be provided.
--

`linear`::
+
--
(Optional, <<rank-feature-query-linear,function object>>) Linear function used
to boost <<relevance-scores,relevance scores>> based on the value of the
rank feature `field`. See <<rank-feature-query-linear,Linear>> for more
information.

Only one function `saturation`, `log`, `sigmoid` or `linear` can be provided.
--


Expand Down Expand Up @@ -311,3 +323,27 @@ GET /test/_search
}
}
--------------------------------------------------
[[rank-feature-query-linear]]
===== Linear
The `linear` function is the simplest function, and gives a score equal
to the indexed value of `S`, where `S` is the value of the rank feature
field.
If a rank feature field is indexed with `"positive_score_impact": true`,
its indexed value is equal to `S` and rounded to preserve only
9 significant bits for the precision.
If a rank feature field is indexed with `"positive_score_impact": false`,
its indexed value is equal to `1/S` and rounded to preserve only 9 significant
bits for the precision.

[source,console]
--------------------------------------------------
GET /test/_search
{
"query": {
"rank_feature": {
"field": "pagerank",
"linear": {}
}
}
}
--------------------------------------------------
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.RankFeatureFieldMapper.RankFeatureFieldType;
Expand Down Expand Up @@ -104,7 +105,7 @@ void doXContent(XContentBuilder builder) throws IOException {
}

@Override
Query toQuery(String field, String feature, boolean positiveScoreImpact) throws IOException {
Query toQuery(String field, String feature, boolean positiveScoreImpact) {
if (positiveScoreImpact == false) {
throw new IllegalArgumentException("Cannot use the [log] function with a field that has a negative score impact as " +
"it would trigger negative scores");
Expand Down Expand Up @@ -175,7 +176,7 @@ void doXContent(XContentBuilder builder) throws IOException {
}

@Override
Query toQuery(String field, String feature, boolean positiveScoreImpact) throws IOException {
Query toQuery(String field, String feature, boolean positiveScoreImpact) {
if (pivot == null) {
return FeatureField.newSaturationQuery(field, feature);
} else {
Expand Down Expand Up @@ -240,10 +241,55 @@ void doXContent(XContentBuilder builder) throws IOException {
}

@Override
Query toQuery(String field, String feature, boolean positiveScoreImpact) throws IOException {
Query toQuery(String field, String feature, boolean positiveScoreImpact) {
return FeatureField.newSigmoidQuery(field, feature, DEFAULT_BOOST, pivot, exp);
}
}

/**
* A scoring function that scores documents as simply {@code S}
* where S is the indexed value of the static feature.
*/
public static class Linear extends ScoreFunction {

private static final ObjectParser<Linear, Void> PARSER = new ObjectParser<>("linear", Linear::new);

public Linear() {
}

private Linear(StreamInput in) {
this();
}

@Override
public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) {
return false;
}
return true;
}

@Override
public int hashCode() {
return getClass().hashCode();
}

@Override
void writeTo(StreamOutput out) throws IOException {
out.writeByte((byte) 3);
}

@Override
void doXContent(XContentBuilder builder) throws IOException {
builder.startObject("linear");
builder.endObject();
}

@Override
Query toQuery(String field, String feature, boolean positiveScoreImpact) {
return FeatureField.newLinearQuery(field, feature, DEFAULT_BOOST);
}
}
}

private static ScoreFunction readScoreFunction(StreamInput in) throws IOException {
Expand All @@ -255,6 +301,8 @@ private static ScoreFunction readScoreFunction(StreamInput in) throws IOExceptio
return new ScoreFunction.Saturation(in);
case 2:
return new ScoreFunction.Sigmoid(in);
case 3:
return new ScoreFunction.Linear(in);
default:
throw new IOException("Illegal score function id: " + b);
}
Expand All @@ -268,7 +316,7 @@ private static ScoreFunction readScoreFunction(StreamInput in) throws IOExceptio
long numNonNulls = Arrays.stream(args, 3, args.length).filter(Objects::nonNull).count();
final RankFeatureQueryBuilder query;
if (numNonNulls > 1) {
throw new IllegalArgumentException("Can only specify one of [log], [saturation] and [sigmoid]");
throw new IllegalArgumentException("Can only specify one of [log], [saturation], [sigmoid] and [linear]");
} else if (numNonNulls == 0) {
query = new RankFeatureQueryBuilder(field, new ScoreFunction.Saturation());
} else {
Expand All @@ -292,6 +340,8 @@ private static ScoreFunction readScoreFunction(StreamInput in) throws IOExceptio
ScoreFunction.Saturation.PARSER, new ParseField("saturation"));
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(),
ScoreFunction.Sigmoid.PARSER, new ParseField("sigmoid"));
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(),
ScoreFunction.Linear.PARSER, new ParseField("linear"));
}

public static final String NAME = "rank_feature";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,13 @@ public static RankFeatureQueryBuilder sigmoid(String fieldName, float pivot, flo
return new RankFeatureQueryBuilder(fieldName, new RankFeatureQueryBuilder.ScoreFunction.Sigmoid(pivot, exp));
}

/**
* Return a new {@link RankFeatureQueryBuilder} that will score documents as
* {@code S)} where S is the indexed value of the static feature.
* @param fieldName field that stores features
*/
public static RankFeatureQueryBuilder linear(String fieldName) {
return new RankFeatureQueryBuilder(fieldName, new RankFeatureQueryBuilder.ScoreFunction.Linear());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ protected Collection<Class<? extends Plugin>> getPlugins() {
protected RankFeatureQueryBuilder doCreateTestQueryBuilder() {
ScoreFunction function;
boolean mayUseNegativeField = true;
switch (random().nextInt(3)) {
switch (random().nextInt(4)) {
case 0:
mayUseNegativeField = false;
function = new ScoreFunction.Log(1 + randomFloat());
Expand All @@ -75,6 +75,9 @@ protected RankFeatureQueryBuilder doCreateTestQueryBuilder() {
case 2:
function = new ScoreFunction.Sigmoid(randomFloat(), randomFloat());
break;
case 3:
function = new ScoreFunction.Linear();
break;
default:
throw new AssertionError();
}
Expand Down Expand Up @@ -106,7 +109,7 @@ public void testDefaultScoreFunction() throws IOException {
assertEquals(FeatureField.newSaturationQuery("_feature", "my_feature_field"), parsedQuery);
}

public void testIllegalField() throws IOException {
public void testIllegalField() {
String query = "{\n" +
" \"rank_feature\" : {\n" +
" \"field\": \"" + TEXT_FIELD_NAME + "\"\n" +
Expand All @@ -118,7 +121,7 @@ public void testIllegalField() throws IOException {
e.getMessage());
}

public void testIllegalCombination() throws IOException {
public void testIllegalCombination() {
String query = "{\n" +
" \"rank_feature\" : {\n" +
" \"field\": \"my_negative_feature_field\",\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ setup:

- do:
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -46,7 +46,7 @@ setup:
scaling_factor: 3

- match:
hits.total: 2
hits.total.value: 2

- match:
hits.hits.0._id: "2"
Expand All @@ -59,7 +59,7 @@ setup:

- do:
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -68,7 +68,7 @@ setup:
pivot: 20

- match:
hits.total: 2
hits.total.value: 2

- match:
hits.hits.0._id: "2"
Expand All @@ -81,7 +81,7 @@ setup:

- do:
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -91,7 +91,27 @@ setup:
exponent: 0.6

- match:
hits.total: 2
hits.total.value: 2

- match:
hits.hits.0._id: "2"

- match:
hits.hits.1._id: "1"

---
"Positive linear":
- do:
search:
index: test
body:
query:
rank_feature:
field: pagerank
linear: {}

- match:
hits.total.value: 2

- match:
hits.hits.0._id: "2"
Expand All @@ -105,7 +125,7 @@ setup:
- do:
catch: bad_request
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -118,7 +138,7 @@ setup:

- do:
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -127,7 +147,7 @@ setup:
pivot: 20

- match:
hits.total: 2
hits.total.value: 2

- match:
hits.hits.0._id: "2"
Expand All @@ -140,7 +160,7 @@ setup:

- do:
search:
rest_total_hits_as_int: true
index: test
body:
query:
rank_feature:
Expand All @@ -150,7 +170,28 @@ setup:
exponent: 0.6

- match:
hits.total: 2
hits.total.value: 2

- match:
hits.hits.0._id: "2"

- match:
hits.hits.1._id: "1"

---
"Negative linear":

- do:
search:
index: test
body:
query:
rank_feature:
field: url_length
linear: {}

- match:
hits.total.value: 2

- match:
hits.hits.0._id: "2"
Expand Down
Loading

0 comments on commit 7648221

Please sign in to comment.