From 89cff8726f180fe0511980ff1ae1ae162b914f64 Mon Sep 17 00:00:00 2001
From: Darren Shepherd <darren@acorn.io>
Date: Mon, 12 Aug 2024 22:13:23 -0700
Subject: [PATCH] feat: add sys.model.provider.credential

---
 go.mod                        |   2 +-
 go.sum                        |   4 +-
 pkg/builtin/builtin.go        |  42 +++++++++++---
 pkg/credentials/credential.go |   1 +
 pkg/engine/cmd.go             |   2 +-
 pkg/engine/engine.go          |   8 ++-
 pkg/llm/proxy.go              | 104 ++++++++++++++++++++++++++++++++++
 pkg/llm/registry.go           |  57 ++++++++++++++++++-
 pkg/openai/client.go          |   7 +++
 pkg/runner/runner.go          |  27 ++++++---
 pkg/tests/tester/runner.go    |   4 ++
 pkg/types/tool.go             |  11 +++-
 pkg/types/toolstring.go       |   2 +-
 13 files changed, 245 insertions(+), 26 deletions(-)
 create mode 100644 pkg/llm/proxy.go

diff --git a/go.mod b/go.mod
index 4cf0ba73..3cfbc98e 100644
--- a/go.mod
+++ b/go.mod
@@ -16,7 +16,7 @@ require (
 	github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
 	github.com/google/uuid v1.6.0
 	github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86
-	github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
+	github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3
 	github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb
 	github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17
 	github.com/gptscript-ai/tui v0.0.0-20240804004233-efc5673dc76e
diff --git a/go.sum b/go.sum
index 1b518130..85a3f76e 100644
--- a/go.sum
+++ b/go.sum
@@ -200,8 +200,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
 github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
 github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 h1:m9yLtIEd0z1ia8qFjq3u0Ozb6QKwidyL856JLJp6nbA=
 github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86/go.mod h1:lK3K5EZx4dyT24UG3yCt0wmspkYqrj4D/8kxdN3relk=
-github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 h1:vYnXoIyCXzaCEw0sYifQ4bDpsv3/fO/dZ2suEsTwCIo=
-github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
+github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 h1:EQiFTZv+BnOWJX2B9XdF09fL2Zj7h19n1l23TpWCafc=
+github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
 github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc=
 github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
 github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17 h1:BTfJ6ls31Roq42lznlZnuPzRf0wrT8jT+tWcvq7wDXY=
diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go
index f6811549..23db5152 100644
--- a/pkg/builtin/builtin.go
+++ b/pkg/builtin/builtin.go
@@ -26,14 +26,15 @@ import (
 )
 
 var SafeTools = map[string]struct{}{
-	"sys.abort":        {},
-	"sys.chat.finish":  {},
-	"sys.chat.history": {},
-	"sys.chat.current": {},
-	"sys.echo":         {},
-	"sys.prompt":       {},
-	"sys.time.now":     {},
-	"sys.context":      {},
+	"sys.abort":                     {},
+	"sys.chat.finish":               {},
+	"sys.chat.history":              {},
+	"sys.chat.current":              {},
+	"sys.echo":                      {},
+	"sys.prompt":                    {},
+	"sys.time.now":                  {},
+	"sys.context":                   {},
+	"sys.model.provider.credential": {},
 }
 
 var tools = map[string]types.Tool{
@@ -248,6 +249,15 @@ var tools = map[string]types.Tool{
 			BuiltinFunc: SysContext,
 		},
 	},
+	"sys.model.provider.credential": {
+		ToolDef: types.ToolDef{
+			Parameters: types.Parameters{
+				Description: "A credential tool to set the OPENAI_API_KEY and OPENAI_BASE_URL to give access to the default model provider",
+				Arguments:   types.ObjectSchema(),
+			},
+			BuiltinFunc: SysModelProviderCredential,
+		},
+	},
 }
 
 func ListTools() (result []types.Tool) {
@@ -678,6 +688,22 @@ func invalidArgument(input string, err error) string {
 	return fmt.Sprintf("Failed to parse arguments %s: %v", input, err)
 }
 
+func SysModelProviderCredential(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
+	engineContext, _ := engine.FromContext(ctx)
+	auth, url, err := engineContext.Engine.Model.ProxyInfo()
+	if err != nil {
+		return "", err
+	}
+	data, err := json.Marshal(map[string]any{
+		"env": map[string]string{
+			"OPENAI_API_KEY":  auth,
+			"OPENAI_BASE_URL": url,
+		},
+		"ephemeral": true,
+	})
+	return string(data), err
+}
+
 func SysContext(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
 	engineContext, _ := engine.FromContext(ctx)
 
diff --git a/pkg/credentials/credential.go b/pkg/credentials/credential.go
index 3d1e2192..f589a065 100644
--- a/pkg/credentials/credential.go
+++ b/pkg/credentials/credential.go
@@ -24,6 +24,7 @@ type Credential struct {
 	ToolName     string            `json:"toolName"`
 	Type         CredentialType    `json:"type"`
 	Env          map[string]string `json:"env"`
+	Ephemeral    bool              `json:"ephemeral,omitempty"`
 	ExpiresAt    *time.Time        `json:"expiresAt"`
 	RefreshToken string            `json:"refreshToken"`
 }
diff --git a/pkg/engine/cmd.go b/pkg/engine/cmd.go
index 14b41183..960bcfe8 100644
--- a/pkg/engine/cmd.go
+++ b/pkg/engine/cmd.go
@@ -109,7 +109,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
 			}
 		}()
 
