Unified mistral import naming scheme
CarterMorris committed Nov 5, 2024
1 parent 4150202 commit 284c2fc
Showing 3 changed files with 49 additions and 50 deletions.
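
The renaming rule is uniform: every type imported from the @mistralai/mistralai SDK is now aliased with a MistralAI prefix, replacing the earlier mix of bare names (HTTPClient), partial prefixes (MistralChatCompletionResponse), and stray artifacts (ContentChunk$). A minimal before/after sketch, with import paths taken from the diff below:

// Before: bare and inconsistently prefixed SDK imports
import { HTTPClient } from "@mistralai/mistralai/lib/http.js";
import { ChatCompletionResponse as MistralChatCompletionResponse } from "@mistralai/mistralai/models/components/chatcompletionresponse.js";

// After: one rule everywhere — alias with a "MistralAI" prefix
import { HTTPClient as MistralAIHTTPClient } from "@mistralai/mistralai/lib/http.js";
import { ChatCompletionResponse as MistralAIChatCompletionResponse } from "@mistralai/mistralai/models/components/chatcompletionresponse.js";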
35 changes: 17 additions & 18 deletions libs/langchain-mistralai/src/chat_models.ts
@@ -5,14 +5,14 @@ import {
ChatCompletionRequestToolChoice as MistralAIToolChoice,
Messages as MistralAIMessage,
} from "@mistralai/mistralai/models/components/chatcompletionrequest.js";
-import { ContentChunk, ContentChunk$ } from "@mistralai/mistralai/models/components/contentchunk.js";
+import { ContentChunk as MistralAIContentChunk} from "@mistralai/mistralai/models/components/contentchunk.js";
import { Tool as MistralAITool } from "@mistralai/mistralai/models/components/tool.js";
import { ToolCall as MistralAIToolCall } from "@mistralai/mistralai/models/components/toolcall.js";
-import { ChatCompletionStreamRequest as MistralChatCompletionStreamRequest } from "@mistralai/mistralai/models/components/chatcompletionstreamrequest.js";
+import { ChatCompletionStreamRequest as MistralAIChatCompletionStreamRequest } from "@mistralai/mistralai/models/components/chatcompletionstreamrequest.js";
import { UsageInfo as MistralAITokenUsage } from "@mistralai/mistralai/models/components/usageinfo.js";
import { CompletionEvent as MistralAIChatCompletionEvent } from "@mistralai/mistralai/models/components/completionevent.js";
-import { ChatCompletionResponse as MistralChatCompletionResponse } from "@mistralai/mistralai/models/components/chatcompletionresponse.js";
-import { HTTPClient } from "@mistralai/mistralai/lib/http.js";
+import { ChatCompletionResponse as MistralAIChatCompletionResponse } from "@mistralai/mistralai/models/components/chatcompletionresponse.js";
+import { HTTPClient as MistralAIHTTPClient} from "@mistralai/mistralai/lib/http.js";
import {
MessageType,
type BaseMessage,
@@ -164,7 +164,7 @@ export interface ChatMistralAIInput
/**
*
*/
-httpClient?: HTTPClient | undefined;
+httpClient?: MistralAIHTTPClient | undefined;

}

@@ -188,22 +188,22 @@ function convertMessagesToMistralMessages(
}
};

-const getContent = (content: MessageContent, role: MessageType): string | ContentChunk[] => {
+const getContent = (content: MessageContent, role: MessageType): string | MistralAIContentChunk[] => {
const mistralRole = getRole(role);

-const _generateContentChunk = (complex: any, role: string): ContentChunk => {
+const _generateContentChunk = (complex: any, role: string): MistralAIContentChunk => {
if (complex.type === "image_url" && role === "user") {
return {
type: complex.type,
imageUrl: complex?.image_url
-} as ContentChunk;
+} as MistralAIContentChunk;
}

if (complex.type === "text" && (role === "user" || role === "system")){
return {
type: complex.type,
text: complex?.text
-} as ContentChunk;
+} as MistralAIContentChunk;
}

throw new Error(
@@ -286,7 +286,7 @@ function convertMessagesToMistralMessages(
}

function mistralAIResponseToChatMessage(
-choice: NonNullable<MistralChatCompletionResponse["choices"]>[0],
+choice: NonNullable<MistralAIChatCompletionResponse["choices"]>[0],
usage?: MistralAITokenUsage
): BaseMessage {
const { message } = choice;
@@ -821,7 +821,7 @@ export class ChatMistralAI<
* Optional custom HTTP client to manage API requests
* Allows users to add custom fetch implementations, hooks, as well as error and response processing.
*/
-httpClient?: HTTPClient;
+httpClient?: MistralAIHTTPClient;

constructor(fields?: ChatMistralAIInput) {
super(fields ?? {});
@@ -840,7 +840,6 @@
this.safePrompt = fields?.safePrompt ?? this.safePrompt;
this.randomSeed = fields?.seed ?? fields?.randomSeed ?? this.seed;
this.seed = this.randomSeed;
-this.httpClient = fields?.httpClient;
this.modelName = fields?.model ?? fields?.modelName ?? this.model;
this.model = this.modelName;
this.streamUsage = fields?.streamUsage ?? this.streamUsage;
@@ -902,32 +901,32 @@
* @returns {Promise<MistralAIChatCompletionResult | AsyncGenerator<MistralAIChatCompletionResult>>} The response from the MistralAI API.
*/
async completionWithRetry(
-input: MistralChatCompletionStreamRequest,
+input: MistralAIChatCompletionStreamRequest,
streaming: true
): Promise<AsyncIterable<MistralAIChatCompletionEvent>>;

async completionWithRetry(
input: MistralAIChatCompletionRequest,
streaming: false
-): Promise<MistralChatCompletionResponse>;
+): Promise<MistralAIChatCompletionResponse>;

async completionWithRetry(
-input: MistralAIChatCompletionRequest | MistralChatCompletionStreamRequest,
+input: MistralAIChatCompletionRequest | MistralAIChatCompletionStreamRequest,
streaming: boolean
): Promise<
-MistralChatCompletionResponse | AsyncIterable<MistralAIChatCompletionEvent>
+MistralAIChatCompletionResponse | AsyncIterable<MistralAIChatCompletionEvent>
> {
const client = new MistralClient({
apiKey: this.apiKey,
serverURL: this.serverURL,
// If httpClient exists, pass it into constructor
-...( this.httpClient ? {httpCLient: this.httpClient} : {})
+...( this.httpClient ? {httpClient: this.httpClient} : {})
});

return this.caller.call(async () => {
try {
let res:
-| MistralChatCompletionResponse
+| MistralAIChatCompletionResponse
| AsyncIterable<MistralAIChatCompletionEvent>;
if (streaming) {
res = await client.chat.stream(input);
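
Beyond renaming, the hunks above fix a real bug: the spread key was misspelled httpCLient, which compiles (spread expressions bypass excess-property checks) but is never read by the SDK, so a user-supplied client was silently dropped. A hedged usage sketch of the now-working option — the { fetcher } constructor shape follows the Speakeasy-generated SDK and may differ across @mistralai/mistralai versions; the model name and env variable are illustrative:

import { ChatMistralAI } from "@langchain/mistralai";
import { HTTPClient } from "@mistralai/mistralai/lib/http.js";

// Hypothetical logging client: trace each request, then delegate to fetch.
const httpClient = new HTTPClient({
  fetcher: (input, init) => {
    console.log("Mistral API request:", input);
    return fetch(input, init);
  },
});

async function main() {
  const model = new ChatMistralAI({
    apiKey: process.env.MISTRAL_API_KEY, // assumed to be set in the environment
    model: "mistral-small-latest",       // illustrative model choice
    httpClient,                          // reaches the SDK now that the key is spelled right
  });
  const res = await model.invoke("Say hello.");
  console.log(res.content);
}

main();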
16 changes: 8 additions & 8 deletions libs/langchain-mistralai/src/embeddings.ts
@@ -2,8 +2,8 @@ import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings";
import { chunkArray } from "@langchain/core/utils/chunk_array";
import { EmbeddingRequest as MistralAIEmbeddingsRequest} from "@mistralai/mistralai/src/models/components/embeddingrequest.js";
-import { EmbeddingResponse as MistralAIEmbeddingsResult} from "@mistralai/mistralai/src/models/components/embeddingresponse.js";
-import { HTTPClient } from "@mistralai/mistralai/lib/http.js";
+import { EmbeddingResponse as MistralAIEmbeddingsResponse} from "@mistralai/mistralai/src/models/components/embeddingresponse.js";
+import { HTTPClient as MistralAIHTTPClient} from "@mistralai/mistralai/lib/http.js";

/**
* Interface for MistralAIEmbeddings parameters. Extends EmbeddingsParams and
@@ -50,7 +50,7 @@ export interface MistralAIEmbeddingsParams extends EmbeddingsParams {
* Optional custom HTTP client to manage API requests
* Allows users to add custom fetch implementations, hooks, as well as error and response processing.
*/
-httpCLient?: HTTPClient;
+httpClient?: MistralAIHTTPClient;

}

@@ -75,7 +75,7 @@ export class MistralAIEmbeddings

serverURL?: string;

-httpClient?: HTTPClient;
+httpClient?: MistralAIHTTPClient;

constructor(fields?: Partial<MistralAIEmbeddingsParams>) {
super(fields ?? {});
@@ -90,7 +90,7 @@
this.encodingFormat = fields?.encodingFormat ?? this.encodingFormat;
this.batchSize = fields?.batchSize ?? this.batchSize;
this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
-this.httpClient = fields?.httpCLient ?? undefined;
+this.httpClient = fields?.httpClient ?? undefined;
}

/**
Expand Down Expand Up @@ -140,17 +140,17 @@ export class MistralAIEmbeddings
* embeddings. Handles the retry logic and returns the response from the
* API.
* @param {string | Array<string>} inputs Text to send to the MistralAI API.
-* @returns {Promise<MistralAIEmbeddingsResult>} Promise that resolves to the response from the API.
+* @returns {Promise<MistralAIEmbeddingsResponse>} Promise that resolves to the response from the API.
*/
private async embeddingWithRetry(
inputs: string | Array<string>
-): Promise<MistralAIEmbeddingsResult> {
+): Promise<MistralAIEmbeddingsResponse> {
const { Mistral } = await this.imports();
const client = new Mistral({
apiKey: this.apiKey,
serverURL: this.serverURL,
// If httpClient exists, pass it into constructor
-...( this.httpClient ? {httpCLient: this.httpClient} : {})
+...( this.httpClient ? {httpClient: this.httpClient} : {})
});
let embeddingsRequest: MistralAIEmbeddingsRequest = {
model: this.model,
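
The embeddings hunks make the same one-letter fix in three places — the params interface, the constructor read, and the SDK constructor spread — so a custom client now actually propagates. A sketch under the same assumptions as the chat example (mistral-embed is the usual Mistral embedding model):

import { MistralAIEmbeddings } from "@langchain/mistralai";
import { HTTPClient } from "@mistralai/mistralai/lib/http.js";

async function embedDemo() {
  // Default SDK client; could carry a custom fetcher or hooks as in the chat sketch.
  const httpClient = new HTTPClient();
  const embeddings = new MistralAIEmbeddings({
    apiKey: process.env.MISTRAL_API_KEY, // assumed to be set in the environment
    model: "mistral-embed",
    httpClient, // the params field was previously spelled "httpCLient"
  });
  // embedDocuments returns one embedding vector per input string
  const vectors = await embeddings.embedDocuments(["hello", "world"]);
  console.log(vectors.length, vectors[0].length);
}

embedDemo();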
48 changes: 24 additions & 24 deletions libs/langchain-mistralai/src/llms.ts
@@ -2,12 +2,12 @@ import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { BaseLLMParams, LLM } from "@langchain/core/language_models/llms";
import { type BaseLanguageModelCallOptions } from "@langchain/core/language_models/base";
import { GenerationChunk, LLMResult } from "@langchain/core/outputs";
-import { FIMCompletionRequest as MistralFIMCompletionRequest } from "@mistralai/mistralai/models/components/fimcompletionrequest.js";
-import { FIMCompletionStreamRequest as MistralFIMCompletionStreamRequest} from "@mistralai/mistralai/models/components/fimcompletionstreamrequest.js";
-import { FIMCompletionResponse as MistralFIMCompletionResponse } from "@mistralai/mistralai/models/components/fimcompletionresponse.js";
-import { ChatCompletionChoice as MistralChatCompletionChoice} from "@mistralai/mistralai/models/components/chatcompletionchoice.js";
-import { CompletionEvent as MistralChatCompletionEvent } from "@mistralai/mistralai/models/components/completionevent.js";
-import { HTTPClient } from "@mistralai/mistralai/lib/http.js";
+import { FIMCompletionRequest as MistralAIFIMCompletionRequest } from "@mistralai/mistralai/models/components/fimcompletionrequest.js";
+import { FIMCompletionStreamRequest as MistralAIFIMCompletionStreamRequest} from "@mistralai/mistralai/models/components/fimcompletionstreamrequest.js";
+import { FIMCompletionResponse as MistralAIFIMCompletionResponse } from "@mistralai/mistralai/models/components/fimcompletionresponse.js";
+import { ChatCompletionChoice as MistralAIChatCompletionChoice} from "@mistralai/mistralai/models/components/chatcompletionchoice.js";
+import { CompletionEvent as MistralAIChatCompletionEvent } from "@mistralai/mistralai/models/components/completionevent.js";
+import { HTTPClient as MistralAIHTTPClient} from "@mistralai/mistralai/lib/http.js";
import { getEnvironmentVariable } from "@langchain/core/utils/env";
import { chunkArray } from "@langchain/core/utils/chunk_array";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
@@ -73,7 +73,7 @@ export interface MistralAIInput extends BaseLLMParams {
* Optional custom HTTP client to manage API requests
* Allows users to add custom fetch implementations, hooks, as well as error and response processing.
*/
-httpClient?: HTTPClient;
+httpClient?: MistralAIHTTPClient;
}

/**
@@ -109,7 +109,7 @@ export class MistralAI

maxConcurrency?: number;

-httpClient?: HTTPClient;
+httpClient?: MistralAIHTTPClient;

constructor(fields?: MistralAIInput) {
super(fields ?? {});
@@ -148,7 +148,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA

invocationParams(
options: this["ParsedCallOptions"]
-): Omit<MistralFIMCompletionRequest | MistralFIMCompletionStreamRequest, "prompt"> {
+): Omit<MistralAIFIMCompletionRequest | MistralAIFIMCompletionStreamRequest, "prompt"> {
return {
model: this.model,
suffix: options.suffix,
@@ -184,22 +184,22 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
runManager?: CallbackManagerForLLMRun
): Promise<LLMResult> {
const subPrompts = chunkArray(prompts, this.batchSize);
-const choices: MistralChatCompletionChoice[][] = [];
+const choices: MistralAIChatCompletionChoice[][] = [];

const params = this.invocationParams(options);

for (let i = 0; i < subPrompts.length; i += 1) {
const data = await (async () => {
if (this.streaming) {
const responseData: Array<
-{ choices: MistralChatCompletionChoice[] } & Partial<
-Omit<MistralFIMCompletionResponse, "choices">
+{ choices: MistralAIChatCompletionChoice[] } & Partial<
+Omit<MistralAIFIMCompletionResponse, "choices">
>
> = [];
for (let x = 0; x < subPrompts[i].length; x += 1) {
-const choices: MistralChatCompletionChoice[] = [];
+const choices: MistralAIChatCompletionChoice[] = [];
let response:
-| Omit<MistralFIMCompletionResponse, "choices" | "usage">
+| Omit<MistralAIFIMCompletionResponse, "choices" | "usage">
| undefined;
const stream = await this.completionWithRetry(
{
@@ -253,7 +253,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
}
return responseData;
} else {
-const responseData: Array<MistralFIMCompletionResponse> = [];
+const responseData: Array<MistralAIFIMCompletionResponse> = [];
for (let x = 0; x < subPrompts[i].length; x += 1) {
const res = await this.completionWithRetry(
{
@@ -286,23 +286,23 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
}

async completionWithRetry(
-request: MistralFIMCompletionRequest,
+request: MistralAIFIMCompletionRequest,
options: this["ParsedCallOptions"],
stream: false
-): Promise<MistralFIMCompletionResponse>;
+): Promise<MistralAIFIMCompletionResponse>;

async completionWithRetry(
-request: MistralFIMCompletionStreamRequest,
+request: MistralAIFIMCompletionStreamRequest,
options: this["ParsedCallOptions"],
stream: true
-): Promise<AsyncIterable<MistralChatCompletionEvent>>;
+): Promise<AsyncIterable<MistralAIChatCompletionEvent>>;

async completionWithRetry(
-request: MistralFIMCompletionRequest | MistralFIMCompletionStreamRequest,
+request: MistralAIFIMCompletionRequest | MistralAIFIMCompletionStreamRequest,
options: this["ParsedCallOptions"],
stream: boolean
): Promise<
-MistralFIMCompletionResponse | AsyncIterable<MistralChatCompletionEvent>
+MistralAIFIMCompletionResponse | AsyncIterable<MistralAIChatCompletionEvent>
> {
const { Mistral } = await this.imports();
const caller = new AsyncCaller({
@@ -314,7 +314,7 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
serverURL: this.serverURL,
timeoutMs: options.timeout,
// If httpClient exists, pass it into constructor
-...( this.httpClient ? {httpCLient: this.httpClient} : {})
+...( this.httpClient ? {httpClient: this.httpClient} : {})
});
return caller.callWithOptions(
{
@@ -323,8 +323,8 @@ Either provide one via the "apiKey" field in the constructor, or set the "MISTRA
async () => {
try {
let res:
-| MistralFIMCompletionResponse
-| AsyncIterable<MistralChatCompletionEvent>;
+| MistralAIFIMCompletionResponse
+| AsyncIterable<MistralAIChatCompletionEvent>;
if (stream) {
res = await client.fim.stream(request);
} else {
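
The idiom shared by all three files deserves a note: the SDK client options are assembled with a conditional spread, so the httpClient key is present only when the caller supplied one, and the SDK otherwise falls back to its default HTTP client. A minimal sketch of the corrected pattern, mirroring the constructor calls above (the misspelled httpCLient key compiled fine under the spread but was never read by the SDK):

import { Mistral } from "@mistralai/mistralai";
import { HTTPClient } from "@mistralai/mistralai/lib/http.js";

// Illustrative helper: httpClient is undefined when the user passed nothing.
function makeClient(apiKey: string, serverURL?: string, httpClient?: HTTPClient): Mistral {
  return new Mistral({
    apiKey,
    serverURL,
    // Conditional spread: only add the key when a client exists, so the
    // SDK's built-in default client is used otherwise.
    ...(httpClient ? { httpClient } : {}),
  });
}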
