Skip to content

Commit

Permalink
feat: add sys.model.provider.credential
Browse files Browse the repository at this point in the history
  • Loading branch information
ibuildthecloud committed Aug 13, 2024
1 parent 50503da commit 6d26908
Show file tree
Hide file tree
Showing 13 changed files with 245 additions and 26 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ require (
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/google/uuid v1.6.0
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb
github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17
github.com/gptscript-ai/tui v0.0.0-20240804004233-efc5673dc76e
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 h1:m9yLtIEd0z1ia8qFjq3u0Ozb6QKwidyL856JLJp6nbA=
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86/go.mod h1:lK3K5EZx4dyT24UG3yCt0wmspkYqrj4D/8kxdN3relk=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 h1:vYnXoIyCXzaCEw0sYifQ4bDpsv3/fO/dZ2suEsTwCIo=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 h1:EQiFTZv+BnOWJX2B9XdF09fL2Zj7h19n1l23TpWCafc=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc=
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17 h1:BTfJ6ls31Roq42lznlZnuPzRf0wrT8jT+tWcvq7wDXY=
Expand Down
42 changes: 34 additions & 8 deletions pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ import (
)

var SafeTools = map[string]struct{}{
"sys.abort": {},
"sys.chat.finish": {},
"sys.chat.history": {},
"sys.chat.current": {},
"sys.echo": {},
"sys.prompt": {},
"sys.time.now": {},
"sys.context": {},
"sys.abort": {},
"sys.chat.finish": {},
"sys.chat.history": {},
"sys.chat.current": {},
"sys.echo": {},
"sys.prompt": {},
"sys.time.now": {},
"sys.context": {},
"sys.model.provider.credential": {},
}

var tools = map[string]types.Tool{
Expand Down Expand Up @@ -248,6 +249,15 @@ var tools = map[string]types.Tool{
BuiltinFunc: SysContext,
},
},
"sys.model.provider.credential": {
ToolDef: types.ToolDef{
Parameters: types.Parameters{
Description: "A credential tool to set the OPENAI_API_KEY and OPENAI_BASE_URL to give access to the default model provider",
Arguments: types.ObjectSchema(),
},
BuiltinFunc: SysModelProviderCredential,
},
},
}

func ListTools() (result []types.Tool) {
Expand Down Expand Up @@ -678,6 +688,22 @@ func invalidArgument(input string, err error) string {
return fmt.Sprintf("Failed to parse arguments %s: %v", input, err)
}

func SysModelProviderCredential(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
engineContext, _ := engine.FromContext(ctx)
auth, url, err := engineContext.Engine.Model.ProxyInfo()
if err != nil {
return "", err
}
data, err := json.Marshal(map[string]any{
"env": map[string]string{
"OPENAI_API_KEY": auth,
"OPENAI_BASE_URL": url,
},
"ephemeral": true,
})
return string(data), err
}

func SysContext(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
engineContext, _ := engine.FromContext(ctx)

Expand Down
1 change: 1 addition & 0 deletions pkg/credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Credential struct {
ToolName string `json:"toolName"`
Type CredentialType `json:"type"`
Env map[string]string `json:"env"`
Ephemeral bool `json:"ephemeral,omitempty"`
ExpiresAt *time.Time `json:"expiresAt"`
RefreshToken string `json:"refreshToken"`
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/engine/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
}
}()

return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input, progress)
return tool.BuiltinFunc(ctx.WrappedContext(e), e.Env, input, progress)
}

var instructions []string
Expand Down
13 changes: 11 additions & 2 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

type Model interface {
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
ProxyInfo() (string, string, error)
}

