diff --git a/.github/workflows/example_test.yaml b/.github/workflows/example_test.yaml
index 29be0696a..f7020a40d 100644
--- a/.github/workflows/example_test.yaml
+++ b/.github/workflows/example_test.yaml
@@ -86,7 +86,7 @@ jobs:
         run: tests/example-test.sh
       - name: Upload logs if test fail
         if: failure()
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: ${{ github.sha }}-${{ matrix.no }}.logs
           path: ${{ env.LOG_DIR }}
diff --git a/api/app-node/chain/v1alpha1/llmchain_types.go b/api/app-node/chain/v1alpha1/llmchain_types.go
index 21058eae4..dfb725a47 100644
--- a/api/app-node/chain/v1alpha1/llmchain_types.go
+++ b/api/app-node/chain/v1alpha1/llmchain_types.go
@@ -42,13 +42,45 @@ type Output struct {
 }
 
 type CommonChainConfig struct {
-    // 记忆相关参数
+    // Memory is the memory-related configuration
     Memory Memory `json:"memory,omitempty"`
+
+    // Model is the model to use in an LLM call, like `gpt-3.5-turbo` or `chatglm_turbo`.
+    // Usually this value is just empty
+    Model string `json:"model,omitempty"`
+    // MaxTokens is the maximum number of tokens to generate in an LLM call.
+    // +kubebuilder:validation:Minimum=10
+    // +kubebuilder:validation:Maximum=4096
+    // +kubebuilder:default=512
+    MaxTokens int `json:"maxTokens,omitempty"`
+    // Temperature is the sampling temperature to use in an LLM call, between 0 and 1.
+    //+kubebuilder:validation:Minimum=0
+    //+kubebuilder:validation:Maximum=1
+    Temperature float64 `json:"temperature,omitempty"`
+    // StopWords is a list of words to stop on in an LLM call.
+    StopWords []string `json:"stopWords,omitempty"`
+    // TopK is the number of tokens to consider for top-k sampling in an LLM call.
+    TopK int `json:"topK,omitempty"`
+    // TopP is the cumulative probability for top-p sampling in an LLM call.
+    TopP float64 `json:"topP,omitempty"`
+    // Seed is a seed for deterministic sampling in an LLM call.
+    Seed int `json:"seed,omitempty"`
+    // MinLength is the minimum length of the generated text in an LLM call.
+    MinLength int `json:"minLength,omitempty"`
+    // MaxLength is the maximum length of the generated text in an LLM call.
+    MaxLength int `json:"maxLength,omitempty"`
+    // RepetitionPenalty is the repetition penalty for sampling in an LLM call.
+    RepetitionPenalty float64 `json:"repetitionPenalty,omitempty"`
 }
 
 type Memory struct {
-    // 能记住的最大 token 数
+    // MaxTokenLimit is the maximum number of tokens to keep in memory. Only one of MaxTokenLimit and ConversionWindowSize can be used.
     MaxTokenLimit int `json:"maxTokenLimit,omitempty"`
+    // ConversionWindowSize is the maximum number of conversation rounds to keep in memory. Only one of MaxTokenLimit and ConversionWindowSize can be used.
+    // +kubebuilder:validation:Minimum=0
+    // +kubebuilder:validation:Maximum=30
+    // +kubebuilder:default=5
+    ConversionWindowSize int `json:"conversionWindowSize,omitempty"`
 }
 
 // LLMChainStatus defines the observed state of LLMChain
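For orientation, a filled-in CommonChainConfig might look like the sketch below. The concrete values are illustrative only; `getChainOptions` further down in this diff is what maps these fields onto langchaingo chain call options, and the kubebuilder defaults above fill in `maxTokens=512` and `conversionWindowSize=5` when the fields are omitted.

```go
package main

import (
	chainv1alpha1 "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1"
)

// A minimal sketch with illustrative values, not recommendations.
// Only one of Memory.MaxTokenLimit and Memory.ConversionWindowSize
// should be set at a time.
var exampleConfig = chainv1alpha1.CommonChainConfig{
	Model:       "chatglm_turbo",
	MaxTokens:   512,
	Temperature: 0.7,
	StopWords:   []string{"Human:"},
	Memory:      chainv1alpha1.Memory{ConversionWindowSize: 5},
}
```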
diff --git a/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go b/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go
index 63266fa59..6c6b0cc8d 100644
--- a/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go
+++ b/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go
@@ -29,6 +29,11 @@ import (
 func (in *CommonChainConfig) DeepCopyInto(out *CommonChainConfig) {
     *out = *in
     out.Memory = in.Memory
+    if in.StopWords != nil {
+        in, out := &in.StopWords, &out.StopWords
+        *out = make([]string, len(*in))
+        copy(*out, *in)
+    }
 }
 
 // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new CommonChainConfig.
@@ -121,7 +126,7 @@ func (in *LLMChainList) DeepCopyObject() runtime.Object {
 func (in *LLMChainSpec) DeepCopyInto(out *LLMChainSpec) {
     *out = *in
     out.CommonSpec = in.CommonSpec
-    out.CommonChainConfig = in.CommonChainConfig
+    in.CommonChainConfig.DeepCopyInto(&out.CommonChainConfig)
     out.Input = in.Input
     in.Output.DeepCopyInto(&out.Output)
 }
@@ -263,7 +268,7 @@ func (in *RetrievalQAChainList) DeepCopyObject() runtime.Object {
 func (in *RetrievalQAChainSpec) DeepCopyInto(out *RetrievalQAChainSpec) {
     *out = *in
     out.CommonSpec = in.CommonSpec
-    out.CommonChainConfig = in.CommonChainConfig
+    in.CommonChainConfig.DeepCopyInto(&out.CommonChainConfig)
     out.Input = in.Input
     in.Output.DeepCopyInto(&out.Output)
 }
diff --git a/api/app-node/retriever/v1alpha1/knowledgebaseretriever_types.go b/api/app-node/retriever/v1alpha1/knowledgebaseretriever_types.go
index 2ba4fff32..859608c87 100644
--- a/api/app-node/retriever/v1alpha1/knowledgebaseretriever_types.go
+++ b/api/app-node/retriever/v1alpha1/knowledgebaseretriever_types.go
@@ -25,9 +25,10 @@ import (
 
 // KnowledgeBaseRetrieverSpec defines the desired state of KnowledgeBaseRetriever
 type KnowledgeBaseRetrieverSpec struct {
-    v1alpha1.CommonSpec `json:",inline"`
-    Input  Input  `json:"input,omitempty"`
-    Output Output `json:"output,omitempty"`
+    v1alpha1.CommonSpec   `json:",inline"`
+    Input                 Input  `json:"input,omitempty"`
+    Output                Output `json:"output,omitempty"`
+    CommonRetrieverConfig `json:",inline"`
 }
 
 type Input struct {
@@ -38,6 +39,22 @@ type Output struct {
     node.CommonOrInPutOrOutputRef `json:",inline"`
 }
 
+type CommonRetrieverConfig struct {
+    // ScoreThreshold is the cosine distance float score threshold. Lower score represents more similarity.
+    // +kubebuilder:validation:Minimum=0
+    // +kubebuilder:validation:Maximum=1
+    // +kubebuilder:default=0.7
+    ScoreThreshold float32 `json:"scoreThreshold,omitempty"`
+    // NumDocuments is the max number of documents to return.
+    // +kubebuilder:default=5
+    // +kubebuilder:validation:Minimum=1
+    // +kubebuilder:validation:Maximum=10
+    NumDocuments int `json:"numDocuments,omitempty"`
+    // DocNullReturn is the return statement when the query result is empty from the retriever.
+    // +kubebuilder:default="未找到您询问的内容,请详细描述您的问题"
+    DocNullReturn string `json:"docNullReturn,omitempty"`
+}
+
 // KnowledgeBaseRetrieverStatus defines the observed state of KnowledgeBaseRetriever
 type KnowledgeBaseRetrieverStatus struct {
     // ObservedGeneration is the last observed generation.
diff --git a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go
index b62187098..2d13b6325 100644
--- a/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go
+++ b/api/app-node/retriever/v1alpha1/zz_generated.deepcopy.go
@@ -25,6 +25,21 @@ import (
     runtime "k8s.io/apimachinery/pkg/runtime"
 )
 
+// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
+func (in *CommonRetrieverConfig) DeepCopyInto(out *CommonRetrieverConfig) {
+    *out = *in
+}
+
+// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new CommonRetrieverConfig.
+func (in *CommonRetrieverConfig) DeepCopy() *CommonRetrieverConfig {
+    if in == nil {
+        return nil
+    }
+    out := new(CommonRetrieverConfig)
+    in.DeepCopyInto(out)
+    return out
+}
+
 // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out.
 func (in *Input) DeepCopyInto(out *Input) {
     *out = *in
@@ -106,6 +121,7 @@ func (in *KnowledgeBaseRetrieverSpec) DeepCopyInto(out *KnowledgeBaseRetrieverSp
     out.CommonSpec = in.CommonSpec
     out.Input = in.Input
     in.Output.DeepCopyInto(&out.Output)
+    out.CommonRetrieverConfig = in.CommonRetrieverConfig
 }
 
 // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new KnowledgeBaseRetrieverSpec.
diff --git a/apiserver/pkg/chat/chat.go b/apiserver/pkg/chat/chat.go
index 7ccfd975c..966a92f6c 100644
--- a/apiserver/pkg/chat/chat.go
+++ b/apiserver/pkg/chat/chat.go
@@ -22,6 +22,7 @@ import (
     "errors"
     "time"
 
+    "github.com/tmc/langchaingo/memory"
     metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
     "k8s.io/apimachinery/pkg/runtime"
     "k8s.io/apimachinery/pkg/runtime/schema"
@@ -73,6 +74,7 @@ func AppRun(ctx context.Context, req ChatReqBody) (*ChatRespBody, chan ChatRespB
             StartedAt: time.Now(),
             UpdatedAt: time.Now(),
             Messages:  make([]Message, 0),
+            History:   memory.NewChatMessageHistory(),
         }
     }
     conversion.Messages = append(conversion.Messages, Message{
@@ -85,7 +87,7 @@ func AppRun(ctx context.Context, req ChatReqBody) (*ChatRespBody, chan ChatRespB
     if err != nil {
         return nil, nil, err
     }
-    out, outStream, err := appRun.Run(ctx, c, application.Input{Question: req.Query, NeedStream: req.ResponseMode == Streaming})
+    out, outStream, err := appRun.Run(ctx, c, application.Input{Question: req.Query, NeedStream: req.ResponseMode == Streaming, History: conversion.History})
     if err != nil {
         return nil, nil, err
     }
diff --git a/apiserver/pkg/chat/chat_type.go b/apiserver/pkg/chat/chat_type.go
index 44e917f18..5d376bae1 100644
--- a/apiserver/pkg/chat/chat_type.go
+++ b/apiserver/pkg/chat/chat_type.go
@@ -16,7 +16,11 @@ limitations under the License.
 
 package chat
 
-import "time"
+import (
+    "time"
+
+    "github.com/tmc/langchaingo/memory"
+)
 
 type ResponseMode string
 
@@ -48,6 +52,7 @@ type Conversion struct {
     StartedAt time.Time `json:"started_at"`
     UpdatedAt time.Time `json:"updated_at"`
     Messages  []Message `json:"messages"`
+    History   *memory.ChatMessageHistory
 }
 
 type Message struct {
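The per-conversation history wiring above is easier to see end to end in isolation. A minimal sketch follows; the method shapes (`AddUserMessage`/`AddAIMessage` without a context argument) follow the langchaingo revision pinned by this PR and changed in later releases, so treat the exact signatures as an assumption:

```go
package main

import (
	"fmt"

	"github.com/tmc/langchaingo/memory"
)

func main() {
	// One ChatMessageHistory per conversation: created on the first
	// request, then reused on follow-ups so the chain's memory can
	// see prior turns.
	history := memory.NewChatMessageHistory()
	history.AddUserMessage("Hi, I am Jim")
	history.AddAIMessage("Nice to meet you, Jim!")

	for _, m := range history.Messages() {
		fmt.Printf("%s: %s\n", m.GetType(), m.GetContent())
	}
}
```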
diff --git a/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml b/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml
index 6feaaeb55..5d9749e3e 100644
--- a/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml
+++ b/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml
@@ -88,13 +88,40 @@ spec:
                 - llm
                 - prompt
                 type: object
+              maxLength:
+                description: MaxLength is the maximum length of the generated text
+                  in an LLM call.
+                type: integer
+              maxTokens:
+                default: 512
+                description: MaxTokens is the maximum number of tokens to generate
+                  in an LLM call.
+                maximum: 4096
+                minimum: 10
+                type: integer
               memory:
-                description: 记忆相关参数
+                description: Memory is the memory-related configuration
                 properties:
+                  conversionWindowSize:
+                    default: 5
+                    description: ConversionWindowSize is the maximum number of conversation
+                      rounds to keep in memory. Only one of MaxTokenLimit and ConversionWindowSize
+                      can be used.
+                    maximum: 30
+                    minimum: 0
+                    type: integer
                   maxTokenLimit:
-                    description: 能记住的最大 token 数
+                    description: MaxTokenLimit is the maximum number of tokens to
+                      keep in memory. Only one of MaxTokenLimit and ConversionWindowSize
+                      can be used.
                     type: integer
                 type: object
+              minLength:
+                description: MinLength is the minimum length of the generated text
+                  in an LLM call.
+                type: integer
+              model:
+                description: Model is the model to use in an LLM call, like `gpt-3.5-turbo`
+                  or `chatglm_turbo`. Usually this value is just empty
+                type: string
               output:
                 properties:
                   apiGroup:
@@ -111,6 +138,33 @@ spec:
                     description: Name is the name of resource being referenced
                     type: string
                 type: object
+              repetitionPenalty:
+                description: RepetitionPenalty is the repetition penalty for sampling
+                  in an LLM call.
+                type: number
+              seed:
+                description: Seed is a seed for deterministic sampling in an LLM call.
+                type: integer
+              stopWords:
+                description: StopWords is a list of words to stop on in an LLM call.
+                items:
+                  type: string
+                type: array
+              temperature:
+                description: Temperature is the sampling temperature to use in an
+                  LLM call, between 0 and 1.
+                maximum: 1
+                minimum: 0
+                type: number
+              topK:
+                description: TopK is the number of tokens to consider for top-k sampling
+                  in an LLM call.
+                type: integer
+              topP:
+                description: TopP is the cumulative probability for top-p sampling
+                  in an LLM call.
+                type: number
             required:
             - input
            - output
diff --git a/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml b/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml
index cd1a62829..fbb2d6162 100644
--- a/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml
+++ b/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml
@@ -109,13 +109,40 @@ spec:
                 - prompt
                 - retriever
                 type: object
+              maxLength:
+                description: MaxLength is the maximum length of the generated text
+                  in an LLM call.
+                type: integer
+              maxTokens:
+                default: 512
+                description: MaxTokens is the maximum number of tokens to generate
+                  in an LLM call.
+                maximum: 4096
+                minimum: 10
+                type: integer
               memory:
-                description: 记忆相关参数
+                description: Memory is the memory-related configuration
                 properties:
+                  conversionWindowSize:
+                    default: 5
+                    description: ConversionWindowSize is the maximum number of conversation
+                      rounds to keep in memory. Only one of MaxTokenLimit and ConversionWindowSize
+                      can be used.
+                    maximum: 30
+                    minimum: 0
+                    type: integer
                   maxTokenLimit:
-                    description: 能记住的最大 token 数
+                    description: MaxTokenLimit is the maximum number of tokens to
+                      keep in memory. Only one of MaxTokenLimit and ConversionWindowSize
+                      can be used.
                     type: integer
                 type: object
+              minLength:
+                description: MinLength is the minimum length of the generated text
+                  in an LLM call.
+                type: integer
+              model:
+                description: Model is the model to use in an LLM call, like `gpt-3.5-turbo`
+                  or `chatglm_turbo`. Usually this value is just empty
+                type: string
               output:
                 properties:
                   apiGroup:
@@ -132,6 +159,33 @@ spec:
                     description: Name is the name of resource being referenced
                     type: string
                 type: object
+              repetitionPenalty:
+                description: RepetitionPenalty is the repetition penalty for sampling
+                  in an LLM call.
+                type: number
+              seed:
+                description: Seed is a seed for deterministic sampling in an LLM call.
+                type: integer
+              stopWords:
+                description: StopWords is a list of words to stop on in an LLM call.
+                items:
+                  type: string
+                type: array
+              temperature:
+                description: Temperature is the sampling temperature to use in an
+                  LLM call, between 0 and 1.
+                maximum: 1
+                minimum: 0
+                type: number
+              topK:
+                description: TopK is the number of tokens to consider for top-k sampling
+                  in an LLM call.
+                type: integer
+              topP:
+                description: TopP is the cumulative probability for top-p sampling
+                  in an LLM call.
+ type: number required: - input - output diff --git a/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml index 1a6bd12ff..efc55ab79 100644 --- a/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml +++ b/config/crd/bases/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml @@ -45,6 +45,11 @@ spec: displayName: description: DisplayName defines datasource display name type: string + docNullReturn: + default: 未找到您询问的内容,请详细描述您的问题 + description: DocNullReturn is the return statement when the query + result is empty from the retriever. + type: string input: properties: apiGroup: @@ -66,6 +71,12 @@ spec: - apiGroup - kind type: object + numDocuments: + default: 5 + description: NumDocuments is the max number of documents to return. + maximum: 10 + minimum: 1 + type: integer output: properties: apiGroup: @@ -82,6 +93,13 @@ spec: description: Name is the name of resource being referenced type: string type: object + scoreThreshold: + default: 0.7 + description: ScoreThreshold is the cosine distance float score threshold. + Lower score represents more similarity. + maximum: 1 + minimum: 0 + type: number type: object status: description: KnowledgeBaseRetrieverStatus defines the observed state of diff --git a/config/samples/app_llmchain_chat_with_bot.yaml b/config/samples/app_llmchain_chat_with_bot.yaml new file mode 100644 index 000000000..a71247ca6 --- /dev/null +++ b/config/samples/app_llmchain_chat_with_bot.yaml @@ -0,0 +1,118 @@ +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Application +metadata: + name: base-chat-with-bot + namespace: arcadia +spec: + displayName: "对话机器人" + description: "和AI对话,品赛博人生" + prologue: "Hello, I am KubeAGI Bot🤖, Tell me something?" + nodes: + - name: Input + displayName: "用户输入" + description: "用户输入节点,必须" + ref: + kind: Input + name: Input + nextNodeName: ["prompt-node"] + - name: prompt-node + displayName: "prompt" + description: "设定prompt,template中可以使用{{xx}}来替换变量" + ref: + apiGroup: prompt.arcadia.kubeagi.k8s.com.cn + kind: Prompt + name: base-chat-with-bot + nextNodeName: ["chain-node"] + - name: llm-node + displayName: "zhipu大模型服务" + description: "设定质谱大模型的访问信息" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: LLM + name: base-chat-with-bot + nextNodeName: ["chain-node"] + - name: chain-node + displayName: "llm chain" + description: "chain是langchain的核心概念,llmChain用于连接prompt和llm" + ref: + apiGroup: chain.arcadia.kubeagi.k8s.com.cn + kind: LLMChain + name: base-chat-with-bot + nextNodeName: ["Output"] + - name: Output + displayName: "最终输出" + description: "最终输出节点,必须" + ref: + kind: Output + name: Output +--- +apiVersion: prompt.arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Prompt +metadata: + name: base-chat-with-bot + namespace: arcadia +spec: + displayName: "设定对话的prompt" + description: "设定对话的prompt" + userMessage: | + The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. 
+
+    Current conversation:
+    {{.history}}
+    Human: {{.question}}
+    AI:
+  input:
+    kind: "Input"
+    name: "Input"
+  output:
+    apiGroup: chain.arcadia.kubeagi.k8s.com.cn
+    kind: LLMChain
+    name: base-chat-with-bot
+---
+apiVersion: chain.arcadia.kubeagi.k8s.com.cn/v1alpha1
+kind: LLMChain
+metadata:
+  name: base-chat-with-bot
+  namespace: arcadia
+spec:
+  displayName: "llm chain"
+  description: "llm chain"
+  memory:
+    conversionWindowSize: 2
+  model: chatglm_turbo # note: the default model chatglm_lite gives poor results in most cases; openai's gpt-3.5-turbo is also good enough
+  input:
+    llm:
+      apiGroup: arcadia.kubeagi.k8s.com.cn
+      kind: LLM
+      name: base-chat-with-bot
+    prompt:
+      apiGroup: prompt.arcadia.kubeagi.k8s.com.cn
+      kind: Prompt
+      name: base-chat-with-bot
+  output:
+    apiGroup: "arcadia.kubeagi.k8s.com.cn"
+    kind: "Output"
+    name: "output-node"
+---
+apiVersion: v1
+kind: Secret
+metadata:
+  name: base-chat-with-bot
+  namespace: arcadia
+type: Opaque
+data:
+  apiKey: "MTZlZDcxYzcwMDE0NGFiMjIyMmI5YmEwZDFhMTBhZTUuUTljWVZtWWxmdjlnZGtDeQ==" # replace this with your API key
+---
+apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1
+kind: LLM
+metadata:
+  name: base-chat-with-bot
+  namespace: arcadia
+spec:
+  type: "zhipuai"
+  provider:
+    endpoint:
+      url: "https://open.bigmodel.cn/api/paas/v3/model-api" # replace this with your LLM URL (ZhiPuAI uses the predefined url https://open.bigmodel.cn/api/paas/v3/model-api)
+      authSecret:
+        kind: secret
+        name: base-chat-with-bot
diff --git a/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml b/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml
index 6feaaeb55..5d9749e3e 100644
--- a/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml
+++ b/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml
@@ -88,13 +88,40 @@ spec:
                 - llm
                 - prompt
                 type: object
+              maxLength:
+                description: MaxLength is the maximum length of the generated text
+                  in an LLM call.
+                type: integer
+              maxTokens:
+                default: 512
+                description: MaxTokens is the maximum number of tokens to generate
+                  in an LLM call.
+                maximum: 4096
+                minimum: 10
+                type: integer
               memory:
-                description: 记忆相关参数
+                description: Memory is the memory-related configuration
                 properties:
+                  conversionWindowSize:
+                    default: 5
+                    description: ConversionWindowSize is the maximum number of conversation
+                      rounds to keep in memory. Only one of MaxTokenLimit and ConversionWindowSize
+                      can be used.
+                    maximum: 30
+                    minimum: 0
+                    type: integer
                   maxTokenLimit:
-                    description: 能记住的最大 token 数
+                    description: MaxTokenLimit is the maximum number of tokens to
+                      keep in memory. Only one of MaxTokenLimit and ConversionWindowSize
+                      can be used.
                     type: integer
                 type: object
+              minLength:
+                description: MinLength is the minimum length of the generated text
+                  in an LLM call.
+                type: integer
+              model:
+                description: Model is the model to use in an LLM call, like `gpt-3.5-turbo`
+                  or `chatglm_turbo`. Usually this value is just empty
+                type: string
               output:
                 properties:
                   apiGroup:
@@ -111,6 +138,33 @@ spec:
                     description: Name is the name of resource being referenced
                     type: string
                 type: object
+              repetitionPenalty:
+                description: RepetitionPenalty is the repetition penalty for sampling
+                  in an LLM call.
+                type: number
+              seed:
+                description: Seed is a seed for deterministic sampling in an LLM call.
+                type: integer
+              stopWords:
+                description: StopWords is a list of words to stop on in an LLM call.
+                items:
+                  type: string
+                type: array
+              temperature:
+                description: Temperature is the sampling temperature to use in an
+                  LLM call, between 0 and 1.
+                maximum: 1
+                minimum: 0
+                type: number
+              topK:
+                description: TopK is the number of tokens to consider for top-k sampling
+                  in an LLM call.
+                type: integer
+              topP:
+                description: TopP is the cumulative probability for top-p sampling
+                  in an LLM call.
+                type: number
             required:
             - input
            - output
diff --git a/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml b/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml
index cd1a62829..fbb2d6162 100644
--- a/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml
+++ b/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml
@@ -109,13 +109,40 @@ spec:
                 - prompt
                 - retriever
                 type: object
+              maxLength:
+                description: MaxLength is the maximum length of the generated text
+                  in an LLM call.
+                type: integer
+              maxTokens:
+                default: 512
+                description: MaxTokens is the maximum number of tokens to generate
+                  in an LLM call.
+                maximum: 4096
+                minimum: 10
+                type: integer
               memory:
-                description: 记忆相关参数
+                description: Memory is the memory-related configuration
                 properties:
+                  conversionWindowSize:
+                    default: 5
+                    description: ConversionWindowSize is the maximum number of conversation
+                      rounds to keep in memory. Only one of MaxTokenLimit and ConversionWindowSize
+                      can be used.
+                    maximum: 30
+                    minimum: 0
+                    type: integer
                   maxTokenLimit:
-                    description: 能记住的最大 token 数
+                    description: MaxTokenLimit is the maximum number of tokens to
+                      keep in memory. Only one of MaxTokenLimit and ConversionWindowSize
+                      can be used.
                     type: integer
                 type: object
+              minLength:
+                description: MinLength is the minimum length of the generated text
+                  in an LLM call.
+                type: integer
+              model:
+                description: Model is the model to use in an LLM call, like `gpt-3.5-turbo`
+                  or `chatglm_turbo`. Usually this value is just empty
+                type: string
               output:
                 properties:
                   apiGroup:
@@ -132,6 +159,33 @@ spec:
                     description: Name is the name of resource being referenced
                     type: string
                 type: object
+              repetitionPenalty:
+                description: RepetitionPenalty is the repetition penalty for sampling
+                  in an LLM call.
+                type: number
+              seed:
+                description: Seed is a seed for deterministic sampling in an LLM call.
+                type: integer
+              stopWords:
+                description: StopWords is a list of words to stop on in an LLM call.
+                items:
+                  type: string
+                type: array
+              temperature:
+                description: Temperature is the sampling temperature to use in an
+                  LLM call, between 0 and 1.
+                maximum: 1
+                minimum: 0
+                type: number
+              topK:
+                description: TopK is the number of tokens to consider for top-k sampling
+                  in an LLM call.
+                type: integer
+              topP:
+                description: TopP is the cumulative probability for top-p sampling
+                  in an LLM call.
+                type: number
             required:
             - input
            - output
diff --git a/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml
index 1a6bd12ff..efc55ab79 100644
--- a/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml
+++ b/deploy/charts/arcadia/crds/retriever.arcadia.kubeagi.k8s.com.cn_knowledgebaseretrievers.yaml
@@ -45,6 +45,11 @@ spec:
               displayName:
                 description: DisplayName defines datasource display name
                 type: string
+              docNullReturn:
+                default: 未找到您询问的内容,请详细描述您的问题
+                description: DocNullReturn is the return statement when the query
+                  result is empty from the retriever.
+ type: string input: properties: apiGroup: @@ -66,6 +71,12 @@ spec: - apiGroup - kind type: object + numDocuments: + default: 5 + description: NumDocuments is the max number of documents to return. + maximum: 10 + minimum: 1 + type: integer output: properties: apiGroup: @@ -82,6 +93,13 @@ spec: description: Name is the name of resource being referenced type: string type: object + scoreThreshold: + default: 0.7 + description: ScoreThreshold is the cosine distance float score threshold. + Lower score represents more similarity. + maximum: 1 + minimum: 0 + type: number type: object status: description: KnowledgeBaseRetrieverStatus defines the observed state of diff --git a/go.mod b/go.mod index 2f152e30a..8561eddbb 100644 --- a/go.mod +++ b/go.mod @@ -161,3 +161,5 @@ require ( ) replace github.com/amikos-tech/chroma-go => github.com/bjwswang/chroma-go v0.0.0-20231011091545-0041221c9bb3 + +replace github.com/tmc/langchaingo => github.com/Abirdcfly/langchaingo v0.0.0-20231215064216-74b306119ffa // TODO remove this replace after https://github.com/tmc/langchaingo/pull/411 is merged diff --git a/go.sum b/go.sum index 7d9ad89ef..7b8fbf9e2 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9 dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/99designs/gqlgen v0.17.40 h1:/l8JcEVQ93wqIfmH9VS1jsAkwm6eAF1NwQn3N+SDqBY= github.com/99designs/gqlgen v0.17.40/go.mod h1:b62q1USk82GYIVjC60h02YguAZLqYZtvWml8KkhJps4= +github.com/Abirdcfly/langchaingo v0.0.0-20231215064216-74b306119ffa h1:IoYt2IecQ35790Pdhp73y7eSK+YuS3PonY0DkLOrlsw= +github.com/Abirdcfly/langchaingo v0.0.0-20231215064216-74b306119ffa/go.mod h1:l3ZkgXSZAighAgArxdo5QoveM4TfSob0u4yEJS8p5mY= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Azure/go-autorest v14.2.0+incompatible h1:V5VMDjClD3GiElqLWO7mz2MxNAK/vTfRHdAubSIPRgs= github.com/Azure/go-autorest v14.2.0+incompatible/go.mod h1:r+4oMnoxhatjLLJ6zxSWATqVooLgysK6ZNox3g/xq24= @@ -609,8 +611,6 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/tmc/langchaingo v0.0.0-20231209214832-00f364f27fe2 h1:jYxFk98N3864Zq88+6seoUke6IUCUn8s1meYtSJuGdk= -github.com/tmc/langchaingo v0.0.0-20231209214832-00f364f27fe2/go.mod h1:VQf9L5xRny7iSOWD2qn7mAU/N7PJILIXD0RgdD9mV2k= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= diff --git a/pkg/application/app_run.go b/pkg/application/app_run.go index 6cfaf29f2..7194cf9f3 100644 --- a/pkg/application/app_run.go +++ b/pkg/application/app_run.go @@ -21,8 +21,8 @@ import ( "context" "errors" "fmt" - "reflect" + langchaingoschema "github.com/tmc/langchaingo/schema" "k8s.io/client-go/dynamic" "k8s.io/klog/v2" "k8s.io/utils/strings/slices" @@ -37,9 +37,9 @@ import ( type Input struct { Question string - // History 
[]schema.ChatMessage // overrideConfig NeedStream bool + History langchaingoschema.ChatMessageHistory } type Output struct { Answer string @@ -53,29 +53,35 @@ type Application struct { EndingNode base.Node } -var cache = map[string]*Application{} +// var cache = map[string]*Application{} -func cacheKey(app *arcadiav1alpha1.Application) string { - return app.Namespace + "/" + app.Name -} +// func cacheKey(app *arcadiav1alpha1.Application) string { +// return app.Namespace + "/" + app.Name +//} func NewAppOrGetFromCache(ctx context.Context, app *arcadiav1alpha1.Application, cli dynamic.Interface) (*Application, error) { if app == nil || app.Name == "" || app.Namespace == "" { return nil, errors.New("app has no name or namespace") } - a, ok := cache[cacheKey(app)] - if !ok { - a = &Application{ - Spec: app.Spec, - } - cache[cacheKey(app)] = a - return a, a.Init(ctx, cli) - } - if reflect.DeepEqual(a.Spec, app.Spec) { - return a, nil - } - a.Spec = app.Spec - a.Inited = false + // TODO: disable cache for now. + // https://github.com/kubeagi/arcadia/issues/391 + // a, ok := cache[cacheKey(app)] + // if !ok { + // a = &Application{ + // Spec: app.Spec, + // } + // cache[cacheKey(app)] = a + // return a, a.Init(ctx, cli) + // } + // if reflect.DeepEqual(a.Spec, app.Spec) { + // return a, nil + // } + a := &Application{ + Spec: app.Spec, + Inited: false, + } + // a.Spec = app.Spec + // a.Inited = false return a, a.Init(ctx, cli) } @@ -100,10 +106,10 @@ func (a *Application) Init(ctx context.Context, cli dynamic.Interface) (err erro for _, node := range a.Spec.Nodes { n, err := InitNode(ctx, node.Name, *node.Ref, cli) if err != nil { - return err + return fmt.Errorf("initnode %s failed: %v", node.Name, err) } if err := n.Init(ctx, cli, map[string]any{}); err != nil { // TODO arg - return err + return fmt.Errorf("node %s init failed: %v", node.Name, err) } a.Nodes[node.Name] = n if node.Name == inputNodeName { @@ -141,8 +147,9 @@ func (a *Application) Init(ctx context.Context, cli dynamic.Interface) (err erro func (a *Application) Run(ctx context.Context, cli dynamic.Interface, input Input) (output Output, outputStream chan string, err error) { out := map[string]any{ - "question": input.Question, - "answer_stream": make(chan string, 1000), + "question": input.Question, + "_answer_stream": make(chan string, 1000), + "_history": input.History, } visited := make(map[string]bool) waitRunningNodes := list.New() @@ -152,12 +159,12 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, input Inpu for e := waitRunningNodes.Front(); e != nil; e = e.Next() { e := e.Value.(base.Node) if !visited[e.Name()] { - out["need_stream"] = false + out["_need_stream"] = false if a.EndingNode.Name() == e.Name() && input.NeedStream { - out["need_stream"] = true + out["_need_stream"] = true } if out, err = e.Run(ctx, cli, out); err != nil { - return Output{}, nil, err + return Output{}, nil, fmt.Errorf("run node %s: %w", e.Name(), err) } visited[e.Name()] = true } @@ -165,12 +172,12 @@ func (a *Application) Run(ctx context.Context, cli dynamic.Interface, input Inpu waitRunningNodes.PushBack(n) } } - if a, ok := out["answer"]; ok { + if a, ok := out["_answer"]; ok { if answer, ok := a.(string); ok && len(answer) > 0 { output = Output{Answer: answer} } } - if a, ok := out["answer_stream"]; ok { + if a, ok := out["_answer_stream"]; ok { if answer, ok := a.(chan string); ok && len(answer) > 0 { outputStream = answer } diff --git a/pkg/application/chain/common.go b/pkg/application/chain/common.go new 
file mode 100644 index 000000000..e8134fc7e --- /dev/null +++ b/pkg/application/chain/common.go @@ -0,0 +1,91 @@ +/* +Copyright 2023 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package chain + +import ( + "context" + "errors" + + "github.com/tmc/langchaingo/chains" + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/memory" + langchaingoschema "github.com/tmc/langchaingo/schema" + "k8s.io/klog/v2" + + "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" +) + +func stream(res map[string]any) func(ctx context.Context, chunk []byte) error { + return func(ctx context.Context, chunk []byte) error { + if _, ok := res["_answer_stream"]; !ok { + res["_answer_stream"] = make(chan string) + } + streamChan, ok := res["_answer_stream"].(chan string) + if !ok { + klog.Errorln("answer_stream is not chan string") + return errors.New("answer_stream is not chan string") + } + klog.V(5).Infoln("stream out:", string(chunk)) + streamChan <- string(chunk) + return nil + } +} + +func getChainOptions(config v1alpha1.CommonChainConfig) []chains.ChainCallOption { + options := make([]chains.ChainCallOption, 0) + if config.MaxTokens > 0 { + options = append(options, chains.WithMaxTokens(config.MaxTokens)) + } + if config.Temperature > 0 { + options = append(options, chains.WithTemperature(config.Temperature)) + } + if len(config.StopWords) > 0 { + options = append(options, chains.WithStopWords(config.StopWords)) + } + if config.TopK > 0 { + options = append(options, chains.WithTopK(config.TopK)) + } + if config.TopP > 0 { + options = append(options, chains.WithTopP(config.TopP)) + } + if config.Seed > 0 { + options = append(options, chains.WithSeed(config.Seed)) + } + if config.MinLength > 0 { + options = append(options, chains.WithMinLength(config.MinLength)) + } + if config.MaxLength > 0 { + options = append(options, chains.WithMaxLength(config.MaxLength)) + } + if config.RepetitionPenalty > 0 { + options = append(options, chains.WithRepetitionPenalty(config.RepetitionPenalty)) + } + if config.Model != "" { + options = append(options, chains.WithModel(config.Model)) + } + return options +} + +func getMemory(llm llms.LanguageModel, config v1alpha1.Memory, history langchaingoschema.ChatMessageHistory) langchaingoschema.Memory { + if config.MaxTokenLimit > 0 { + return memory.NewConversationTokenBuffer(llm, config.MaxTokenLimit, memory.WithInputKey("question"), memory.WithOutputKey("text"), memory.WithChatHistory(history)) + } + if config.ConversionWindowSize > 0 { + return memory.NewConversationWindowBuffer(config.ConversionWindowSize, memory.WithInputKey("question"), memory.WithOutputKey("text"), memory.WithChatHistory(history)) + } + return memory.NewSimple() +} diff --git a/pkg/application/chain/llmchain.go b/pkg/application/chain/llmchain.go index a46221df5..748ac92af 100644 --- a/pkg/application/chain/llmchain.go +++ b/pkg/application/chain/llmchain.go @@ -19,12 +19,18 @@ package chain import ( "context" "errors" + "fmt" "github.com/tmc/langchaingo/chains" "github.com/tmc/langchaingo/llms" 
"github.com/tmc/langchaingo/prompts" + langchaingoschema "github.com/tmc/langchaingo/schema" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/dynamic" + "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" "github.com/kubeagi/arcadia/pkg/application/base" ) @@ -40,7 +46,7 @@ func NewLLMChain(baseNode base.BaseNode) *LLMChain { } } -func (l *LLMChain) Run(ctx context.Context, _ dynamic.Interface, args map[string]any) (map[string]any, error) { +func (l *LLMChain) Run(ctx context.Context, cli dynamic.Interface, args map[string]any) (map[string]any, error) { v1, ok := args["llm"] if !ok { return args, errors.New("no llm") @@ -57,18 +63,44 @@ func (l *LLMChain) Run(ctx context.Context, _ dynamic.Interface, args map[string if !ok { return args, errors.New("prompt not prompts.FormatPrompter") } + v3, ok := args["_history"] + if !ok { + return args, errors.New("no history") + } + history, ok := v3.(langchaingoschema.ChatMessageHistory) + if !ok { + return args, errors.New("history not memory.ChatMessageHistory") + } + + ns := base.GetAppNamespace(ctx) + instance := &v1alpha1.LLMChain{} + obj, err := cli.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "llmchains"}). + Namespace(l.Ref.GetNamespace(ns)).Get(ctx, l.Ref.Name, metav1.GetOptions{}) + if err != nil { + return args, fmt.Errorf("cant find the chain in cluster: %w", err) + } + err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), instance) + if err != nil { + return args, err + } + options := getChainOptions(instance.Spec.CommonChainConfig) + chain := chains.NewLLMChain(llm, prompt) + chain.Memory = getMemory(llm, instance.Spec.Memory, history) l.LLMChain = *chain var out string - var err error - if needStream, ok := args["need_stream"].(bool); ok && needStream { - option := chains.WithStreamingFunc(stream(args)) - out, err = chains.Predict(ctx, l.LLMChain, args, option) + if needStream, ok := args["_need_stream"].(bool); ok && needStream { + options = append(options, chains.WithStreamingFunc(stream(args))) + out, err = chains.Predict(ctx, l.LLMChain, args, options...) } else { - out, err = chains.Predict(ctx, l.LLMChain, args) + if len(options) > 0 { + out, err = chains.Predict(ctx, l.LLMChain, args, options...) 
diff --git a/pkg/application/chain/retrievalqachain.go b/pkg/application/chain/retrievalqachain.go
index 971f7e6f9..af0f860eb 100644
--- a/pkg/application/chain/retrievalqachain.go
+++ b/pkg/application/chain/retrievalqachain.go
@@ -19,31 +19,36 @@ package chain
 import (
     "context"
     "errors"
+    "fmt"
 
     "github.com/tmc/langchaingo/chains"
     "github.com/tmc/langchaingo/llms"
     "github.com/tmc/langchaingo/prompts"
-    "github.com/tmc/langchaingo/schema"
+    langchainschema "github.com/tmc/langchaingo/schema"
+    metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+    "k8s.io/apimachinery/pkg/runtime"
+    "k8s.io/apimachinery/pkg/runtime/schema"
     "k8s.io/client-go/dynamic"
     "k8s.io/klog/v2"
 
+    "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1"
     "github.com/kubeagi/arcadia/pkg/application/base"
     appretriever "github.com/kubeagi/arcadia/pkg/application/retriever"
 )
 
 type RetrievalQAChain struct {
-    chains.RetrievalQA
+    chains.ConversationalRetrievalQA
     base.BaseNode
 }
 
 func NewRetrievalQAChain(baseNode base.BaseNode) *RetrievalQAChain {
     return &RetrievalQAChain{
-        chains.RetrievalQA{},
+        chains.ConversationalRetrievalQA{},
         baseNode,
     }
 }
 
-func (l *RetrievalQAChain) Run(ctx context.Context, _ dynamic.Interface, args map[string]any) (map[string]any, error) {
+func (l *RetrievalQAChain) Run(ctx context.Context, cli dynamic.Interface, args map[string]any) (map[string]any, error) {
     v1, ok := args["llm"]
     if !ok {
         return args, errors.New("no llm")
@@ -64,32 +69,56 @@ func (l *RetrievalQAChain) Run(ctx context.Context, _ dynamic.Interface, args ma
     if !ok {
         return args, errors.New("no retriever")
     }
-    retriever, ok := v3.(schema.Retriever)
+    retriever, ok := v3.(langchainschema.Retriever)
     if !ok {
         return args, errors.New("retriever not schema.Retriever")
     }
+    v4, ok := args["_history"]
+    if !ok {
+        return args, errors.New("no history")
+    }
+    history, ok := v4.(langchainschema.ChatMessageHistory)
+    if !ok {
+        return args, errors.New("history is not a schema.ChatMessageHistory")
+    }
+
+    ns := base.GetAppNamespace(ctx)
+    instance := &v1alpha1.RetrievalQAChain{}
+    obj, err := cli.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "retrievalqachains"}).
+        Namespace(l.Ref.GetNamespace(ns)).Get(ctx, l.Ref.Name, metav1.GetOptions{})
+    if err != nil {
+        return args, fmt.Errorf("can't find the chain in cluster: %w", err)
+    }
+    err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), instance)
+    if err != nil {
+        return args, err
+    }
+    options := getChainOptions(instance.Spec.CommonChainConfig)
+
     llmChain := chains.NewLLMChain(llm, prompt)
     var baseChain chains.Chain
-    if _, ok := v3.(*appretriever.KnowledgeBaseRetriever); ok {
-        baseChain = appretriever.NewStuffDocuments(llmChain)
+    if knowledgeBaseRetriever, ok := v3.(*appretriever.KnowledgeBaseRetriever); ok {
+        baseChain = appretriever.NewStuffDocuments(llmChain, knowledgeBaseRetriever.DocNullReturn)
     } else {
         baseChain = chains.NewStuffDocuments(llmChain)
     }
-    chain := chains.NewRetrievalQA(baseChain, retriever)
-    l.RetrievalQA = chain
+    chain := chains.NewConversationalRetrievalQA(baseChain, chains.LoadCondenseQuestionGenerator(llm), retriever, getMemory(llm, instance.Spec.Memory, history))
+    l.ConversationalRetrievalQA = chain
     args["query"] = args["question"]
     var out string
-    var err error
-    if needStream, ok := args["need_stream"].(bool); ok && needStream {
-        option := chains.WithStreamingFunc(stream(args))
-        out, err = chains.Predict(ctx, l.RetrievalQA, args, option)
+    if needStream, ok := args["_need_stream"].(bool); ok && needStream {
+        options = append(options, chains.WithStreamingFunc(stream(args)))
+        out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...)
     } else {
-        out, err = chains.Predict(ctx, l.RetrievalQA, args)
+        if len(options) > 0 {
+            out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args, options...)
+        } else {
+            out, err = chains.Predict(ctx, l.ConversationalRetrievalQA, args)
+        }
     }
     klog.Infof("out:%v, err:%s", out, err)
     if err == nil {
-        args["answer"] = out
+        args["_answer"] = out
     }
     return args, err
 }
diff --git a/pkg/application/chain/stream.go b/pkg/application/chain/stream.go
deleted file mode 100644
index 3dd004c3d..000000000
--- a/pkg/application/chain/stream.go
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
-Copyright 2023 KubeAGI.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
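stream.go is deleted because the helper now lives in pkg/application/chain/common.go (later in this diff) under the renamed `_answer_stream` key. To show why the callback writes into a channel stored in the shared args map, here is a hedged sketch of the consuming side; the function and variable names are illustrative, not part of this PR:

```go
// drainAnswerStream prints chunks as the streaming callback produces
// them, stopping when the producer closes the channel or the context
// is cancelled. The apiserver's chat handler does something similar.
func drainAnswerStream(ctx context.Context, args map[string]any) {
	outStream, ok := args["_answer_stream"].(chan string)
	if !ok {
		return
	}
	for {
		select {
		case chunk, ok := <-outStream:
			if !ok {
				return // producer closed the stream
			}
			fmt.Print(chunk)
		case <-ctx.Done():
			return
		}
	}
}
```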
-*/
-
-package chain
-
-import (
-    "context"
-    "errors"
-
-    "k8s.io/klog/v2"
-)
-
-func stream(res map[string]any) func(ctx context.Context, chunk []byte) error {
-    return func(ctx context.Context, chunk []byte) error {
-        if _, ok := res["answer_stream"]; !ok {
-            res["answer_stream"] = make(chan string)
-        }
-        streamChan, ok := res["answer_stream"].(chan string)
-        if !ok {
-            klog.Errorln("answer_stream is not chan string")
-            return errors.New("answer_stream is not chan string")
-        }
-        klog.V(5).Infoln("stream out:", string(chunk))
-        streamChan <- string(chunk)
-        return nil
-    }
-}
diff --git a/pkg/application/llm/llm.go b/pkg/application/llm/llm.go
index 7a2fce57b..757b22889 100644
--- a/pkg/application/llm/llm.go
+++ b/pkg/application/llm/llm.go
@@ -18,6 +18,7 @@ package llm
 
 import (
     "context"
+    "fmt"
 
     langchainllms "github.com/tmc/langchaingo/llms"
     metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -47,7 +48,7 @@ func (z *LLM) Init(ctx context.Context, cli dynamic.Interface, args map[string]a
     obj, err := cli.Resource(schema.GroupVersionResource{Group: v1alpha1.GroupVersion.Group, Version: v1alpha1.GroupVersion.Version, Resource: "llms"}).
         Namespace(z.Ref.GetNamespace(ns)).Get(ctx, z.Ref.Name, metav1.GetOptions{})
     if err != nil {
-        return err
+        return fmt.Errorf("can't find the llm in cluster: %w", err)
     }
     err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), instance)
     if err != nil {
diff --git a/pkg/application/prompt/prompt.go b/pkg/application/prompt/prompt.go
index c7dfe8c7a..01b48ae20 100644
--- a/pkg/application/prompt/prompt.go
+++ b/pkg/application/prompt/prompt.go
@@ -53,7 +53,7 @@ func (p *Prompt) Run(ctx context.Context, cli dynamic.Interface, args map[string
         return args, err
     }
     template := prompts.NewChatPromptTemplate([]prompts.MessageFormatter{
-        // prompts.NewSystemMessagePromptTemplate(instance.Spec.SystemMessage, []string{}), // It's not working now, and it's counterproductive.
+        prompts.NewSystemMessagePromptTemplate(instance.Spec.SystemMessage, []string{}),
         prompts.NewHumanMessagePromptTemplate(instance.Spec.UserMessage, []string{"question"}),
     })
     // todo format
diff --git a/pkg/application/retriever/knowledgebaseretriever.go b/pkg/application/retriever/knowledgebaseretriever.go
index a163c34aa..6241ee4fe 100644
--- a/pkg/application/retriever/knowledgebaseretriever.go
+++ b/pkg/application/retriever/knowledgebaseretriever.go
@@ -20,6 +20,7 @@ import (
     "context"
     "fmt"
 
+    "github.com/tmc/langchaingo/callbacks"
     "github.com/tmc/langchaingo/chains"
     langchaingoschema "github.com/tmc/langchaingo/schema"
     "github.com/tmc/langchaingo/vectorstores"
@@ -39,6 +40,7 @@ import (
 type KnowledgeBaseRetriever struct {
     langchaingoschema.Retriever
     base.BaseNode
+    DocNullReturn string
 }
 
 func NewKnowledgeBaseRetriever(ctx context.Context, baseNode base.BaseNode, cli dynamic.Interface) (*KnowledgeBaseRetriever, error) {
@@ -47,7 +49,7 @@ func NewKnowledgeBaseRetriever(ctx context.Context, baseNode base.BaseNode, cli
     obj, err := cli.Resource(schema.GroupVersionResource{Group: apiretriever.GroupVersion.Group, Version: apiretriever.GroupVersion.Version, Resource: "knowledgebaseretrievers"}).
         Namespace(baseNode.Ref.GetNamespace(ns)).Get(ctx, baseNode.Ref.Name, metav1.GetOptions{})
     if err != nil {
-        return nil, err
+        return nil, fmt.Errorf("can't find the retriever in cluster: %w", err)
     }
     err = runtime.DefaultUnstructuredConverter.FromUnstructured(obj.UnstructuredContent(), instance)
     if err != nil {
@@ -107,11 +109,10 @@ func NewKnowledgeBaseRetriever(ctx context.Context, baseNode base.BaseNode, cli
         if err != nil {
             return nil, err
         }
-        // TODO: allow to configure how many relevant documents should be returned
-        // and the threshold of similiarity(0-1), use 5 and 0.3 by default for now
         return &KnowledgeBaseRetriever{
-            vectorstores.ToRetriever(s, 5, vectorstores.WithScoreThreshold(0.3)),
+            vectorstores.ToRetriever(s, instance.Spec.NumDocuments, vectorstores.WithScoreThreshold(instance.Spec.ScoreThreshold)),
             baseNode,
+            instance.Spec.DocNullReturn,
         }, nil
     default:
         return nil, fmt.Errorf("unknown vectorstore type: %s", vectorStore.Spec.Type())
@@ -126,14 +127,26 @@ func (l *KnowledgeBaseRetriever) Run(ctx context.Context, _ dynamic.Interface, a
 
 // KnowledgeBaseStuffDocuments is similar to chains.StuffDocuments but with new joinDocuments method
 type KnowledgeBaseStuffDocuments struct {
     chains.StuffDocuments
+    isDocNullReturn bool
+    DocNullReturn   string
+    callbacks.SimpleHandler
 }
 
 var _ chains.Chain = KnowledgeBaseStuffDocuments{}
+var _ callbacks.Handler = KnowledgeBaseStuffDocuments{}
 
-func (c KnowledgeBaseStuffDocuments) joinDocuments(docs []langchaingoschema.Document) string {
+func (c *KnowledgeBaseStuffDocuments) joinDocuments(docs []langchaingoschema.Document) string {
     var text string
     docLen := len(docs)
     for k, doc := range docs {
+        klog.Infof("KnowledgeBaseRetriever: related doc[%d] raw text: %s, raw score: %v\n", k, doc.PageContent, doc.Score)
+        for key, v := range doc.Metadata {
+            if str, ok := v.([]byte); ok {
+                klog.Infof("KnowledgeBaseRetriever: related doc[%d] metadata[%s]: %s\n", k, key, string(str))
+            } else {
+                klog.Infof("KnowledgeBaseRetriever: related doc[%d] metadata[%s]: %#v\n", k, key, v)
+            }
+        }
         answer := doc.Metadata["a"]
         answerBytes, _ := answer.([]byte)
         text += doc.PageContent
@@ -144,13 +157,17 @@ func (c KnowledgeBaseStuffDocuments) joinDocuments(docs []langchaingoschema.Docu
             text += c.Separator
         }
     }
-    klog.Infof("get related text: %s\n", text)
+    klog.Infof("KnowledgeBaseRetriever: finally get related text: %s\n", text)
+    if len(text) == 0 {
+        c.isDocNullReturn = true
+    }
     return text
 }
 
-func NewStuffDocuments(llmChain *chains.LLMChain) KnowledgeBaseStuffDocuments {
+func NewStuffDocuments(llmChain *chains.LLMChain, docNullReturn string) KnowledgeBaseStuffDocuments {
     return KnowledgeBaseStuffDocuments{
         StuffDocuments: chains.NewStuffDocuments(llmChain),
+        DocNullReturn:  docNullReturn,
     }
 }
 
@@ -180,3 +197,11 @@ func (c KnowledgeBaseStuffDocuments) GetInputKeys() []string {
 func (c KnowledgeBaseStuffDocuments) GetOutputKeys() []string {
     return c.StuffDocuments.GetOutputKeys()
 }
+
+func (c KnowledgeBaseStuffDocuments) HandleChainEnd(_ context.Context, outputValues map[string]any) {
+    if !c.isDocNullReturn {
+        return
+    }
+    klog.Infof("raw llmChain output: %s, but there is no doc return, so set output to %s\n", outputValues[c.LLMChain.OutputKey], c.DocNullReturn)
+    outputValues[c.LLMChain.OutputKey] = c.DocNullReturn
+}
diff --git a/pkg/langchainwrap/llm.go b/pkg/langchainwrap/llm.go
index e3ac81cc3..5e230f49b 100644
--- a/pkg/langchainwrap/llm.go
+++ b/pkg/langchainwrap/llm.go
@@ -49,9 +49,9 @@ func GetLangchainLLM(ctx context.Context, llm *v1alpha1.LLM, c
client.Client, cl switch llm.Spec.Type { case llms.ZhiPuAI: z := zhipuai.NewZhiPuAI(apiKey) - return zhipuai.ZhiPuAILLM{ZhiPuAI: *z}, nil + return &zhipuai.ZhiPuAILLM{ZhiPuAI: *z, RetryTimes: 3}, nil case llms.OpenAI: - return openai.New(openai.WithToken(apiKey), openai.WithBaseURL(llm.Spec.Enpoint.URL)) + return openai.NewChat(openai.WithToken(apiKey), openai.WithBaseURL(llm.Spec.Enpoint.URL)) } case v1alpha1.ProviderTypeWorker: gateway, err := config.GetGateway(ctx, c, cli) @@ -86,7 +86,7 @@ func GetLangchainLLM(ctx context.Context, llm *v1alpha1.LLM, c client.Client, cl return nil, fmt.Errorf("worker.spec.model not defined") } modelName := worker.MakeRegistrationModelName() - return openai.New(openai.WithModel(modelName), openai.WithBaseURL(gateway.APIServer), openai.WithToken("fake")) + return openai.NewChat(openai.WithModel(modelName), openai.WithBaseURL(gateway.APIServer), openai.WithToken("fake")) } return nil, fmt.Errorf("unknown provider type") } diff --git a/pkg/llms/zhipuai/langchainllm.go b/pkg/llms/zhipuai/langchainllm.go index f220804df..bd072f7cc 100644 --- a/pkg/llms/zhipuai/langchainllm.go +++ b/pkg/llms/zhipuai/langchainllm.go @@ -21,7 +21,11 @@ import ( "context" "encoding/json" "errors" + "fmt" + "math/rand" + "reflect" "strings" + "time" "github.com/r3labs/sse/v2" langchainllm "github.com/tmc/langchaingo/llms" @@ -30,80 +34,134 @@ import ( ) var ( - ErrEmptyResponse = errors.New("no response") - ErrEmptyPrompt = errors.New("empty prompt") - ErrIncompleteEmbedding = errors.New("no all input got emmbedded") + ErrEmptyResponse = errors.New("no response") + ErrEmptyPrompt = errors.New("empty prompt") +) + +var ( + _ langchainllm.LanguageModel = (*ZhiPuAILLM)(nil) + _ langchainllm.ChatLLM = (*ZhiPuAILLM)(nil) ) type ZhiPuAILLM struct { ZhiPuAI + RetryTimes int +} + +func (z *ZhiPuAILLM) GeneratePrompt(ctx context.Context, promptValues []schema.PromptValue, options ...langchainllm.CallOption) (langchainllm.LLMResult, error) { + return langchainllm.GenerateChatPrompt(ctx, z, promptValues, options...) +} + +func (z *ZhiPuAILLM) GetNumTokens(text string) int { + return langchainllm.CountTokens("gpt2", text) } -func (z ZhiPuAILLM) Call(ctx context.Context, prompt string, options ...langchainllm.CallOption) (string, error) { - r, err := z.Generate(ctx, []string{prompt}, options...) +var _ langchainllm.ChatLLM = (*ZhiPuAILLM)(nil) + +func (z *ZhiPuAILLM) Call(ctx context.Context, messages []schema.ChatMessage, options ...langchainllm.CallOption) (*schema.AIChatMessage, error) { + r, err := z.Generate(ctx, [][]schema.ChatMessage{messages}, options...) 
if err != nil { - return "", err + return nil, fmt.Errorf("failed to generate: %w", err) } if len(r) == 0 { - return "", ErrEmptyResponse + return nil, ErrEmptyResponse } - return r[0].Text, nil + return r[0].Message, nil } -func (z ZhiPuAILLM) Generate(ctx context.Context, prompts []string, options ...langchainllm.CallOption) ([]*langchainllm.Generation, error) { +func (z *ZhiPuAILLM) Generate(ctx context.Context, messageSets [][]schema.ChatMessage, options ...langchainllm.CallOption) ([]*langchainllm.Generation, error) { opts := langchainllm.CallOptions{} for _, opt := range options { opt(&opts) } + generations := make([]*langchainllm.Generation, 0, len(messageSets)) params := DefaultModelParams() - if len(prompts) == 0 { - return nil, ErrEmptyPrompt + if opts.TopP > 0 && opts.TopP < 1 { + params.TopP = float32(opts.TopP) + } + if opts.Temperature > 0 && opts.Temperature < 1 { + params.Temperature = float32(opts.Temperature) } - params.Prompt = []Prompt{ - {Role: User, Content: prompts[0]}, + if opts.Model != "" { + params.Model = opts.Model + } + if len(messageSets) == 0 { + return nil, ErrEmptyPrompt } - klog.Infoln("prompt:", prompts[0]) - client := NewZhiPuAI(z.apiKey) - needStream := opts.StreamingFunc != nil - if needStream { - res := bytes.NewBuffer(nil) - err := client.SSEInvoke(params, func(event *sse.Event) { - if string(event.Event) == "finish" { - return + for _, messageSet := range messageSets { + for _, m := range messageSet { + typ := m.GetType() + switch typ { + case schema.ChatMessageTypeAI: + params.Prompt = append(params.Prompt, Prompt{Role: Assistant, Content: m.GetContent()}) + case schema.ChatMessageTypeHuman, schema.ChatMessageTypeGeneric: + params.Prompt = append(params.Prompt, Prompt{Role: User, Content: m.GetContent()}) + default: + klog.Infof("zhipuai: message type %s not supported, just skip\n", typ) } - _, _ = res.Write(event.Data) - _ = opts.StreamingFunc(ctx, event.Data) - }) - if err != nil { + } + klog.Infof("all history prompts: %#v\n", params.Prompt) + client := NewZhiPuAI(z.apiKey) + needStream := opts.StreamingFunc != nil + if needStream { + res := bytes.NewBuffer(nil) + err := client.SSEInvoke(params, func(event *sse.Event) { + if string(event.Event) == "finish" { + return + } + _, _ = res.Write(event.Data) + _ = opts.StreamingFunc(ctx, event.Data) + }) + if err != nil { + return nil, err + } + return []*langchainllm.Generation{ + { + Text: res.String(), + }, + }, nil + } + var resp *Response + var err error + i := 0 + for { + i++ + resp, err = client.Invoke(params) + if err != nil { + return nil, err + } + if resp == nil { + return nil, ErrEmptyResponse + } + if resp.Data == nil { + klog.Errorf("zhipullm get empty response: msg:%s code:%d\n", resp.Msg, resp.Code) + if i <= z.RetryTimes && (resp.Code == CodeConcurrencyHigh || resp.Code == CodefrequencyHigh || resp.Code == CodeTimesHigh) { + r := rand.Intn(5) + klog.Infof("zhipullm triggers retry[%d], sleep %d seconds, then recall...\n", i, r) + time.Sleep(time.Duration(r) * time.Second) + continue + } + return nil, ErrEmptyResponse + } + if len(resp.Data.Choices) == 0 { + return nil, ErrEmptyResponse + } + break + } + generationInfo := make(map[string]any, reflect.ValueOf(resp.Data.Usage).NumField()) + generationInfo["TotalTokens"] = resp.Data.Usage.TotalTokens + var s string + if err := json.Unmarshal([]byte(resp.Data.Choices[0].Content), &s); err != nil { return nil, err } - return []*langchainllm.Generation{ - { - Text: res.String(), - }, - }, nil - } - resp, err := client.Invoke(params) - if err 
!= nil {
-        return nil, err
-    }
-    var s string
-    klog.Infoln("resp:", resp.String())
-    if err := json.Unmarshal([]byte(resp.Data.Choices[0].Content), &s); err != nil {
-        return nil, err
+        msg := &schema.AIChatMessage{
+            Content: strings.TrimSpace(s),
+        }
+        generations = append(generations, &langchainllm.Generation{
+            Message:        msg,
+            Text:           msg.Content,
+            GenerationInfo: generationInfo,
+        })
     }
-    return []*langchainllm.Generation{
-        {
-            Text: strings.TrimSpace(s),
-        },
-    }, nil
-}
-
-func (z ZhiPuAILLM) GeneratePrompt(ctx context.Context, promptValues []schema.PromptValue, options ...langchainllm.CallOption) (langchainllm.LLMResult, error) {
-    return langchainllm.GeneratePrompt(ctx, z, promptValues, options...)
-}
-
-func (z ZhiPuAILLM) GetNumTokens(text string) int {
-    // TODO implement me
-    panic("implement me")
+    return generations, nil
 }
diff --git a/pkg/llms/zhipuai/response.go b/pkg/llms/zhipuai/response.go
index 617ccfaab..ebee9fb22 100644
--- a/pkg/llms/zhipuai/response.go
+++ b/pkg/llms/zhipuai/response.go
@@ -107,3 +107,9 @@ type Choice struct {
     Content string `json:"content"`
     Role    string `json:"role"`
 }
+
+const (
+    CodeConcurrencyHigh = 1302 // concurrency for this API is currently too high; reduce concurrency or contact support to raise the limit
+    CodefrequencyHigh   = 1303 // request frequency for this API is currently too high; reduce frequency or contact support to raise the limit
+    CodeTimesHigh       = 1305 // too many API requests right now; retry later
+)
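These three codes are exactly the ones the retry loop in langchainllm.go checks before sleeping and re-invoking. A small hedged helper expressing that policy (not part of this PR, just a restatement of the condition above):

```go
// isRetryable reports whether a ZhiPuAI business code is one of the
// transient rate/concurrency errors that Generate retries.
func isRetryable(code int) bool {
	switch code {
	case CodeConcurrencyHigh, CodefrequencyHigh, CodeTimesHigh:
		return true
	default:
		return false
	}
}
```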
info "port-forward portal in pid: $portal_pid" sleep 3 -curl -XPOST http://127.0.0.1:8081/chat --data '{"query":"hi, how are you?","response_mode":"blocking","conversion_id":"","app_name":"base-chat-english-teacher", "app_namespace":"arcadia"}' | jq -e '.message' +getRespInAppChat "base-chat-english-teacher" "arcadia" "hi how are you?" "" "true" -info "9 validation:QA app using knowledgebase can work normally" +info "8.2 QA app using knowledgebase" kubectl apply -f config/samples/app_retrievalqachain_knowledgebase.yaml waitCRDStatusReady "Application" "arcadia" "base-chat-with-knowledgebase" sleep 3 -curl -XPOST http://127.0.0.1:8081/chat --data '{"query":"旷工最小计算单位为多少天?","response_mode":"blocking","conversion_id":"","app_name":"base-chat-with-knowledgebase", "app_namespace":"arcadia"}' | jq -e '.message' +getRespInAppChat "base-chat-with-knowledgebase" "arcadia" "旷工最小计算单位为多少天?" "" "true" + +info "8.3 conversion chat app" +kubectl apply -f config/samples/app_llmchain_chat_with_bot.yaml +waitCRDStatusReady "Application" "arcadia" "base-chat-with-bot" +sleep 3 +getRespInAppChat "base-chat-with-bot" "arcadia" "Hi I am Jim" "" "false" +getRespInAppChat "base-chat-with-bot" "arcadia" "What is my name?" ${resp_conversion_id} "false" +if [[ $resp != *"Jim"* ]]; then + echo "Because conversionWindowSize is enabled to be 2, llm should record history, but resp:"$resp "dont contains Jim" + exit 1 +fi +# There is uncertainty in the AI replies, most of the time, it will pass the test, a small percentage of the time, the AI will call names in each reply, causing the test to fail, therefore, temporarily disable the following tests +#getRespInAppChat "base-chat-with-bot" "arcadia" "What is your model?" ${resp_conversion_id} "false" +#getRespInAppChat "base-chat-with-bot" "arcadia" "Does your model based on gpt-3.5?" ${resp_conversion_id} "false" +#getRespInAppChat "base-chat-with-bot" "arcadia" "When was the model you used released?" ${resp_conversion_id} "false" +#getRespInAppChat "base-chat-with-bot" "arcadia" "What is my name?" ${resp_conversion_id} "false" +#if [[ $resp == *"Jim"* ]]; then +# echo "Because conversionWindowSize is enabled to be 2, and current is the 6th conversion, llm should not record My name, but resp:"$resp "still contains Jim" +# exit 1 +#fi -info "10 show apiserver logs for debug" +info "9. show apiserver logs for debug" kubectl logs --tail=100 -n arcadia -l app=arcadia-apiserver >/tmp/apiserver.log cat /tmp/apiserver.log