-		return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input, progress)
+		return tool.BuiltinFunc(ctx.WrappedContext(e), e.Env, input, progress)
 	}
 
 	var instructions []string
diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go
index d3daa674..20ca43a9 100644
--- a/pkg/engine/engine.go
+++ b/pkg/engine/engine.go
@@ -16,6 +16,7 @@ import (
 
 type Model interface {
 	Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
+	ProxyInfo() (string, string, error)
 }
 
 type RuntimeManager interface {
@@ -79,6 +80,7 @@ type Context struct {
 	Parent        *Context
 	LastReturn    *Return
 	CurrentReturn *Return
+	Engine        *Engine
 	Program       *types.Program
 	// Input is saved only so that we can render display text, don't use otherwise
 	Input string
@@ -250,8 +252,10 @@ func FromContext(ctx context.Context) (*Context, bool) {
 	return c, ok
 }
 
-func (c *Context) WrappedContext() context.Context {
-	return context.WithValue(c.Ctx, engineContext{}, c)
+func (c *Context) WrappedContext(e *Engine) context.Context {
+	cp := *c
+	cp.Engine = e
+	return context.WithValue(c.Ctx, engineContext{}, &cp)
 }
 
 func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) {
diff --git a/pkg/llm/proxy.go b/pkg/llm/proxy.go
new file mode 100644
index 00000000..7c3091b3
--- /dev/null
+++ b/pkg/llm/proxy.go
@@ -0,0 +1,104 @@
+package llm
+
+import (
+	"bytes"
+	"encoding/json"
+	"io"
+	"net"
+	"net/http"
+	"net/http/httputil"
+	"net/url"
+	"path"
+	"strings"
+
+	"github.com/gptscript-ai/gptscript/pkg/builtin"
+	"github.com/gptscript-ai/gptscript/pkg/openai"
+)
+
+func (r *Registry) ProxyInfo() (string, string, error) {
+	r.proxyLock.Lock()
+	defer r.proxyLock.Unlock()
+
+	if r.proxyURL != "" {
+		return r.proxyToken, r.proxyURL, nil
+	}
+
+	l, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		return "", "", err
+	}
+
+	go func() {
+		_ = http.Serve(l, r)
+		r.proxyLock.Lock()
+		defer r.proxyLock.Unlock()
+		_ = l.Close()
+		r.proxyURL = ""
+	}()
+
+	r.proxyURL = "http://" + l.Addr().String()
+	return r.proxyToken, r.proxyURL, nil
+}
+
+func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+	if r.proxyToken != strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") {
+		http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
+		return
+	}
+
+	inBytes, err := io.ReadAll(req.Body)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusBadRequest)
+		return
+	}
+
+	var (
+		model string
+		data  = map[string]any{}
+	)
+
+	if json.Unmarshal(inBytes, &data) == nil {
+		model, _ = data["model"].(string)
+	}
+
+	if model == "" {
+		model = builtin.GetDefaultModel()
+	}
+
+	c, err := r.getClient(req.Context(), model)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	oai, ok := c.(*openai.Client)
+	if !ok {
+		http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
+		return
+	}
+
+	auth, targetURL := oai.ProxyInfo()
+	if targetURL == "" {
+		http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
+		return
+	}
+
+	newURL, err := url.Parse(targetURL)
+	if err != nil {
+		http.Error(w, err.Error(), http.StatusInternalServerError)
+		return
+	}
+
+	newURL.Path = path.Join(newURL.Path, req.URL.Path)
+
+	rp := httputil.ReverseProxy{
+		Director: func(proxyReq *http.Request) {
+			proxyReq.Body = io.NopCloser(bytes.NewReader(inBytes))
+			proxyReq.URL = newURL
+			proxyReq.Header.Del("Authorization")
+			proxyReq.Header.Add("Authorization", "Bearer "+auth)
+			proxyReq.Host = newURL.Hostname()
+		},
+	}
+	rp.ServeHTTP(w, req)
+}
diff --git a/pkg/llm/registry.go b/pkg/llm/registry.go
index c568b43c..8129c788 100644
--- a/pkg/llm/registry.go
+++ b/pkg/llm/registry.go
@@ -5,7 +5,10 @@ import (
 	"errors"
 	"fmt"
 	"sort"
+	"sync"
 
+	"github.com/google/uuid"
+	"github.com/gptscript-ai/gptscript/pkg/env"
 	"github.com/gptscript-ai/gptscript/pkg/openai"
 	"github.com/gptscript-ai/gptscript/pkg/remote"
 	"github.com/gptscript-ai/gptscript/pkg/types"
@@ -18,11 +21,16 @@ type Client interface {
 }
 
 type Registry struct {
-	clients []Client
+	proxyToken string
+	proxyURL   string
+	proxyLock  sync.Mutex
+	clients    []Client
 }
 
 func NewRegistry() *Registry {
-	return &Registry{}
+	return &Registry{
+		proxyToken: env.VarOrDefault("GPTSCRIPT_INTERNAL_PROXY_TOKEN", uuid.New().String()),
+	}
 }
 
 func (r *Registry) AddClient(client Client) error {
@@ -44,6 +52,10 @@ func (r *Registry) ListModels(ctx context.Context, providers ...string) (result
 
 func (r *Registry) fastPath(modelName string) Client {
 	// This is optimization hack to avoid doing List Models
+	if len(r.clients) == 1 {
+		return r.clients[0]
+	}
+
 	if len(r.clients) != 2 {
 		return nil
 	}
@@ -66,6 +78,47 @@ func (r *Registry) fastPath(modelName string) Client {
 	return r.clients[0]
 }
 
+func (r *Registry) getClient(ctx context.Context, modelName string) (Client, error) {
+	if c := r.fastPath(modelName); c != nil {
+		return c, nil
+	}
+
+	var errs []error
+	var oaiClient *openai.Client
+	for _, client := range r.clients {
+		ok, err := client.Supports(ctx, modelName)
+		if err != nil {
+			// If we got an OpenAI invalid auth error back, store the OpenAI client for later.
+			if errors.Is(err, openai.InvalidAuthError{}) {
+				oaiClient = client.(*openai.Client)
+			}
+
+			errs = append(errs, err)
+		} else if ok {
+			return client, nil
+		}
+	}
+
+	if len(errs) > 0 && oaiClient != nil {
+		// Prompt the user to enter their OpenAI API key and try again.
+		if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
+			return nil, err
+		}
+		ok, err := oaiClient.Supports(ctx, modelName)
+		if err != nil {
+			return nil, err
+		} else if ok {
+			return oaiClient, nil
+		}
+	}
+
+	if len(errs) == 0 {
+		return nil, fmt.Errorf("failed to find a model provider for model [%s]", modelName)
+	}
+
+	return nil, errors.Join(errs...)
+}
+
 func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
 	if messageRequest.Model == "" {
 		return nil, fmt.Errorf("model is required")
diff --git a/pkg/openai/client.go b/pkg/openai/client.go
index 53252895..42a1a39e 100644
--- a/pkg/openai/client.go
+++ b/pkg/openai/client.go
@@ -130,6 +130,13 @@ func NewClient(ctx context.Context, credStore credentials.CredentialStore, opts
 	}, nil
 }
 
+func (c *Client) ProxyInfo() (token, urlBase string) {
+	if c.invalidAuth {
+		return "", ""
+	}
+	return c.c.GetAPIKeyAndBaseURL()
+}
+
 func (c *Client) ValidAuth() error {
 	if c.invalidAuth {
 		return InvalidAuthError{}
diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go
index a8d88fee..f92b0705 100644
--- a/pkg/runner/runner.go
+++ b/pkg/runner/runner.go
@@ -872,6 +872,11 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
 			return nil, fmt.Errorf("failed to parse credential tool %q: %w", ref.Reference, err)
 		}
 
+		if callCtx.Program.ToolSet[ref.ToolID].IsNoop() {
+			// ignore empty tools
+			continue
+		}
+
 		credName := toolName
 		if credentialAlias != "" {
 			credName = credentialAlias
@@ -944,6 +949,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
 				return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", ref.Reference)
 			}
 
+			if *res.Result == "" {
+				continue
+			}
+
 			if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
 				return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err)
 			}
@@ -958,15 +967,17 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
 				}
 			}
 
