diff --git a/llms/ollama/internal/ollamaclient/ollamaclient.go b/llms/ollama/internal/ollamaclient/ollamaclient.go index 42cd02efa..d91e00eb8 100644 --- a/llms/ollama/internal/ollamaclient/ollamaclient.go +++ b/llms/ollama/internal/ollamaclient/ollamaclient.go @@ -16,8 +16,8 @@ import ( ) type Client struct { - base *url.URL - http http.Client + base *url.URL + httpClient *http.Client } func checkError(resp *http.Response, body []byte) error { @@ -36,7 +36,7 @@ func checkError(resp *http.Response, body []byte) error { return apiError } -func NewClient(ourl *url.URL) (*Client, error) { +func NewClient(ourl *url.URL, ohttp *http.Client) (*Client, error) { if ourl == nil { scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://") if !ok { @@ -57,14 +57,17 @@ func NewClient(ourl *url.URL) (*Client, error) { } } - client := Client{ - base: ourl, + if ohttp == nil { + ohttp = &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + }, + } } - client.http = http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - }, + client := Client{ + base: ourl, + httpClient: ohttp, } return &client, nil @@ -93,7 +96,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData request.Header.Set("User-Agent", fmt.Sprintf("langchaingo/ (%s %s) Go/%s", runtime.GOARCH, runtime.GOOS, runtime.Version())) - respObj, err := c.http.Do(request) + respObj, err := c.httpClient.Do(request) if err != nil { return err } @@ -140,7 +143,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f request.Header.Set("User-Agent", fmt.Sprintf("langchaingo (%s %s) Go/%s", runtime.GOARCH, runtime.GOOS, runtime.Version())) - response, err := c.http.Do(request) + response, err := c.httpClient.Do(request) if err != nil { return err } diff --git a/llms/ollama/ollamallm.go b/llms/ollama/ollamallm.go index 9734afbb7..f31a41abd 100644 --- a/llms/ollama/ollamallm.go +++ b/llms/ollama/ollamallm.go @@ -31,7 +31,7 @@ func New(opts ...Option) (*LLM, error) { opt(&o) } - client, err := ollamaclient.NewClient(o.ollamaServerURL) + client, err := ollamaclient.NewClient(o.ollamaServerURL, o.httpClient) if err != nil { return nil, err } diff --git a/llms/ollama/options.go b/llms/ollama/options.go index acb8c44e3..ff093d27b 100644 --- a/llms/ollama/options.go +++ b/llms/ollama/options.go @@ -2,6 +2,7 @@ package ollama import ( "log" + "net/http" "net/url" "github.com/tmc/langchaingo/llms/ollama/internal/ollamaclient" @@ -9,6 +10,7 @@ import ( type options struct { ollamaServerURL *url.URL + httpClient *http.Client model string ollamaOptions ollamaclient.Options customModelTemplate string @@ -52,6 +54,13 @@ func WithServerURL(rawURL string) Option { } } +// WithHTTPClient Set custom http client. +func WithHTTPClient(client *http.Client) Option { + return func(opts *options) { + opts.httpClient = client + } +} + // WithBackendUseNUMA Use NUMA optimization on certain systems. func WithRunnerUseNUMA(numa bool) Option { return func(opts *options) {