Skip to content

Commit

Permalink
Merge pull request #8 from wesen/task/add-image-support
Browse files Browse the repository at this point in the history
Add support for image content in conversations
  • Loading branch information
wesen authored Jul 9, 2024
2 parents d06369b + 567e29a commit 0060012
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 2 deletions.
108 changes: 106 additions & 2 deletions pkg/conversation/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"encoding/json"
"fmt"
"github.com/google/uuid"
"io"
"os"
"path/filepath"
"strings"
"time"
)
Expand All @@ -17,6 +20,7 @@ const (
// See also the comment to refactor this in openai/helpers.go, where tool use information is actually stored in the metadata of the message
ContentTypeToolUse ContentType = "tool-use"
ContentTypeToolResult ContentType = "tool-result"
ContentTypeImage ContentType = "image"
)

// MessageContent is an interface for different types of node content.
Expand All @@ -36,8 +40,9 @@ const (
)

// ChatMessageContent is the content of a chat message: the role it is
// attributed to, its text, and any images carried along with it.
type ChatMessageContent struct {
Role Role `json:"role"`
Text string `json:"text"`
Role Role `json:"role"`
Text string `json:"text"`
// Images lists the images carried by the message, serialized under "images".
Images []*ImageContent `json:"images"`
}

func (c *ChatMessageContent) ContentType() ContentType {
Expand Down Expand Up @@ -99,6 +104,105 @@ func (t *ToolResultContent) View() string {

var _ MessageContent = (*ToolResultContent)(nil)

// ImageDetail names the level of detail requested for an image.
// NOTE(review): presumably forwarded to the model provider's image "detail"
// setting — confirm against the code that consumes ImageContent.
type ImageDetail string

// Known ImageDetail values. ImageDetailAuto is what the constructors in this
// file assign by default.
const (
ImageDetailLow ImageDetail = "low"
ImageDetailHigh ImageDetail = "high"
ImageDetailAuto ImageDetail = "auto"
)

// ImageContent describes a single image, referenced either by URL or carried
// inline as raw bytes.
type ImageContent struct {
// ImageURL is the remote location of the image; set when the image is
// created from an http(s) URL.
ImageURL string `json:"imageURL"`
// ImageContent holds the raw bytes of the image; set when the image is
// loaded from a local file.
ImageContent []byte `json:"imageContent"`
// ImageName is a short name for the image, taken from the last element of
// the file path or URL it was created from.
ImageName string `json:"imageName"`
// MediaType is the MIME type (e.g. "image/png"); the local-file constructor
// derives it from the file extension.
MediaType string `json:"mediaType"`
// Detail is the requested level of detail for the image.
Detail ImageDetail `json:"detail"`
}

// NewImageContentFromFile builds an ImageContent from path, which may be
// either an http(s) URL or a path on the local filesystem.
//
// A path starting with "http://" or "https://" is handed to
// newImageContentFromURL; anything else is handed to
// newImageContentFromLocalFile. The result of the chosen constructor is
// returned unchanged.
func NewImageContentFromFile(path string) (*ImageContent, error) {
	switch {
	case strings.HasPrefix(path, "http://"), strings.HasPrefix(path, "https://"):
		// Remote reference.
		return newImageContentFromURL(path)
	default:
		// Everything else is treated as a local file.
		return newImageContentFromLocalFile(path)
	}
}

// newImageContentFromURL builds an ImageContent that refers to an image by
// URL. Nothing is fetched: only the URL, a name derived from it, and the
// default detail level (ImageDetailAuto) are recorded.
//
// The image name is the last segment of the URL's path. Any query string or
// fragment is dropped first so that "https://host/a/cat.png?size=large#top"
// yields "cat.png" rather than "cat.png?size=large#top". The URL is split on
// "/" directly instead of going through path/filepath, whose rules are those
// of the host operating system's file paths, not of URLs.
//
// The returned error is always nil; it exists to match the signature of the
// local-file constructor.
func newImageContentFromURL(url string) (*ImageContent, error) {
	// Strip the query string and fragment; they are not part of the resource name.
	resource := url
	if cut := strings.IndexAny(resource, "?#"); cut >= 0 {
		resource = resource[:cut]
	}

	// Ignore trailing slashes so "https://host/dir/" is named "dir".
	resource = strings.TrimRight(resource, "/")

	imageName := resource
	if slash := strings.LastIndex(resource, "/"); slash >= 0 {
		imageName = resource[slash+1:]
	}

	return &ImageContent{
		ImageURL:  url,
		ImageName: imageName,
		Detail:    ImageDetailAuto,
	}, nil
}

// newImageContentFromLocalFile loads the image at path from the local
// filesystem and returns it as an ImageContent carrying the raw bytes, the
// file's base name, its media type and the default detail level
// (ImageDetailAuto).
//
// It returns an error if the file cannot be opened, inspected or read, if it
// is larger than 20MB, or if its extension is not one recognised by
// getMediaTypeFromExtension. Underlying I/O errors are wrapped with %w so
// callers can inspect them with errors.Is / errors.As.
func newImageContentFromLocalFile(path string) (*ImageContent, error) {
	const maxImageSize = 20 * 1024 * 1024

	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to open file: %w", err)
	}
	// The handle is read-only, so a failure to close it is not actionable.
	defer func() { _ = file.Close() }()

	// Validate size and format before reading, so an oversized or unsupported
	// file is rejected without first being pulled into memory.
	fileInfo, err := file.Stat()
	if err != nil {
		return nil, fmt.Errorf("failed to get file info: %w", err)
	}
	if fileInfo.Size() > maxImageSize {
		return nil, fmt.Errorf("image size exceeds 20MB limit")
	}

	ext := filepath.Ext(path)
	mediaType := getMediaTypeFromExtension(ext)
	if mediaType == "" {
		return nil, fmt.Errorf("unsupported image format: %s", ext)
	}

	// Read at most one byte past the limit: this bounds memory even when the
	// reported size is unreliable (the file grows, or is a special file), and
	// the extra byte lets us tell "exactly at the limit" from "over it".
	content, err := io.ReadAll(io.LimitReader(file, maxImageSize+1))
	if err != nil {
		return nil, fmt.Errorf("failed to read file content: %w", err)
	}
	if len(content) > maxImageSize {
		return nil, fmt.Errorf("image size exceeds 20MB limit")
	}

	return &ImageContent{
		ImageContent: content,
		ImageName:    fileInfo.Name(),
		MediaType:    mediaType,
		Detail:       ImageDetailAuto,
	}, nil
}

// getMediaTypeFromExtension maps a file extension, including its leading dot
// (e.g. ".png"), to the matching image MIME type. The comparison ignores
// case. An empty string is returned for any extension that is not PNG, JPEG,
// WebP or GIF.
func getMediaTypeFromExtension(ext string) string {
	var mediaType string // stays empty for unrecognised extensions

	switch normalized := strings.ToLower(ext); normalized {
	case ".gif":
		mediaType = "image/gif"
	case ".webp":
		mediaType = "image/webp"
	case ".jpeg", ".jpg":
		mediaType = "image/jpeg"
	case ".png":
		mediaType = "image/png"
	}

	return mediaType
}

// ContentType reports the content type of an image, which is always
// ContentTypeImage.
func (i *ImageContent) ContentType() ContentType {
return ContentTypeImage
}

// String returns a compact, single-line description of the image made up of
// its URL, name and detail level. The raw image bytes and the media type are
// not part of the output.
func (i *ImageContent) String() string {
	url, name, detail := i.ImageURL, i.ImageName, string(i.Detail)
	return "ImageContent{ImageURL: " + url +
		", ImageName: " + name +
		", Detail: " + detail + "}"
}

// View returns the display form of the image, which is the same text that
// String produces.
func (i *ImageContent) View() string {
return i.String()
}

// Compile-time check that *ImageContent implements MessageContent.
var _ MessageContent = (*ImageContent)(nil)

// Message represents a single message node in the conversation tree.
type Message struct {
ParentID NodeID `json:"parentID"`
Expand Down
6 changes: 6 additions & 0 deletions pkg/conversation/tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ func (mn *Message) UnmarshalJSON(data []byte) error {
return err
}
mn.Content = content
case ContentTypeImage:
var content *ImageContent
if err := json.Unmarshal(mna.Content, &content); err != nil {
return err
}
mn.Content = content
default:
return errors.New("unknown content type")
}
Expand Down

0 comments on commit 0060012

Please sign in to comment.