diff --git a/api/app-node/chain/v1alpha1/llmchain_types.go b/api/app-node/chain/v1alpha1/llmchain_types.go index 3c318f975..442d66533 100644 --- a/api/app-node/chain/v1alpha1/llmchain_types.go +++ b/api/app-node/chain/v1alpha1/llmchain_types.go @@ -20,6 +20,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" node "github.com/kubeagi/arcadia/api/app-node" + agent "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" "github.com/kubeagi/arcadia/api/base/v1alpha1" ) @@ -31,6 +32,7 @@ type LLMChainSpec struct { } type CommonChainConfig struct { + Tools []agent.Tool `json:"tools,omitempty"` // for memory Memory Memory `json:"memory,omitempty"` diff --git a/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go b/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go index efd6edd5f..da51ff254 100644 --- a/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go +++ b/api/app-node/chain/v1alpha1/zz_generated.deepcopy.go @@ -22,6 +22,7 @@ limitations under the License. package v1alpha1 import ( + agentv1alpha1 "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" runtime "k8s.io/apimachinery/pkg/runtime" ) @@ -120,6 +121,13 @@ func (in *APIChainStatus) DeepCopy() *APIChainStatus { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *CommonChainConfig) DeepCopyInto(out *CommonChainConfig) { *out = *in + if in.Tools != nil { + in, out := &in.Tools, &out.Tools + *out = make([]agentv1alpha1.Tool, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } out.Memory = in.Memory if in.StopWords != nil { in, out := &in.StopWords, &out.StopWords diff --git a/apiserver/docs/docs.go b/apiserver/docs/docs.go index 84c965d45..0476f3756 100644 --- a/apiserver/docs/docs.go +++ b/apiserver/docs/docs.go @@ -1174,7 +1174,7 @@ const docTemplate = `{ "example": "旷工最小计算单位为 0.5 天。" }, "content": { - "description": "related content in the source file", + "description": "related content in the source file or in webpage", "type": "string", "example": "旷工最小计算单位为0.5天,不足0.5天以0.5天计算,超过0.5天不满1天以1天计算,以此类推。" }, @@ -1207,6 +1207,16 @@ const docTemplate = `{ "description": "vector search score", "type": "number", "example": 0.34 + }, + "title": { + "description": "Title of the webpage", + "type": "string", + "example": "开始使用 Microsoft 帐户 – Microsoft" + }, + "url": { + "description": "URL of the webpage", + "type": "string", + "example": "https://www.microsoft.com/zh-cn/welcome" } } }, diff --git a/apiserver/docs/swagger.json b/apiserver/docs/swagger.json index 83ee95fc0..ef909b560 100644 --- a/apiserver/docs/swagger.json +++ b/apiserver/docs/swagger.json @@ -1163,7 +1163,7 @@ "example": "旷工最小计算单位为 0.5 天。" }, "content": { - "description": "related content in the source file", + "description": "related content in the source file or in webpage", "type": "string", "example": "旷工最小计算单位为0.5天,不足0.5天以0.5天计算,超过0.5天不满1天以1天计算,以此类推。" }, @@ -1196,6 +1196,16 @@ "description": "vector search score", "type": "number", "example": 0.34 + }, + "title": { + "description": "Title of the webpage", + "type": "string", + "example": "开始使用 Microsoft 帐户 – Microsoft" + }, + "url": { + "description": "URL of the webpage", + "type": "string", + "example": "https://www.microsoft.com/zh-cn/welcome" } } }, diff --git a/apiserver/docs/swagger.yaml b/apiserver/docs/swagger.yaml index de9d73469..a6684f511 100644 --- a/apiserver/docs/swagger.yaml +++ b/apiserver/docs/swagger.yaml @@ -183,7 +183,7 @@ definitions: example: 旷工最小计算单位为 0.5 天。 type: string content: - description: related content in the source file + description: related content in the source file or in webpage example: 旷工最小计算单位为0.5天,不足0.5天以0.5天计算,超过0.5天不满1天以1天计算,以此类推。 type: string file_name: @@ -210,6 +210,14 @@ definitions: description: vector search score example: 0.34 type: number + title: + description: Title of the webpage + example: 开始使用 Microsoft 帐户 – Microsoft + type: string + url: + description: URL of the webpage + example: https://www.microsoft.com/zh-cn/welcome + type: string type: object service.Chunk: properties: diff --git a/apiserver/graph/generated/generated.go b/apiserver/graph/generated/generated.go index bc5b00899..cae6003b6 100644 --- a/apiserver/graph/generated/generated.go +++ b/apiserver/graph/generated/generated.go @@ -90,6 +90,7 @@ type ComplexityRoot struct { ShowRespInfo func(childComplexity int) int ShowRetrievalInfo func(childComplexity int) int Temperature func(childComplexity int) int + Tools func(childComplexity int) int UserPrompt func(childComplexity int) int } @@ -621,6 +622,11 @@ type ComplexityRoot struct { MatchLabels func(childComplexity int) int } + Tool struct { + Name func(childComplexity int) int + Params func(childComplexity int) int + } + TypedObjectReference struct { APIGroup func(childComplexity int) int DisplayName func(childComplexity int) int @@ -1003,6 +1009,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Application.Temperature(childComplexity), true + case "Application.tools": + if e.complexity.Application.Tools == nil { + break + } + + return e.complexity.Application.Tools(childComplexity), true + case "Application.userPrompt": if e.complexity.Application.UserPrompt == nil { break @@ -3741,6 +3754,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Selector.MatchLabels(childComplexity), true + case "Tool.name": + if e.complexity.Tool.Name == nil { + break + } + + return e.complexity.Tool.Name(childComplexity), true + + case "Tool.params": + if e.complexity.Tool.Params == nil { + break + } + + return e.complexity.Tool.Params(childComplexity), true + case "TypedObjectReference.apiGroup": if e.complexity.TypedObjectReference.APIGroup == nil { break @@ -4298,6 +4325,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { ec.unmarshalInputResourceInput, ec.unmarshalInputResourcesInput, ec.unmarshalInputSelectorInput, + ec.unmarshalInputToolInput, ec.unmarshalInputTypedObjectReferenceInput, ec.unmarshalInputUpdateApplicationConfigInput, ec.unmarshalInputUpdateApplicationMetadataInput, @@ -4507,6 +4535,10 @@ type Application { showNextGuide 下一步引导,即是否在chat界面显示下一步引导 """ showNextGuide: Boolean + """ + tools 要使用的工具列表 + """ + tools: [Tool] } """ @@ -4747,6 +4779,10 @@ input UpdateApplicationConfigInput { showNextGuide 下一步引导,即是否在chat界面显示下一步引导 """ showNextGuide: Boolean + """ + tools 要使用的工具列表 + """ + tools: [ToolInput] } `, BuiltIn: false}, {Name: "../schema/dataprocessing.graphqls", Input: `# 数据处理 Mutation @@ -5591,7 +5627,35 @@ type TypedObjectReference { namespace: String } -union PageNode = Datasource | Model | Embedder | KnowledgeBase | Dataset | VersionedDataset | F | Worker | ApplicationMetadata | LLM | ModelService | RayCluster | RAG +""" +ToolInput 应用和Agent中用到的工具 +""" +input ToolInput { + """ + 名称(必填),目前只有bing可选 + """ + name: String! + """ + params 参数,可选 + """ + params: Map +} + +""" +Tool 应用和Agent中用到的工具 +""" +type Tool { + """ + 名称,目前只有bing可选 + """ + name: String + """ + params 参数 + """ + params: Map +} + +union PageNode = Datasource | Model | Embedder | KnowledgeBase | Dataset | VersionedDataset | F | Worker | ApplicationMetadata | LLM | ModelService | RayCluster `, BuiltIn: false}, {Name: "../schema/k8s.graphqls", Input: `type LabelSelectorRequirement { key: String @@ -9000,6 +9064,53 @@ func (ec *executionContext) fieldContext_Application_showNextGuide(ctx context.C return fc, nil } +func (ec *executionContext) _Application_tools(ctx context.Context, field graphql.CollectedField, obj *Application) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Application_tools(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Tools, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.([]*Tool) + fc.Result = res + return ec.marshalOTool2ᚕᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐTool(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Application_tools(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Application", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "name": + return ec.fieldContext_Tool_name(ctx, field) + case "params": + return ec.fieldContext_Tool_params(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type Tool", field.Name) + }, + } + return fc, nil +} + func (ec *executionContext) _ApplicationMetadata_name(ctx context.Context, field graphql.CollectedField, obj *ApplicationMetadata) (ret graphql.Marshaler) { fc, err := ec.fieldContext_ApplicationMetadata_name(ctx, field) if err != nil { @@ -9828,6 +9939,8 @@ func (ec *executionContext) fieldContext_ApplicationMutation_updateApplicationCo return ec.fieldContext_Application_showRetrievalInfo(ctx, field) case "showNextGuide": return ec.fieldContext_Application_showNextGuide(ctx, field) + case "tools": + return ec.fieldContext_Application_tools(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type Application", field.Name) }, @@ -9917,6 +10030,8 @@ func (ec *executionContext) fieldContext_ApplicationQuery_getApplication(ctx con return ec.fieldContext_Application_showRetrievalInfo(ctx, field) case "showNextGuide": return ec.fieldContext_Application_showNextGuide(ctx, field) + case "tools": + return ec.fieldContext_Application_tools(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type Application", field.Name) }, @@ -24477,6 +24592,8 @@ func (ec *executionContext) fieldContext_RAG_application(ctx context.Context, fi return ec.fieldContext_Application_showRetrievalInfo(ctx, field) case "showNextGuide": return ec.fieldContext_Application_showNextGuide(ctx, field) + case "tools": + return ec.fieldContext_Application_tools(ctx, field) } return nil, fmt.Errorf("no field named %q was found under type Application", field.Name) }, @@ -26154,6 +26271,88 @@ func (ec *executionContext) fieldContext_Selector_matchExpressions(ctx context.C return fc, nil } +func (ec *executionContext) _Tool_name(ctx context.Context, field graphql.CollectedField, obj *Tool) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Tool_name(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Name, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*string) + fc.Result = res + return ec.marshalOString2ᚖstring(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Tool_name(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Tool", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _Tool_params(ctx context.Context, field graphql.CollectedField, obj *Tool) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Tool_params(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.Params, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(map[string]interface{}) + fc.Result = res + return ec.marshalOMap2map(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Tool_params(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Tool", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Map does not have child fields") + }, + } + return fc, nil +} + func (ec *executionContext) _TypedObjectReference_apiGroup(ctx context.Context, field graphql.CollectedField, obj *TypedObjectReference) (ret graphql.Marshaler) { fc, err := ec.fieldContext_TypedObjectReference_apiGroup(ctx, field) if err != nil { @@ -34223,6 +34422,44 @@ func (ec *executionContext) unmarshalInputSelectorInput(ctx context.Context, obj return it, nil } +func (ec *executionContext) unmarshalInputToolInput(ctx context.Context, obj interface{}) (ToolInput, error) { + var it ToolInput + asMap := map[string]interface{}{} + for k, v := range obj.(map[string]interface{}) { + asMap[k] = v + } + + fieldsInOrder := [...]string{"name", "params"} + for _, k := range fieldsInOrder { + v, ok := asMap[k] + if !ok { + continue + } + switch k { + case "name": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("name")) + data, err := ec.unmarshalNString2string(ctx, v) + if err != nil { + return it, err + } + it.Name = data + case "params": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("params")) + data, err := ec.unmarshalOMap2map(ctx, v) + if err != nil { + return it, err + } + it.Params = data + } + } + + return it, nil +} + func (ec *executionContext) unmarshalInputTypedObjectReferenceInput(ctx context.Context, obj interface{}) (TypedObjectReferenceInput, error) { var it TypedObjectReferenceInput asMap := map[string]interface{}{} @@ -34286,7 +34523,7 @@ func (ec *executionContext) unmarshalInputUpdateApplicationConfigInput(ctx conte asMap[k] = v } - fieldsInOrder := [...]string{"name", "namespace", "prologue", "model", "llm", "temperature", "maxLength", "maxTokens", "conversionWindowSize", "knowledgebase", "scoreThreshold", "numDocuments", "docNullReturn", "userPrompt", "showRespInfo", "showRetrievalInfo", "showNextGuide"} + fieldsInOrder := [...]string{"name", "namespace", "prologue", "model", "llm", "temperature", "maxLength", "maxTokens", "conversionWindowSize", "knowledgebase", "scoreThreshold", "numDocuments", "docNullReturn", "userPrompt", "showRespInfo", "showRetrievalInfo", "showNextGuide", "tools"} for _, k := range fieldsInOrder { v, ok := asMap[k] if !ok { @@ -34446,6 +34683,15 @@ func (ec *executionContext) unmarshalInputUpdateApplicationConfigInput(ctx conte return it, err } it.ShowNextGuide = data + case "tools": + var err error + + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("tools")) + data, err := ec.unmarshalOToolInput2ᚕᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐToolInput(ctx, v) + if err != nil { + return it, err + } + it.Tools = data } } @@ -35713,13 +35959,6 @@ func (ec *executionContext) _PageNode(ctx context.Context, sel ast.SelectionSet, return graphql.Null } return ec._RayCluster(ctx, sel, obj) - case Rag: - return ec._RAG(ctx, sel, &obj) - case *Rag: - if obj == nil { - return graphql.Null - } - return ec._RAG(ctx, sel, obj) default: panic(fmt.Errorf("unexpected type %T", obj)) } @@ -35775,6 +36014,8 @@ func (ec *executionContext) _Application(ctx context.Context, sel ast.SelectionS out.Values[i] = ec._Application_showRetrievalInfo(ctx, field, obj) case "showNextGuide": out.Values[i] = ec._Application_showNextGuide(ctx, field, obj) + case "tools": + out.Values[i] = ec._Application_tools(ctx, field, obj) default: panic("unknown field " + strconv.Quote(field.Name)) } @@ -40335,7 +40576,7 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr return out } -var rAGImplementors = []string{"RAG", "PageNode"} +var rAGImplementors = []string{"RAG"} func (ec *executionContext) _RAG(ctx context.Context, sel ast.SelectionSet, obj *Rag) graphql.Marshaler { fields := graphql.CollectFields(ec.OperationContext, sel, rAGImplementors) @@ -41120,6 +41361,44 @@ func (ec *executionContext) _Selector(ctx context.Context, sel ast.SelectionSet, return out } +var toolImplementors = []string{"Tool"} + +func (ec *executionContext) _Tool(ctx context.Context, sel ast.SelectionSet, obj *Tool) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, toolImplementors) + + out := graphql.NewFieldSet(fields) + deferred := make(map[string]*graphql.FieldSet) + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("Tool") + case "name": + out.Values[i] = ec._Tool_name(ctx, field, obj) + case "params": + out.Values[i] = ec._Tool_params(ctx, field, obj) + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch(ctx) + if out.Invalids > 0 { + return graphql.Null + } + + atomic.AddInt32(&ec.deferred, int32(len(deferred))) + + for label, dfs := range deferred { + ec.processDeferredGroup(graphql.DeferredGroup{ + Label: label, + Path: graphql.GetPath(ctx), + FieldSet: dfs, + Context: ctx, + }) + } + + return out +} + var typedObjectReferenceImplementors = []string{"TypedObjectReference"} func (ec *executionContext) _TypedObjectReference(ctx context.Context, sel ast.SelectionSet, obj *TypedObjectReference) graphql.Marshaler { @@ -44662,6 +44941,82 @@ func (ec *executionContext) marshalOTime2ᚖtimeᚐTime(ctx context.Context, sel return res } +func (ec *executionContext) marshalOTool2ᚕᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐTool(ctx context.Context, sel ast.SelectionSet, v []*Tool) graphql.Marshaler { + if v == nil { + return graphql.Null + } + ret := make(graphql.Array, len(v)) + var wg sync.WaitGroup + isLen1 := len(v) == 1 + if !isLen1 { + wg.Add(len(v)) + } + for i := range v { + i := i + fc := &graphql.FieldContext{ + Index: &i, + Result: &v[i], + } + ctx := graphql.WithFieldContext(ctx, fc) + f := func(i int) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + if !isLen1 { + defer wg.Done() + } + ret[i] = ec.marshalOTool2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐTool(ctx, sel, v[i]) + } + if isLen1 { + f(i) + } else { + go f(i) + } + + } + wg.Wait() + + return ret +} + +func (ec *executionContext) marshalOTool2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐTool(ctx context.Context, sel ast.SelectionSet, v *Tool) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return ec._Tool(ctx, sel, v) +} + +func (ec *executionContext) unmarshalOToolInput2ᚕᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐToolInput(ctx context.Context, v interface{}) ([]*ToolInput, error) { + if v == nil { + return nil, nil + } + var vSlice []interface{} + if v != nil { + vSlice = graphql.CoerceList(v) + } + var err error + res := make([]*ToolInput, len(vSlice)) + for i := range vSlice { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithIndex(i)) + res[i], err = ec.unmarshalOToolInput2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐToolInput(ctx, vSlice[i]) + if err != nil { + return nil, err + } + } + return res, nil +} + +func (ec *executionContext) unmarshalOToolInput2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐToolInput(ctx context.Context, v interface{}) (*ToolInput, error) { + if v == nil { + return nil, nil + } + res, err := ec.unmarshalInputToolInput(ctx, v) + return &res, graphql.ErrorOnPath(ctx, err) +} + func (ec *executionContext) marshalOTypedObjectReference2ᚖgithubᚗcomᚋkubeagiᚋarcadiaᚋapiserverᚋgraphᚋgeneratedᚐTypedObjectReference(ctx context.Context, sel ast.SelectionSet, v *TypedObjectReference) graphql.Marshaler { if v == nil { return graphql.Null diff --git a/apiserver/graph/generated/models_gen.go b/apiserver/graph/generated/models_gen.go index 55a2abb3f..c0888a2ab 100644 --- a/apiserver/graph/generated/models_gen.go +++ b/apiserver/graph/generated/models_gen.go @@ -70,6 +70,8 @@ type Application struct { ShowRetrievalInfo *bool `json:"showRetrievalInfo,omitempty"` // showNextGuide 下一步引导,即是否在chat界面显示下一步引导 ShowNextGuide *bool `json:"showNextGuide,omitempty"` + // tools 要使用的工具列表 + Tools []*Tool `json:"tools,omitempty"` } // Application @@ -1274,8 +1276,6 @@ type Rag struct { PhaseMessage *string `json:"phaseMessage,omitempty"` } -func (Rag) IsPageNode() {} - type RAGDataset struct { Source *TypedObjectReference `json:"source,omitempty"` Files []*F `json:"files,omitempty"` @@ -1370,6 +1370,22 @@ type SelectorInput struct { MatchExpressions []*LabelSelectorRequirementInput `json:"matchExpressions,omitempty"` } +// Tool 应用和Agent中用到的工具 +type Tool struct { + // 名称,目前只有bing可选 + Name *string `json:"name,omitempty"` + // params 参数 + Params map[string]interface{} `json:"params,omitempty"` +} + +// ToolInput 应用和Agent中用到的工具 +type ToolInput struct { + // 名称(必填),目前只有bing可选 + Name string `json:"name"` + // params 参数,可选 + Params map[string]interface{} `json:"params,omitempty"` +} + type TypedObjectReference struct { APIGroup *string `json:"apiGroup,omitempty"` Kind string `json:"kind"` @@ -1422,6 +1438,8 @@ type UpdateApplicationConfigInput struct { ShowRetrievalInfo *bool `json:"showRetrievalInfo,omitempty"` // showNextGuide 下一步引导,即是否在chat界面显示下一步引导 ShowNextGuide *bool `json:"showNextGuide,omitempty"` + // tools 要使用的工具列表 + Tools []*ToolInput `json:"tools,omitempty"` } type UpdateApplicationMetadataInput struct { diff --git a/apiserver/graph/schema/application.gql b/apiserver/graph/schema/application.gql index d36edd869..789a6a0a1 100644 --- a/apiserver/graph/schema/application.gql +++ b/apiserver/graph/schema/application.gql @@ -67,6 +67,10 @@ mutation updateApplicationConfig($input: UpdateApplicationConfigInput!){ showRespInfo showRetrievalInfo showNextGuide + tools { + name + params + } } } } @@ -104,6 +108,10 @@ query getApplication($name: String!, $namespace: String!){ showRespInfo showRetrievalInfo showNextGuide + tools { + name + params + } } } } diff --git a/apiserver/graph/schema/application.graphqls b/apiserver/graph/schema/application.graphqls index bca4bcca0..d60788537 100644 --- a/apiserver/graph/schema/application.graphqls +++ b/apiserver/graph/schema/application.graphqls @@ -95,6 +95,10 @@ type Application { showNextGuide 下一步引导,即是否在chat界面显示下一步引导 """ showNextGuide: Boolean + """ + tools 要使用的工具列表 + """ + tools: [Tool] } """ @@ -335,4 +339,8 @@ input UpdateApplicationConfigInput { showNextGuide 下一步引导,即是否在chat界面显示下一步引导 """ showNextGuide: Boolean + """ + tools 要使用的工具列表 + """ + tools: [ToolInput] } diff --git a/apiserver/graph/schema/entrypoint.graphqls b/apiserver/graph/schema/entrypoint.graphqls index d73ea98bc..211eee3a7 100644 --- a/apiserver/graph/schema/entrypoint.graphqls +++ b/apiserver/graph/schema/entrypoint.graphqls @@ -76,4 +76,32 @@ type TypedObjectReference { namespace: String } -union PageNode = Datasource | Model | Embedder | KnowledgeBase | Dataset | VersionedDataset | F | Worker | ApplicationMetadata | LLM | ModelService | RayCluster | RAG +""" +ToolInput 应用和Agent中用到的工具 +""" +input ToolInput { + """ + 名称(必填),目前只有bing可选 + """ + name: String! + """ + params 参数,可选 + """ + params: Map +} + +""" +Tool 应用和Agent中用到的工具 +""" +type Tool { + """ + 名称,目前只有bing可选 + """ + name: String + """ + params 参数 + """ + params: Map +} + +union PageNode = Datasource | Model | Embedder | KnowledgeBase | Dataset | VersionedDataset | F | Worker | ApplicationMetadata | LLM | ModelService | RayCluster diff --git a/apiserver/pkg/application/application.go b/apiserver/pkg/application/application.go index afa7d88f1..58c8e3432 100644 --- a/apiserver/pkg/application/application.go +++ b/apiserver/pkg/application/application.go @@ -30,6 +30,7 @@ import ( "k8s.io/client-go/dynamic" "k8s.io/utils/pointer" + agent "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" apichain "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" apiprompt "github.com/kubeagi/arcadia/api/app-node/prompt/v1alpha1" apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1" @@ -90,6 +91,12 @@ func cr2app(prompt *apiprompt.Prompt, chainConfig *apichain.CommonChainConfig, r gApp.MaxLength = pointer.Int(chainConfig.MaxLength) gApp.MaxTokens = pointer.Int(chainConfig.MaxTokens) gApp.ConversionWindowSize = pointer.Int(chainConfig.Memory.ConversionWindowSize) + for _, v := range chainConfig.Tools { + gApp.Tools = append(gApp.Tools, &generated.Tool{ + Name: pointer.String(v.Name), + Params: utils.MapStr2Any(v.Params), + }) + } } for _, node := range app.Spec.Nodes { if node.Ref == nil { @@ -382,6 +389,12 @@ func UpdateApplicationConfig(ctx context.Context, c dynamic.Interface, input gen qachain.Spec.MaxTokens = pointer.IntDeref(input.MaxTokens, qachain.Spec.MaxTokens) qachain.Spec.Temperature = pointer.Float64Deref(input.Temperature, qachain.Spec.Temperature) qachain.Spec.Memory.ConversionWindowSize = pointer.IntDeref(input.ConversionWindowSize, qachain.Spec.Memory.ConversionWindowSize) + for _, v := range input.Tools { + qachain.Spec.Tools = append(qachain.Spec.Tools, agent.Tool{ + Name: v.Name, + Params: utils.MapAny2Str(v.Params), + }) + } }, qachain); err != nil { return nil, err } @@ -418,6 +431,12 @@ func UpdateApplicationConfig(ctx context.Context, c dynamic.Interface, input gen llmchain.Spec.MaxTokens = pointer.IntDeref(input.MaxTokens, llmchain.Spec.MaxTokens) llmchain.Spec.Temperature = pointer.Float64Deref(input.Temperature, llmchain.Spec.Temperature) llmchain.Spec.Memory.ConversionWindowSize = pointer.IntDeref(input.ConversionWindowSize, llmchain.Spec.Memory.ConversionWindowSize) + for _, v := range input.Tools { + llmchain.Spec.Tools = append(llmchain.Spec.Tools, agent.Tool{ + Name: v.Name, + Params: utils.MapAny2Str(v.Params), + }) + } }, llmchain); err != nil { return nil, err } diff --git a/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_apichains.yaml b/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_apichains.yaml index 51a1fc033..d3ddb5c8f 100644 --- a/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_apichains.yaml +++ b/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_apichains.yaml @@ -102,6 +102,20 @@ spec: maximum: 1 minimum: 0 type: number + tools: + items: + description: Tool/Capability that this agent will use + properties: + name: + description: Name of the tool + type: string + params: + additionalProperties: + type: string + description: Map of key/value that will be passed to the tool + type: object + type: object + type: array topK: description: TopK is the number of tokens to consider for top-k sampling in a llm call. diff --git a/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml b/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml index 094f4ea93..9fe511543 100644 --- a/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml +++ b/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml @@ -98,6 +98,20 @@ spec: maximum: 1 minimum: 0 type: number + tools: + items: + description: Tool/Capability that this agent will use + properties: + name: + description: Name of the tool + type: string + params: + additionalProperties: + type: string + description: Map of key/value that will be passed to the tool + type: object + type: object + type: array topK: description: TopK is the number of tokens to consider for top-k sampling in a llm call. diff --git a/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml b/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml index 05d449dfb..148fe177a 100644 --- a/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml +++ b/config/crd/bases/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml @@ -101,6 +101,20 @@ spec: maximum: 1 minimum: 0 type: number + tools: + items: + description: Tool/Capability that this agent will use + properties: + name: + description: Name of the tool + type: string + params: + additionalProperties: + type: string + description: Map of key/value that will be passed to the tool + type: object + type: object + type: array topK: description: TopK is the number of tokens to consider for top-k sampling in a llm call. diff --git a/config/samples/app_llmchain_chat_with_bot_bing.yaml b/config/samples/app_llmchain_chat_with_bot_bing.yaml new file mode 100644 index 000000000..3241666fb --- /dev/null +++ b/config/samples/app_llmchain_chat_with_bot_bing.yaml @@ -0,0 +1,86 @@ +apiVersion: arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Application +metadata: + name: base-chat-with-bot-bing + namespace: arcadia +spec: + displayName: "对话机器人" + description: "和AI对话,品赛博人生" + prologue: "Hello, I am KubeAGI Bot🤖, Tell me something?" + nodes: + - name: Input + displayName: "用户输入" + description: "用户输入节点,必须" + ref: + kind: Input + name: Input + nextNodeName: ["prompt-node"] + - name: prompt-node + displayName: "prompt" + description: "设定prompt,template中可以使用{{xx}}来替换变量" + ref: + apiGroup: prompt.arcadia.kubeagi.k8s.com.cn + kind: Prompt + name: base-chat-with-bot-bing + nextNodeName: ["chain-node"] + - name: llm-node + displayName: "zhipu大模型服务" + description: "设定大模型的访问信息" + ref: + apiGroup: arcadia.kubeagi.k8s.com.cn + kind: LLM + name: app-shared-llm-service + nextNodeName: ["chain-node"] + - name: chain-node + displayName: "llm chain" + description: "chain是langchain的核心概念,llmChain用于连接prompt和llm" + ref: + apiGroup: chain.arcadia.kubeagi.k8s.com.cn + kind: LLMChain + name: base-chat-with-bot-bing + nextNodeName: ["Output"] + - name: Output + displayName: "最终输出" + description: "最终输出节点,必须" + ref: + kind: Output + name: Output +--- +apiVersion: prompt.arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: Prompt +metadata: + name: base-chat-with-bot-bing + namespace: arcadia + annotations: + arcadia.kubeagi.k8s.com.cn/input-rules: '[{"kind":"Input","length":1}]' + arcadia.kubeagi.k8s.com.cn/output-rules: '[{"length":1}]' +spec: + displayName: "设定对话的prompt" + description: "设定对话的prompt" + userMessage: | + The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know. + Current Context: + {{.context}} + + Current conversation: + {{.history}} + + Human: {{.question}} + AI: +--- +apiVersion: chain.arcadia.kubeagi.k8s.com.cn/v1alpha1 +kind: LLMChain +metadata: + name: base-chat-with-bot-bing + namespace: arcadia + annotations: + arcadia.kubeagi.k8s.com.cn/input-rules: '[{"kind":"LLM","group":"arcadia.kubeagi.k8s.com.cn","length":1},{"kind":"prompt","group":"prompt.arcadia.kubeagi.k8s.com.cn","length":1}]' + arcadia.kubeagi.k8s.com.cn/output-rules: '[{"kind":"Output","length":1}]' +spec: + displayName: "llm chain" + description: "llm chain" + memory: + conversionWindowSize: 2 + model: chatglm_turbo # notice: default model chatglm_lite gets poor results in most cases, openai's gpt-3.5-turbo is also good enough + tools: + - name: bing diff --git a/deploy/charts/arcadia/Chart.yaml b/deploy/charts/arcadia/Chart.yaml index 3eca815d8..6f63fa007 100644 --- a/deploy/charts/arcadia/Chart.yaml +++ b/deploy/charts/arcadia/Chart.yaml @@ -2,7 +2,7 @@ apiVersion: v2 name: arcadia description: A Helm chart(KubeBB Component) for KubeAGI Arcadia type: application -version: 0.2.21 +version: 0.2.22 appVersion: "0.1.0" keywords: diff --git a/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_apichains.yaml b/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_apichains.yaml index 51a1fc033..d3ddb5c8f 100644 --- a/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_apichains.yaml +++ b/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_apichains.yaml @@ -102,6 +102,20 @@ spec: maximum: 1 minimum: 0 type: number + tools: + items: + description: Tool/Capability that this agent will use + properties: + name: + description: Name of the tool + type: string + params: + additionalProperties: + type: string + description: Map of key/value that will be passed to the tool + type: object + type: object + type: array topK: description: TopK is the number of tokens to consider for top-k sampling in a llm call. diff --git a/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml b/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml index 094f4ea93..9fe511543 100644 --- a/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml +++ b/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_llmchains.yaml @@ -98,6 +98,20 @@ spec: maximum: 1 minimum: 0 type: number + tools: + items: + description: Tool/Capability that this agent will use + properties: + name: + description: Name of the tool + type: string + params: + additionalProperties: + type: string + description: Map of key/value that will be passed to the tool + type: object + type: object + type: array topK: description: TopK is the number of tokens to consider for top-k sampling in a llm call. diff --git a/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml b/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml index 05d449dfb..148fe177a 100644 --- a/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml +++ b/deploy/charts/arcadia/crds/chain.arcadia.kubeagi.k8s.com.cn_retrievalqachains.yaml @@ -101,6 +101,20 @@ spec: maximum: 1 minimum: 0 type: number + tools: + items: + description: Tool/Capability that this agent will use + properties: + name: + description: Name of the tool + type: string + params: + additionalProperties: + type: string + description: Map of key/value that will be passed to the tool + type: object + type: object + type: array topK: description: TopK is the number of tokens to consider for top-k sampling in a llm call. diff --git a/deploy/charts/arcadia/templates/apiserver.yaml b/deploy/charts/arcadia/templates/apiserver.yaml index 3da470f3c..a8715ad7c 100644 --- a/deploy/charts/arcadia/templates/apiserver.yaml +++ b/deploy/charts/arcadia/templates/apiserver.yaml @@ -32,6 +32,8 @@ spec: fieldPath: metadata.namespace - name: DEFAULT_CONFIG value: {{ .Release.Name }}-config + - name: BING_KEY + value: {{ .Values.apiserver.bingKey }} command: - "./apiserver" args: diff --git a/deploy/charts/arcadia/values.yaml b/deploy/charts/arcadia/values.yaml index 3ca6a2651..81cebe79a 100644 --- a/deploy/charts/arcadia/values.yaml +++ b/deploy/charts/arcadia/values.yaml @@ -25,6 +25,7 @@ controller: # @section graphql and bff server # related project: https://github.com/kubeagi/arcadia/tree/main/apiserver apiserver: + bingKey: c30e4d7f3ec24c31a489f883616844b5 image: kubeagi/arcadia:v0.1.0-20240110-0dd9a1f enableplayground: false port: 8081 diff --git a/pkg/appruntime/chain/common.go b/pkg/appruntime/chain/common.go index 6ebe26aa6..e4b356b18 100644 --- a/pkg/appruntime/chain/common.go +++ b/pkg/appruntime/chain/common.go @@ -19,6 +19,8 @@ package chain import ( "context" "fmt" + "strings" + "sync" "github.com/tmc/langchaingo/chains" "github.com/tmc/langchaingo/llms" @@ -26,7 +28,10 @@ import ( langchaingoschema "github.com/tmc/langchaingo/schema" "k8s.io/klog/v2" + agent "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" "github.com/kubeagi/arcadia/api/app-node/chain/v1alpha1" + "github.com/kubeagi/arcadia/pkg/appruntime/retriever" + "github.com/kubeagi/arcadia/pkg/appruntime/tools/bingsearch" ) func stream(res map[string]any) func(ctx context.Context, chunk []byte) error { @@ -98,3 +103,64 @@ func getMemory(llm llms.LLM, config v1alpha1.Memory, history langchaingoschema.C } return memory.NewSimple() } + +func runTools(ctx context.Context, args map[string]any, tools []agent.Tool) map[string]any { + if len(tools) == 0 { + return args + } + input, ok := args["question"].(string) + if !ok { + return args + } + result := make([]string, len(tools)) + resultRef := make([][]retriever.Reference, len(tools)) + for i := range resultRef { + resultRef[i] = make([]retriever.Reference, 0) + } + var wg sync.WaitGroup + wg.Add(len(tools)) + for i, tool := range tools { + i, tool := i, tool + go func(int, agent.Tool) { + defer wg.Done() + switch tool.Name { // nolint:gocritic + case "bing": + klog.V(3).Infof("tools call bing search: %s", input) + client := bingsearch.NewBingClient(tool.Params[bingsearch.ParamAPIKey]) + data, _, err := client.GetWebPages(ctx, input) + if err != nil { + klog.Errorf("failed to call bing search tool: %w", err) + return + } + ref := make([]retriever.Reference, len(data)) + for j := range data { + ref[j] = retriever.Reference{ + Title: data[j].Title, + Content: data[j].Description, + URL: data[j].URL, + } + } + resultRef[i] = ref + result[i] = bingsearch.FormatResults(data) + klog.V(3).Infof("tools call bing search done: %s", input) + } + }(i, tool) + } + wg.Wait() + res := make([]string, 0, len(result)) + for i := range result { + if s := strings.TrimSpace(result[i]); s != "" { + res = append(res, s) + } + } + toolOut := strings.Join(res, "\n") + old, exist := args["context"] + if exist { + toolOut = old.(string) + "\n" + toolOut + } + args["context"] = toolOut + for i := range resultRef { + args = retriever.AddReferencesToArgs(args, resultRef[i]) + } + return args +} diff --git a/pkg/appruntime/chain/llmchain.go b/pkg/appruntime/chain/llmchain.go index 27d87840e..6bd36cd29 100644 --- a/pkg/appruntime/chain/llmchain.go +++ b/pkg/appruntime/chain/llmchain.go @@ -91,6 +91,7 @@ func (l *LLMChain) Run(ctx context.Context, cli dynamic.Interface, args map[stri klog.Infoln("get answer from upstream:", args["_answer"]) args["context"] = args["_answer"] } + args = runTools(ctx, args, instance.Spec.Tools) chain := chains.NewLLMChain(llm, prompt) if history != nil { chain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "") diff --git a/pkg/appruntime/chain/retrievalqachain.go b/pkg/appruntime/chain/retrievalqachain.go index a61aacda4..d2f8880d9 100644 --- a/pkg/appruntime/chain/retrievalqachain.go +++ b/pkg/appruntime/chain/retrievalqachain.go @@ -95,6 +95,7 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli dynamic.Interface, args } options := getChainOptions(instance.Spec.CommonChainConfig) + args = runTools(ctx, args, instance.Spec.Tools) llmChain := chains.NewLLMChain(llm, prompt) if history != nil { llmChain.Memory = getMemory(llm, instance.Spec.Memory, history, "", "") @@ -122,7 +123,7 @@ func (l *RetrievalQAChain) Run(ctx context.Context, cli dynamic.Interface, args } } if stuffDocuments != nil && len(stuffDocuments.References) > 0 { - args["_references"] = stuffDocuments.References + args = appretriever.AddReferencesToArgs(args, stuffDocuments.References) } klog.FromContext(ctx).V(5).Info("use retrievalqachain, blocking out:" + out) if err == nil { diff --git a/pkg/appruntime/retriever/knowledgebaseretriever.go b/pkg/appruntime/retriever/knowledgebaseretriever.go index a0df132d3..d584bd538 100644 --- a/pkg/appruntime/retriever/knowledgebaseretriever.go +++ b/pkg/appruntime/retriever/knowledgebaseretriever.go @@ -56,8 +56,12 @@ type Reference struct { FileName string `json:"file_name" example:"员工考勤管理制度-2023.pdf"` // page number in the source file PageNumber int `json:"page_number" example:"1"` - // related content in the source file + // related content in the source file or in webpage Content string `json:"content" example:"旷工最小计算单位为0.5天,不足0.5天以0.5天计算,超过0.5天不满1天以1天计算,以此类推。"` + // Title of the webpage + Title string `json:"title,omitempty" example:"开始使用 Microsoft 帐户 – Microsoft"` + // URL of the webpage + URL string `json:"url,omitempty" example:"https://www.microsoft.com/zh-cn/welcome"` } func (reference Reference) String() string { @@ -68,6 +72,20 @@ func (reference Reference) String() string { return string(bytes) } +func AddReferencesToArgs(args map[string]any, refs []Reference) map[string]any { + if len(refs) == 0 { + return args + } + old, exist := args["_references"] + if exist { + oldRefs := old.([]Reference) + args["_references"] = append(oldRefs, refs...) + return args + } + args["_references"] = refs + return args +} + type KnowledgeBaseRetriever struct { langchaingoschema.Retriever base.BaseNode diff --git a/pkg/appruntime/tools/bingsearch/bing.go b/pkg/appruntime/tools/bingsearch/bing.go new file mode 100644 index 000000000..bcc4d2f62 --- /dev/null +++ b/pkg/appruntime/tools/bingsearch/bing.go @@ -0,0 +1,76 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package bingsearch + +import ( + "context" + + "github.com/tmc/langchaingo/callbacks" + "github.com/tmc/langchaingo/tools" + + "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" +) + +const ( + ToolName = "Bing Search API" + ParamAPIKey = "apiKey" +) + +type Tool struct { + client *BingClient + CallbacksHandler callbacks.Handler +} + +var _ tools.Tool = Tool{} + +// New creates a new bing search tool to search on internet +func New(tool *v1alpha1.Tool) (*Tool, error) { + return &Tool{ + client: NewBingClient(tool.Params[ParamAPIKey]), + }, nil +} + +func (t Tool) Name() string { + return ToolName +} + +func (t Tool) Description() string { + return "Invoke API to get the realtime bing search data." +} + +func (t Tool) Call(ctx context.Context, input string) (string, error) { + if t.CallbacksHandler != nil { + t.CallbacksHandler.HandleToolStart(ctx, input) + } + result, err := t.client.Search(ctx, input) + if err != nil { + if t.CallbacksHandler != nil { + t.CallbacksHandler.HandleToolError(ctx, err) + } + return "", err + } + if t.CallbacksHandler != nil { + t.CallbacksHandler.HandleToolEnd(ctx, input) + } + return result, nil +} + +type WebPage struct { + Title string + Description string + URL string +} diff --git a/pkg/appruntime/tools/bingsearch/bing_test.go b/pkg/appruntime/tools/bingsearch/bing_test.go new file mode 100644 index 000000000..4741e164c --- /dev/null +++ b/pkg/appruntime/tools/bingsearch/bing_test.go @@ -0,0 +1,51 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package bingsearch + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" +) + +func TestBingSearch(t *testing.T) { + t.Parallel() + apikey := os.Getenv("BING_KEY") + if apikey == "" { + t.Skip("Must set BING_SEARCH_V7_SUBSCRIPTION_KEY to run TestBingSearch") + } + rightTool := &v1alpha1.Tool{ + Params: map[string]string{ + "apiKey": apikey, + }, + } + tool, _ := New(rightTool) + resp, err := tool.Call(context.Background(), "langchain") + require.NoError(t, err) + t.Logf("get format resp:\n%s", resp) + + wrongTool := rightTool + wrongTool.Params["apiKey"] = "xxxxx" + tool, _ = New(wrongTool) + _, err = tool.Call(context.Background(), "langchain") + t.Logf("should get err:\n%s", err) + require.Error(t, err) +} diff --git a/pkg/appruntime/tools/bingsearch/client.go b/pkg/appruntime/tools/bingsearch/client.go new file mode 100644 index 000000000..5aef50296 --- /dev/null +++ b/pkg/appruntime/tools/bingsearch/client.go @@ -0,0 +1,103 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package bingsearch + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "os" + + "k8s.io/klog/v2" +) + +const ( + Endpoint = "https://api.bing.microsoft.com/v7.0/search?mkt=zh-CN&q=" + AuthHeaderKey = "Ocp-Apim-Subscription-Key" +) + +type BingClient struct { + apiKey string +} + +func NewBingClient(apiKey string) *BingClient { + if apiKey == "" { + apiKey = os.Getenv("BING_KEY") + } + return &BingClient{ + apiKey: apiKey, + } +} + +func (client *BingClient) Search(ctx context.Context, query string) (string, error) { + p, data, err := client.GetWebPages(ctx, query) + if len(p) > 0 { + return FormatResults(p), nil + } + return data, err +} +func (client *BingClient) GetWebPages(ctx context.Context, query string) (p []WebPage, data string, err error) { + queryURL := Endpoint + url.QueryEscape(query) + + request, err := http.NewRequestWithContext(ctx, http.MethodGet, queryURL, nil) + if err != nil { + return nil, "", fmt.Errorf("creating bingSearch request failed: %w", err) + } + request.Header.Add(AuthHeaderKey, client.apiKey) + + response, err := http.DefaultClient.Do(request) + if err != nil { + return nil, "", fmt.Errorf("bingSearch[%s] get error: %w", queryURL, err) + } + + defer response.Body.Close() + code := response.StatusCode + resp := &RespData{} + if err := json.NewDecoder(response.Body).Decode(&resp); err != nil { + return nil, "", fmt.Errorf("bingSearch parse json resp get err:%w, http status code:%d", err, code) + } + if resp.ErrorResp != nil { + return nil, "", fmt.Errorf("bingSearch get error resp from bing server: http status code:%d message:%s, code:%s", code, resp.ErrorResp.Message, resp.ErrorResp.Code) + } + if len(resp.WebPages.Value) > 0 { + p = make([]WebPage, len(resp.WebPages.Value)) + for i, v := range resp.WebPages.Value { + v := v + p[i] = WebPage{ + Title: v.Name, + Description: v.Snippet, + URL: v.URL, + } + } + } + bytes, err := json.Marshal(resp) + if err != nil { + return nil, "", fmt.Errorf("bingSearch json marshal resp, get err:%w", err) + } + klog.V(3).Infof("bingSearch get webpages: %#v", p) + klog.V(5).Infof("bingSearch get resp: %s", string(bytes)) + return p, string(bytes), nil +} + +func FormatResults(vals []WebPage) (res string) { + for _, val := range vals { + res += fmt.Sprintf("Title: %s\nDescription: %s\nURL: %s\n\n", val.Title, val.Description, val.URL) + } + return res +} diff --git a/pkg/appruntime/tools/bingsearch/resp.go b/pkg/appruntime/tools/bingsearch/resp.go new file mode 100644 index 000000000..850a46ed9 --- /dev/null +++ b/pkg/appruntime/tools/bingsearch/resp.go @@ -0,0 +1,165 @@ +/* +Copyright 2024 KubeAGI. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package bingsearch + +import "time" + +type RespData struct { + Type string `json:"_type"` + QueryContext QueryContext `json:"queryContext"` + WebPages WebPages `json:"webPages"` + Entities Entities `json:"entities"` + Videos Videos `json:"videos"` + RankingResponse RankingResponse `json:"rankingResponse"` + ErrorResp *ErrorResp `json:"error,omitempty"` +} + +type QueryContext struct { + OriginalQuery string `json:"originalQuery"` +} + +type WebPages struct { + WebSearchURL string `json:"webSearchUrl"` + TotalEstimatedMatches int `json:"totalEstimatedMatches"` + Value []WebPagesValue `json:"value"` + SomeResultsRemoved bool `json:"someResultsRemoved"` +} + +type WebPagesValue struct { + ID string `json:"id"` + Name string `json:"name"` + URL string `json:"url"` + DatePublished string `json:"datePublished,omitempty"` + DatePublishedDisplayText string `json:"datePublishedDisplayText,omitempty"` + IsFamilyFriendly bool `json:"isFamilyFriendly"` + DisplayURL string `json:"displayUrl"` + Snippet string `json:"snippet"` + DeepLinks []DeepLink `json:"deepLinks,omitempty"` + DateLastCrawled time.Time `json:"dateLastCrawled"` + CachedPageURL string `json:"cachedPageUrl,omitempty"` + Language string `json:"language"` + IsNavigational bool `json:"isNavigational"` + ThumbnailURL string `json:"thumbnailUrl,omitempty"` + PrimaryImageOfPage PrimaryImageOfPage `json:"primaryImageOfPage,omitempty"` + SearchTags []SearchTag `json:"searchTags,omitempty"` +} + +type DeepLink struct { + Name string `json:"name"` + URL string `json:"url"` + Snippet string `json:"snippet"` +} + +type PrimaryImageOfPage struct { + ThumbnailURL string `json:"thumbnailUrl"` + Width int `json:"width"` + Height int `json:"height"` + ImageID string `json:"imageId"` +} + +type SearchTag struct { + Name string `json:"name"` + Content string `json:"content"` +} + +type Entities struct { + Value []EntitiesValue `json:"value"` +} + +type EntitiesValue struct { + ID string `json:"id"` + EntityPresentationInfo EntityPresentationInfo `json:"entityPresentationInfo"` + BingID string `json:"bingId"` +} + +type EntityPresentationInfo struct { + EntityScenario string `json:"entityScenario"` +} + +type RankingResponse struct { + Mainline RankingResponseMainline `json:"mainline"` + Sidebar RankingResponseSidebar `json:"sidebar"` +} + +type RankingResponseMainline struct { + Items []RankingResponseItem `json:"items"` +} + +type RankingResponseSidebar struct { + Items []RankingResponseItem `json:"items"` +} + +type RankingResponseItem struct { + AnswerType string `json:"answerType"` + ResultIndex int `json:"resultIndex"` + Value RankingResponseItemValue `json:"value"` +} + +type RankingResponseItemValue struct { + ID string `json:"id"` +} + +type ErrorResp struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type Videos struct { + ID string `json:"id"` + ReadLink string `json:"readLink"` + WebSearchURL string `json:"webSearchUrl"` + IsFamilyFriendly bool `json:"isFamilyFriendly"` + Value []VideoValue `json:"value"` + Scenario string `json:"scenario"` +} + +type VideoValue struct { + WebSearchURL string `json:"webSearchUrl"` + Name string `json:"name"` + Description string `json:"description"` + ThumbnailURL string `json:"thumbnailUrl"` + DatePublished string `json:"datePublished"` + Publisher []Publisher `json:"publisher"` + Creator Creator `json:"creator,omitempty"` + ContentURL string `json:"contentUrl"` + HostPageURL string `json:"hostPageUrl"` + EncodingFormat string `json:"encodingFormat"` + HostPageDisplayURL string `json:"hostPageDisplayUrl"` + Width int `json:"width"` + Height int `json:"height"` + Duration string `json:"duration,omitempty"` + MotionThumbnailURL string `json:"motionThumbnailUrl,omitempty"` + EmbedHTML string `json:"embedHtml"` + AllowHTTPSEmbed bool `json:"allowHttpsEmbed"` + ViewCount int `json:"viewCount"` + Thumbnail Thumbnail `json:"thumbnail"` + AllowMobileEmbed bool `json:"allowMobileEmbed"` + IsSuperfresh bool `json:"isSuperfresh"` +} + +type Thumbnail struct { + Width int `json:"width"` + Height int `json:"height"` +} + +type Publisher struct { + Name string `json:"name"` +} + +type Creator struct { + Name string `json:"name"` +} diff --git a/pkg/appruntime/tools/weather/weatherapi.go b/pkg/appruntime/tools/weather/weatherapi.go index 9f0f8ce60..ac907c9d7 100644 --- a/pkg/appruntime/tools/weather/weatherapi.go +++ b/pkg/appruntime/tools/weather/weatherapi.go @@ -19,6 +19,7 @@ import ( "context" "strings" + "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/tools" "github.com/kubeagi/arcadia/api/app-node/agent/v1alpha1" @@ -30,7 +31,8 @@ const ( ) type Tool struct { - client *internal.Client + client *internal.Client + CallbacksHandler callbacks.Handler } var _ tools.Tool = Tool{} @@ -51,9 +53,19 @@ func (t Tool) Description() string { } func (t Tool) Call(ctx context.Context, input string) (string, error) { + if t.CallbacksHandler != nil { + t.CallbacksHandler.HandleToolStart(ctx, input) + } result, err := t.client.GetData(ctx, input) if err != nil { + if t.CallbacksHandler != nil { + t.CallbacksHandler.HandleToolError(ctx, err) + } return "", err } - return strings.Join(strings.Fields(result), " "), nil + result = strings.Join(strings.Fields(result), " ") + if t.CallbacksHandler != nil { + t.CallbacksHandler.HandleToolEnd(ctx, result) + } + return result, nil } diff --git a/tests/example-test.sh b/tests/example-test.sh index 1257f72f4..e8447a4e9 100755 --- a/tests/example-test.sh +++ b/tests/example-test.sh @@ -420,6 +420,17 @@ curl -s -XPOST http://127.0.0.1:8081/chat --data '{"query":"年会不能停的 # echo "Because conversationWindowSize is enabled to be 2, llm should record history, but resp:"$resp "dont contains Jim" # exit 1 #fi +if [[ $GITHUB_ACTIONS != "true" ]]; then + info "8.6 bingsearch test" + kubectl apply -f config/samples/app_llmchain_chat_with_bot_bing.yaml + waitCRDStatusReady "Application" "arcadia" "base-chat-with-bot-bing" + sleep 3 + getRespInAppChat "base-chat-with-bot-bing" "arcadia" "介绍一下微软的产品" "" "false" + if [ -z "$references" ] || [ "$references" = "null" ]; then + echo $resp + exit 1 + fi +fi info "9. show apiserver logs for debug" kubectl logs --tail=100 -n arcadia -l app=arcadia-apiserver >/tmp/apiserver.log