Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Output Parsers for agents #802

Merged — 15 commits merged on Apr 17, 2023
2 changes: 1 addition & 1 deletion examples/src/agents/chat_mrkl.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { ChatOpenAI } from "langchain/chat_models/openai";
import { initializeAgentExecutor } from "langchain/agents";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { SerpAPI } from "langchain/tools";
import { Calculator } from "langchain/tools/calculator";

Expand Down
6 changes: 2 additions & 4 deletions examples/src/agents/custom_agent.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { AgentExecutor, ZeroShotAgent } from "langchain/agents";
import { LLMChain } from "langchain/chains";
import { OpenAI } from "langchain/llms/openai";
import { ZeroShotAgent, AgentExecutor } from "langchain/agents";
import { SerpAPI } from "langchain/tools";
import { Calculator } from "langchain/tools/calculator";
import { LLMChain } from "langchain/chains";

export const run = async () => {
const model = new OpenAI({ temperature: 0 });
Expand All @@ -29,8 +29,6 @@ Question: {input}

const prompt = ZeroShotAgent.createPrompt(tools, createPromptArgs);

console.log(prompt.template);

const llmChain = new LLMChain({ llm: model, prompt });
const agent = new ZeroShotAgent({
llmChain,
Expand Down
10 changes: 5 additions & 5 deletions examples/src/agents/custom_llm_agent_chat.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import {
LLMSingleActionAgent,
AgentActionOutputParser,
AgentExecutor,
LLMSingleActionAgent,
} from "langchain/agents";
import { LLMChain } from "langchain/chains";
import { ChatOpenAI } from "langchain/chat_models/openai";
import {
BaseChatPromptTemplate,
BasePromptTemplate,
SerializedBasePromptTemplate,
renderTemplate,
BaseChatPromptTemplate,
} from "langchain/prompts";
import {
InputValues,
PartialValues,
AgentStep,
AgentAction,
AgentFinish,
AgentStep,
BaseChatMessage,
HumanChatMessage,
InputValues,
PartialValues,
} from "langchain/schema";
import { SerpAPI, Tool } from "langchain/tools";
import { Calculator } from "langchain/tools/calculator";
Expand Down
2 changes: 1 addition & 1 deletion examples/src/agents/mrkl.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { OpenAI } from "langchain/llms/openai";
import { initializeAgentExecutor } from "langchain/agents";
import { OpenAI } from "langchain/llms/openai";
import { SerpAPI } from "langchain/tools";
import { Calculator } from "langchain/tools/calculator";

Expand Down
47 changes: 21 additions & 26 deletions langchain/src/agents/agent.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import { BaseLanguageModel } from "../base_language/index.js";
import { CallbackManager } from "../callbacks/base.js";
import { LLMChain } from "../chains/llm_chain.js";
import { BasePromptTemplate } from "../prompts/base.js";
import {
AgentAction,
AgentFinish,
AgentStep,
ChainValues,
BaseChatMessage,
ChainValues,
} from "../schema/index.js";
import { Tool } from "../tools/base.js";
import {
AgentActionOutputParser,
AgentInput,
SerializedAgent,
StoppingMethod,
AgentActionOutputParser,
} from "./types.js";
import { Tool } from "../tools/base.js";

class ParseError extends Error {
output: string;
Expand Down Expand Up @@ -160,6 +161,11 @@ export class LLMSingleActionAgent extends BaseSingleActionAgent {
}
}

/**
 * Optional configuration accepted by agent factory helpers
 * (e.g. `fromLLMAndTools`).
 */
export interface AgentArgs {
  // Parser that turns raw LLM output into an AgentAction/AgentFinish;
  // when omitted, the agent's static getDefaultOutputParser() is used.
  outputParser?: AgentActionOutputParser;
  // Forwarded to the underlying LLMChain so callbacks/tracing fire.
  callbackManager?: CallbackManager;
}

/**
* Class responsible for calling a language model and deciding an action.
*
Expand All @@ -170,6 +176,8 @@ export class LLMSingleActionAgent extends BaseSingleActionAgent {
export abstract class Agent extends BaseSingleActionAgent {
llmChain: LLMChain;

outputParser: AgentActionOutputParser;

private _allowedTools?: string[] = undefined;

get allowedTools(): string[] | undefined {
Expand All @@ -184,15 +192,7 @@ export abstract class Agent extends BaseSingleActionAgent {
super();
this.llmChain = input.llmChain;
this._allowedTools = input.allowedTools;
}

/**
* Extract tool and tool input from LLM output.
*/
async extractToolAndInput(
_input: string
): Promise<{ tool: string; input: string } | null> {
throw new Error("Not implemented");
this.outputParser = input.outputParser;
}

/**
Expand All @@ -210,6 +210,13 @@ export abstract class Agent extends BaseSingleActionAgent {
*/
abstract _agentType(): string;

/**
* Get the default output parser for this agent.
*/
static getDefaultOutputParser(): AgentActionOutputParser {
throw new Error("Not implemented");
}

/**
* Create a prompt for this class
*
Expand All @@ -231,7 +238,7 @@ export abstract class Agent extends BaseSingleActionAgent {
_llm: BaseLanguageModel,
_tools: Tool[],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
_args?: Record<string, any>
_args?: AgentArgs
): Agent {
throw new Error("Not implemented");
}
Expand Down Expand Up @@ -284,19 +291,7 @@ export abstract class Agent extends BaseSingleActionAgent {
}

const output = await this.llmChain.predict(newInputs);
const parsed = await this.extractToolAndInput(output);
if (!parsed) {
throw new ParseError(`Invalid output: ${output}`, output);
}
const action = {
tool: parsed.tool,
toolInput: parsed.input,
log: output,
};
if (action.tool === this.finishToolName()) {
return { returnValues: { output: action.toolInput }, log: action.log };
}
return action;
return this.outputParser.parse(output);
}

/**
Expand Down
61 changes: 27 additions & 34 deletions langchain/src/agents/chat/index.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import { BaseLanguageModel } from "../../base_language/index.js";
import { LLMChain } from "../../chains/llm_chain.js";
import { Agent } from "../agent.js";
import {
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
} from "../../prompts/chat.js";
import { PREFIX, SUFFIX, FORMAT_INSTRUCTIONS } from "./prompt.js";
import { BaseLanguageModel } from "../../base_language/index.js";
import { AgentStep } from "../../schema/index.js";
import { AgentInput } from "../types.js";
import { Tool } from "../../tools/base.js";

const FINAL_ANSWER_ACTION = "Final Answer:";
import { Optional } from "../../types/type-utils.js";
import { Agent, AgentArgs } from "../agent.js";
import { AgentInput } from "../types.js";
import { ChatAgentOutputParser } from "./outputParser.js";
import { FORMAT_INSTRUCTIONS, PREFIX, SUFFIX } from "./prompt.js";

export type CreatePromptArgs = {
/** String to put after the list of tools. */
Expand All @@ -22,15 +22,17 @@ export type CreatePromptArgs = {
inputVariables?: string[];
};

type ZeroShotAgentInput = AgentInput;
type ChatAgentInput = Optional<AgentInput, "outputParser">;

/**
* Agent for the MRKL chain.
* @augments Agent
*/
export class ChatAgent extends Agent {
constructor(input: ZeroShotAgentInput) {
super(input);
// Fill in the default output parser when the caller did not supply one,
// then delegate to the base Agent constructor with a complete AgentInput.
constructor(input: ChatAgentInput) {
const outputParser =
input?.outputParser ?? ChatAgent.getDefaultOutputParser();
super({ ...input, outputParser });
}

_agentType() {
Expand Down Expand Up @@ -59,6 +61,10 @@ export class ChatAgent extends Agent {
}
}

/** Parser used when the constructor / factory receives no outputParser. */
static getDefaultOutputParser() {
return new ChatAgentOutputParser();
}

constructScratchPad(steps: AgentStep[]): string {
const agentScratchpad = super.constructScratchPad(steps);
if (agentScratchpad) {
Expand Down Expand Up @@ -93,35 +99,22 @@ export class ChatAgent extends Agent {
static fromLLMAndTools(
llm: BaseLanguageModel,
tools: Tool[],
args?: CreatePromptArgs
args?: CreatePromptArgs & AgentArgs
) {
ChatAgent.validateTools(tools);
const prompt = ChatAgent.createPrompt(tools, args);
const chain = new LLMChain({ prompt, llm });
const chain = new LLMChain({
prompt,
llm,
callbackManager: args?.callbackManager,
});
const outputParser =
args?.outputParser ?? ChatAgent.getDefaultOutputParser();

return new ChatAgent({
llmChain: chain,
outputParser,
allowedTools: tools.map((t) => t.name),
});
}

async extractToolAndInput(
text: string
): Promise<{ tool: string; input: string } | null> {
if (text.includes(FINAL_ANSWER_ACTION)) {
const parts = text.split(FINAL_ANSWER_ACTION);
const input = parts[parts.length - 1].trim();
return { tool: "Final Answer", input };
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
const [_, action, __] = text.split("```");
try {
const response = JSON.parse(action.trim());
return { tool: response.action, input: response.action_input };
} catch {
throw new Error(
`Unable to parse JSON response from chat agent.\n\n${text}`
);
}
}
}
34 changes: 34 additions & 0 deletions langchain/src/agents/chat/outputParser.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import { AgentActionOutputParser } from "../../agents/types.js";
import { AgentFinish } from "../../schema/index.js";
import { FORMAT_INSTRUCTIONS } from "./prompt.js";

export const FINAL_ANSWER_ACTION = "Final Answer:";

/**
 * Parses raw chat-model completions for the chat agent into either an
 * `AgentFinish` (when the "Final Answer:" marker is present) or an
 * action object naming the tool to invoke and its input, extracted from
 * the ``` / ```json fenced JSON blob the prompt instructs the model to emit.
 */
export class ChatAgentOutputParser extends AgentActionOutputParser {
  /**
   * @param text Raw completion text from the model.
   * @returns An `AgentFinish` when the final-answer marker is found,
   *   otherwise `{ tool, toolInput, log }` parsed from the fenced JSON.
   * @throws Error when no parseable fenced JSON payload is present.
   */
  async parse(text: string) {
    if (text.includes(FINAL_ANSWER_ACTION)) {
      // Everything after the (last) marker is treated as the final answer.
      const parts = text.split(FINAL_ANSWER_ACTION);
      const output = parts[parts.length - 1].trim();
      return { returnValues: { output }, log: text } satisfies AgentFinish;
    }

    // The payload is the segment between the opening and closing fence.
    // NOTE: the disable directive must sit immediately above the line it
    // suppresses — a blank line in between (as before) makes it a no-op.
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    const [_, action, __] = text.split(/```(?:json)?/g);
    try {
      // A missing fence leaves `action` undefined; the resulting TypeError
      // is caught below and surfaced as the single parse-failure message.
      const response = JSON.parse(action.trim());
      return {
        tool: response.action,
        toolInput: response.action_input,
        log: text,
      };
    } catch {
      throw new Error(
        `Unable to parse JSON response from chat agent.\n\n${text}`
      );
    }
  }

  getFormatInstructions(): string {
    return FORMAT_INSTRUCTIONS;
  }
}
Loading