feat: add sys.model.provider.credential #792

Merged: 1 commit, Aug 14, 2024
2 changes: 1 addition & 1 deletion go.mod
@@ -16,7 +16,7 @@ require (
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/google/uuid v1.6.0
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb
github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17
github.com/gptscript-ai/tui v0.0.0-20240804004233-efc5673dc76e
4 changes: 2 additions & 2 deletions go.sum
@@ -200,8 +200,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 h1:m9yLtIEd0z1ia8qFjq3u0Ozb6QKwidyL856JLJp6nbA=
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86/go.mod h1:lK3K5EZx4dyT24UG3yCt0wmspkYqrj4D/8kxdN3relk=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 h1:vYnXoIyCXzaCEw0sYifQ4bDpsv3/fO/dZ2suEsTwCIo=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 h1:EQiFTZv+BnOWJX2B9XdF09fL2Zj7h19n1l23TpWCafc=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc=
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17 h1:BTfJ6ls31Roq42lznlZnuPzRf0wrT8jT+tWcvq7wDXY=
42 changes: 34 additions & 8 deletions pkg/builtin/builtin.go
@@ -26,14 +26,15 @@ import (
)

var SafeTools = map[string]struct{}{
"sys.abort": {},
"sys.chat.finish": {},
"sys.chat.history": {},
"sys.chat.current": {},
"sys.echo": {},
"sys.prompt": {},
"sys.time.now": {},
"sys.context": {},
"sys.abort": {},
"sys.chat.finish": {},
"sys.chat.history": {},
"sys.chat.current": {},
"sys.echo": {},
"sys.prompt": {},
"sys.time.now": {},
"sys.context": {},
"sys.model.provider.credential": {},
}

var tools = map[string]types.Tool{
@@ -248,6 +249,15 @@ var tools = map[string]types.Tool{
BuiltinFunc: SysContext,
},
},
"sys.model.provider.credential": {
ToolDef: types.ToolDef{
Parameters: types.Parameters{
Description: "A credential tool to set the OPENAI_API_KEY and OPENAI_BASE_URL to give access to the default model provider",
Arguments: types.ObjectSchema(),
},
BuiltinFunc: SysModelProviderCredential,
},
},
}

func ListTools() (result []types.Tool) {
@@ -678,6 +688,22 @@ func invalidArgument(input string, err error) string {
return fmt.Sprintf("Failed to parse arguments %s: %v", input, err)
}

func SysModelProviderCredential(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
engineContext, _ := engine.FromContext(ctx)
auth, url, err := engineContext.Engine.Model.ProxyInfo()
if err != nil {
return "", err
}
data, err := json.Marshal(map[string]any{
"env": map[string]string{
"OPENAI_API_KEY": auth,
"OPENAI_BASE_URL": url,
},
"ephemeral": true,
})
return string(data), err
}

func SysContext(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
engineContext, _ := engine.FromContext(ctx)

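For context on the new builtin: `SysModelProviderCredential` does not call the model itself; it asks the engine's model registry for proxy connection details via `ProxyInfo()` and returns them as a credential payload whose `env` map points the tool at the internal proxy. A minimal sketch of that payload shape and how a consumer could decode it follows; the struct only loosely mirrors `pkg/credentials.Credential`, and the token and URL values are illustrative, not taken from the diff:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// credentialPayload mirrors the JSON emitted by SysModelProviderCredential:
// an env map plus the new ephemeral flag.
type credentialPayload struct {
	Env       map[string]string `json:"env"`
	Ephemeral bool              `json:"ephemeral"`
}

func main() {
	// Example output of the builtin; the token and port are made up.
	raw := `{"env":{"OPENAI_API_KEY":"<proxy-token>","OPENAI_BASE_URL":"http://127.0.0.1:54321"},"ephemeral":true}`

	var c credentialPayload
	if err := json.Unmarshal([]byte(raw), &c); err != nil {
		panic(err)
	}
	fmt.Println(c.Env["OPENAI_BASE_URL"], c.Ephemeral)
}
```

Because `ephemeral` is set, the runner injects these variables for the current run but never persists them (see the `pkg/runner/runner.go` change below).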
1 change: 1 addition & 0 deletions pkg/credentials/credential.go
@@ -24,6 +24,7 @@ type Credential struct {
ToolName string `json:"toolName"`
Type CredentialType `json:"type"`
Env map[string]string `json:"env"`
Ephemeral bool `json:"ephemeral,omitempty"`
ExpiresAt *time.Time `json:"expiresAt"`
RefreshToken string `json:"refreshToken"`
}
2 changes: 1 addition & 1 deletion pkg/engine/cmd.go
@@ -109,7 +109,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
}
}()

return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input, progress)
return tool.BuiltinFunc(ctx.WrappedContext(e), e.Env, input, progress)
}

var instructions []string
8 changes: 6 additions & 2 deletions pkg/engine/engine.go
@@ -16,6 +16,7 @@ import (

type Model interface {
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
ProxyInfo() (string, string, error)
}

type RuntimeManager interface {
@@ -79,6 +80,7 @@ type Context struct {
Parent *Context
LastReturn *Return
CurrentReturn *Return
Engine *Engine
Program *types.Program
// Input is saved only so that we can render display text, don't use otherwise
Input string
@@ -250,8 +252,10 @@ func FromContext(ctx context.Context) (*Context, bool) {
return c, ok
}

func (c *Context) WrappedContext() context.Context {
return context.WithValue(c.Ctx, engineContext{}, c)
func (c *Context) WrappedContext(e *Engine) context.Context {
cp := *c
cp.Engine = e
return context.WithValue(c.Ctx, engineContext{}, &cp)
}

func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) {
104 changes: 104 additions & 0 deletions pkg/llm/proxy.go
@@ -0,0 +1,104 @@
package llm

import (
"bytes"
"encoding/json"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"path"
"strings"

"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/openai"
)

func (r *Registry) ProxyInfo() (string, string, error) {
r.proxyLock.Lock()
defer r.proxyLock.Unlock()

if r.proxyURL != "" {
return r.proxyToken, r.proxyURL, nil
}

l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return "", "", err
}

go func() {
_ = http.Serve(l, r)
r.proxyLock.Lock()
defer r.proxyLock.Unlock()
_ = l.Close()
r.proxyURL = ""
}()

r.proxyURL = "http://" + l.Addr().String()
return r.proxyToken, r.proxyURL, nil
}

func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if r.proxyToken != strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

inBytes, err := io.ReadAll(req.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

var (
model string
data = map[string]any{}
)

if json.Unmarshal(inBytes, &data) == nil {
model, _ = data["model"].(string)
}

if model == "" {
model = builtin.GetDefaultModel()
}

c, err := r.getClient(req.Context(), model)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

oai, ok := c.(*openai.Client)
if !ok {
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}

auth, targetURL := oai.ProxyInfo()
if targetURL == "" {
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
}

newURL, err := url.Parse(targetURL)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

newURL.Path = path.Join(newURL.Path, req.URL.Path)

rp := httputil.ReverseProxy{
Director: func(proxyReq *http.Request) {
proxyReq.Body = io.NopCloser(bytes.NewReader(inBytes))
proxyReq.URL = newURL
proxyReq.Header.Del("Authorization")
proxyReq.Header.Add("Authorization", "Bearer "+auth)
proxyReq.Host = newURL.Hostname()
},
}
rp.ServeHTTP(w, req)
}
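The proxy above is what the credential points at: `ProxyInfo` lazily starts a loopback listener, and `ServeHTTP` checks the bearer token, resolves a client for the requested model (falling back to the default model), and reverse-proxies the request to the real OpenAI-compatible backend with the upstream key. A sketch of driving it directly, assuming the registry is importable as shown; the `/chat/completions` path, model name, and request body are illustrative and follow the OpenAI convention rather than anything mandated by the proxy:

```go
package example

import (
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gptscript-ai/gptscript/pkg/llm"
)

// callThroughProxy shows how a client could exercise the loopback proxy
// started by Registry.ProxyInfo. The proxy only verifies the bearer token
// and forwards the request body to the resolved backend.
func callThroughProxy(r *llm.Registry) error {
	token, baseURL, err := r.ProxyInfo()
	if err != nil {
		return err
	}

	body := strings.NewReader(`{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`)
	req, err := http.NewRequest(http.MethodPost, baseURL+"/chat/completions", body)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	out, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	fmt.Println(resp.Status, string(out))
	return nil
}
```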
57 changes: 55 additions & 2 deletions pkg/llm/registry.go
@@ -5,7 +5,10 @@ import (
"errors"
"fmt"
"sort"
"sync"

"github.com/google/uuid"
"github.com/gptscript-ai/gptscript/pkg/env"
"github.com/gptscript-ai/gptscript/pkg/openai"
"github.com/gptscript-ai/gptscript/pkg/remote"
"github.com/gptscript-ai/gptscript/pkg/types"
@@ -18,11 +21,16 @@ type Client interface {
}

type Registry struct {
clients []Client
proxyToken string
proxyURL string
proxyLock sync.Mutex
clients []Client
}

func NewRegistry() *Registry {
return &Registry{}
return &Registry{
proxyToken: env.VarOrDefault("GPTSCRIPT_INTERNAL_PROXY_TOKEN", uuid.New().String()),
}
}

func (r *Registry) AddClient(client Client) error {
@@ -44,6 +52,10 @@ func (r *Registry) ListModels(ctx context.Context, providers ...string) (result

func (r *Registry) fastPath(modelName string) Client {
// This is an optimization hack to avoid doing List Models
if len(r.clients) == 1 {
return r.clients[0]
}

if len(r.clients) != 2 {
return nil
}
@@ -66,6 +78,47 @@
return r.clients[0]
}

func (r *Registry) getClient(ctx context.Context, modelName string) (Client, error) {
if c := r.fastPath(modelName); c != nil {
return c, nil
}

var errs []error
var oaiClient *openai.Client
for _, client := range r.clients {
ok, err := client.Supports(ctx, modelName)
if err != nil {
// If we got an OpenAI invalid auth error back, store the OpenAI client for later.
if errors.Is(err, openai.InvalidAuthError{}) {
oaiClient = client.(*openai.Client)
}

errs = append(errs, err)
} else if ok {
return client, nil
}
}

if len(errs) > 0 && oaiClient != nil {
// Prompt the user to enter their OpenAI API key and try again.
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
return nil, err
}
ok, err := oaiClient.Supports(ctx, modelName)
if err != nil {
return nil, err
} else if ok {
return oaiClient, nil
}
}

if len(errs) == 0 {
return nil, fmt.Errorf("failed to find a model provider for model [%s]", modelName)
}

return nil, errors.Join(errs...)
}

func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
if messageRequest.Model == "" {
return nil, fmt.Errorf("model is required")
7 changes: 7 additions & 0 deletions pkg/openai/client.go
@@ -130,6 +130,13 @@ func NewClient(ctx context.Context, credStore credentials.CredentialStore, opts
}, nil
}

func (c *Client) ProxyInfo() (token, urlBase string) {
if c.invalidAuth {
return "", ""
}
return c.c.GetAPIKeyAndBaseURL()
}

func (c *Client) ValidAuth() error {
if c.invalidAuth {
return InvalidAuthError{}
27 changes: 19 additions & 8 deletions pkg/runner/runner.go
@@ -872,6 +872,11 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
return nil, fmt.Errorf("failed to parse credential tool %q: %w", ref.Reference, err)
}

if callCtx.Program.ToolSet[ref.ToolID].IsNoop() {
// ignore empty tools
continue
}

credName := toolName
if credentialAlias != "" {
credName = credentialAlias
@@ -944,6 +949,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", ref.Reference)
}

if *res.Result == "" {
continue
}

if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err)
}
@@ -958,15 +967,17 @@
}
}

// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
if isEmpty {
log.Warnf("Not saving empty credential for tool %s", toolName)
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
if !c.Ephemeral {
// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
if isEmpty {
log.Warnf("Not saving empty credential for tool %s", toolName)
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
}
} else {
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
}
} else {
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
}
}

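The nested conditions in `handleCredentials` boil down to a single persistence rule; here is a sketch of it as a standalone helper (names are illustrative, not part of the diff):

```go
package sketch

// shouldStore captures the persistence rule the runner now applies:
// ephemeral or empty credentials are never saved, and the rest are only
// saved for GitHub-hosted tools or tools that use a credential alias.
func shouldStore(ephemeral, isGitHubTool, hasRepo, isEmpty bool, alias string) bool {
	if ephemeral || isEmpty {
		return false
	}
	return (isGitHubTool && hasRepo) || alias != ""
}
```

`sys.model.provider.credential` sets `ephemeral`, so its proxy token and loopback URL are injected into the tool's environment for the current run but never written to the credential store.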
4 changes: 4 additions & 0 deletions pkg/tests/tester/runner.go
@@ -31,6 +31,10 @@ type Result struct {
Err error
}

func (c *Client) ProxyInfo() (string, string, error) {
return "test-auth", "test-url", nil
}

func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) {
msgData, err := json.MarshalIndent(messageRequest, "", " ")
require.NoError(c.t, err)