type RuntimeManager interface {
Expand Down Expand Up @@ -79,6 +80,7 @@ type Context struct {
Parent *Context
LastReturn *Return
CurrentReturn *Return
Engine *Engine
Program *types.Program
// Input is saved only so that we can render display text, don't use otherwise
Input string
Expand Down Expand Up @@ -250,8 +252,10 @@ func FromContext(ctx context.Context) (*Context, bool) {
return c, ok
}

func (c *Context) WrappedContext() context.Context {
return context.WithValue(c.Ctx, engineContext{}, c)
func (c *Context) WrappedContext(e *Engine) context.Context {
cp := *c
cp.Engine = e
return context.WithValue(c.Ctx, engineContext{}, &cp)
}

func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) {
Expand Down Expand Up @@ -280,6 +284,11 @@ func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) {
return &Return{
Result: &s,
}, nil
} else if tool.IsNoop() {
var empty string
return &Return{
Result: &empty,
}, nil
}

if ctx.ToolCategory == CredentialToolCategory {
Expand Down
104 changes: 104 additions & 0 deletions pkg/llm/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package llm

import (
"bytes"
"encoding/json"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"path"
"strings"

"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/openai"
)

func (r *Registry) ProxyInfo() (string, string, error) {
r.proxyLock.Lock()
defer r.proxyLock.Unlock()

if r.proxyURL != "" {
return r.proxyToken, r.proxyURL, nil
}

l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return "", "", err
}

go func() {
_ = http.Serve(l, r)
r.proxyLock.Lock()
defer r.proxyLock.Unlock()
_ = l.Close()
r.proxyURL = ""
}()

r.proxyURL = "http://" + l.Addr().String()
return r.proxyToken, r.proxyURL, nil
}

func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.proxyToken != strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

inBytes, err := io.ReadAll(req.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

var (
model string
data = map[string]any{}
)

if json.Unmarshal(inBytes, &data) == nil {
model, _ = data["model"].(string)
}

if model == "" {
model = builtin.GetDefaultModel()
}

c, err := r.getClient(req.Context(), model)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

oai, ok := c.(*openai.Client)
if !ok {
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}

auth, targetURL := oai.ProxyInfo()
if targetURL == "" {
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}

newURL, err := url.Parse(targetURL)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

newURL.Path = path.Join(newURL.Path, req.URL.Path)

rp := httputil.ReverseProxy{
Director: func(proxyReq *http.Request) {
proxyReq.Body = io.NopCloser(bytes.NewReader(inBytes))
proxyReq.URL = newURL
proxyReq.Header.Del("Authorization")
proxyReq.Header.Add("Authorization", "Bearer "+auth)
proxyReq.Host = newURL.Hostname()
},
}
rp.ServeHTTP(w, req)
}
57 changes: 55 additions & 2 deletions pkg/llm/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import (
"errors"
"fmt"
"sort"
"sync"

"github.com/google/uuid"
"github.com/gptscript-ai/gptscript/pkg/env"
"github.com/gptscript-ai/gptscript/pkg/openai"
"github.com/gptscript-ai/gptscript/pkg/remote"
"github.com/gptscript-ai/gptscript/pkg/types"
Expand All @@ -18,11 +21,16 @@ type Client interface {
}

type Registry struct {
clients []Client
proxyToken string
proxyURL string
proxyLock sync.Mutex
clients []Client
}

func NewRegistry() *Registry {
return &Registry{}
return &Registry{
proxyToken: env.VarOrDefault("GPTSCRIPT_INTERNAL_PROXY_TOKEN", uuid.New().String()),
}
}

func (r *Registry) AddClient(client Client) error {
Expand All @@ -44,6 +52,10 @@ func (r *Registry) ListModels(ctx context.Context, providers ...string) (result

func (r *Registry) fastPath(modelName string) Client {
// This is optimization hack to avoid doing List Models
if len(r.clients) == 1 {
return r.clients[0]
}

if len(r.clients) != 2 {
return nil
}
Expand All @@ -66,6 +78,47 @@ func (r *Registry) fastPath(modelName string) Client {
return r.clients[0]
}

func (r *Registry) getClient(ctx context.Context, modelName string) (Client, error) {
if c := r.fastPath(modelName); c != nil {
return c, nil
}

var errs []error
var oaiClient *openai.Client
for _, client := range r.clients {
ok, err := client.Supports(ctx, modelName)
if err != nil {
// If we got an OpenAI invalid auth error back, store the OpenAI client for later.
if errors.Is(err, openai.InvalidAuthError{}) {
oaiClient = client.(*openai.Client)
}

errs = append(errs, err)
} else if ok {
return client, nil
}
}

if len(errs) > 0 && oaiClient != nil {
// Prompt the user to enter their OpenAI API key and try again.
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
return nil, err
}
ok, err := oaiClient.Supports(ctx, modelName)
if err != nil {
return nil, err
} else if ok {
return oaiClient, nil
}
}

if len(errs) == 0 {
return nil, fmt.Errorf("failed to find a model provider for model [%s]", modelName)
}

return nil, errors.Join(errs...)
}

func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
if messageRequest.Model == "" {
return nil, fmt.Errorf("model is required")
Expand Down
7 changes: 7 additions & 0 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ func NewClient(ctx context.Context, credStore credentials.CredentialStore, opts
}, nil
}

func (c *Client) ProxyInfo() (token, urlBase string) {
if c.invalidAuth {
return "", ""
}
return c.c.GetAPIKeyAndBaseURL()
}

func (c *Client) ValidAuth() error {
if c.invalidAuth {
return InvalidAuthError{}
Expand Down
22 changes: 14 additions & 8 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", ref.Reference)
}

if *res.Result == "" {
continue
}

if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err)
}
Expand All @@ -958,15 +962,17 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
}
}

// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
if isEmpty {
log.Warnf("Not saving empty credential for tool %s", toolName)
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
if !c.Ephemeral {
// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
if isEmpty {
log.Warnf("Not saving empty credential for tool %s", toolName)
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
}
} else {
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
}
} else {
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
}
}

Expand Down
4 changes: 4 additions & 0 deletions pkg/tests/tester/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ type Result struct {
Err error
}

func (c *Client) ProxyInfo() (string, string, error) {
return "test-auth", "test-url", nil
}

func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) {
msgData, err := json.MarshalIndent(messageRequest, "", " ")
require.NoError(c.t, err)
Expand Down
Loading

0 comments on commit 6d26908

Please sign in to comment.