Skip to content

Commit

Permalink
feat: able to get model list provided by Embedder and LLM
Browse files Browse the repository at this point in the history
Signed-off-by: bjwswang <bjwswang@gmail.com>
  • Loading branch information
bjwswang committed Dec 15, 2023
1 parent 511dddd commit 6ec48be
Show file tree
Hide file tree
Showing 24 changed files with 474 additions and 52 deletions.
36 changes: 31 additions & 5 deletions api/base/v1alpha1/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (

"k8s.io/client-go/dynamic"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/kubeagi/arcadia/pkg/embeddings"
)

func (e Embedder) AuthAPIKey(ctx context.Context, c client.Client, cli dynamic.Interface) (string, error) {
Expand All @@ -30,9 +32,33 @@ func (e Embedder) AuthAPIKey(ctx context.Context, c client.Client, cli dynamic.I
return e.Spec.Enpoint.AuthAPIKey(ctx, e.GetNamespace(), c, cli)
}

type EmbeddingType string
// GetModelList returns the models served by this Embedder, selected by the
// configured provider type. An empty list is returned for unknown providers.
func (e Embedder) GetModelList() []string {
	providerType := e.Spec.Provider.GetType()
	if providerType == ProviderTypeWorker {
		return e.GetWorkerModels()
	}
	if providerType == ProviderType3rdParty {
		return e.Get3rdPartyModels()
	}
	return []string{}
}

const (
OpenAI EmbeddingType = "openai"
ZhiPuAI EmbeddingType = "zhipuai"
)
// GetWorkerModels returns the model list served by the worker provider.
// A worker serves exactly one model, identified here by the Embedder's UID.
func (e Embedder) GetWorkerModels() []string {
	uid := string(e.GetUID())
	return []string{uid}
}

// Get3rdPartyModels returns the model list offered by the 3rd party provider.
// It returns an empty list when the provider is not a 3rd party one, or when
// the embedding service type is not a known provider.
func (e Embedder) Get3rdPartyModels() []string {
	if e.Spec.Provider.GetType() != ProviderType3rdParty {
		return []string{}
	}

	if e.Spec.Type == embeddings.ZhiPuAI {
		return embeddings.ZhiPuAIModels
	}
	if e.Spec.Type == embeddings.OpenAI {
		return embeddings.OpenAIModels
	}
	return []string{}
}
4 changes: 3 additions & 1 deletion api/base/v1alpha1/embedder_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package v1alpha1

import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/kubeagi/arcadia/pkg/embeddings"
)

// EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN!
Expand All @@ -28,7 +30,7 @@ type EmbedderSpec struct {
CommonSpec `json:",inline"`

// ServiceType indicates the source type of embedding service
Type EmbeddingType `json:"type"`
Type embeddings.EmbeddingType `json:"type"`

// Provider defines the provider info which provide this embedder service
Provider `json:"provider,omitempty"`
Expand Down
32 changes: 32 additions & 0 deletions api/base/v1alpha1/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (

"k8s.io/client-go/dynamic"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/kubeagi/arcadia/pkg/llms"
)

func (llm LLM) AuthAPIKey(ctx context.Context, c client.Client, cli dynamic.Interface) (string, error) {
Expand All @@ -39,3 +41,33 @@ func (llmStatus LLMStatus) LLMReady() (string, bool) {
}
return "", true
}

// GetModelList returns the models served by this LLM, selected by the
// configured provider type. Unknown providers yield an empty list.
func (llm LLM) GetModelList() []string {
	switch t := llm.Spec.Provider.GetType(); t {
	case ProviderTypeWorker:
		return llm.GetWorkerModels()
	case ProviderType3rdParty:
		return llm.Get3rdPartyModels()
	default:
		return []string{}
	}
}

// GetWorkerModels returns the model list served by the worker provider.
// A worker serves exactly one model, identified here by the LLM's UID.
func (llm LLM) GetWorkerModels() []string {
	models := []string{string(llm.GetUID())}
	return models
}

// Get3rdPartyModels returns the model list offered by the 3rd party provider.
// An empty list is returned when the provider is not a 3rd party one, or when
// the LLM service type is not a known provider.
func (llm LLM) Get3rdPartyModels() []string {
	if llm.Spec.Provider.GetType() != ProviderType3rdParty {
		return []string{}
	}
	if llm.Spec.Type == llms.ZhiPuAI {
		return llms.ZhiPuAIModels
	}
	if llm.Spec.Type == llms.OpenAI {
		return llms.OpenAIModels
	}
	return []string{}
}
3 changes: 2 additions & 1 deletion api/base/v1alpha1/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

"github.com/kubeagi/arcadia/pkg/embeddings"
"github.com/kubeagi/arcadia/pkg/llms"
)

