diff --git a/cmd/list/list.go b/cmd/list/list.go index 7543793..1db503f 100644 --- a/cmd/list/list.go +++ b/cmd/list/list.go @@ -6,7 +6,6 @@ import ( "github.com/cli/go-gh/v2/pkg/tableprinter" "github.com/github/gh-models/internal/azuremodels" - "github.com/github/gh-models/internal/ux" "github.com/github/gh-models/pkg/command" "github.com/mgutz/ansi" "github.com/spf13/cobra" @@ -33,7 +32,7 @@ func NewListCommand(cfg *command.Config) *cobra.Command { // For now, filter to just chat models. // Once other tasks are supported (like embeddings), update the list to show all models, with the task as a column. models = filterToChatModels(models) - ux.SortModels(models) + azuremodels.SortModels(models) if cfg.IsTerminalOutput { cfg.WriteToOut("\n") @@ -67,7 +66,7 @@ func NewListCommand(cfg *command.Config) *cobra.Command { func filterToChatModels(models []*azuremodels.ModelSummary) []*azuremodels.ModelSummary { var chatModels []*azuremodels.ModelSummary for _, model := range models { - if ux.IsChatModel(model) { + if model.IsChatModel() { chatModels = append(chatModels, model) } } diff --git a/cmd/run/run.go b/cmd/run/run.go index 7b54235..ad96626 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -16,7 +16,6 @@ import ( "github.com/briandowns/spinner" "github.com/github/gh-models/internal/azuremodels" "github.com/github/gh-models/internal/sse" - "github.com/github/gh-models/internal/ux" "github.com/github/gh-models/pkg/command" "github.com/github/gh-models/pkg/util" "github.com/spf13/cobra" @@ -381,7 +380,7 @@ func (h *runCommandHandler) loadModels() ([]*azuremodels.ModelSummary, error) { return nil, err } - ux.SortModels(models) + azuremodels.SortModels(models) return models, nil } @@ -397,7 +396,7 @@ func (h *runCommandHandler) getModelNameFromArgs(models []*azuremodels.ModelSumm } for _, model := range models { - if !ux.IsChatModel(model) { + if !model.IsChatModel() { continue } prompt.Options = append(prompt.Options, model.FriendlyName) diff --git a/cmd/view/view.go b/cmd/view/view.go index 37e34e0..1a547d6 100644 --- a/cmd/view/view.go +++ b/cmd/view/view.go @@ -6,7 +6,6 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/github/gh-models/internal/azuremodels" - "github.com/github/gh-models/internal/ux" "github.com/github/gh-models/pkg/command" "github.com/spf13/cobra" ) @@ -25,7 +24,7 @@ func NewViewCommand(cfg *command.Config) *cobra.Command { return err } - ux.SortModels(models) + azuremodels.SortModels(models) modelName := "" switch { @@ -37,7 +36,7 @@ func NewViewCommand(cfg *command.Config) *cobra.Command { } for _, model := range models { - if !ux.IsChatModel(model) { + if !model.IsChatModel() { continue } prompt.Options = append(prompt.Options, model.FriendlyName) diff --git a/internal/azuremodels/model_details.go b/internal/azuremodels/model_details.go new file mode 100644 index 0000000..2683648 --- /dev/null +++ b/internal/azuremodels/model_details.go @@ -0,0 +1,24 @@ +package azuremodels + +import "fmt" + +// ModelDetails includes detailed information about a model. +type ModelDetails struct { + Description string `json:"description"` + Evaluation string `json:"evaluation"` + License string `json:"license"` + LicenseDescription string `json:"license_description"` + Notes string `json:"notes"` + Tags []string `json:"tags"` + SupportedInputModalities []string `json:"supported_input_modalities"` + SupportedOutputModalities []string `json:"supported_output_modalities"` + SupportedLanguages []string `json:"supported_languages"` + MaxOutputTokens int `json:"max_output_tokens"` + MaxInputTokens int `json:"max_input_tokens"` + RateLimitTier string `json:"rateLimitTier"` +} + +// ContextLimits returns a summary of the context limits for the model. +func (m *ModelDetails) ContextLimits() string { + return fmt.Sprintf("up to %d input tokens and %d output tokens", m.MaxInputTokens, m.MaxOutputTokens) +} diff --git a/internal/azuremodels/model_details_test.go b/internal/azuremodels/model_details_test.go new file mode 100644 index 0000000..8a41f06 --- /dev/null +++ b/internal/azuremodels/model_details_test.go @@ -0,0 +1,15 @@ +package azuremodels + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestModelDetails(t *testing.T) { + t.Run("ContextLimits", func(t *testing.T) { + details := &ModelDetails{MaxInputTokens: 123, MaxOutputTokens: 456} + result := details.ContextLimits() + require.Equal(t, "up to 123 input tokens and 456 output tokens", result) + }) +} diff --git a/internal/ux/sorting.go b/internal/azuremodels/model_summary.go similarity index 50% rename from internal/ux/sorting.go rename to internal/azuremodels/model_summary.go index 59b0e4e..9b7798a 100644 --- a/internal/ux/sorting.go +++ b/internal/azuremodels/model_summary.go @@ -1,19 +1,39 @@ -package ux +package azuremodels import ( "slices" "sort" "strings" - - "github.com/github/gh-models/internal/azuremodels" ) +// ModelSummary includes basic information about a model. +type ModelSummary struct { + ID string `json:"id"` + Name string `json:"name"` + FriendlyName string `json:"friendly_name"` + Task string `json:"task"` + Publisher string `json:"publisher"` + Summary string `json:"summary"` + Version string `json:"version"` + RegistryName string `json:"registry_name"` +} + +// IsChatModel returns true if the model is for chat completions. +func (m *ModelSummary) IsChatModel() bool { + return m.Task == "chat-completion" +} + +// HasName checks if the model has the given name. +func (m *ModelSummary) HasName(name string) bool { + return strings.EqualFold(m.FriendlyName, name) || strings.EqualFold(m.Name, name) +} + var ( featuredModelNames = []string{} ) // SortModels sorts the given models in place, with featured models first, and then by friendly name. -func SortModels(models []*azuremodels.ModelSummary) { +func SortModels(models []*ModelSummary) { sort.Slice(models, func(i, j int) bool { // Sort featured models first, by name isFeaturedI := slices.Contains(featuredModelNames, models[i].Name) diff --git a/internal/azuremodels/model_summary_test.go b/internal/azuremodels/model_summary_test.go new file mode 100644 index 0000000..82358f0 --- /dev/null +++ b/internal/azuremodels/model_summary_test.go @@ -0,0 +1,44 @@ +package azuremodels + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestModelSummary(t *testing.T) { + t.Run("IsChatModel", func(t *testing.T) { + embeddingModel := &ModelSummary{Task: "embeddings"} + chatCompletionModel := &ModelSummary{Task: "chat-completion"} + otherModel := &ModelSummary{Task: "something-else"} + + require.False(t, embeddingModel.IsChatModel()) + require.True(t, chatCompletionModel.IsChatModel()) + require.False(t, otherModel.IsChatModel()) + }) + + t.Run("HasName", func(t *testing.T) { + model := &ModelSummary{Name: "foo123", FriendlyName: "Foo 123"} + + require.True(t, model.HasName(model.Name)) + require.True(t, model.HasName("FOO123")) + require.True(t, model.HasName(model.FriendlyName)) + require.True(t, model.HasName("fOo 123")) + require.False(t, model.HasName("completely different value")) + require.False(t, model.HasName("foo")) + }) + + t.Run("SortModels sorts given slice in-place by friendly name, case-insensitive", func(t *testing.T) { + modelA := &ModelSummary{Name: "z", FriendlyName: "AARDVARK"} + modelB := &ModelSummary{Name: "y", FriendlyName: "betta"} + modelC := &ModelSummary{Name: "x", FriendlyName: "Cat"} + models := []*ModelSummary{modelB, modelA, modelC} + + SortModels(models) + + require.Equal(t, 3, len(models)) + require.Equal(t, "AARDVARK", models[0].FriendlyName) + require.Equal(t, "betta", models[1].FriendlyName) + require.Equal(t, "Cat", models[2].FriendlyName) + }) +} diff --git a/internal/azuremodels/types.go b/internal/azuremodels/types.go index 2cd4d2a..29d4a7d 100644 --- a/internal/azuremodels/types.go +++ b/internal/azuremodels/types.go @@ -2,8 +2,6 @@ package azuremodels import ( "encoding/json" - "fmt" - "strings" "github.com/github/gh-models/internal/sse" ) @@ -81,23 +79,6 @@ type modelCatalogSearchSummary struct { Summary string `json:"summary"` } -// ModelSummary includes basic information about a model. -type ModelSummary struct { - ID string `json:"id"` - Name string `json:"name"` - FriendlyName string `json:"friendly_name"` - Task string `json:"task"` - Publisher string `json:"publisher"` - Summary string `json:"summary"` - Version string `json:"version"` - RegistryName string `json:"registry_name"` -} - -// HasName checks if the model has the given name. -func (m *ModelSummary) HasName(name string) bool { - return strings.EqualFold(m.FriendlyName, name) || strings.EqualFold(m.Name, name) -} - type modelCatalogTextLimits struct { MaxOutputTokens int `json:"maxOutputTokens"` InputContextWindow int `json:"inputContextWindow"` @@ -136,24 +117,3 @@ type modelCatalogDetailsResponse struct { PlaygroundLimits *modelCatalogPlaygroundLimits `json:"playgroundLimits"` ModelLimits *modelCatalogLimits `json:"modelLimits"` } - -// ModelDetails includes detailed information about a model. -type ModelDetails struct { - Description string `json:"description"` - Evaluation string `json:"evaluation"` - License string `json:"license"` - LicenseDescription string `json:"license_description"` - Notes string `json:"notes"` - Tags []string `json:"tags"` - SupportedInputModalities []string `json:"supported_input_modalities"` - SupportedOutputModalities []string `json:"supported_output_modalities"` - SupportedLanguages []string `json:"supported_languages"` - MaxOutputTokens int `json:"max_output_tokens"` - MaxInputTokens int `json:"max_input_tokens"` - RateLimitTier string `json:"rateLimitTier"` -} - -// ContextLimits returns a summary of the context limits for the model. -func (m *ModelDetails) ContextLimits() string { - return fmt.Sprintf("up to %d input tokens and %d output tokens", m.MaxInputTokens, m.MaxOutputTokens) -} diff --git a/internal/sse/eventreader.go b/internal/sse/event_reader.go similarity index 100% rename from internal/sse/eventreader.go rename to internal/sse/event_reader.go diff --git a/internal/sse/event_reader_test.go b/internal/sse/event_reader_test.go new file mode 100644 index 0000000..e8a5041 --- /dev/null +++ b/internal/sse/event_reader_test.go @@ -0,0 +1,81 @@ +package sse + +import ( + "io" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +type sampleContent struct { + Name string `json:"name"` + NestedData []*struct { + Count int `json:"count"` + Value string `json:"value"` + } `json:"nested_data"` +} + +type badReader struct{} + +func (br badReader) Read(p []byte) (n int, err error) { + return 0, io.ErrClosedPipe +} + +func TestEventReader(t *testing.T) { + t.Run("invalid type", func(t *testing.T) { + data := []string{ + "invaliddata: {\"name\":\"chatcmpl-7Z4kUpXX6HN85cWY28IXM4EwemLU3\",\"object\":\"chat.completion.chunk\",\"created\":1688594090,\"model\":\"gpt-4-0613\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"finish_reason\":null}]}\n\n", + } + + text := strings.NewReader(strings.Join(data, "\n")) + eventReader := NewEventReader[sampleContent](io.NopCloser(text)) + + firstEvent, err := eventReader.Read() + require.Empty(t, firstEvent) + require.EqualError(t, err, "unexpected event type: invaliddata") + }) + + t.Run("bad reader", func(t *testing.T) { + eventReader := NewEventReader[sampleContent](io.NopCloser(badReader{})) + defer eventReader.Close() + + firstEvent, err := eventReader.Read() + require.Empty(t, firstEvent) + require.ErrorIs(t, io.ErrClosedPipe, err) + }) + + t.Run("stream is closed before done", func(t *testing.T) { + buf := strings.NewReader("data: {}") + + eventReader := NewEventReader[sampleContent](io.NopCloser(buf)) + + evt, err := eventReader.Read() + require.Empty(t, evt) + require.NoError(t, err) + + evt, err = eventReader.Read() + require.Empty(t, evt) + require.EqualError(t, err, "incomplete stream") + }) + + t.Run("spaces around areas", func(t *testing.T) { + buf := strings.NewReader( + // spaces between data + "data: {\"name\":\"chatcmpl-7Z4kUpXX6HN85cWY28IXM4EwemLU3\",\"nested_data\":[{\"count\":0,\"value\":\"with-spaces\"}]}\n" + + // no spaces + "data:{\"name\":\"chatcmpl-7Z4kUpXX6HN85cWY28IXM4EwemLU3\",\"nested_data\":[{\"count\":0,\"value\":\"without-spaces\"}]}\n", + ) + + eventReader := NewEventReader[sampleContent](io.NopCloser(buf)) + + evt, err := eventReader.Read() + require.NoError(t, err) + require.Equal(t, "with-spaces", evt.NestedData[0].Value) + + evt, err = eventReader.Read() + require.NoError(t, err) + require.NotEmpty(t, evt) + require.Equal(t, "without-spaces", evt.NestedData[0].Value) + }) +} diff --git a/internal/ux/filtering.go b/internal/ux/filtering.go deleted file mode 100644 index f456c85..0000000 --- a/internal/ux/filtering.go +++ /dev/null @@ -1,9 +0,0 @@ -// Package ux provides utility functions around presentation and user experience. -package ux - -import "github.com/github/gh-models/internal/azuremodels" - -// IsChatModel returns true if the given model is for chat completions. -func IsChatModel(model *azuremodels.ModelSummary) bool { - return model.Task == "chat-completion" -} diff --git a/internal/ux/sorting_test.go b/internal/ux/sorting_test.go deleted file mode 100644 index 3516944..0000000 --- a/internal/ux/sorting_test.go +++ /dev/null @@ -1,24 +0,0 @@ -package ux - -import ( - "testing" - - "github.com/github/gh-models/internal/azuremodels" - "github.com/stretchr/testify/require" -) - -func TestSorting(t *testing.T) { - t.Run("SortModels sorts given slice in-place by friendly name, case-insensitive", func(t *testing.T) { - modelA := &azuremodels.ModelSummary{Name: "z", FriendlyName: "AARDVARK"} - modelB := &azuremodels.ModelSummary{Name: "y", FriendlyName: "betta"} - modelC := &azuremodels.ModelSummary{Name: "x", FriendlyName: "Cat"} - models := []*azuremodels.ModelSummary{modelB, modelA, modelC} - - SortModels(models) - - require.Equal(t, 3, len(models)) - require.Equal(t, "AARDVARK", models[0].FriendlyName) - require.Equal(t, "betta", models[1].FriendlyName) - require.Equal(t, "Cat", models[2].FriendlyName) - }) -}