Skip to content

Commit

Permalink
feat: add TogetherAI integration and split Composer with Filter
Browse files Browse the repository at this point in the history
Add TogetherAI integration and split Composer with Filter
Closes #47
  • Loading branch information
samgozman authored Jan 1, 2024
2 parents 256a373 + 2063996 commit 40885c7
Show file tree
Hide file tree
Showing 11 changed files with 343 additions and 33 deletions.
1 change: 1 addition & 0 deletions .env_sample
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
TELEGRAM_CHANNEL_ID=
TELEGRAM_BOT_TOKEN=
OPENAI_TOKEN=
TOGETHER_AI_TOKEN=
# DSN in gorm format
POSTGRES_DSN="host=postgres user=postgres password=postgres dbname=finfeed port=5432 sslmode=disable"
SENTRY_DSN=https://public@sentry.example.com/1
2 changes: 1 addition & 1 deletion app.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (a *App) start() {
panic(err)
}

composer := NewComposer(a.cnf.env.OpenAiToken)
composer := NewComposer(a.cnf.env.OpenAiToken, a.cnf.env.TogetherAIToken)

marketJournalist := NewJournalist("MarketNews", []NewsProvider{
NewRssProvider("benzinga:large-cap", "https://www.benzinga.com/news/large-cap/feed"),
Expand Down
100 changes: 100 additions & 0 deletions composer/clients.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package composer

import (
"bytes"
"context"
"encoding/json"
"github.com/sashabaranov/go-openai"
"io"
"net/http"
)

// OpenAiClientInterface is the part of the OpenAI API client used by this package.
// Depending on the interface instead of the concrete client allows a mock to be
// injected in tests.
type OpenAiClientInterface interface {
	// CreateChatCompletion sends req to the chat completion API and returns the
	// response, or an error if the call failed.
	//
	// The results are left unnamed on purpose: naming a result `error` shadows the
	// predeclared error type inside the signature. Result names are not part of
	// the method set, so existing implementations still satisfy the interface.
	CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error)
}

// TogetherAIClientInterface describes the single TogetherAI operation this package
// relies on, so that the real client can be swapped for a test double.
type TogetherAIClientInterface interface {
	// CreateChatCompletion submits options to the TogetherAI API and returns the
	// decoded response, or an error if the call failed.
	CreateChatCompletion(ctx context.Context, options TogetherAIRequest) (TogetherAIResponse, error)
}

// TogetherAIRequest is the JSON body sent to the TogetherAI API.
// Every field is always serialized (no omitempty) under the key given in its tag.
type TogetherAIRequest struct {
	// Model is the model identifier, e.g. "mistralai/Mistral-7B-Instruct-v0.2".
	Model string `json:"model"`
	// Prompt is the full prompt text.
	Prompt string `json:"prompt"`

	// The remaining fields are forwarded to the API verbatim under their JSON names.
	MaxTokens         int     `json:"max_tokens"`
	Temperature       float64 `json:"temperature"`
	TopP              float64 `json:"top_p"`
	TopK              int     `json:"top_k"`
	RepetitionPenalty float64 `json:"repetition_penalty"`
}

// TogetherAIResponse is the JSON body decoded from a TogetherAI API response.
type TogetherAIResponse struct {
	ID string `json:"id"`
	// Choices holds the generated completions. The element type is kept as this
	// exact anonymous struct because callers build values of it with a matching
	// struct literal; changing its fields or tags would break them.
	Choices []struct {
		Text string `json:"text"`
	} `json:"choices"`
	// Usage carries the token counters reported by the API. The explicit tag makes
	// the field marshal as "usage" like its siblings; without it encoding/json
	// would emit "Usage" (decoding already matched case-insensitively).
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
	Created int64  `json:"created"` // presumably a Unix timestamp — confirm against the API docs
	Model   string `json:"model"`
	Object  string `json:"object"`
}

// TogetherAI is a minimal HTTP client for the TogetherAI API, used in place of
// the OpenAI API for some tasks.
type TogetherAI struct {
	// APIKey is sent as the Bearer token in the Authorization header.
	APIKey string
	// URL is the endpoint that requests are POSTed to.
	URL string
}

// togetherAIStatusError is returned when the TogetherAI API answers with an HTTP
// status outside the 2xx range.
type togetherAIStatusError struct {
	Status string // HTTP status line, e.g. "401 Unauthorized"
	Body   string // leading part of the response body, kept for diagnostics
}

// Error implements the error interface.
func (e *togetherAIStatusError) Error() string {
	return "togetherai: unexpected response status " + e.Status + ": " + e.Body
}

// CreateChatCompletion sends options as a JSON POST request to t.URL,
// authenticated with t.APIKey, and decodes the JSON response.
//
// The request is bound to ctx, so cancelling ctx or exceeding its deadline aborts
// the call. An error is returned when the request cannot be built or sent, when
// the API responds with a non-2xx status, or when the body is not valid JSON for
// TogetherAIResponse.
func (t *TogetherAI) CreateChatCompletion(ctx context.Context, options TogetherAIRequest) (TogetherAIResponse, error) {
	var response TogetherAIResponse

	bodyJSON, err := json.Marshal(options)
	if err != nil {
		return response, err
	}

	// Request.WithContext returns a copy and leaves the receiver untouched, so the
	// context has to be attached when the request is created.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.URL, bytes.NewReader(bodyJSON))
	if err != nil {
		return response, err
	}
	req.Header.Set("Authorization", "Bearer "+t.APIKey)
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return response, err
	}
	// A failure to close the body is not actionable here; the payload has already
	// been consumed (or abandoned) by the time this runs.
	defer resp.Body.Close()

	// An error payload would otherwise decode "successfully" into a response with
	// no choices and surprise the caller later.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		// Best effort: the snippet only enriches the error message, so a read
		// failure is deliberately ignored.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
		return response, &togetherAIStatusError{Status: resp.Status, Body: string(snippet)}
	}

	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		// Do not hand back a partially populated value together with an error.
		return TogetherAIResponse{}, err
	}

	return response, nil
}