Expand Down Expand Up @@ -138,7 +139,7 @@ func (worker Worker) BuildEmbedder() *Embedder {
DisplayName: worker.Spec.Model.Name,
Description: "Embedder created by Worker(OpenAI compatible)",
},
Type: OpenAI,
Type: embeddings.OpenAI,
Provider: Provider{
Worker: &TypedObjectReference{
Kind: "Worker",
Expand Down
5 changes: 3 additions & 2 deletions controllers/embedder_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/predicate"

arcadiav1alpha1 "github.com/kubeagi/arcadia/api/base/v1alpha1"
"github.com/kubeagi/arcadia/pkg/embeddings"
"github.com/kubeagi/arcadia/pkg/llms/openai"
"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
)
Expand Down Expand Up @@ -161,14 +162,14 @@ func (r *EmbedderReconciler) check3rdPartyEmbedder(ctx context.Context, logger l
}

switch instance.Spec.Type {
case arcadiav1alpha1.ZhiPuAI:
case embeddings.ZhiPuAI:
embedClient := zhipuai.NewZhiPuAI(apiKey)
res, err := embedClient.Validate()
if err != nil {
return r.UpdateStatus(ctx, instance, nil, err)
}
msg = res.String()
case arcadiav1alpha1.OpenAI:
case embeddings.OpenAI:
embedClient := openai.NewOpenAI(apiKey, instance.Spec.Enpoint.URL)
res, err := embedClient.Validate()
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions examples/chat_with_document/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/valyala/fasthttp"

zhipuaiembeddings "github.com/kubeagi/arcadia/pkg/embeddings/zhipuai"
"github.com/kubeagi/arcadia/pkg/llms"
"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
"github.com/tmc/langchaingo/vectorstores/chroma"
)
Expand Down Expand Up @@ -112,7 +113,7 @@ func QueryHandler(c *fiber.Ctx) error {

params := zhipuai.ModelParams{
Method: zhipuai.ZhiPuAIInvoke,
Model: zhipuai.ZhiPuAIPro,
Model: llms.ZhiPuAIPro,
Temperature: 0.5,
TopP: 0.7,
Prompt: prompt,
Expand Down Expand Up @@ -219,7 +220,7 @@ func StreamQueryHandler(c *fiber.Ctx) error {

params := zhipuai.ModelParams{
Method: zhipuai.ZhiPuAISSEInvoke,
Model: zhipuai.ZhiPuAIPro,
Model: llms.ZhiPuAIPro,
Temperature: 0.5,
TopP: 0.7,
Prompt: prompt,
Expand Down
2 changes: 1 addition & 1 deletion examples/rbac/inquiry.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func Inquiry() *cobra.Command {
}

params := zhipuai.DefaultModelParams()
params.Model = zhipuai.Model(model)
params.Model = model
params.Method = zhipuai.Method(method)
params.Prompt = []zhipuai.Prompt{
{Role: zhipuai.User, Content: output.String()},
Expand Down
4 changes: 2 additions & 2 deletions examples/rbac/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ limitations under the License.
package main

import (
"github.com/kubeagi/arcadia/pkg/llms/zhipuai"
"github.com/kubeagi/arcadia/pkg/llms"
"github.com/spf13/cobra"
)

Expand All @@ -34,7 +34,7 @@ func NewCmd() *cobra.Command {
}

cmd.PersistentFlags().StringVar(&apiKey, "apiKey", "", "apiKey to access LLM service")
cmd.PersistentFlags().StringVar(&model, "model", string(zhipuai.ZhiPuAILite), "which model to use: chatglm_lite/chatglm_std/chatglm_pro")
cmd.PersistentFlags().StringVar(&model, "model", string(llms.ZhiPuAILite), "which model to use: chatglm_lite/chatglm_std/chatglm_pro")
cmd.PersistentFlags().StringVar(&method, "method", "sse-invoke", "Invoke method used when access LLM service(invoke/sse-invoke)")

cmd.MarkPersistentFlagRequired("apiKey")
Expand Down
Loading

0 comments on commit 6ec48be

Please sign in to comment.