Skip to content

Commit

Permalink
chore: refactor to move runtime data to context
Browse files Browse the repository at this point in the history
  • Loading branch information
kardolus committed Jan 18, 2025
1 parent 8dbc215 commit 51310fd
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 73 deletions.
55 changes: 37 additions & 18 deletions api/client/client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
Expand All @@ -9,6 +10,7 @@ import (
"github.com/kardolus/chatgpt-cli/api"
"github.com/kardolus/chatgpt-cli/api/http"
"github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/internal"
"net/url"
"os"
"strings"
Expand Down Expand Up @@ -156,23 +158,31 @@ func (c *Client) ListModels() ([]string, error) {
// and the method will split it into messages, preserving punctuation and special
// characters.
func (c *Client) ProvideContext(context string) {
if len(c.Config.Binary) > 0 {
return
}

c.initHistory()
historyEntries := c.createHistoryEntriesFromString(context)
c.History = append(c.History, historyEntries...)
}

// Query sends a query to the API, returning the response as a string along with the token usage.
// It takes an input string, constructs a request body, and makes a POST API call.
//
// It takes a context `ctx` and an input string, constructs a request body, and makes a POST API call.
// The context allows for request scoping, timeouts, and cancellation handling.
//
// Returns the API response string, the number of tokens used, and an error if any issues occur.
// If the response contains choices, it decodes the JSON and returns the content of the first choice.
func (c *Client) Query(input string) (string, int, error) {
//
// Parameters:
// - ctx: A context.Context that controls request cancellation and deadlines.
// - input: The query string to send to the API.
//
// Returns:
// - string: The content of the first response choice from the API.
// - int: The total number of tokens used in the request.
// - error: An error if the request fails or the response is invalid.
func (c *Client) Query(ctx context.Context, input string) (string, int, error) {
c.prepareQuery(input)

body, err := c.createBody(false)
body, err := c.createBody(ctx, false)
if err != nil {
return "", 0, err
}
Expand Down Expand Up @@ -207,14 +217,23 @@ func (c *Client) Query(input string) (string, int, error) {
}

// Stream sends a query to the API and processes the response as a stream.
// It takes an input string as a parameter and returns an error if there's
// any issue during the process. The method creates a request body with the
// input and then makes an API call using the Post method. The actual
// processing of the streamed response is done in the Post method.
func (c *Client) Stream(input string) error {
//
// It takes a context `ctx` and an input string, constructs a request body, and makes a POST API call.
// The context allows for request scoping, timeouts, and cancellation handling.
//
// The method creates a request body with the input and calls the API using the `Post` method.
// The actual processing of the streamed response is handled inside the `Post` method.
//
// Parameters:
// - ctx: A context.Context that controls request cancellation and deadlines.
// - input: The query string to send to the API.
//
// Returns:
// - error: An error if the request fails or the response is invalid.
func (c *Client) Stream(ctx context.Context, input string) error {
c.prepareQuery(input)

body, err := c.createBody(true)
body, err := c.createBody(ctx, true)
if err != nil {
return err
}
Expand All @@ -235,7 +254,7 @@ func (c *Client) Stream(input string) error {
return nil
}

func (c *Client) createBody(stream bool) ([]byte, error) {
func (c *Client) createBody(ctx context.Context, stream bool) ([]byte, error) {
var messages []api.Message

for index, item := range c.History {
Expand All @@ -257,17 +276,17 @@ func (c *Client) createBody(stream bool) ([]byte, error) {
Stream: stream,
}

if len(c.Config.Binary) > 0 {
content, err := c.createImageContentFromBinary(c.Config.Binary)
if data, ok := ctx.Value(internal.BinaryDataKey).([]byte); ok {
content, err := c.createImageContentFromBinary(data)
if err != nil {
return nil, err
}
body.Messages = append(body.Messages, api.Message{
Role: UserRole,
Content: []api.ImageContent{content},
})
} else if c.Config.Image != "" {
content, err := c.createImageContentFromURLOrFile(c.Config.Image)
} else if path, ok := ctx.Value(internal.ImagePathKey).(string); ok {
content, err := c.createImageContentFromURLOrFile(path)
if err != nil {
return nil, err
}
Expand Down
70 changes: 35 additions & 35 deletions api/client/client_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client_test

import (
"context"
"encoding/json"
"errors"
"github.com/golang/mock/gomock"
Expand All @@ -10,6 +11,7 @@ import (
"github.com/kardolus/chatgpt-cli/api/http"
config2 "github.com/kardolus/chatgpt-cli/config"
"github.com/kardolus/chatgpt-cli/history"
"github.com/kardolus/chatgpt-cli/internal"
"github.com/kardolus/chatgpt-cli/test"
"os"
"strings"
Expand Down Expand Up @@ -175,7 +177,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {

mockTimer.EXPECT().Now().Return(time.Time{}).Times(2)

_, _, err = subject.Query(query)
_, _, err = subject.Query(context.TODO(), query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(tt.expectedError))
})
Expand Down Expand Up @@ -235,7 +237,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
}))
}

result, usage, err := subject.Query(query)
result, usage, err := subject.Query(context.TODO(), query)
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal(answer))
Expect(usage).To(Equal(tokens))
Expand Down Expand Up @@ -378,7 +380,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockTimer.EXPECT().Now().Return(time.Now()).AnyTimes()
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(nil, nil)

_, _, _ = subject.Query("test query")
_, _, _ = subject.Query(context.TODO(), "test query")
})
it("should include all messages when the model does not start with o1Prefix", func() {
const systemRole = "System role for this test"
Expand All @@ -405,7 +407,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockTimer.EXPECT().Now().Return(time.Now()).AnyTimes()
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(nil, nil)

_, _, _ = subject.Query("test query")
_, _, _ = subject.Query(context.TODO(), "test query")
})
})

Expand All @@ -424,9 +426,12 @@ func testClient(t *testing.T, when spec.G, it spec.S) {

it("should update a callout as expected when a valid image URL is provided", func() {
subject := factory.buildClientWithoutConfig()
subject.Config.Image = website

subject.Config.Role = systemRole

ctx := context.Background()
ctx = context.WithValue(ctx, internal.ImagePathKey, website)

expectedBody, err := createBody([]api.Message{
{Role: client.SystemRole, Content: systemRole},
{Role: client.UserRole, Content: query},
Expand All @@ -444,58 +449,66 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(nil, nil)

_, _, _ = subject.Query(query)
_, _, _ = subject.Query(ctx, query)
})
it("throws an error when the image mime type cannot be obtained due to an open-error", func() {
subject := factory.buildClientWithoutConfig()
subject.Config.Image = image
subject.Config.Role = systemRole

ctx := context.Background()
ctx = context.WithValue(ctx, internal.ImagePathKey, image)

mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
mockReader.EXPECT().Open(image).Return(nil, errors.New(errorMessage))

_, _, err := subject.Query(query)
_, _, err := subject.Query(ctx, query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMessage))
})
it("throws an error when the image mime type cannot be obtained due to a read-error", func() {
imageFile := &os.File{}

subject := factory.buildClientWithoutConfig()
subject.Config.Image = image
subject.Config.Role = systemRole

ctx := context.Background()
ctx = context.WithValue(ctx, internal.ImagePathKey, image)

mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
mockReader.EXPECT().Open(image).Return(imageFile, nil)
mockReader.EXPECT().ReadBufferFromFile(imageFile).Return(nil, errors.New(errorMessage))

_, _, err := subject.Query(query)
_, _, err := subject.Query(ctx, query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMessage))
})
it("throws an error when the image base64 encoded content cannot be obtained due to a read-error", func() {
imageFile := &os.File{}

subject := factory.buildClientWithoutConfig()
subject.Config.Image = image
subject.Config.Role = systemRole

ctx := context.Background()
ctx = context.WithValue(ctx, internal.ImagePathKey, image)

mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
mockReader.EXPECT().Open(image).Return(imageFile, nil)
mockReader.EXPECT().ReadBufferFromFile(imageFile).Return(nil, nil)
mockReader.EXPECT().ReadFile(image).Return(nil, errors.New(errorMessage))

_, _, err := subject.Query(query)
_, _, err := subject.Query(ctx, query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMessage))
})
it("should update a callout as expected when a valid local image is provided", func() {
imageFile := &os.File{}

subject := factory.buildClientWithoutConfig()
subject.Config.Image = image
subject.Config.Role = systemRole

ctx := context.Background()
ctx = context.WithValue(ctx, internal.ImagePathKey, image)

mockReader.EXPECT().Open(image).Return(imageFile, nil)
mockReader.EXPECT().ReadBufferFromFile(imageFile).Return(nil, nil)
mockReader.EXPECT().ReadFile(image).Return(nil, nil)
Expand All @@ -517,7 +530,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
mockTimer.EXPECT().Now().Return(time.Now()).Times(2)
mockCaller.EXPECT().Post(subject.Config.URL+subject.Config.CompletionsPath, expectedBody, false).Return(nil, nil)

_, _, _ = subject.Query(query)
_, _, _ = subject.Query(ctx, query)
})
})
})
Expand All @@ -541,7 +554,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {

mockTimer.EXPECT().Now().Return(time.Time{}).Times(2)

err := subject.Stream(query)
err := subject.Stream(context.TODO(), query)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(Equal(errorMsg))
})
Expand Down Expand Up @@ -574,7 +587,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
},
}))

