Skip to content

Commit

Permalink
[AI] [Inference] chat and embeddings regression tests (#30478)
Browse files Browse the repository at this point in the history
### Packages impacted by this PR
@azure-rest/ai-inference

### Issues associated with this PR


### Describe the problem that is addressed by this PR
Add regression tests for chat completions and embeddings routes

### What are the possible designs available to address the problem? If
there are more than one possible design, why was the one in this PR
chosen?


### Are there test cases added in this PR? _(If not, why?)_


### Provide a list of related PRs _(if any)_


### Command used to generate this PR:**_(Applicable only to SDK release
request PRs)_

### Checklists
- [ ] Added impacted package name to the issue description
- [ ] Does this PR needs any fixes in the SDK Generator?** _(If so,
create an Issue in the
[Autorest/typescript](https://github.com/Azure/autorest.typescript)
repository and link it here)_
- [ ] Added a changelog (if necessary)
  • Loading branch information
glharper authored Jul 20, 2024
1 parent da2b372 commit 9663ddb
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 5 deletions.
2 changes: 1 addition & 1 deletion sdk/ai/ai-inference-rest/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "js",
"TagPrefix": "js/ai/ai-inference-rest",
"Tag": "js/ai/ai-inference-rest_0c8119a2a5"
"Tag": "js/ai/ai-inference-rest_c675ded6ae"
}
101 changes: 100 additions & 1 deletion sdk/ai/ai-inference-rest/test/public/chatCompletions.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
import { createRecorder, createModelClient } from "./utils/recordedClient.js";
import { Recorder } from "@azure-tools/test-recorder";
import { assert, beforeEach, afterEach, it, describe } from "vitest";
import { ChatCompletionsOutput, ModelClient, ChatCompletionsFunctionToolCallOutput, isUnexpected } from "../../src/index.js";
import {
ChatCompletionsOutput,
ModelClient,
ChatCompletionsFunctionToolCallOutput,
ChatMessageContentItem,
ChatMessageImageContentItem,
isUnexpected
} from "../../src/index.js";

