Skip to content

Commit

Permalink
feat: Add RefineChain (#607)
Browse files Browse the repository at this point in the history
* init refine_chain

* formatting

* export refineChain

* added QA refine prompts

* added loadQARefineChain

* added imports

* added test

* formatting for test

* added test

* Updated prompts and changed to predict

* updated initial input fn

* Fix serde after changes in main

---------

Co-authored-by: Nuno Campos <nuno@boringbits.io>
  • Loading branch information
RohitMidha23 and nfcampos authored Apr 11, 2023
1 parent 2adce82 commit 38ed068
Show file tree
Hide file tree
Showing 7 changed files with 323 additions and 4 deletions.
156 changes: 155 additions & 1 deletion langchain/src/chains/combine_docs_chain.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import type {
SerializedStuffDocumentsChain,
SerializedMapReduceDocumentsChain,
SerializedRefineDocumentsChain,
} from "./serde.js";
import { BaseChain } from "./base.js";
import { LLMChain } from "./llm_chain.js";

import { Document } from "../document.js";

import { ChainValues } from "../schema/index.js";
import { BasePromptTemplate } from "../prompts/base.js";
import { PromptTemplate } from "../prompts/prompt.js";

export interface StuffDocumentsChainInput {
/** LLM Wrapper to use after formatting documents */
Expand Down Expand Up @@ -96,7 +99,7 @@ export interface MapReduceDocumentsChainInput extends StuffDocumentsChainInput {
}

/**
* Chain that combines documents by stuffing into context.
* Combine documents by mapping a chain over them, then combining results.
* @augments BaseChain
* @augments StuffDocumentsChainInput
*/
Expand Down Expand Up @@ -215,3 +218,154 @@ export class MapReduceDocumentsChain
};
}
}

/**
 * Input contract for {@link RefineDocumentsChain}.
 * Extends StuffDocumentsChainInput, so it also carries the initial `llmChain`.
 */
export interface RefineDocumentsChainInput extends StuffDocumentsChainInput {
  /** Chain run on every document after the first, refining the running answer. */
  refineLLMChain: LLMChain;
  /** Prompt used to render each document before insertion into the LLM prompt. */
  documentPrompt: BasePromptTemplate;
}

/**
 * Combine documents by doing a first pass and then refining on more documents.
 *
 * The first document is answered with `llmChain`; every subsequent document is
 * fed to `refineLLMChain` together with the previous answer (under
 * `initialResponseName`) so the model can improve it incrementally.
 * @augments BaseChain
 * @augments RefineDocumentsChainInput
 */
