Skip to content

Commit

Permalink
fix(openai): Prevent extra constructor params from being serialized, …
Browse files Browse the repository at this point in the history
…add script (#7669)
  • Loading branch information
jacoblee93 authored Feb 7, 2025
1 parent 1c1e6cd commit c77a8a5
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 2 deletions.
2 changes: 1 addition & 1 deletion libs/langchain-openai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"zod-to-json-schema": "^3.22.3"
},
"peerDependencies": {
"@langchain/core": ">=0.3.29 <0.4.0"
"@langchain/core": ">=0.3.39 <0.4.0"
},
"devDependencies": {
"@azure/identity": "^4.2.1",
Expand Down
15 changes: 15 additions & 0 deletions libs/langchain-openai/src/azure/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,21 @@ export class AzureChatOpenAI extends ChatOpenAI {
};
}

get lc_serializable_keys(): string[] {
return [
...super.lc_serializable_keys,
"azureOpenAIApiKey",
"azureOpenAIApiVersion",
"azureOpenAIBasePath",
"azureOpenAIEndpoint",
"azureOpenAIApiInstanceName",
"azureOpenAIApiDeploymentName",
"deploymentName",
"openAIApiKey",
"openAIApiVersion",
];
}

constructor(
fields?: Partial<OpenAIChatInput> &
Partial<AzureOpenAIInput> & {
Expand Down
39 changes: 39 additions & 0 deletions libs/langchain-openai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,45 @@ export class ChatOpenAI<
};
}

get lc_serializable_keys(): string[] {
return [
"configuration",
"logprobs",
"topLogprobs",
"prefixMessages",
"supportsStrictToolCalling",
"modalities",
"audio",
"reasoningEffort",
"temperature",
"maxTokens",
"topP",
"frequencyPenalty",
"presencePenalty",
"n",
"logitBias",
"user",
"streaming",
"streamUsage",
"modelName",
"model",
"modelKwargs",
"stop",
"stopSequences",
"timeout",
"openAIApiKey",
"apiKey",
"cache",
"maxConcurrency",
"maxRetries",
"verbose",
"callbacks",
"tags",
"metadata",
"disableStreaming",
];
}

temperature?: number;

topP?: number;
Expand Down
14 changes: 14 additions & 0 deletions libs/langchain-openai/src/tests/azure/chat_models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@ test("Test Azure OpenAI serialization from azure endpoint", async () => {
);
});

test("Test Azure OpenAI serialization does not pass along extra params", async () => {
const chat = new AzureChatOpenAI({
azureOpenAIEndpoint: "https://foobar.openai.azure.com/",
azureOpenAIApiDeploymentName: "gpt-4o",
azureOpenAIApiVersion: "2024-08-01-preview",
azureOpenAIApiKey: "foo",
extraParam: "extra",
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
expect(JSON.stringify(chat)).toEqual(
`{"lc":1,"type":"constructor","id":["langchain","chat_models","azure_openai","AzureChatOpenAI"],"kwargs":{"azure_endpoint":"https://foobar.openai.azure.com/","deployment_name":"gpt-4o","openai_api_version":"2024-08-01-preview","azure_open_ai_api_key":{"lc":1,"type":"secret","id":["AZURE_OPENAI_API_KEY"]}}}`
);
});

test("Test Azure OpenAI serialization from base path", async () => {
const chat = new AzureChatOpenAI({
azureOpenAIBasePath:
Expand Down
12 changes: 12 additions & 0 deletions libs/langchain-openai/src/tests/chat_models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,15 @@ describe("strict tool calling", () => {
}
});
});

test("Test OpenAI serialization doesn't pass along extra params", async () => {
const chat = new ChatOpenAI({
apiKey: "test-key",
model: "o3-mini",
somethingUnexpected: true,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
expect(JSON.stringify(chat)).toEqual(
`{"lc":1,"type":"constructor","id":["langchain","chat_models","openai","ChatOpenAI"],"kwargs":{"openai_api_key":{"lc":1,"type":"secret","id":["OPENAI_API_KEY"]},"model":"o3-mini"}}`
);
});
2 changes: 2 additions & 0 deletions libs/langchain-scripts/bin/extract_serializable_fields.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/usr/bin/env node
import "../dist/extract_serializable_fields.js";
1 change: 1 addition & 0 deletions libs/langchain-scripts/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
},
"homepage": "https://github.com/langchain-ai/langchainjs/tree/main/libs/langchain-scripts/",
"bin": {
"extract_serializable_fields": "bin/extract_serializable_fields.js",
"filter_spam_comment": "bin/filter_spam_comment.js",
"lc_build": "bin/build.js",
"notebook_validate": "bin/validate_notebook.js"
Expand Down
81 changes: 81 additions & 0 deletions libs/langchain-scripts/src/extract_serializable_fields.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import ts from "typescript";
import * as path from "path";

function extractConstructorParams(
sourceFile: string,
className: string
): { type: string; fields: string[] } | null {
const absolutePath = path.resolve(sourceFile);
const program = ts.createProgram([absolutePath], {
target: ts.ScriptTarget.ES2015,
module: ts.ModuleKind.CommonJS,
});
const source = program.getSourceFile(absolutePath);
const typeChecker = program.getTypeChecker();

if (!source) {
console.error(`Could not find source file: ${absolutePath}`);
return null;
}

let result: { type: string; fields: string[] } | null = null;

function visit(node: ts.Node) {
if (ts.isClassDeclaration(node) && node.name?.text === className) {
node.members.forEach((member) => {
if (
ts.isConstructorDeclaration(member) &&
member.parameters.length > 0
) {
const firstParam = member.parameters[0];
const type = typeChecker.getTypeAtLocation(firstParam);
const typeString = typeChecker.typeToString(type);

// Get properties of the type
const fields: string[] = [];
type.getProperties().forEach((prop) => {
// Get the type of the property
const propType = typeChecker.getTypeOfSymbolAtLocation(
prop,
firstParam
);
// Only include non-function properties that don't start with __
if (
!prop.getName().startsWith("__") &&
prop.getName() !== "callbackManager" &&
!(propType.getCallSignatures().length > 0)
) {
fields.push(prop.getName());
}
});

result = {
type: typeString,
fields,
};
}
});
}
ts.forEachChild(node, visit);
}

visit(source);
return result;
}
const filepath = process.argv[2];
const className = process.argv[3];

if (!filepath || !className) {
console.error(
"Usage: node extract_serializable_fields.ts <filepath> <className>"
);
process.exit(1);
}

const results = extractConstructorParams(filepath, className);

if (results?.fields?.length) {
console.log(JSON.stringify(results?.fields, null, 2));
} else {
console.error("No constructor parameters found");
}
3 changes: 2 additions & 1 deletion yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -13035,7 +13035,7 @@ __metadata:
zod: ^3.22.4
zod-to-json-schema: ^3.22.3
peerDependencies:
"@langchain/core": ">=0.3.29 <0.4.0"
"@langchain/core": ">=0.3.39 <0.4.0"
languageName: unknown
linkType: soft

Expand Down Expand Up @@ -13200,6 +13200,7 @@ __metadata:
tsx: ^4.16.2
typescript: ^5.4.5
bin:
extract_serializable_fields: bin/extract_serializable_fields.js
filter_spam_comment: bin/filter_spam_comment.js
lc_build: bin/build.js
notebook_validate: bin/validate_notebook.js
Expand Down

0 comments on commit c77a8a5

Please sign in to comment.