// NewTogetherAI returns a TogetherAI client that authenticates with apiKey and
// targets the https://api.together.xyz/completions endpoint.
func NewTogetherAI(apiKey string) *TogetherAI {
	client := new(TogetherAI)
	client.APIKey = apiKey
	client.URL = "https://api.together.xyz/completions"
	return client
}
66 changes: 58 additions & 8 deletions composer/composer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,25 @@ import (
"github.com/sashabaranov/go-openai"
)

type OpenAiClientInterface interface {
CreateChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (response openai.ChatCompletionResponse, error error)
}

// Composer is used to compose (rephrase) news and events, find some meta information about them,
// filter out some unnecessary stuff, summarise them and so on.
type Composer struct {
OpenAiClient OpenAiClientInterface
Config *PromptConfig
OpenAiClient OpenAiClientInterface
TogetherAIClient TogetherAIClientInterface
Config *PromptConfig
}

func NewComposer(oaiToken string) *Composer {
return &Composer{OpenAiClient: openai.NewClient(oaiToken), Config: DefaultPromptConfig()}
// NewComposer builds a Composer wired with an OpenAI client (authenticated with
// oaiToken), a TogetherAI client (authenticated with tgrAiToken) and the default
// prompt configuration.
func NewComposer(oaiToken, tgrAiToken string) *Composer {
	c := new(Composer)
	c.OpenAiClient = openai.NewClient(oaiToken)
	c.TogetherAIClient = NewTogetherAI(tgrAiToken)
	c.Config = DefaultPromptConfig()
	return c
}

// Compose creates a new AI-composed news from the given news list.
// It will also find some meta information about the news and events (markets, tickers, hashtags).
func (c *Composer) Compose(ctx context.Context, news journalist.NewsList) ([]*ComposedNews, error) {
// Filter out news that are not from today
var todayNews journalist.NewsList = lo.Filter(news, func(n *journalist.News, _ int) bool {
Expand Down Expand Up @@ -145,6 +151,50 @@ func (c *Composer) Summarise(ctx context.Context, headlines []*Headline, headlin
return h, nil
}

// Filter removes unnecessary news from the given news list using TogetherAI API.
//
// The news list is serialized to JSON, embedded into the filter prompt and sent to
// the TogetherAI client; the text of the first returned choice is then parsed back
// into a news list.
//
// It returns (nil, nil) for an empty input without calling the API. An error is
// returned when serialization fails, when the API call fails or yields no choices,
// or when the model output cannot be parsed as a JSON news list.
func (c *Composer) Filter(ctx context.Context, news journalist.NewsList) (journalist.NewsList, error) {
	if len(news) == 0 {
		return nil, nil
	}

	// TODO: This can be optimised by using ToContentJSON() method.
	// But it will require to map the response back to the original news list.
	// Also prompt can be optimised to return only IDs of the news to reduce tokens count.
	jsonNews, err := news.ToJSON()
	if err != nil {
		return nil, newErr(err, "Filter", "json.Marshal news").WithValue(fmt.Sprintf("%+v", news))
	}

	resp, err := c.TogetherAIClient.CreateChatCompletion(
		ctx,
		TogetherAIRequest{
			Model:             "mistralai/Mistral-7B-Instruct-v0.2",
			Prompt:            c.Config.FilterPromptInstruct(jsonNews),
			MaxTokens:         2048,
			Temperature:       0.7,
			TopP:              0.7,
			TopK:              50,
			RepetitionPenalty: 1,
		},
	)
	if err != nil {
		return nil, newErr(err, "Filter", "TogetherAIClient.CreateChatCompletion")
	}

	// A response without choices (e.g. an API error payload decoded into an empty
	// struct) must surface as an error instead of panicking on Choices[0].
	if len(resp.Choices) == 0 {
		return nil, newErr(fmt.Errorf("response contains no choices"), "Filter", "TogetherAIClient.CreateChatCompletion")
	}
	completion := resp.Choices[0].Text

	// NOTE(review): openaiJSONStringFixer is defined elsewhere; presumably it
	// extracts or repairs the JSON portion of the model output — confirm.
	matches, err := openaiJSONStringFixer(completion)
	if err != nil {
		return nil, newErr(err, "Filter", "openaiJSONStringFixer")
	}

	var filtered journalist.NewsList
	err = json.Unmarshal([]byte(matches), &filtered)
	if err != nil {
		return nil, newErr(err, "Filter", "json.Unmarshal").WithValue(completion)
	}

	return filtered, nil
}

// Headline is the base data structure for the data to summarise
type Headline struct {
ID string `json:"id"`
Expand Down
137 changes: 135 additions & 2 deletions composer/composer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ func (m *MockOpenAiClient) CreateChatCompletion(ctx context.Context, req openai.
return args.Get(0).(openai.ChatCompletionResponse), args.Error(1)
}

// MockTogetherAIClient is a testify-based test double for TogetherAIClientInterface.
type MockTogetherAIClient struct {
	mock.Mock
}

// CreateChatCompletion records the call and returns whatever response and error
// were configured for it on the mock.
func (m *MockTogetherAIClient) CreateChatCompletion(ctx context.Context, options TogetherAIRequest) (TogetherAIResponse, error) {
	call := m.Called(ctx, options)
	resp := call.Get(0).(TogetherAIResponse)
	return resp, call.Error(1)
}

func TestComposer_Compose(t *testing.T) {
news := journalist.NewsList{
{
Expand All @@ -46,7 +55,7 @@ func TestComposer_Compose(t *testing.T) {
Title: "Wholesale prices fell 0.5% in October for biggest monthly drop since April 2020",
Description: "Wholesale prices fell 0.5% in October for biggest monthly drop since April 2020",
Link: "https://www.cnbc.com/",
Date: time.Now().Add(-24 * time.Hour * 2).UTC(), // Should be filtered out
Date: time.Now().UTC(),
ProviderName: "cnbc",
},
}
Expand All @@ -68,7 +77,7 @@ func TestComposer_Compose(t *testing.T) {
ctx: context.Background(),
news: news,
},
expectedFilteredNews: journalist.NewsList{news[0], news[1]},
expectedFilteredNews: journalist.NewsList{news[0], news[1], news[2]},
want: []*ComposedNews{
{
ID: "1",
Expand All @@ -84,6 +93,13 @@ func TestComposer_Compose(t *testing.T) {
Markets: []string{},
Hashtags: []string{"interestrates"},
},
{
ID: "3",
Text: "Wholesale prices fell 0.5% in October for biggest monthly drop since April 2020",
Tickers: []string{},
Markets: []string{},
Hashtags: []string{},
},
},
wantErr: false,
},
Expand Down Expand Up @@ -291,3 +307,120 @@ func TestComposer_Summarise(t *testing.T) {
})
}
}

// TestComposer_Filter checks that Composer.Filter sends the expected request to
// the TogetherAI client and returns the news list decoded from the model output.
func TestComposer_Filter(t *testing.T) {
	type args struct {
		ctx  context.Context
		news journalist.NewsList
	}
	tests := []struct {
		name    string
		args    args
		want    journalist.NewsList
		wantErr bool
	}{
		{
			name: "Should pass and return correct filtered news",
			args: args{
				ctx: context.Background(),
				news: journalist.NewsList{
					{
						ID:           "1",
						Title:        "Ray Dalio says U.S. reaching an inflection point where the debt problem quickly gets even worse",
						Description:  "Soaring U.S. government debt is reaching a point where it will begin creating larger problems, the hedge fund titan said Friday.",
						Link:         "https://www.cnbc.com/",
						Date:         time.Now().UTC(),
						ProviderName: "cnbc",
					},
					{
						ID:           "2",
						Title:        "The market thinks the Fed is going to start cutting rates aggressively. Investors could be in for a letdown",
						Description:  "Markets may be at least a tad optimistic, particularly considering the cautious approach central bank officials have taken.",
						Link:         "https://www.cnbc.com/",
						Date:         time.Now().UTC(),
						ProviderName: "cnbc",
					},
					{
						ID:           "3",
						Title:        "Wholesale prices fell 0.5% in October for biggest monthly drop since April 2020",
						Description:  "Wholesale prices fell 0.5% in October for biggest monthly drop since April 2020",
						Link:         "https://www.cnbc.com/",
						Date:         time.Now().UTC(),
						ProviderName: "cnbc",
					},
				},
			},
			want: journalist.NewsList{
				{
					ID:           "1",
					Title:        "Ray Dalio says U.S. reaching an inflection point where the debt problem quickly gets even worse",
					Description:  "Soaring U.S. government debt is reaching a point where it will begin creating larger problems, the hedge fund titan said Friday.",
					Link:         "https://www.cnbc.com/",
					Date:         time.Now().UTC(),
					ProviderName: "cnbc",
				},
				{
					ID:           "3",
					Title:        "Wholesale prices fell 0.5% in October for biggest monthly drop since April 2020",
					Description:  "Wholesale prices fell 0.5% in October for biggest monthly drop since April 2020",
					Link:         "https://www.cnbc.com/",
					Date:         time.Now().UTC(),
					ProviderName: "cnbc",
				},
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockClient := new(MockTogetherAIClient)
			defConf := DefaultPromptConfig()

			// Set expectations for the mock client
			if tt.wantErr {
				mockError := errors.New("some error")
				mockClient.On("CreateChatCompletion", mock.Anything, mock.Anything).Return(TogetherAIResponse{}, mockError)
			} else {
				jsonNews, _ := tt.args.news.ToJSON()
				// The mocked model output is the expected list itself, serialized.
				expectedJsonNews, _ := tt.want.ToJSON()

				mockClient.On("CreateChatCompletion",
					mock.Anything,
					TogetherAIRequest{
						Model:             "mistralai/Mistral-7B-Instruct-v0.2",
						Prompt:            defConf.FilterPromptInstruct(jsonNews),
						MaxTokens:         2048,
						Temperature:       0.7,
						TopP:              0.7,
						TopK:              50,
						RepetitionPenalty: 1,
					},
				).Return(TogetherAIResponse{
					Choices: []struct {
						Text string `json:"text"`
					}{
						{
							Text: expectedJsonNews,
						},
					},
				}, nil)
			}

			c := &Composer{
				TogetherAIClient: mockClient,
				Config:           DefaultPromptConfig(),
			}
			got, err := c.Filter(tt.args.ctx, tt.args.news)
			if (err != nil) != tt.wantErr {
				t.Errorf("Filter() error = %v, wantErr %v", err, tt.wantErr)
				return
			}

			// Compare lengths first: ranging over got alone would pass vacuously on
			// an empty result and index out of range on a longer one.
			if len(got) != len(tt.want) {
				t.Fatalf("Filter() returned %d news, want %d", len(got), len(tt.want))
			}

			for i, n := range got {
				if !reflect.DeepEqual(n, tt.want[i]) {
					t.Errorf("Filter() = %v, want %v", n, tt.want[i])
				}
			}
		})
	}
}
Loading

0 comments on commit 40885c7

Please sign in to comment.