export class RefineDocumentsChain
  extends BaseChain
  implements RefineDocumentsChainInput
{
  /** Chain used to produce the initial answer from the first document. */
  llmChain: LLMChain;

  /** Key under which callers pass the Document[] to combine. */
  inputKey = "input_documents";

  /** Key under which the final refined answer is returned. */
  outputKey = "output_text";

  /** Prompt variable that receives each formatted document. */
  documentVariableName = "context";

  /** Prompt variable that receives the previous answer during refinement. */
  initialResponseName = "existing_answer";

  /** Chain used to refine the running answer with each additional document. */
  refineLLMChain: LLMChain;

  /** Default prompt: pass the document's page content through unchanged. */
  get defaultDocumentPrompt(): BasePromptTemplate {
    return new PromptTemplate({
      inputVariables: ["page_content"],
      template: "{page_content}",
    });
  }

  documentPrompt = this.defaultDocumentPrompt;

  get inputKeys() {
    return [this.inputKey, ...this.refineLLMChain.inputKeys];
  }

  constructor(fields: {
    llmChain: LLMChain;
    refineLLMChain: LLMChain;
    inputKey?: string;
    outputKey?: string;
    documentVariableName?: string;
    documentPrompt?: BasePromptTemplate;
    initialResponseName?: string;
  }) {
    super();
    this.llmChain = fields.llmChain;
    this.refineLLMChain = fields.refineLLMChain;
    this.documentVariableName =
      fields.documentVariableName ?? this.documentVariableName;
    this.inputKey = fields.inputKey ?? this.inputKey;
    // BUG FIX: `outputKey` was accepted in `fields` but never assigned,
    // so a caller-supplied output key was silently ignored.
    this.outputKey = fields.outputKey ?? this.outputKey;
    this.documentPrompt = fields.documentPrompt ?? this.documentPrompt;
    this.initialResponseName =
      fields.initialResponseName ?? this.initialResponseName;
  }

  /**
   * Render `doc` through `documentPrompt` and return it keyed under
   * `documentVariableName`. Only the prompt's declared input variables are
   * pulled from the document (page content plus metadata).
   */
  private _formatDocumentInputs(doc: Document): Record<string, unknown> {
    const baseInfo: Record<string, unknown> = {
      page_content: doc.pageContent,
      ...doc.metadata,
    };
    const documentInfo: Record<string, unknown> = {};
    this.documentPrompt.inputVariables.forEach((value) => {
      documentInfo[value] = baseInfo[value];
    });
    return {
      [this.documentVariableName]: this.documentPrompt.format({
        ...documentInfo,
      }),
    };
  }

  /** Build inputs for the initial pass over the first document. */
  _constructInitialInputs(doc: Document, rest: Record<string, unknown>) {
    return { ...this._formatDocumentInputs(doc), ...rest };
  }

  /** Build inputs for a refinement pass: previous answer + next document. */
  _constructRefineInputs(doc: Document, res: string) {
    return {
      [this.initialResponseName]: res,
      ...this._formatDocumentInputs(doc),
    };
  }

  /**
   * Answer from the first document, then refine sequentially over the rest.
   * @returns `{ [outputKey]: finalAnswer }`
   * @throws Error when `inputKey` is missing from `values`.
   */
  async _call(values: ChainValues): Promise<ChainValues> {
    if (!(this.inputKey in values)) {
      throw new Error(`Document key ${this.inputKey} not found.`);
    }
    const { [this.inputKey]: docs, ...rest } = values;

    const currentDocs = docs as Document[];

    const initialInputs = this._constructInitialInputs(currentDocs[0], rest);
    let res = await this.llmChain.predict({ ...initialInputs });

    // Each later document refines the running answer; passes are inherently
    // sequential because every call depends on the previous result.
    for (let i = 1; i < currentDocs.length; i += 1) {
      const refineInputs = this._constructRefineInputs(currentDocs[i], res);
      const inputs = { ...refineInputs, ...rest };
      res = await this.refineLLMChain.predict({ ...inputs });
    }

    return { [this.outputKey]: res };
  }

  _chainType() {
    return "refine_documents_chain" as const;
  }

  static async deserialize(data: SerializedRefineDocumentsChain) {
    const SerializedLLMChain = data.llm_chain;

    if (!SerializedLLMChain) {
      throw new Error("Missing llm_chain");
    }

    const SerializedRefineDocumentChain = data.refine_llm_chain;

    if (!SerializedRefineDocumentChain) {
      throw new Error("Missing refine_llm_chain");
    }

    return new RefineDocumentsChain({
      llmChain: await LLMChain.deserialize(SerializedLLMChain),
      refineLLMChain: await LLMChain.deserialize(SerializedRefineDocumentChain),
    });
  }

  serialize(): SerializedRefineDocumentsChain {
    return {
      _type: this._chainType(),
      llm_chain: this.llmChain.serialize(),
      refine_llm_chain: this.refineLLMChain.serialize(),
    };
  }
}
3 changes: 3 additions & 0 deletions langchain/src/chains/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export { LLMChain, ConversationChain } from "./llm_chain.js";
export {
StuffDocumentsChain,
MapReduceDocumentsChain,
RefineDocumentsChain,
} from "./combine_docs_chain.js";
export { ChatVectorDBQAChain } from "./chat_vector_db_chain.js";
export { AnalyzeDocumentChain } from "./analyze_documents_chain.js";
Expand All @@ -11,6 +12,7 @@ export {
loadQAChain,
loadQAStuffChain,
loadQAMapReduceChain,
loadQARefineChain,
} from "./question_answering/load.js";
export { loadSummarizationChain } from "./summarization/load.js";
export { SqlDatabaseChain } from "./sql_db/sql_db_chain.js";
Expand All @@ -25,4 +27,5 @@ export {
SerializedMapReduceDocumentsChain,
SerializedStuffDocumentsChain,
SerializedVectorDBQAChain,
SerializedRefineDocumentsChain,
} from "./serde.js";
44 changes: 44 additions & 0 deletions langchain/src/chains/question_answering/load.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { BasePromptTemplate } from "../../prompts/base.js";
import {
StuffDocumentsChain,
MapReduceDocumentsChain,
RefineDocumentsChain,
} from "../combine_docs_chain.js";
import { QA_PROMPT_SELECTOR, DEFAULT_QA_PROMPT } from "./stuff_prompts.js";
import {
Expand All @@ -12,11 +13,17 @@ import {
COMBINE_QA_PROMPT_SELECTOR,
} from "./map_reduce_prompts.js";
import { BaseLanguageModel } from "../../base_language/index.js";
import {
QUESTION_PROMPT_SELECTOR,
REFINE_PROMPT_SELECTOR,
} from "./refine_prompts.js";

/** Options for {@link loadQAChain}; which fields apply depends on `type`. */
interface qaChainParams {
  /** Prompt for the "stuff" chain. */
  prompt?: BasePromptTemplate;
  /** Per-document prompt for the "map_reduce" chain. */
  combineMapPrompt?: BasePromptTemplate;
  /** Combination prompt for the "map_reduce" chain. */
  combinePrompt?: BasePromptTemplate;
  /** Initial-answer prompt for the "refine" chain. */
  questionPrompt?: BasePromptTemplate;
  /** Refinement prompt for the "refine" chain. */
  refinePrompt?: BasePromptTemplate;
  /** Chain flavor: "stuff" | "map_reduce" | "refine". */
  type?: string;
}
export const loadQAChain = (
Expand Down Expand Up @@ -47,6 +54,20 @@ export const loadQAChain = (
});
return chain;
}
if (type === "refine") {
const {
questionPrompt = QUESTION_PROMPT_SELECTOR.getPrompt(llm),
refinePrompt = REFINE_PROMPT_SELECTOR.getPrompt(llm),
} = params;
const llmChain = new LLMChain({ prompt: questionPrompt, llm });
const refineLLMChain = new LLMChain({ prompt: refinePrompt, llm });

const chain = new RefineDocumentsChain({
llmChain,
refineLLMChain,
});
return chain;
}
throw new Error(`Invalid _type: ${type}`);
};

