Skip to content

Commit

Permalink
feat[community]: Add chat deployment to IBM chat class (#7633)
Browse files Browse the repository at this point in the history
  • Loading branch information
FilipZmijewski authored Feb 5, 2025
1 parent 2ff134e commit 99829ef
Show file tree
Hide file tree
Showing 11 changed files with 375 additions and 160 deletions.
2 changes: 1 addition & 1 deletion libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
"@gradientai/nodejs-sdk": "^1.2.0",
"@huggingface/inference": "^2.6.4",
"@huggingface/transformers": "^3.2.3",
"@ibm-cloud/watsonx-ai": "^1.3.0",
"@ibm-cloud/watsonx-ai": "^1.4.0",
"@jest/globals": "^29.5.0",
"@lancedb/lancedb": "^0.13.0",
"@langchain/core": "workspace:*",
Expand Down
190 changes: 126 additions & 64 deletions libs/langchain-community/src/chat_models/ibm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import {
} from "@langchain/core/outputs";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
import {
DeploymentsTextChatParams,
RequestCallbacks,
TextChatMessagesTextChatMessageAssistant,
TextChatParameterTools,
Expand Down Expand Up @@ -65,7 +66,13 @@ import {
import { isZodSchema } from "@langchain/core/utils/types";
import { zodToJsonSchema } from "zod-to-json-schema";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
import { WatsonxAuth, WatsonxParams } from "../types/ibm.js";
import {
Neverify,
WatsonxAuth,
WatsonxChatBasicOptions,
WatsonxDeployedParams,
WatsonxParams,
} from "../types/ibm.js";
import {
_convertToolCallIdToMistralCompatible,
authenticateAndSetInstance,
Expand All @@ -80,27 +87,43 @@ export interface WatsonxDeltaStream {
}

export interface WatsonxCallParams
extends Partial<Omit<TextChatParams, "modelId" | "toolChoice">> {
maxRetries?: number;
watsonxCallbacks?: RequestCallbacks;
}
extends Partial<
Omit<TextChatParams, "modelId" | "toolChoice" | "messages" | "headers">
> {}

/**
 * Call parameters for the deployed-model variant of the chat class.
 *
 * Declares no members of its own: it is structurally the imported
 * `DeploymentsTextChatParams`, given a name that parallels
 * `WatsonxCallParams`.
 */
export interface WatsonxCallDeployedParams extends DeploymentsTextChatParams {}

export interface WatsonxCallOptionsChat
extends Omit<BaseChatModelCallOptions, "stop">,
WatsonxCallParams {
WatsonxCallParams,
WatsonxChatBasicOptions {
promptIndex?: number;
tool_choice?: TextChatParameterTools | string | "auto" | "any";
watsonxCallbacks?: RequestCallbacks;
}

/**
 * Per-call options shape for the deployed-model variant; it is one arm of the
 * default `CallOptions` union on `ChatWatsonx`.
 *
 * Combines `WatsonxCallDeployedParams` with `WatsonxChatBasicOptions`.
 *
 * NOTE(review): `invocationParams` throws "Options cannot be provided to a
 * deployed model" when `idOrName` is set and the options hold anything other
 * than `signal` and `promptIndex`, so the inherited members look unusable at
 * call time — confirm this is intended.
 */
export interface WatsonxCallOptionsDeployedChat
extends WatsonxCallDeployedParams,
WatsonxChatBasicOptions {
// Destructured out of the options in `invocationParams` before the
// deployed-model option check, so passing it does not trigger that error.
promptIndex?: number;
}

/** A tool definition in either of two shapes: `BindToolsInput` or `TextChatParameterTools`. */
type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools;

export interface ChatWatsonxInput
extends BaseChatModelParams,
WatsonxParams,
WatsonxCallParams {
streaming?: boolean;
}
WatsonxCallParams,
Neverify<DeploymentsTextChatParams> {}

/**
 * Constructor input for the deployed-model form of `ChatWatsonx`. The
 * constructor accepts either this or `ChatWatsonxInput`, intersected with
 * `WatsonxAuth`.
 *
 * NOTE(review): `Neverify<TextChatParams>` presumably types every
 * `TextChatParams` key as `never` so that plain-model parameters cannot be
 * mixed into a deployed configuration — confirm against `../types/ibm.js`.
 */
export interface ChatWatsonxDeployedInput
extends BaseChatModelParams,
WatsonxDeployedParams,
Neverify<TextChatParams> {}

/**
 * The member shape that the `ChatWatsonx` class declares through
 * `implements`.
 *
 * It intersects the base chat-model params, the deployed params and the call
 * params. `WatsonxParams` is wrapped in `Partial`, which is consistent with
 * the class declaring `model` as optional.
 */
export type ChatWatsonxConstructor = BaseChatModelParams &
Partial<WatsonxParams> &
WatsonxDeployedParams &
WatsonxCallParams;
function _convertToValidToolId(model: string, tool_call_id: string) {
if (model.startsWith("mistralai"))
return _convertToolCallIdToMistralCompatible(tool_call_id);
Expand All @@ -127,7 +150,7 @@ function _convertToolToWatsonxTool(

function _convertMessagesToWatsonxMessages(
messages: BaseMessage[],
model: string
model?: string
): TextChatResultMessage[] {
const getRole = (role: MessageType) => {
switch (role) {
Expand All @@ -151,7 +174,7 @@ function _convertMessagesToWatsonxMessages(
return message.tool_calls
.map((toolCall) => ({
...toolCall,
id: _convertToValidToolId(model, toolCall.id ?? ""),
id: _convertToValidToolId(model ?? "", toolCall.id ?? ""),
}))
.map(convertLangChainToolCallToOpenAI) as TextChatToolCall[];
}
Expand All @@ -166,7 +189,7 @@ function _convertMessagesToWatsonxMessages(
role: getRole(message._getType()),
content,
name: message.name,
tool_call_id: _convertToValidToolId(model, message.tool_call_id),
tool_call_id: _convertToValidToolId(model ?? "", message.tool_call_id),
};
}

Expand Down Expand Up @@ -229,7 +252,7 @@ function _watsonxResponseToChatMessage(
function _convertDeltaToMessageChunk(
delta: WatsonxDeltaStream,
rawData: TextChatResponse,
model: string,
model?: string,
usage?: TextChatUsage,
defaultRole?: TextChatMessagesTextChatMessageAssistant.Constants.Role
) {
Expand All @@ -245,7 +268,7 @@ function _convertDeltaToMessageChunk(
} => ({
...toolCall,
index,
id: _convertToValidToolId(model, toolCall.id),
id: _convertToValidToolId(model ?? "", toolCall.id),
type: "function",
})
)
Expand Down Expand Up @@ -298,7 +321,7 @@ function _convertDeltaToMessageChunk(
return new ToolMessageChunk({
content,
additional_kwargs,
tool_call_id: _convertToValidToolId(model, rawToolCalls?.[0].id),
tool_call_id: _convertToValidToolId(model ?? "", rawToolCalls?.[0].id),
});
} else if (role === "function") {
return new FunctionMessageChunk({
Expand Down Expand Up @@ -335,10 +358,12 @@ function _convertToolChoiceToWatsonxToolChoice(
}

export class ChatWatsonx<
CallOptions extends WatsonxCallOptionsChat = WatsonxCallOptionsChat
CallOptions extends WatsonxCallOptionsChat =
| WatsonxCallOptionsChat
| WatsonxCallOptionsDeployedChat
>
extends BaseChatModel<CallOptions>
implements ChatWatsonxInput
implements ChatWatsonxConstructor
{
static lc_name() {
return "ChatWatsonx";
Expand Down Expand Up @@ -385,7 +410,7 @@ export class ChatWatsonx<
};
}

model: string;
model?: string;

version = "2024-05-31";

Expand All @@ -399,6 +424,8 @@ export class ChatWatsonx<

projectId?: string;

idOrName?: string;

frequencyPenalty?: number;

logprobs?: boolean;
Expand All @@ -425,37 +452,44 @@ export class ChatWatsonx<

watsonxCallbacks?: RequestCallbacks;

constructor(fields: ChatWatsonxInput & WatsonxAuth) {
constructor(
fields: (ChatWatsonxInput | ChatWatsonxDeployedInput) & WatsonxAuth
) {
super(fields);
if (
(fields.projectId && fields.spaceId) ||
(fields.idOrName && fields.projectId) ||
(fields.spaceId && fields.idOrName)
("projectId" in fields && "spaceId" in fields) ||
("projectId" in fields && "idOrName" in fields) ||
("spaceId" in fields && "idOrName" in fields)
)
throw new Error("Maximum 1 id type can be specified per instance");

if (!fields.projectId && !fields.spaceId && !fields.idOrName)
if (!("projectId" in fields || "spaceId" in fields || "idOrName" in fields))
throw new Error(
"No id specified! At least id of 1 type has to be specified"
);
this.projectId = fields?.projectId;
this.spaceId = fields?.spaceId;
this.temperature = fields?.temperature;
this.maxRetries = fields?.maxRetries || this.maxRetries;
this.maxConcurrency = fields?.maxConcurrency;
this.frequencyPenalty = fields?.frequencyPenalty;
this.topLogprobs = fields?.topLogprobs;
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
this.presencePenalty = fields?.presencePenalty;
this.topP = fields?.topP;
this.timeLimit = fields?.timeLimit;
this.responseFormat = fields?.responseFormat ?? this.responseFormat;

if ("model" in fields) {
this.projectId = fields?.projectId;
this.spaceId = fields?.spaceId;
this.temperature = fields?.temperature;
this.maxRetries = fields?.maxRetries || this.maxRetries;
this.maxConcurrency = fields?.maxConcurrency;
this.frequencyPenalty = fields?.frequencyPenalty;
this.topLogprobs = fields?.topLogprobs;
this.maxTokens = fields?.maxTokens ?? this.maxTokens;
this.presencePenalty = fields?.presencePenalty;
this.topP = fields?.topP;
this.timeLimit = fields?.timeLimit;
this.responseFormat = fields?.responseFormat ?? this.responseFormat;
this.streaming = fields?.streaming ?? this.streaming;
this.n = fields?.n ?? this.n;
this.model = fields?.model ?? this.model;
} else this.idOrName = fields?.idOrName;

this.watsonxCallbacks = fields?.watsonxCallbacks ?? this.watsonxCallbacks;
this.serviceUrl = fields?.serviceUrl;
this.streaming = fields?.streaming ?? this.streaming;
this.n = fields?.n ?? this.n;
this.model = fields?.model ?? this.model;
this.version = fields?.version ?? this.version;
this.watsonxCallbacks = fields?.watsonxCallbacks ?? this.watsonxCallbacks;

const {
watsonxAIApikey,
watsonxAIAuthType,
Expand Down Expand Up @@ -486,6 +520,10 @@ export class ChatWatsonx<
}

invocationParams(options: this["ParsedCallOptions"]) {
const { signal, promptIndex, ...rest } = options;
if (this.idOrName && Object.keys(rest).length > 0)
throw new Error("Options cannot be provided to a deployed model");

const params = {
maxTokens: options.maxTokens ?? this.maxTokens,
temperature: options?.temperature ?? this.temperature,
Expand Down Expand Up @@ -521,10 +559,16 @@ export class ChatWatsonx<
} as CallOptions);
}

scopeId() {
if (this.projectId)
scopeId():
| { idOrName: string }
| { projectId: string; modelId: string }
| { spaceId: string; modelId: string } {
if (this.projectId && this.model)
return { projectId: this.projectId, modelId: this.model };
else return { spaceId: this.spaceId, modelId: this.model };
else if (this.spaceId && this.model)
return { spaceId: this.spaceId, modelId: this.model };
else if (this.idOrName) return { idOrName: this.idOrName };
else throw new Error("No scope id provided");
}

async completionWithRetry<T>(
Expand Down Expand Up @@ -595,23 +639,30 @@ export class ChatWatsonx<
.map(([_, value]) => value);
return { generations, llmOutput: { tokenUsage } };
} else {
const params = {
...this.invocationParams(options),
...this.scopeId(),
};
const params = this.invocationParams(options);
const scopeId = this.scopeId();
const watsonxCallbacks = this.invocationCallbacks(options);
const watsonxMessages = _convertMessagesToWatsonxMessages(
messages,
this.model
);
const callback = () =>
this.service.textChat(
{
...params,
messages: watsonxMessages,
},
watsonxCallbacks
);
"idOrName" in scopeId
? this.service.deploymentsTextChat(
{
...scopeId,
messages: watsonxMessages,
},
watsonxCallbacks
)
: this.service.textChat(
{
...params,
...scopeId,
messages: watsonxMessages,
},
watsonxCallbacks
);
const { result } = await this.completionWithRetry(callback, options);
const generations: ChatGeneration[] = [];
for (const part of result.choices) {
Expand Down Expand Up @@ -646,21 +697,33 @@ export class ChatWatsonx<
options: this["ParsedCallOptions"],
_runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const params = { ...this.invocationParams(options), ...this.scopeId() };
const params = this.invocationParams(options);
const scopeId = this.scopeId();
const watsonxMessages = _convertMessagesToWatsonxMessages(
messages,
this.model
);
const watsonxCallbacks = this.invocationCallbacks(options);
const callback = () =>
this.service.textChatStream(
{
...params,
messages: watsonxMessages,
returnObject: true,
},
watsonxCallbacks
);
"idOrName" in scopeId
? this.service.deploymentsTextChatStream(
{
...scopeId,
messages: watsonxMessages,
returnObject: true,
},
watsonxCallbacks
)
: this.service.textChatStream(
{
...params,
...scopeId,
messages: watsonxMessages,
returnObject: true,
},
watsonxCallbacks
);

const stream = await this.completionWithRetry(callback, options);
let defaultRole;
let usage: TextChatUsage | undefined;
Expand Down Expand Up @@ -707,7 +770,6 @@ export class ChatWatsonx<
if (message === null || (!delta.content && !delta.tool_calls)) {
continue;
}

const generationChunk = new ChatGenerationChunk({
message,
text: delta.content ?? "",
Expand Down
Loading

0 comments on commit 99829ef

Please sign in to comment.