err := subject.Stream(query)
err := subject.Stream(context.TODO(), query)
Expect(err).NotTo(HaveOccurred())
}

Expand Down Expand Up @@ -671,12 +684,12 @@ func testClient(t *testing.T, when spec.G, it spec.S) {
it("updates the history with the provided context", func() {
subject := factory.buildClientWithoutConfig()

context := "This is a story about a dog named Kya. Kya loves to play fetch and swim in the lake."
chatContext := "This is a story about a dog named Kya. Kya loves to play fetch and swim in the lake."
mockHistoryStore.EXPECT().Read().Return(nil, nil).Times(1)

mockTimer.EXPECT().Now().Return(time.Time{}).AnyTimes()

subject.ProvideContext(context)
subject.ProvideContext(chatContext)

Expect(len(subject.History)).To(Equal(2)) // The system message and the provided context

Expand All @@ -686,7 +699,7 @@ func testClient(t *testing.T, when spec.G, it spec.S) {

contextMessage := subject.History[1]
Expect(contextMessage.Role).To(Equal(client.UserRole))
Expect(contextMessage.Content).To(Equal(context))
Expect(contextMessage.Content).To(Equal(chatContext))
})
it("behaves as expected with a non empty initial history", func() {
subject := factory.buildClientWithoutConfig()
Expand All @@ -707,27 +720,14 @@ func testClient(t *testing.T, when spec.G, it spec.S) {

mockTimer.EXPECT().Now().Return(time.Time{}).AnyTimes()

context := "test context"
subject.ProvideContext(context)
chatContext := "test context"
subject.ProvideContext(chatContext)

Expect(len(subject.History)).To(Equal(3))

contextMessage := subject.History[2]
Expect(contextMessage.Role).To(Equal(client.UserRole))
Expect(contextMessage.Content).To(Equal(context))
})
it("does not update history if Config.Binary is provided", func() {
subject := factory.buildClientWithoutConfig()

subject.Config.Binary = []byte("binary data")

mockHistoryStore.EXPECT().Read().Times(0) // No read should be called, early return happens
mockTimer.EXPECT().Now().Times(0) // No need to mock time since we should not enter the function body

initialHistoryLength := len(subject.History)
subject.ProvideContext("some context")

Expect(len(subject.History)).To(Equal(initialHistoryLength))
Expect(contextMessage.Content).To(Equal(chatContext))
})
})
}
Expand Down
Loading

0 comments on commit 51310fd

Please sign in to comment.