From fe9a465bdc83bac7cda0f84f4b5a0998662c1895 Mon Sep 17 00:00:00 2001
From: Abirdcfly <fp544037857@gmail.com>
Date: Wed, 20 Mar 2024 17:33:22 +0800
Subject: [PATCH] fix: can't generate PromptStarter when knowledgebase has only
 chunks

Signed-off-by: Abirdcfly <fp544037857@gmail.com>
---
 apiserver/pkg/chat/chat_server.go | 148 +++++++++++++++---------------
 1 file changed, 73 insertions(+), 75 deletions(-)

diff --git a/apiserver/pkg/chat/chat_server.go b/apiserver/pkg/chat/chat_server.go
index 8e0d2fb62..8dbe9ce71 100644
--- a/apiserver/pkg/chat/chat_server.go
+++ b/apiserver/pkg/chat/chat_server.go
@@ -21,8 +21,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io"
-	"path/filepath"
 	"strings"
 	"sync"
 	"time"
@@ -31,10 +29,12 @@ import (
 	langchainllms "github.com/tmc/langchaingo/llms"
 	"github.com/tmc/langchaingo/memory"
 	"github.com/tmc/langchaingo/prompts"
+	langchainschema "github.com/tmc/langchaingo/schema"
 	"k8s.io/apimachinery/pkg/types"
 	"k8s.io/klog/v2"
 	runtimeclient "sigs.k8s.io/controller-runtime/pkg/client"
 
+	apiretriever "github.com/kubeagi/arcadia/api/app-node/retriever/v1alpha1"
 	"github.com/kubeagi/arcadia/api/base/v1alpha1"
 	"github.com/kubeagi/arcadia/apiserver/pkg/auth"
 	"github.com/kubeagi/arcadia/apiserver/pkg/chat/storage"
@@ -47,7 +47,7 @@ import (
 	"github.com/kubeagi/arcadia/pkg/appruntime/retriever"
 	pkgconfig "github.com/kubeagi/arcadia/pkg/config"
 	"github.com/kubeagi/arcadia/pkg/datasource"
-	pkgdocumentloaders "github.com/kubeagi/arcadia/pkg/documentloaders"
+	"github.com/kubeagi/arcadia/pkg/documentloaders"
 )
 
 type ChatServer struct {
@@ -286,94 +286,84 @@ func (cs *ChatServer) ListPromptStarters(ctx context.Context, req APPMetadata, l
 		}
 	}
 	promptStarters = make([]string, 0, limit)
-	remains := limit
+	content := bytes.Buffer{}
+	// if there is a knowledgebase, use it to generate prompt starter
 	if kb != nil {
-		system, err := pkgconfig.GetSystemDatasource(ctx, c)
+		outArg, finish, err := retriever.GenerateKnowledgebaseRetriever(ctx, c, kb.Name, kb.Namespace, apiretriever.CommonRetrieverConfig{NumDocuments: limit}, map[string]any{"question": ""})
 		if err != nil {
 			return nil, err
 		}
-		endpoint := system.Spec.Endpoint.DeepCopy()
-		if endpoint != nil && endpoint.AuthSecret != nil {
-			endpoint.AuthSecret.WithNameSpace(system.Namespace)
+		if finish != nil {
+			defer finish()
 		}
-		ds, err := datasource.NewLocal(ctx, c, endpoint)
-		if err != nil {
-			return nil, err
-		}
-	Outer:
-		for _, detail := range kb.Status.FileGroupDetail {
-			if detail.Source == nil || detail.Source.Name == "" {
-				continue
-			}
-			versionedDataset := &v1alpha1.VersionedDataset{}
-			if err := c.Get(ctx, types.NamespacedName{Namespace: detail.Source.GetNamespace(kb.Namespace), Name: detail.Source.Name}, versionedDataset); err != nil {
-				klog.Infof("failed to get versionedDataset: %s, try next one", err)
-				continue
-			}
-			if !versionedDataset.Status.IsReady() {
-				klog.Infof("versionedDataset is not ready, try next one")
-				continue
-			}
-			info := &v1alpha1.OSS{Bucket: versionedDataset.Namespace}
-			for _, fileDetails := range detail.FileDetails {
-				info.Object = filepath.Join("dataset", versionedDataset.Spec.Dataset.Name, versionedDataset.Spec.Version, fileDetails.Path)
-				file, err := ds.ReadFile(ctx, info)
-				if err != nil {
-					klog.Infof("failed to read file: %s, try next one", err)
-					continue
-				}
-				defer file.Close()
-				data, err := io.ReadAll(file)
+		v, ok := outArg[base.LangchaingoRetrieverKeyInArg]
+		if ok {
+			r, ok := v.(langchainschema.Retriever)
+			if ok {
+				doc, err := r.GetRelevantDocuments(ctx, "")
 				if err != nil {
-					klog.Infof("failed to read file: %s, try next one", err)
-					continue
-				}
-				dataReader := bytes.NewReader(data)
-				doc, err := pkgdocumentloaders.NewQACSV(dataReader, "").Load(ctx)
-				if err != nil {
-					klog.Infof("failed to load doc file: %s, try next one", err)
-					continue
-				}
-				for i := 0; i < remains && i < len(doc); i++ {
-					content := strings.TrimPrefix(doc[i].PageContent, "q: ")
-					promptStarters = append(promptStarters, content)
+					return nil, err
 				}
-				remains = limit - len(promptStarters)
-				if remains == 0 {
-					break Outer
+				for _, d := range doc {
+					hasAnswer := false
+					// has answer, means qa.csv, just return the question
+					v, ok := d.Metadata[documentloaders.AnswerCol]
+					if ok {
+						answer, ok := v.(string)
+						if ok && answer != "" {
+							question := strings.TrimSuffix(d.PageContent, "\na: "+answer)
+							promptStarters = append(promptStarters, strings.TrimPrefix(question, "q: "))
+							hasAnswer = true
+						}
+					}
+					if !hasAnswer {
+						content.WriteString(d.PageContent + "\n")
+						// if content is too long, may cause llm error
+						if content.Len() > 500 {
+							break
+						}
+					}
 				}
 			}
 		}
+	}
+	if len(promptStarters) == limit {
+		klog.V(3).Infoln("app has knowlegebase with qa.csv, just read some question")
+		return promptStarters, nil
+	}
+	if model == nil {
+		return nil, fmt.Errorf("can't find model in app")
+	}
+	var p prompts.PromptTemplate
+	predictArg := map[string]any{"limit": limit}
+	if content.Len() > 0 {
+		klog.V(3).Infoln("app has knowlegebase with chunk information, let llm generate some question")
+		p = prompts.NewPromptTemplate(PromptForGeneratePromptStartersByChunk, []string{"limit", "information"})
+		predictArg["information"] = content.String()
 	} else {
 		klog.V(3).Infoln("app has no knowlegebase, let llm generate some question")
-		if model != nil {
-			p := prompts.NewPromptTemplate(PromptForGeneratePromptStarters, []string{"limit", "displayName", "description"})
-			var c *chains.LLMChain
-			if len(chainOptions) > 0 {
-				c = chains.NewLLMChain(model, p, chainOptions...)
-			} else {
-				c = chains.NewLLMChain(model, p)
-			}
-			result, err := chains.Predict(ctx, c,
-				map[string]any{
-					"limit":       limit,
-					"displayName": app.Spec.DisplayName,
-					"description": app.Spec.Description,
-				},
-			)
-			if err != nil {
-				return nil, err
-			}
-			res := strings.Split(result, "\n")
-			for _, r := range res {
-				promptStarters = append(promptStarters, strings.TrimSpace(r))
-			}
-		}
+		p = prompts.NewPromptTemplate(PromptForGeneratePromptStartersByAppInfo, []string{"limit", "displayName", "description"})
+		predictArg["displayName"] = app.Spec.DisplayName
+		predictArg["description"] = app.Spec.Description
+	}
+	var llmchain *chains.LLMChain
+	if len(chainOptions) > 0 {
+		llmchain = chains.NewLLMChain(model, p, chainOptions...)
+	} else {
+		llmchain = chains.NewLLMChain(model, p)
+	}
+	result, err := chains.Predict(ctx, llmchain, predictArg)
+	if err != nil {
+		return nil, err
+	}
+	res := strings.Split(result, "\n")
+	for _, r := range res {
+		promptStarters = append(promptStarters, strings.TrimSpace(r))
 	}
 	return promptStarters, nil
 }
 
-const PromptForGeneratePromptStarters = `You are the friendly and curious questioner, please ask {{.limit}} questions based on the name and description of this app below.
+const PromptForGeneratePromptStartersByAppInfo = `You are the friendly and curious questioner, please ask {{.limit}} questions based on the name and description of this app below.
 
 Requires language consistent with the name and description of the application, no restating of my words, questions only, one question per line, no subheadings.
 
@@ -383,6 +373,14 @@ The description of the application is: {{.description}}
 
 The question you asked is:`
 
+const PromptForGeneratePromptStartersByChunk = `You are the friendly and curious questioner, please ask {{.limit}} questions based on the information below.
+
+Requires language consistent with the information, no restating of my words, questions only, one question per line, no subheadings.
+---
+{{.information}}
+---
+The question you asked is:`
+
 func (cs *ChatServer) getApp(ctx context.Context, appName, appNamespace string) (*v1alpha1.Application, runtimeclient.Client, error) {
 	token := auth.ForOIDCToken(ctx)
 	c, err := client.GetClient(token)