describe("chat test suite", () => {
let recorder: Recorder;
Expand All @@ -25,6 +32,98 @@ describe("chat test suite", () => {
assert.isNotNull(client.pipeline);
});

it("chat regression test", async function () {
const url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg";
const headers = { "extra-parameters": "allow" };
const body = {
messages: [
{ role: "system", content: "You are a helpful assistant." },
{ role: "user", content: "How many feet are in a mile?" },
{ role: "user", content: [{ type: "image_url", image_url: { url, detail: "auto" } } ]},
],
frequency_penalty: 1,
stream: false,
presence_penalty: 1,
temperature: 1,
top_p: 1,
max_tokens: 1,
stop: ["<stop>"],
seed: 1,
model: "foo",
response_format: "foo",
tool_choice: "auto",
tools: [ { type: "function", function: { name: "foo", description: "bar" } } ]
}
const response = await client.path("/chat/completions").post({
headers,
body
});
const responseHeaders = response.request.headers.toJSON();
assert.isDefined(responseHeaders);
assert.isDefined(responseHeaders["extra-parameters"]);
assert.isTrue(responseHeaders["extra-parameters"] == headers["extra-parameters"]);

const request = response.request;
assert.isDefined(request);

const reqBody = request.body as string;
assert.isDefined(reqBody);
const json = JSON.parse(reqBody);

assert.isDefined(json["messages"]);
if (json["messages"]) {
assert.isNotEmpty(json["messages"]);
assert.isDefined(json["messages"][0]);
assert.isTrue(json["messages"][0]["role"] == body.messages[0].role);
assert.isTrue(json["messages"][0]["content"] == body.messages[0].content);
assert.isTrue(json["messages"][1]["role"] == body.messages[1].role);
assert.isTrue(json["messages"][1]["content"] == body.messages[1].content);
assert.isTrue(json["messages"][2]["role"] == body.messages[2].role);

const contentArray = json["messages"][2]["content"];
assert.isDefined(contentArray);
assert.isNotEmpty(contentArray);
if (contentArray) {
const sourceArray = body.messages[2].content as Array<ChatMessageContentItem>;
assert.isTrue(contentArray[0].type == sourceArray[0].type);
const imageUrlItem = sourceArray[0] as ChatMessageImageContentItem;
assert.isTrue(contentArray[0].image_url.url == imageUrlItem.image_url.url);
assert.isTrue(contentArray[0].image_url.detail == imageUrlItem.image_url.detail);
}
}
assert.isTrue(json["frequency_penalty"] == body.frequency_penalty);
assert.isTrue(json["stream"] == body.stream);
assert.isTrue(json["presence_penalty"] == body.presence_penalty);
assert.isTrue(json["temperature"] == body.temperature);
assert.isTrue(json["top_p"] == body.top_p);
assert.isTrue(json["max_tokens"] == body.max_tokens);
assert.isDefined(json["stop"]);
assert.isArray(json["stop"]);
assert.isNotEmpty(json["stop"]);

if (json["stop"]) {
assert.isDefined(json["stop"][0]);
assert.isTrue(json["stop"][0] == body.stop[0]);
}
assert.isTrue(json["seed"] == body.seed);
assert.isTrue(json["model"] == body.model);
assert.isTrue(json["response_format"] == body.response_format);
assert.isTrue(json["tool_choice"] == body.tool_choice);
assert.isDefined(json["tools"]);
assert.isArray(json["tools"]);
assert.isNotEmpty(json["tools"]);
if (json["tools"]) {
assert.isDefined(json["tools"][0]);
assert.isTrue(json["tools"][0].type == body.tools[0].type);
assert.isTrue(json["tools"][0].function.name == body.tools[0].function.name);
assert.isTrue(json["tools"][0].function.description == body.tools[0].function.description);
}
},
{
timeout: 50000
});


it("simple chat test", async function () {
const response = await client.path("/chat/completions").post({
body: {
Expand Down
48 changes: 47 additions & 1 deletion sdk/ai/ai-inference-rest/test/public/embeddings.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import { createRecorder, createModelClient } from "./utils/recordedClient.js";
import { Recorder } from "@azure-tools/test-recorder";
import { assert, beforeEach, afterEach, it, describe } from "vitest";
import { ModelClient, isUnexpected, EmbeddingsResultOutput } from "../../src/index.js";
import { ModelClient, GetEmbeddingsBodyParam, isUnexpected, EmbeddingsResultOutput } from "../../src/index.js";

describe("embeddings test suite", () => {
let recorder: Recorder;
Expand All @@ -19,6 +19,52 @@ describe("embeddings test suite", () => {
await recorder.stop();
});

it("embeddings regression test", async function () {
const headers = { "extra-parameters": "allow" };
const embeddingParams = {
body: {
input: ["first phrase"],
dimensions: 1,
encoding_format: "foo",
input_type: "foo",
model: "foo"
}
} as GetEmbeddingsBodyParam;

assert.isDefined(embeddingParams);

const response = await client.path("/embeddings").post({
headers,
body: embeddingParams.body
});
const responseHeaders = response.request.headers.toJSON();
assert.isDefined(responseHeaders);
assert.isDefined(responseHeaders["extra-parameters"]);
assert.isTrue(responseHeaders["extra-parameters"] == headers["extra-parameters"]);

const request = response.request;
assert.isDefined(request);

const reqBody = request.body as string;
assert.isDefined(reqBody);
const json = JSON.parse(reqBody);
assert.isDefined(json["input"]);
assert.isArray(json["input"]);
assert.isNotEmpty(json["input"]);

if (json["input"]) {
assert.isDefined(json["input"][0]);
assert.isTrue(json["input"][0] == embeddingParams.body?.input[0]);
}
assert.isTrue(json["dimensions"] == embeddingParams.body?.dimensions);
assert.isTrue(json["model"] == embeddingParams.body?.model);
assert.isTrue(json["encoding_format"] == embeddingParams.body?.encoding_format);
assert.isTrue(json["input_type"] == embeddingParams.body?.input_type);
},
{
timeout: 50000
});

it("simple embeddings test", async function () {
const response = await client.path("/embeddings").post({
body: {
Expand Down
4 changes: 2 additions & 2 deletions sdk/ai/ai-inference-rest/test/public/utils/recordedClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import createClient, { ModelClient } from "../../../src/index.js";
import { DeploymentType } from "../types.js";

const envSetupForPlayback: Record<string, string> = {
AZURE_AAD_ENDPOINT: "https://endpoint.openai.azure.com/openai/deployments/gpt-4o/",
AZURE_AAD_COMPLETIONS_ENDPOINT: "https://endpoint.openai.azure.com/openai/deployments/gpt-4o/",
AZURE_EMBEDDINGS_ENDPOINT: "https://endpoint.openai.azure.com/openai/deployments/text-embedding-3-small/",
SUBSCRIPTION_ID: "azure_subscription_id"
};
Expand All @@ -38,7 +38,7 @@ function getEndpointFromResourceType(resourceType: DeploymentType): string {
case "embeddings":
return assertEnvironmentVariable("AZURE_EMBEDDINGS_ENDPOINT");
case "completions":
return assertEnvironmentVariable("AZURE_AAD_ENDPOINT");
return assertEnvironmentVariable("AZURE_AAD_COMPLETIONS_ENDPOINT");
}
}

Expand Down

0 comments on commit 9663ddb

Please sign in to comment.