Skip to content

Commit

Permalink
Fix TestUpdateMessages()
Browse files Browse the repository at this point in the history
  • Loading branch information
hsluoyz committed Mar 2, 2024
1 parent b10a758 commit 538386c
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 8 deletions.
2 changes: 1 addition & 1 deletion controllers/message_answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (c *ApiController) GetMessageAnswer() {
}

writer := &RefinedWriter{*c.Ctx.ResponseWriter, *NewCleaner(6), []byte{}}
history, err := object.GetRecentRawMessages(chat.Name, store.MemoryLimit)
history, err := object.GetRecentRawMessages(chat.Name, message.CreatedTime, store.MemoryLimit)
if err != nil {
c.ResponseErrorStream(err.Error())
return
Expand Down
4 changes: 2 additions & 2 deletions object/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,14 @@ func (message *Message) GetId() string {
return fmt.Sprintf("%s/%s", message.Owner, message.Name)
}

func GetRecentRawMessages(chat string, memoryLimit int) ([]*model.RawMessage, error) {
func GetRecentRawMessages(chat string, createdTime string, memoryLimit int) ([]*model.RawMessage, error) {
res := []*model.RawMessage{}
if memoryLimit == 0 {
return res, nil
}

messages := []*Message{}
err := adapter.engine.Desc("created_time").Limit(2*memoryLimit, 2).Find(&messages, &Message{Chat: chat})
err := adapter.engine.Where("created_time <= ?", createdTime).Desc("created_time").Limit(2*memoryLimit, 2).Find(&messages, &Message{Chat: chat})
if err != nil {
return nil, err
}
Expand Down
84 changes: 79 additions & 5 deletions object/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,36 @@
package object

import (
"fmt"
"testing"

"github.com/casibase/casibase/embedding"
"github.com/casibase/casibase/model"
)

func TestUpdateMessages(t *testing.T) {
InitConfig()

messages, err := GetGlobalMessages()
store, err := GetDefaultStore("admin")
if err != nil {
panic(err)
}

for _, message := range messages {
if message.Author != "AI" && message.Text != "" && (message.TokenCount == 0 || message.Price == 0) {
var defaultEmbeddingResult *embedding.EmbeddingResult
defaultEmbeddingResult, err = embedding.GetDefaultEmbeddingResult("text-embedding-ada-002", message.Text)
allMessages, err := GetGlobalMessages()
if err != nil {
panic(err)
}

modelSubType := "gpt-4-vision-preview"
maxTokens := model.GetOpenAiMaxTokens(modelSubType)

for i, message := range allMessages {
if message.Text == "" || (message.TokenCount != 0 && message.Price != 0) {
continue
}

if message.Author != "AI" {
defaultEmbeddingResult, err := embedding.GetDefaultEmbeddingResult("text-embedding-ada-002", message.Text)
if err != nil {
panic(err)
}
Expand All @@ -43,6 +56,67 @@ func TestUpdateMessages(t *testing.T) {
message.Price = defaultEmbeddingResult.Price
message.Currency = defaultEmbeddingResult.Currency

_, err = UpdateMessage(message.GetId(), message)
if err != nil {
panic(err)
}
} else {
question := store.Welcome
if message.ReplyTo != "Welcome" {
questionMessage, err := GetMessage(message.ReplyTo)
if err != nil {
panic(err)
}

question = questionMessage.Text
}

history, err := GetRecentRawMessages(message.Chat, message.CreatedTime, store.MemoryLimit)
if err != nil {
panic(err)
}

prompt := store.Prompt
knowledge := []*model.RawMessage{}

rawMessages, err := model.OpenaiGenerateMessages(prompt, question, history, knowledge, modelSubType, maxTokens)
if err != nil {
panic(err)
}

messages, err := model.OpenaiRawMessagesToGpt4VisionMessages(rawMessages)
if err != nil {
panic(err)
}

// https://github.com/sashabaranov/go-openai/pull/223#issuecomment-1494372875
promptTokenCount, err := model.OpenaiNumTokensFromMessages(messages, modelSubType)
if err != nil {
panic(err)
}

responseTokenCount, err := model.GetTokenSize(modelSubType, message.Text)
if err != nil {
panic(err)
}

modelResult := &model.ModelResult{}
modelResult.PromptTokenCount = promptTokenCount
modelResult.ResponseTokenCount = responseTokenCount
modelResult.TotalTokenCount = modelResult.PromptTokenCount + modelResult.ResponseTokenCount

p, err := model.NewLocalModelProvider("", modelSubType, "", 0, 0, 0, 0, "")
err = p.CalculatePrice(modelResult)
if err != nil {
panic(err)
}

message.TokenCount = modelResult.TotalTokenCount
message.Price = modelResult.TotalPrice
message.Currency = modelResult.Currency

fmt.Printf("[%d/%d] message: %s, user: %s, author: %s, tokenCount: %d, price: %f\n", i+1, len(allMessages), message.Name, message.User, message.Author, message.TokenCount, message.Price)

_, err = UpdateMessage(message.GetId(), message)
if err != nil {
panic(err)
Expand Down

0 comments on commit 538386c

Please sign in to comment.