Skip to content

Commit

Permalink
Adding request source for cohere (#104926)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-buttner authored Jan 30, 2024
1 parent 202a81f commit 422e6f6
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public HttpRequest createHttpRequest() {

httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
httpPost.setHeader(createAuthBearerHeader(account.apiKey()));
httpPost.setHeader(CohereUtils.createRequestSourceHeader());

return new HttpRequest(httpPost, getInferenceEntityId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,19 @@

package org.elasticsearch.xpack.inference.external.request.cohere;

import org.apache.http.Header;
import org.apache.http.message.BasicHeader;

public class CohereUtils {
public static final String HOST = "api.cohere.ai";
public static final String VERSION_1 = "v1";
public static final String EMBEDDINGS_PATH = "embed";
public static final String REQUEST_SOURCE_HEADER = "Request-Source";
public static final String ELASTIC_REQUEST_SOURCE = "unspecified:elasticsearch";

public static Header createRequestSourceHeader() {
return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE);
}

private CohereUtils() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.results.TextEmbeddingByteResultsTests;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
Expand Down Expand Up @@ -130,6 +131,10 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException {
equalTo(XContentType.JSON.mediaType())
);
MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
MatcherAssert.assertThat(
webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER),
equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE)
);

var requestMap = entityAsMap(webServer.requests().get(0).getBody());
MatcherAssert.assertThat(
Expand Down Expand Up @@ -210,6 +215,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I
equalTo(XContentType.JSON.mediaType())
);
MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
MatcherAssert.assertThat(
webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER),
equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE)
);

var requestMap = entityAsMap(webServer.requests().get(0).getBody());
MatcherAssert.assertThat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ public void testCreateRequest_UrlDefined() throws URISyntaxException, IOExceptio
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
MatcherAssert.assertThat(
httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(),
is(CohereUtils.ELASTIC_REQUEST_SOURCE)
);

var requestMap = entityAsMap(httpPost.getEntity().getContent());
MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"))));
Expand Down Expand Up @@ -71,6 +75,10 @@ public void testCreateRequest_AllOptionsDefined() throws URISyntaxException, IOE
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
MatcherAssert.assertThat(
httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(),
is(CohereUtils.ELASTIC_REQUEST_SOURCE)
);

var requestMap = entityAsMap(httpPost.getEntity().getContent());
MatcherAssert.assertThat(
Expand Down Expand Up @@ -114,6 +122,10 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
MatcherAssert.assertThat(
httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(),
is(CohereUtils.ELASTIC_REQUEST_SOURCE)
);

var requestMap = entityAsMap(httpPost.getEntity().getContent());
MatcherAssert.assertThat(
Expand Down Expand Up @@ -157,6 +169,10 @@ public void testCreateRequest_TruncateNone() throws URISyntaxException, IOExcept
MatcherAssert.assertThat(httpPost.getURI().toString(), is("url"));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
MatcherAssert.assertThat(
httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(),
is(CohereUtils.ELASTIC_REQUEST_SOURCE)
);

var requestMap = entityAsMap(httpPost.getEntity().getContent());
MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "truncate", "none")));
Expand Down

0 comments on commit 422e6f6

Please sign in to comment.