-			// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
-			if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
-				if isEmpty {
-					log.Warnf("Not saving empty credential for tool %s", toolName)
-				} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
-					return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
+			if !c.Ephemeral {
+				// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
+				if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
+					if isEmpty {
+						log.Warnf("Not saving empty credential for tool %s", toolName)
+					} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
+						return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
+					}
+				} else {
+					log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
 				}
-			} else {
-				log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
 			}
 		}
 
diff --git a/pkg/tests/tester/runner.go b/pkg/tests/tester/runner.go
index a36c5e91..66337ff5 100644
--- a/pkg/tests/tester/runner.go
+++ b/pkg/tests/tester/runner.go
@@ -31,6 +31,10 @@ type Result struct {
 	Err     error
 }
 
+func (c *Client) ProxyInfo() (string, string, error) {
+	return "test-auth", "test-url", nil
+}
+
 func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) {
 	msgData, err := json.MarshalIndent(messageRequest, "", "  ")
 	require.NoError(c.t, err)
diff --git a/pkg/types/tool.go b/pkg/types/tool.go
index b59a1953..57ce3fbf 100644
--- a/pkg/types/tool.go
+++ b/pkg/types/tool.go
@@ -753,7 +753,16 @@ func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]Too
 
 	result.AddAll(t.getCompletionToolRefs(prg, nil, ToolTypeCredential))
 
