diff --git a/llms/googleai/googleai_llm_test.go b/llms/googleai/googleai_llm_test.go index 529d9bb98..1b46d12a6 100644 --- a/llms/googleai/googleai_llm_test.go +++ b/llms/googleai/googleai_llm_test.go @@ -48,6 +48,33 @@ func TestMultiContentText(t *testing.T) { assert.Regexp(t, "dog|canid|canine", strings.ToLower(c1.Content)) } +func TestMultiContentTextChatSequence(t *testing.T) { + t.Parallel() + llm := newClient(t) + + content := []llms.MessageContent{ + { + Role: schema.ChatMessageTypeHuman, + Parts: []llms.ContentPart{llms.TextContent{Text: "Name some countries"}}, + }, + { + Role: schema.ChatMessageTypeAI, + Parts: []llms.ContentPart{llms.TextContent{Text: "Spain and Lesotho"}}, + }, + { + Role: schema.ChatMessageTypeHuman, + Parts: []llms.ContentPart{llms.TextContent{Text: "Which if these is larger?"}}, + }, + } + + rsp, err := llm.GenerateContent(context.Background(), content, llms.WithModel("gemini-pro")) + require.NoError(t, err) + + assert.NotEmpty(t, rsp.Choices) + c1 := rsp.Choices[0] + assert.Regexp(t, "spain.*larger", strings.ToLower(c1.Content)) +} + func TestMultiContentImage(t *testing.T) { t.Parallel() llm := newClient(t) diff --git a/llms/openai/multicontent_test.go b/llms/openai/multicontent_test.go index ecf7b82c4..5d8118601 100644 --- a/llms/openai/multicontent_test.go +++ b/llms/openai/multicontent_test.go @@ -49,6 +49,33 @@ func TestMultiContentText(t *testing.T) { assert.Regexp(t, "dog|canid", strings.ToLower(c1.Content)) } +func TestMultiContentTextChatSequence(t *testing.T) { + t.Parallel() + llm := newChatClient(t) + + content := []llms.MessageContent{ + { + Role: schema.ChatMessageTypeHuman, + Parts: []llms.ContentPart{llms.TextContent{Text: "Name some countries"}}, + }, + { + Role: schema.ChatMessageTypeAI, + Parts: []llms.ContentPart{llms.TextContent{Text: "Spain and Lesotho"}}, + }, + { + Role: schema.ChatMessageTypeHuman, + Parts: []llms.ContentPart{llms.TextContent{Text: "Which if these is larger?"}}, + }, + } + + rsp, err := llm.GenerateContent(context.Background(), content) + require.NoError(t, err) + + assert.NotEmpty(t, rsp.Choices) + c1 := rsp.Choices[0] + assert.Regexp(t, "spain.*larger", strings.ToLower(c1.Content)) +} + func TestMultiContentImage(t *testing.T) { t.Parallel()