From 89cff8726f180fe0511980ff1ae1ae162b914f64 Mon Sep 17 00:00:00 2001 From: Darren Shepherd <darren@acorn.io> Date: Mon, 12 Aug 2024 22:13:23 -0700 Subject: [PATCH] feat: add sys.model.provider.credential --- go.mod | 2 +- go.sum | 4 +- pkg/builtin/builtin.go | 42 +++++++++++--- pkg/credentials/credential.go | 1 + pkg/engine/cmd.go | 2 +- pkg/engine/engine.go | 8 ++- pkg/llm/proxy.go | 104 ++++++++++++++++++++++++++++++++++ pkg/llm/registry.go | 57 ++++++++++++++++++- pkg/openai/client.go | 7 +++ pkg/runner/runner.go | 27 ++++++--- pkg/tests/tester/runner.go | 4 ++ pkg/types/tool.go | 11 +++- pkg/types/toolstring.go | 2 +- 13 files changed, 245 insertions(+), 26 deletions(-) create mode 100644 pkg/llm/proxy.go diff --git a/go.mod b/go.mod index 4cf0ba73..3cfbc98e 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/uuid v1.6.0 github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 - github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 + github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17 github.com/gptscript-ai/tui v0.0.0-20240804004233-efc5673dc76e diff --git a/go.sum b/go.sum index 1b518130..85a3f76e 100644 --- a/go.sum +++ b/go.sum @@ -200,8 +200,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 h1:m9yLtIEd0z1ia8qFjq3u0Ozb6QKwidyL856JLJp6nbA= github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86/go.mod h1:lK3K5EZx4dyT24UG3yCt0wmspkYqrj4D/8kxdN3relk= -github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 h1:vYnXoIyCXzaCEw0sYifQ4bDpsv3/fO/dZ2suEsTwCIo= -github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo= +github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 h1:EQiFTZv+BnOWJX2B9XdF09fL2Zj7h19n1l23TpWCafc= +github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo= github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc= github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw= github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17 h1:BTfJ6ls31Roq42lznlZnuPzRf0wrT8jT+tWcvq7wDXY= diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index f6811549..23db5152 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -26,14 +26,15 @@ import ( ) var SafeTools = map[string]struct{}{ - "sys.abort": {}, - "sys.chat.finish": {}, - "sys.chat.history": {}, - "sys.chat.current": {}, - "sys.echo": {}, - "sys.prompt": {}, - "sys.time.now": {}, - "sys.context": {}, + "sys.abort": {}, + "sys.chat.finish": {}, + "sys.chat.history": {}, + "sys.chat.current": {}, + "sys.echo": {}, + "sys.prompt": {}, + "sys.time.now": {}, + "sys.context": {}, + "sys.model.provider.credential": {}, } var tools = map[string]types.Tool{ @@ -248,6 +249,15 @@ var tools = map[string]types.Tool{ BuiltinFunc: SysContext, }, }, + "sys.model.provider.credential": { + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + Description: "A credential tool to set the OPENAI_API_KEY and OPENAI_BASE_URL to give access to the default model provider", + Arguments: types.ObjectSchema(), + }, + BuiltinFunc: SysModelProviderCredential, + }, + }, } func ListTools() (result []types.Tool) { @@ -678,6 +688,22 @@ func invalidArgument(input string, err error) string { return fmt.Sprintf("Failed to parse arguments %s: %v", input, err) } +func SysModelProviderCredential(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) { + engineContext, _ := engine.FromContext(ctx) + auth, url, err := engineContext.Engine.Model.ProxyInfo() + if err != nil { + return "", err + } + data, err := json.Marshal(map[string]any{ + "env": map[string]string{ + "OPENAI_API_KEY": auth, + "OPENAI_BASE_URL": url, + }, + "ephemeral": true, + }) + return string(data), err +} + func SysContext(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) { engineContext, _ := engine.FromContext(ctx) diff --git a/pkg/credentials/credential.go b/pkg/credentials/credential.go index 3d1e2192..f589a065 100644 --- a/pkg/credentials/credential.go +++ b/pkg/credentials/credential.go @@ -24,6 +24,7 @@ type Credential struct { ToolName string `json:"toolName"` Type CredentialType `json:"type"` Env map[string]string `json:"env"` + Ephemeral bool `json:"ephemeral,omitempty"` ExpiresAt *time.Time `json:"expiresAt"` RefreshToken string `json:"refreshToken"` } diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go index 14b41183..960bcfe8 100644 --- a/pkg/engine/cmd.go +++ b/pkg/engine/cmd.go @@ -109,7 +109,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate } }() - return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input, progress) + return tool.BuiltinFunc(ctx.WrappedContext(e), e.Env, input, progress) } var instructions []string diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index d3daa674..20ca43a9 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -16,6 +16,7 @@ import ( type Model interface { Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) + ProxyInfo() (string, string, error) } type RuntimeManager interface { @@ -79,6 +80,7 @@ type Context struct { Parent *Context LastReturn *Return CurrentReturn *Return + Engine *Engine Program *types.Program // Input is saved only so that we can render display text, don't use otherwise Input string @@ -250,8 +252,10 @@ func FromContext(ctx context.Context) (*Context, bool) { return c, ok } -func (c *Context) WrappedContext() context.Context { - return context.WithValue(c.Ctx, engineContext{}, c) +func (c *Context) WrappedContext(e *Engine) context.Context { + cp := *c + cp.Engine = e + return context.WithValue(c.Ctx, engineContext{}, &cp) } func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) { diff --git a/pkg/llm/proxy.go b/pkg/llm/proxy.go new file mode 100644 index 00000000..7c3091b3 --- /dev/null +++ b/pkg/llm/proxy.go @@ -0,0 +1,104 @@ +package llm + +import ( + "bytes" + "encoding/json" + "io" + "net" + "net/http" + "net/http/httputil" + "net/url" + "path" + "strings" + + "github.com/gptscript-ai/gptscript/pkg/builtin" + "github.com/gptscript-ai/gptscript/pkg/openai" +) + +func (r *Registry) ProxyInfo() (string, string, error) { + r.proxyLock.Lock() + defer r.proxyLock.Unlock() + + if r.proxyURL != "" { + return r.proxyToken, r.proxyURL, nil + } + + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return "", "", err + } + + go func() { + _ = http.Serve(l, r) + r.proxyLock.Lock() + defer r.proxyLock.Unlock() + _ = l.Close() + r.proxyURL = "" + }() + + r.proxyURL = "http://" + l.Addr().String() + return r.proxyToken, r.proxyURL, nil +} + +func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if r.proxyToken != strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + inBytes, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var ( + model string + data = map[string]any{} + ) + + if json.Unmarshal(inBytes, &data) == nil { + model, _ = data["model"].(string) + } + + if model == "" { + model = builtin.GetDefaultModel() + } + + c, err := r.getClient(req.Context(), model) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + oai, ok := c.(*openai.Client) + if !ok { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } + + auth, targetURL := oai.ProxyInfo() + if targetURL == "" { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + return + } + + newURL, err := url.Parse(targetURL) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + newURL.Path = path.Join(newURL.Path, req.URL.Path) + + rp := httputil.ReverseProxy{ + Director: func(proxyReq *http.Request) { + proxyReq.Body = io.NopCloser(bytes.NewReader(inBytes)) + proxyReq.URL = newURL + proxyReq.Header.Del("Authorization") + proxyReq.Header.Add("Authorization", "Bearer "+auth) + proxyReq.Host = newURL.Hostname() + }, + } + rp.ServeHTTP(w, req) +} diff --git a/pkg/llm/registry.go b/pkg/llm/registry.go index c568b43c..8129c788 100644 --- a/pkg/llm/registry.go +++ b/pkg/llm/registry.go @@ -5,7 +5,10 @@ import ( "errors" "fmt" "sort" + "sync" + "github.com/google/uuid" + "github.com/gptscript-ai/gptscript/pkg/env" "github.com/gptscript-ai/gptscript/pkg/openai" "github.com/gptscript-ai/gptscript/pkg/remote" "github.com/gptscript-ai/gptscript/pkg/types" @@ -18,11 +21,16 @@ type Client interface { } type Registry struct { - clients []Client + proxyToken string + proxyURL string + proxyLock sync.Mutex + clients []Client } func NewRegistry() *Registry { - return &Registry{} + return &Registry{ + proxyToken: env.VarOrDefault("GPTSCRIPT_INTERNAL_PROXY_TOKEN", uuid.New().String()), + } } func (r *Registry) AddClient(client Client) error { @@ -44,6 +52,10 @@ func (r *Registry) ListModels(ctx context.Context, providers ...string) (result func (r *Registry) fastPath(modelName string) Client { // This is optimization hack to avoid doing List Models + if len(r.clients) == 1 { + return r.clients[0] + } + if len(r.clients) != 2 { return nil } @@ -66,6 +78,47 @@ func (r *Registry) fastPath(modelName string) Client { return r.clients[0] } +func (r *Registry) getClient(ctx context.Context, modelName string) (Client, error) { + if c := r.fastPath(modelName); c != nil { + return c, nil + } + + var errs []error + var oaiClient *openai.Client + for _, client := range r.clients { + ok, err := client.Supports(ctx, modelName) + if err != nil { + // If we got an OpenAI invalid auth error back, store the OpenAI client for later. + if errors.Is(err, openai.InvalidAuthError{}) { + oaiClient = client.(*openai.Client) + } + + errs = append(errs, err) + } else if ok { + return client, nil + } + } + + if len(errs) > 0 && oaiClient != nil { + // Prompt the user to enter their OpenAI API key and try again. + if err := oaiClient.RetrieveAPIKey(ctx); err != nil { + return nil, err + } + ok, err := oaiClient.Supports(ctx, modelName) + if err != nil { + return nil, err + } else if ok { + return oaiClient, nil + } + } + + if len(errs) == 0 { + return nil, fmt.Errorf("failed to find a model provider for model [%s]", modelName) + } + + return nil, errors.Join(errs...) +} + func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { if messageRequest.Model == "" { return nil, fmt.Errorf("model is required") diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 53252895..42a1a39e 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -130,6 +130,13 @@ func NewClient(ctx context.Context, credStore credentials.CredentialStore, opts }, nil } +func (c *Client) ProxyInfo() (token, urlBase string) { + if c.invalidAuth { + return "", "" + } + return c.c.GetAPIKeyAndBaseURL() +} + func (c *Client) ValidAuth() error { if c.invalidAuth { return InvalidAuthError{} diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index a8d88fee..f92b0705 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -872,6 +872,11 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env return nil, fmt.Errorf("failed to parse credential tool %q: %w", ref.Reference, err) } + if callCtx.Program.ToolSet[ref.ToolID].IsNoop() { + // ignore empty tools + continue + } + credName := toolName if credentialAlias != "" { credName = credentialAlias @@ -944,6 +949,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", ref.Reference) } + if *res.Result == "" { + continue + } + if err := json.Unmarshal([]byte(*res.Result), &c); err != nil { return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err) } @@ -958,15 +967,17 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env } } - // Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty. - if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" { - if isEmpty { - log.Warnf("Not saving empty credential for tool %s", toolName) - } else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil { - return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err) + if !c.Ephemeral { + // Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty. + if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" { + if isEmpty { + log.Warnf("Not saving empty credential for tool %s", toolName) + } else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil { + return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err) + } + } else { + log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName) } - } else { - log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName) } } diff --git a/pkg/tests/tester/runner.go b/pkg/tests/tester/runner.go index a36c5e91..66337ff5 100644 --- a/pkg/tests/tester/runner.go +++ b/pkg/tests/tester/runner.go @@ -31,6 +31,10 @@ type Result struct { Err error } +func (c *Client) ProxyInfo() (string, string, error) { + return "test-auth", "test-url", nil +} + func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) { msgData, err := json.MarshalIndent(messageRequest, "", " ") require.NoError(c.t, err) diff --git a/pkg/types/tool.go b/pkg/types/tool.go index b59a1953..57ce3fbf 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -753,7 +753,16 @@ func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]Too result.AddAll(t.getCompletionToolRefs(prg, nil, ToolTypeCredential)) - toolRefs, err := t.getCompletionToolRefs(prg, agentGroup) + toolRefs, err := result.List() + if err != nil { + return nil, err + } + for _, toolRef := range toolRefs { + referencedTool := prg.ToolSet[toolRef.ToolID] + result.AddAll(referencedTool.GetToolRefsFromNames(referencedTool.ExportCredentials)) + } + + toolRefs, err = t.getCompletionToolRefs(prg, agentGroup) if err != nil { return nil, err } diff --git a/pkg/types/toolstring.go b/pkg/types/toolstring.go index 64f53638..2be6d0fc 100644 --- a/pkg/types/toolstring.go +++ b/pkg/types/toolstring.go @@ -74,7 +74,7 @@ func ToSysDisplayString(id string, args map[string]string) (string, error) { return fmt.Sprintf("Removing `%s`", args["location"]), nil case "sys.write": return fmt.Sprintf("Writing `%s`", args["filename"]), nil - case "sys.context", "sys.stat", "sys.getenv", "sys.abort", "sys.chat.current", "sys.chat.finish", "sys.chat.history", "sys.echo", "sys.prompt", "sys.time.now": + case "sys.context", "sys.stat", "sys.getenv", "sys.abort", "sys.chat.current", "sys.chat.finish", "sys.chat.history", "sys.echo", "sys.prompt", "sys.time.now", "sys.model.provider.credential": return "", nil default: return "", fmt.Errorf("unknown tool for display string: %s", id)