From 567e29aac83f4c4139313534f6497fb874eead76 Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Tue, 9 Jul 2024 15:06:16 -0400 Subject: [PATCH] :sparkles: Add images to conversation messages --- pkg/conversation/message.go | 108 +++++++++++++++++++++++++++++++++++- pkg/conversation/tree.go | 6 ++ 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/pkg/conversation/message.go b/pkg/conversation/message.go index c094a05..d540a25 100644 --- a/pkg/conversation/message.go +++ b/pkg/conversation/message.go @@ -4,6 +4,9 @@ import ( "encoding/json" "fmt" "github.com/google/uuid" + "io" + "os" + "path/filepath" "strings" "time" ) @@ -17,6 +20,7 @@ const ( // See also the comment to refactor this in openai/helpers.go, where tool use information is actually stored in the metadata of the message ContentTypeToolUse ContentType = "tool-use" ContentTypeToolResult ContentType = "tool-result" + ContentTypeImage ContentType = "image" ) // MessageContent is an interface for different types of node content. @@ -36,8 +40,9 @@ const ( ) type ChatMessageContent struct { - Role Role `json:"role"` - Text string `json:"text"` + Role Role `json:"role"` + Text string `json:"text"` + Images []*ImageContent `json:"images"` } func (c *ChatMessageContent) ContentType() ContentType { @@ -99,6 +104,105 @@ func (t *ToolResultContent) View() string { var _ MessageContent = (*ToolResultContent)(nil) +type ImageDetail string + +const ( + ImageDetailLow ImageDetail = "low" + ImageDetailHigh ImageDetail = "high" + ImageDetailAuto ImageDetail = "auto" +) + +type ImageContent struct { + ImageURL string `json:"imageURL"` + ImageContent []byte `json:"imageContent"` + ImageName string `json:"imageName"` + MediaType string `json:"mediaType"` + Detail ImageDetail `json:"detail"` +} + +func NewImageContentFromFile(path string) (*ImageContent, error) { + // Check if the path is a URL + if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { + return newImageContentFromURL(path) + } + return newImageContentFromLocalFile(path) +} + +func newImageContentFromURL(url string) (*ImageContent, error) { + imageName := filepath.Base(url) + + return &ImageContent{ + ImageURL: url, + ImageName: imageName, + Detail: ImageDetailAuto, + }, nil +} + +func newImageContentFromLocalFile(path string) (*ImageContent, error) { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("failed to open file: %v", err) + } + defer func(file *os.File) { + _ = file.Close() + }(file) + + content, err := io.ReadAll(file) + if err != nil { + return nil, fmt.Errorf("failed to read file content: %v", err) + } + + fileInfo, err := file.Stat() + if err != nil { + return nil, fmt.Errorf("failed to get file info: %v", err) + } + + if fileInfo.Size() > 20*1024*1024 { + return nil, fmt.Errorf("image size exceeds 20MB limit") + } + + mediaType := getMediaTypeFromExtension(filepath.Ext(path)) + if mediaType == "" { + return nil, fmt.Errorf("unsupported image format: %s", filepath.Ext(path)) + } + + return &ImageContent{ + ImageContent: content, + ImageName: fileInfo.Name(), + MediaType: mediaType, + Detail: ImageDetailAuto, + }, nil +} + +func getMediaTypeFromExtension(ext string) string { + switch strings.ToLower(ext) { + case ".png": + return "image/png" + case ".jpg", ".jpeg": + return "image/jpeg" + case ".webp": + return "image/webp" + case ".gif": + return "image/gif" + default: + return "" + } +} + +func (i *ImageContent) ContentType() ContentType { + return ContentTypeImage +} + +func (i *ImageContent) String() string { + return fmt.Sprintf("ImageContent{ImageURL: %s, ImageName: %s, Detail: %s}", i.ImageURL, i.ImageName, i.Detail) +} + +func (i *ImageContent) View() string { + return i.String() +} + +var _ MessageContent = (*ImageContent)(nil) + // Message represents a single message node in the conversation tree. type Message struct { ParentID NodeID `json:"parentID"` diff --git a/pkg/conversation/tree.go b/pkg/conversation/tree.go index fa68670..55a4afd 100644 --- a/pkg/conversation/tree.go +++ b/pkg/conversation/tree.go @@ -70,6 +70,12 @@ func (mn *Message) UnmarshalJSON(data []byte) error { return err } mn.Content = content + case ContentTypeImage: + var content *ImageContent + if err := json.Unmarshal(mna.Content, &content); err != nil { + return err + } + mn.Content = content default: return errors.New("unknown content type") }