Expand Down Expand Up @@ -89,3 +110,26 @@ export const loadQAMapReduceChain = (
});
return chain;
};

/** Optional prompt overrides for {@link loadQARefineChain}. */
interface RefineQAChainParams {
  /** Prompt used to answer from the first document; model-appropriate default. */
  questionPrompt?: BasePromptTemplate;
  /** Prompt used to refine the answer with each later document. */
  refinePrompt?: BasePromptTemplate;
}

export const loadQARefineChain = (
llm: BaseLanguageModel,
params: RefineQAChainParams = {}
) => {
const {
questionPrompt = QUESTION_PROMPT_SELECTOR.getPrompt(llm),
refinePrompt = REFINE_PROMPT_SELECTOR.getPrompt(llm),
} = params;
const llmChain = new LLMChain({ prompt: questionPrompt, llm });
const refineLLMChain = new LLMChain({ prompt: refinePrompt, llm });

const chain = new RefineDocumentsChain({
llmChain,
refineLLMChain,
});
return chain;
};
76 changes: 76 additions & 0 deletions langchain/src/chains/question_answering/refine_prompts.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/* eslint-disable tree-shaking/no-side-effects-in-initialization */
/* eslint-disable spaced-comment */
import {
  PromptTemplate,
  ChatPromptTemplate,
  SystemMessagePromptTemplate,
  HumanMessagePromptTemplate,
  AIMessagePromptTemplate,
} from "../../prompts/index.js";
import { ConditionalPromptSelector, isChatModel } from "../prompt_selector.js";

// ---------- Refine step: improve an existing answer with new context ----------

