Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase test coverage, split up some files #23

Merged
merged 8 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions cmd/list/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"github.com/cli/go-gh/v2/pkg/tableprinter"
"github.com/github/gh-models/internal/azuremodels"
"github.com/github/gh-models/internal/ux"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This package was able to go away because its two functions made more sense elsewhere, either as a method on a type rather than a standalone function (IsChatModel) or because the function dealt with types defined in another package (SortModels).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I agree that's a sensible move.

"github.com/github/gh-models/pkg/command"
"github.com/mgutz/ansi"
"github.com/spf13/cobra"
Expand All @@ -33,7 +32,7 @@ func NewListCommand(cfg *command.Config) *cobra.Command {
// For now, filter to just chat models.
// Once other tasks are supported (like embeddings), update the list to show all models, with the task as a column.
models = filterToChatModels(models)
ux.SortModels(models)
azuremodels.SortModels(models)

if cfg.IsTerminalOutput {
cfg.WriteToOut("\n")
Expand Down Expand Up @@ -67,7 +66,7 @@ func NewListCommand(cfg *command.Config) *cobra.Command {
func filterToChatModels(models []*azuremodels.ModelSummary) []*azuremodels.ModelSummary {
var chatModels []*azuremodels.ModelSummary
for _, model := range models {
if ux.IsChatModel(model) {
if model.IsChatModel() {
chatModels = append(chatModels, model)
}
}
Expand Down
5 changes: 2 additions & 3 deletions cmd/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/briandowns/spinner"
"github.com/github/gh-models/internal/azuremodels"
"github.com/github/gh-models/internal/sse"
"github.com/github/gh-models/internal/ux"
"github.com/github/gh-models/pkg/command"
"github.com/github/gh-models/pkg/util"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -381,7 +380,7 @@ func (h *runCommandHandler) loadModels() ([]*azuremodels.ModelSummary, error) {
return nil, err
}

ux.SortModels(models)
azuremodels.SortModels(models)
return models, nil
}

Expand All @@ -397,7 +396,7 @@ func (h *runCommandHandler) getModelNameFromArgs(models []*azuremodels.ModelSumm
}

for _, model := range models {
if !ux.IsChatModel(model) {
if !model.IsChatModel() {
continue
}
prompt.Options = append(prompt.Options, model.FriendlyName)
Expand Down
5 changes: 2 additions & 3 deletions cmd/view/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"github.com/AlecAivazis/survey/v2"
"github.com/github/gh-models/internal/azuremodels"
"github.com/github/gh-models/internal/ux"
"github.com/github/gh-models/pkg/command"
"github.com/spf13/cobra"
)
Expand All @@ -25,7 +24,7 @@ func NewViewCommand(cfg *command.Config) *cobra.Command {
return err
}

ux.SortModels(models)
azuremodels.SortModels(models)

modelName := ""
switch {
Expand All @@ -37,7 +36,7 @@ func NewViewCommand(cfg *command.Config) *cobra.Command {
}

for _, model := range models {
if !ux.IsChatModel(model) {
if !model.IsChatModel() {
continue
}
prompt.Options = append(prompt.Options, model.FriendlyName)
Expand Down
24 changes: 24 additions & 0 deletions internal/azuremodels/model_details.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package azuremodels

import "fmt"

// ModelDetails includes detailed information about a model.
type ModelDetails struct {
Description string `json:"description"`
Evaluation string `json:"evaluation"`
License string `json:"license"`
LicenseDescription string `json:"license_description"`
Notes string `json:"notes"`
Tags []string `json:"tags"`
SupportedInputModalities []string `json:"supported_input_modalities"`
SupportedOutputModalities []string `json:"supported_output_modalities"`
SupportedLanguages []string `json:"supported_languages"`
MaxOutputTokens int `json:"max_output_tokens"`
MaxInputTokens int `json:"max_input_tokens"`
RateLimitTier string `json:"rateLimitTier"`
}

// ContextLimits returns a summary of the context limits for the model.
func (m *ModelDetails) ContextLimits() string {
return fmt.Sprintf("up to %d input tokens and %d output tokens", m.MaxInputTokens, m.MaxOutputTokens)
}
15 changes: 15 additions & 0 deletions internal/azuremodels/model_details_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package azuremodels

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestModelDetails(t *testing.T) {
t.Run("ContextLimits", func(t *testing.T) {
details := &ModelDetails{MaxInputTokens: 123, MaxOutputTokens: 456}
result := details.ContextLimits()
require.Equal(t, "up to 123 input tokens and 456 output tokens", result)
})
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,39 @@
package ux
package azuremodels

import (
"slices"
"sort"
"strings"

"github.com/github/gh-models/internal/azuremodels"
)

// ModelSummary includes basic information about a model.
type ModelSummary struct {
ID string `json:"id"`
Name string `json:"name"`
FriendlyName string `json:"friendly_name"`
Task string `json:"task"`
Publisher string `json:"publisher"`
Summary string `json:"summary"`
Version string `json:"version"`
RegistryName string `json:"registry_name"`
}

// IsChatModel returns true if the model is for chat completions.
func (m *ModelSummary) IsChatModel() bool {
return m.Task == "chat-completion"
}

// HasName checks if the model has the given name.
func (m *ModelSummary) HasName(name string) bool {
return strings.EqualFold(m.FriendlyName, name) || strings.EqualFold(m.Name, name)
}

var (
featuredModelNames = []string{}
)

// SortModels sorts the given models in place, with featured models first, and then by friendly name.
func SortModels(models []*azuremodels.ModelSummary) {
func SortModels(models []*ModelSummary) {
sort.Slice(models, func(i, j int) bool {
// Sort featured models first, by name
isFeaturedI := slices.Contains(featuredModelNames, models[i].Name)
Expand Down
44 changes: 44 additions & 0 deletions internal/azuremodels/model_summary_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package azuremodels

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestModelSummary(t *testing.T) {
t.Run("IsChatModel", func(t *testing.T) {
embeddingModel := &ModelSummary{Task: "embeddings"}
chatCompletionModel := &ModelSummary{Task: "chat-completion"}
otherModel := &ModelSummary{Task: "something-else"}

require.False(t, embeddingModel.IsChatModel())
require.True(t, chatCompletionModel.IsChatModel())
require.False(t, otherModel.IsChatModel())
})

t.Run("HasName", func(t *testing.T) {
model := &ModelSummary{Name: "foo123", FriendlyName: "Foo 123"}

require.True(t, model.HasName(model.Name))
require.True(t, model.HasName("FOO123"))
require.True(t, model.HasName(model.FriendlyName))
require.True(t, model.HasName("fOo 123"))
require.False(t, model.HasName("completely different value"))
require.False(t, model.HasName("foo"))
})

t.Run("SortModels sorts given slice in-place by friendly name, case-insensitive", func(t *testing.T) {
modelA := &ModelSummary{Name: "z", FriendlyName: "AARDVARK"}
modelB := &ModelSummary{Name: "y", FriendlyName: "betta"}
modelC := &ModelSummary{Name: "x", FriendlyName: "Cat"}
models := []*ModelSummary{modelB, modelA, modelC}

SortModels(models)

require.Equal(t, 3, len(models))
require.Equal(t, "AARDVARK", models[0].FriendlyName)
require.Equal(t, "betta", models[1].FriendlyName)
require.Equal(t, "Cat", models[2].FriendlyName)
})
}
40 changes: 0 additions & 40 deletions internal/azuremodels/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package azuremodels

import (
"encoding/json"
"fmt"
"strings"

"github.com/github/gh-models/internal/sse"
)
Expand Down Expand Up @@ -81,23 +79,6 @@ type modelCatalogSearchSummary struct {
Summary string `json:"summary"`
}

// ModelSummary includes basic information about a model.
type ModelSummary struct {
ID string `json:"id"`
Name string `json:"name"`
FriendlyName string `json:"friendly_name"`
Task string `json:"task"`
Publisher string `json:"publisher"`
Summary string `json:"summary"`
Version string `json:"version"`
RegistryName string `json:"registry_name"`
}

// HasName checks if the model has the given name.
func (m *ModelSummary) HasName(name string) bool {
return strings.EqualFold(m.FriendlyName, name) || strings.EqualFold(m.Name, name)
}

type modelCatalogTextLimits struct {
MaxOutputTokens int `json:"maxOutputTokens"`
InputContextWindow int `json:"inputContextWindow"`
Expand Down Expand Up @@ -136,24 +117,3 @@ type modelCatalogDetailsResponse struct {
PlaygroundLimits *modelCatalogPlaygroundLimits `json:"playgroundLimits"`
ModelLimits *modelCatalogLimits `json:"modelLimits"`
}

// ModelDetails includes detailed information about a model.
type ModelDetails struct {
Description string `json:"description"`
Evaluation string `json:"evaluation"`
License string `json:"license"`
LicenseDescription string `json:"license_description"`
Notes string `json:"notes"`
Tags []string `json:"tags"`
SupportedInputModalities []string `json:"supported_input_modalities"`
SupportedOutputModalities []string `json:"supported_output_modalities"`
SupportedLanguages []string `json:"supported_languages"`
MaxOutputTokens int `json:"max_output_tokens"`
MaxInputTokens int `json:"max_input_tokens"`
RateLimitTier string `json:"rateLimitTier"`
}

// ContextLimits returns a summary of the context limits for the model.
func (m *ModelDetails) ContextLimits() string {
return fmt.Sprintf("up to %d input tokens and %d output tokens", m.MaxInputTokens, m.MaxOutputTokens)
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type is EventReader so I wanted the file name to match.

File renamed without changes.
81 changes: 81 additions & 0 deletions internal/sse/event_reader_test.go
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package sse

import (
"io"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

type sampleContent struct {
Name string `json:"name"`
NestedData []*struct {
Count int `json:"count"`
Value string `json:"value"`
} `json:"nested_data"`
}

type badReader struct{}

func (br badReader) Read(p []byte) (n int, err error) {
return 0, io.ErrClosedPipe
}

func TestEventReader(t *testing.T) {
t.Run("invalid type", func(t *testing.T) {
data := []string{
"invaliddata: {\"name\":\"chatcmpl-7Z4kUpXX6HN85cWY28IXM4EwemLU3\",\"object\":\"chat.completion.chunk\",\"created\":1688594090,\"model\":\"gpt-4-0613\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"finish_reason\":null}]}\n\n",
}

text := strings.NewReader(strings.Join(data, "\n"))
eventReader := NewEventReader[sampleContent](io.NopCloser(text))

firstEvent, err := eventReader.Read()
require.Empty(t, firstEvent)
require.EqualError(t, err, "unexpected event type: invaliddata")
})

t.Run("bad reader", func(t *testing.T) {
eventReader := NewEventReader[sampleContent](io.NopCloser(badReader{}))
defer eventReader.Close()

firstEvent, err := eventReader.Read()
require.Empty(t, firstEvent)
require.ErrorIs(t, io.ErrClosedPipe, err)
})

t.Run("stream is closed before done", func(t *testing.T) {
buf := strings.NewReader("data: {}")

eventReader := NewEventReader[sampleContent](io.NopCloser(buf))

evt, err := eventReader.Read()
require.Empty(t, evt)
require.NoError(t, err)

evt, err = eventReader.Read()
require.Empty(t, evt)
require.EqualError(t, err, "incomplete stream")
})

t.Run("spaces around areas", func(t *testing.T) {
buf := strings.NewReader(
// spaces between data
"data: {\"name\":\"chatcmpl-7Z4kUpXX6HN85cWY28IXM4EwemLU3\",\"nested_data\":[{\"count\":0,\"value\":\"with-spaces\"}]}\n" +
// no spaces
"data:{\"name\":\"chatcmpl-7Z4kUpXX6HN85cWY28IXM4EwemLU3\",\"nested_data\":[{\"count\":0,\"value\":\"without-spaces\"}]}\n",
)

eventReader := NewEventReader[sampleContent](io.NopCloser(buf))

evt, err := eventReader.Read()
require.NoError(t, err)
require.Equal(t, "with-spaces", evt.NestedData[0].Value)

evt, err = eventReader.Read()
require.NoError(t, err)
require.NotEmpty(t, evt)
require.Equal(t, "without-spaces", evt.NestedData[0].Value)
})
}
9 changes: 0 additions & 9 deletions internal/ux/filtering.go

This file was deleted.

24 changes: 0 additions & 24 deletions internal/ux/sorting_test.go

This file was deleted.

Loading