-	toolRefs, err := t.getCompletionToolRefs(prg, agentGroup)
+	toolRefs, err := result.List()
+	if err != nil {
+		return nil, err
+	}
+	for _, toolRef := range toolRefs {
+		referencedTool := prg.ToolSet[toolRef.ToolID]
+		result.AddAll(referencedTool.GetToolRefsFromNames(referencedTool.ExportCredentials))
+	}
+
+	toolRefs, err = t.getCompletionToolRefs(prg, agentGroup)
 	if err != nil {
 		return nil, err
 	}
diff --git a/pkg/types/toolstring.go b/pkg/types/toolstring.go
index 64f53638..2be6d0fc 100644
--- a/pkg/types/toolstring.go
+++ b/pkg/types/toolstring.go
@@ -74,7 +74,7 @@ func ToSysDisplayString(id string, args map[string]string) (string, error) {
 		return fmt.Sprintf("Removing `%s`", args["location"]), nil
 	case "sys.write":
 		return fmt.Sprintf("Writing `%s`", args["filename"]), nil
-	case "sys.context", "sys.stat", "sys.getenv", "sys.abort", "sys.chat.current", "sys.chat.finish", "sys.chat.history", "sys.echo", "sys.prompt", "sys.time.now":
+	case "sys.context", "sys.stat", "sys.getenv", "sys.abort", "sys.chat.current", "sys.chat.finish", "sys.chat.history", "sys.echo", "sys.prompt", "sys.time.now", "sys.model.provider.credential":
 		return "", nil
 	default:
 		return "", fmt.Errorf("unknown tool for display string: %s", id)