export const DEFAULT_REFINE_PROMPT_TMPL = `The original question is as follows: {question}
We have provided an existing answer: {existing_answer}
We have the opportunity to refine the existing answer
(only if needed) with some more context below.
------------
{context}
------------
Given the new context, refine the original answer to better answer the question.
If the context isn't useful, return the original answer.`;
export const DEFAULT_REFINE_PROMPT = /*#__PURE__*/ new PromptTemplate({
  inputVariables: ["question", "existing_answer", "context"],
  template: DEFAULT_REFINE_PROMPT_TMPL,
});

// Chat variant: replay the question and prior answer as conversation turns,
// then send the refine instructions. Reuses DEFAULT_REFINE_PROMPT_TMPL, which
// was previously duplicated verbatim in a separate `refineTemplate` constant.
const refineMessages = [
  /*#__PURE__*/ HumanMessagePromptTemplate.fromTemplate("{question}"),
  /*#__PURE__*/ AIMessagePromptTemplate.fromTemplate("{existing_answer}"),
  /*#__PURE__*/ HumanMessagePromptTemplate.fromTemplate(
    DEFAULT_REFINE_PROMPT_TMPL
  ),
];

export const CHAT_REFINE_PROMPT =
  /*#__PURE__*/ ChatPromptTemplate.fromPromptMessages(refineMessages);

export const REFINE_PROMPT_SELECTOR =
  /*#__PURE__*/ new ConditionalPromptSelector(DEFAULT_REFINE_PROMPT, [
    [isChatModel, CHAT_REFINE_PROMPT],
  ]);

// ---------- Initial step: answer the question from the first document ----------

export const DEFAULT_TEXT_QA_PROMPT_TMPL = `Context information is below.
---------------------
{context}
---------------------
Given the context information and not prior knowledge, answer the question: {question}`;
export const DEFAULT_TEXT_QA_PROMPT = /*#__PURE__*/ new PromptTemplate({
  inputVariables: ["context", "question"],
  template: DEFAULT_TEXT_QA_PROMPT_TMPL,
});

// Locals renamed from snake_case (chat_qa_prompt_template / chat_messages)
// to match the file's camelCase convention.
const chatQaPromptTemplate = `Context information is below.
---------------------
{context}
---------------------
Given the context information and not prior knowledge, answer any questions`;
const chatQaMessages = [
  /*#__PURE__*/ SystemMessagePromptTemplate.fromTemplate(chatQaPromptTemplate),
  /*#__PURE__*/ HumanMessagePromptTemplate.fromTemplate("{question}"),
];
export const CHAT_QUESTION_PROMPT =
  /*#__PURE__*/ ChatPromptTemplate.fromPromptMessages(chatQaMessages);
export const QUESTION_PROMPT_SELECTOR =
  /*#__PURE__*/ new ConditionalPromptSelector(DEFAULT_TEXT_QA_PROMPT, [
    [isChatModel, CHAT_QUESTION_PROMPT],
  ]);
20 changes: 19 additions & 1 deletion langchain/src/chains/question_answering/tests/load.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import { test } from "@jest/globals";
import { OpenAI } from "../../../llms/openai.js";
import { loadQAMapReduceChain, loadQAStuffChain } from "../load.js";
import {
loadQAMapReduceChain,
loadQARefineChain,
loadQAStuffChain,
} from "../load.js";
import { Document } from "../../../document.js";

test("Test loadQAStuffChain", async () => {
Expand All @@ -26,3 +30,17 @@ test("Test loadQAMapReduceChain", async () => {
const res = await chain.call({ input_documents: docs, question: "Whats up" });
console.log({ res });
});

test("Test loadQARefineChain", async () => {
  // Smoke test: two documents force at least one refine pass after the
  // initial answer. Uses a small model to keep the integration test cheap.
  const llm = new OpenAI({ modelName: "text-ada-001" });
  const qaChain = loadQARefineChain(llm);

  const inputDocuments = [
    new Document({ pageContent: "Harrison went to Harvard." }),
    new Document({ pageContent: "Ankush went to Princeton." }),
  ];

  const res = await qaChain.call({
    input_documents: inputDocuments,
    question: "Where did Harrison go to college?",
  });
  console.log({ res });
});
Loading

0 comments on commit 38ed068

Please sign in to comment.