From 9663ddbf7aacaa69ec4177bc6da819855a5d7257 Mon Sep 17 00:00:00 2001 From: Glenn Harper <64209257+glharper@users.noreply.github.com> Date: Sat, 20 Jul 2024 10:15:34 -0700 Subject: [PATCH] [AI] [Inference] chat and embeddings regression tests (#30478) ### Packages impacted by this PR @azure-rest/ai-inference ### Issues associated with this PR ### Describe the problem that is addressed by this PR Add regression tests for chat completions and embeddings routes ### What are the possible designs available to address the problem? If there are more than one possible design, why was the one in this PR chosen? ### Are there test cases added in this PR? _(If not, why?)_ ### Provide a list of related PRs _(if any)_ ### Command used to generate this PR:**_(Applicable only to SDK release request PRs)_ ### Checklists - [ ] Added impacted package name to the issue description - [ ] Does this PR needs any fixes in the SDK Generator?** _(If so, create an Issue in the [Autorest/typescript](https://github.com/Azure/autorest.typescript) repository and link it here)_ - [ ] Added a changelog (if necessary) --- sdk/ai/ai-inference-rest/assets.json | 2 +- .../test/public/chatCompletions.spec.ts | 101 +++++++++++++++++- .../test/public/embeddings.spec.ts | 48 ++++++++- .../test/public/utils/recordedClient.ts | 4 +- 4 files changed, 150 insertions(+), 5 deletions(-) diff --git a/sdk/ai/ai-inference-rest/assets.json b/sdk/ai/ai-inference-rest/assets.json index cb7c6d3094c8..74da96f6f358 100644 --- a/sdk/ai/ai-inference-rest/assets.json +++ b/sdk/ai/ai-inference-rest/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "js", "TagPrefix": "js/ai/ai-inference-rest", - "Tag": "js/ai/ai-inference-rest_0c8119a2a5" + "Tag": "js/ai/ai-inference-rest_c675ded6ae" } diff --git a/sdk/ai/ai-inference-rest/test/public/chatCompletions.spec.ts b/sdk/ai/ai-inference-rest/test/public/chatCompletions.spec.ts index 96e37fe9a367..e6145190a83f 100644 --- a/sdk/ai/ai-inference-rest/test/public/chatCompletions.spec.ts +++ b/sdk/ai/ai-inference-rest/test/public/chatCompletions.spec.ts @@ -4,7 +4,14 @@ import { createRecorder, createModelClient } from "./utils/recordedClient.js"; import { Recorder } from "@azure-tools/test-recorder"; import { assert, beforeEach, afterEach, it, describe } from "vitest"; -import { ChatCompletionsOutput, ModelClient, ChatCompletionsFunctionToolCallOutput, isUnexpected } from "../../src/index.js"; +import { + ChatCompletionsOutput, + ModelClient, + ChatCompletionsFunctionToolCallOutput, + ChatMessageContentItem, + ChatMessageImageContentItem, + isUnexpected +} from "../../src/index.js"; describe("chat test suite", () => { let recorder: Recorder; @@ -25,6 +32,98 @@ describe("chat test suite", () => { assert.isNotNull(client.pipeline); }); + it("chat regression test", async function () { + const url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"; + const headers = { "extra-parameters": "allow" }; + const body = { + messages: [ + { role: "system", content: "You are a helpful assistant." }, + { role: "user", content: "How many feet are in a mile?" }, + { role: "user", content: [{ type: "image_url", image_url: { url, detail: "auto" } } ]}, + ], + frequency_penalty: 1, + stream: false, + presence_penalty: 1, + temperature: 1, + top_p: 1, + max_tokens: 1, + stop: [""], + seed: 1, + model: "foo", + response_format: "foo", + tool_choice: "auto", + tools: [ { type: "function", function: { name: "foo", description: "bar" } } ] + } + const response = await client.path("/chat/completions").post({ + headers, + body + }); + const responseHeaders = response.request.headers.toJSON(); + assert.isDefined(responseHeaders); + assert.isDefined(responseHeaders["extra-parameters"]); + assert.isTrue(responseHeaders["extra-parameters"] == headers["extra-parameters"]); + + const request = response.request; + assert.isDefined(request); + + const reqBody = request.body as string; + assert.isDefined(reqBody); + const json = JSON.parse(reqBody); + + assert.isDefined(json["messages"]); + if (json["messages"]) { + assert.isNotEmpty(json["messages"]); + assert.isDefined(json["messages"][0]); + assert.isTrue(json["messages"][0]["role"] == body.messages[0].role); + assert.isTrue(json["messages"][0]["content"] == body.messages[0].content); + assert.isTrue(json["messages"][1]["role"] == body.messages[1].role); + assert.isTrue(json["messages"][1]["content"] == body.messages[1].content); + assert.isTrue(json["messages"][2]["role"] == body.messages[2].role); + + const contentArray = json["messages"][2]["content"]; + assert.isDefined(contentArray); + assert.isNotEmpty(contentArray); + if (contentArray) { + const sourceArray = body.messages[2].content as Array; + assert.isTrue(contentArray[0].type == sourceArray[0].type); + const imageUrlItem = sourceArray[0] as ChatMessageImageContentItem; + assert.isTrue(contentArray[0].image_url.url == imageUrlItem.image_url.url); + assert.isTrue(contentArray[0].image_url.detail == imageUrlItem.image_url.detail); + } + } + assert.isTrue(json["frequency_penalty"] == body.frequency_penalty); + assert.isTrue(json["stream"] == body.stream); + assert.isTrue(json["presence_penalty"] == body.presence_penalty); + assert.isTrue(json["temperature"] == body.temperature); + assert.isTrue(json["top_p"] == body.top_p); + assert.isTrue(json["max_tokens"] == body.max_tokens); + assert.isDefined(json["stop"]); + assert.isArray(json["stop"]); + assert.isNotEmpty(json["stop"]); + + if (json["stop"]) { + assert.isDefined(json["stop"][0]); + assert.isTrue(json["stop"][0] == body.stop[0]); + } + assert.isTrue(json["seed"] == body.seed); + assert.isTrue(json["model"] == body.model); + assert.isTrue(json["response_format"] == body.response_format); + assert.isTrue(json["tool_choice"] == body.tool_choice); + assert.isDefined(json["tools"]); + assert.isArray(json["tools"]); + assert.isNotEmpty(json["tools"]); + if (json["tools"]) { + assert.isDefined(json["tools"][0]); + assert.isTrue(json["tools"][0].type == body.tools[0].type); + assert.isTrue(json["tools"][0].function.name == body.tools[0].function.name); + assert.isTrue(json["tools"][0].function.description == body.tools[0].function.description); + } + }, + { + timeout: 50000 + }); + + it("simple chat test", async function () { const response = await client.path("/chat/completions").post({ body: { diff --git a/sdk/ai/ai-inference-rest/test/public/embeddings.spec.ts b/sdk/ai/ai-inference-rest/test/public/embeddings.spec.ts index d4aac8e487f7..2c21f43e6e33 100644 --- a/sdk/ai/ai-inference-rest/test/public/embeddings.spec.ts +++ b/sdk/ai/ai-inference-rest/test/public/embeddings.spec.ts @@ -4,7 +4,7 @@ import { createRecorder, createModelClient } from "./utils/recordedClient.js"; import { Recorder } from "@azure-tools/test-recorder"; import { assert, beforeEach, afterEach, it, describe } from "vitest"; -import { ModelClient, isUnexpected, EmbeddingsResultOutput } from "../../src/index.js"; +import { ModelClient, GetEmbeddingsBodyParam, isUnexpected, EmbeddingsResultOutput } from "../../src/index.js"; describe("embeddings test suite", () => { let recorder: Recorder; @@ -19,6 +19,52 @@ describe("embeddings test suite", () => { await recorder.stop(); }); + it("embeddings regression test", async function () { + const headers = { "extra-parameters": "allow" }; + const embeddingParams = { + body: { + input: ["first phrase"], + dimensions: 1, + encoding_format: "foo", + input_type: "foo", + model: "foo" + } + } as GetEmbeddingsBodyParam; + + assert.isDefined(embeddingParams); + + const response = await client.path("/embeddings").post({ + headers, + body: embeddingParams.body + }); + const responseHeaders = response.request.headers.toJSON(); + assert.isDefined(responseHeaders); + assert.isDefined(responseHeaders["extra-parameters"]); + assert.isTrue(responseHeaders["extra-parameters"] == headers["extra-parameters"]); + + const request = response.request; + assert.isDefined(request); + + const reqBody = request.body as string; + assert.isDefined(reqBody); + const json = JSON.parse(reqBody); + assert.isDefined(json["input"]); + assert.isArray(json["input"]); + assert.isNotEmpty(json["input"]); + + if (json["input"]) { + assert.isDefined(json["input"][0]); + assert.isTrue(json["input"][0] == embeddingParams.body?.input[0]); + } + assert.isTrue(json["dimensions"] == embeddingParams.body?.dimensions); + assert.isTrue(json["model"] == embeddingParams.body?.model); + assert.isTrue(json["encoding_format"] == embeddingParams.body?.encoding_format); + assert.isTrue(json["input_type"] == embeddingParams.body?.input_type); + }, + { + timeout: 50000 + }); + it("simple embeddings test", async function () { const response = await client.path("/embeddings").post({ body: { diff --git a/sdk/ai/ai-inference-rest/test/public/utils/recordedClient.ts b/sdk/ai/ai-inference-rest/test/public/utils/recordedClient.ts index a3f3582b90f1..cd5a21c6305d 100644 --- a/sdk/ai/ai-inference-rest/test/public/utils/recordedClient.ts +++ b/sdk/ai/ai-inference-rest/test/public/utils/recordedClient.ts @@ -13,7 +13,7 @@ import createClient, { ModelClient } from "../../../src/index.js"; import { DeploymentType } from "../types.js"; const envSetupForPlayback: Record = { - AZURE_AAD_ENDPOINT: "https://endpoint.openai.azure.com/openai/deployments/gpt-4o/", + AZURE_AAD_COMPLETIONS_ENDPOINT: "https://endpoint.openai.azure.com/openai/deployments/gpt-4o/", AZURE_EMBEDDINGS_ENDPOINT: "https://endpoint.openai.azure.com/openai/deployments/text-embedding-3-small/", SUBSCRIPTION_ID: "azure_subscription_id" }; @@ -38,7 +38,7 @@ function getEndpointFromResourceType(resourceType: DeploymentType): string { case "embeddings": return assertEnvironmentVariable("AZURE_EMBEDDINGS_ENDPOINT"); case "completions": - return assertEnvironmentVariable("AZURE_AAD_ENDPOINT"); + return assertEnvironmentVariable("AZURE_AAD_COMPLETIONS_ENDPOINT"); } }