Skip to content

Commit

Permalink
add: added relevance score to rerank task
Browse files Browse the repository at this point in the history
  • Loading branch information
namwoam committed Jul 4, 2024
1 parent 1c911ba commit d4f640c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
10 changes: 5 additions & 5 deletions ai/cohere/v0/component_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func TestComponent_Tasks(t *testing.T) {
wantResp rerankOutput
}{
input: map[string]any{"documents": []string{"a", "b", "c", "d"}},
wantResp: rerankOutput{Ranking: []string{"d", "c", "b", "a"}, Usage: rerankUsage{Search: 5}},
wantResp: rerankOutput{Ranking: []string{"d", "c", "b", "a"}, Usage: rerankUsage{Search: 5}, Relevance: []float64{10, 9, 8, 7}},
}
c.Run("ok - task rerank", func(c *qt.C) {
setup, err := structpb.NewStruct(map[string]any{
Expand Down Expand Up @@ -178,10 +178,10 @@ func (m *MockCohereClient) generateRerank(request cohereSDK.RerankRequest) (cohe
{Text: request.Documents[0].String},
}
result := []*cohereSDK.RerankResponseResultsItem{
{Document: &documents[0]},
{Document: &documents[1]},
{Document: &documents[2]},
{Document: &documents[3]},
{Document: &documents[0], RelevanceScore: 10},
{Document: &documents[1], RelevanceScore: 9},
{Document: &documents[2], RelevanceScore: 8},
{Document: &documents[3], RelevanceScore: 7},
}
searchCnt := float64(5)
bill := cohereSDK.ApiMetaBilledUnits{SearchUnits: &searchCnt}
Expand Down
13 changes: 9 additions & 4 deletions ai/cohere/v0/rerank.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ type rerankInput struct {
}

type rerankOutput struct {
Ranking []string `json:"ranking"`
Usage rerankUsage `json:"usage"`
Ranking []string `json:"ranking"`
Usage rerankUsage `json:"usage"`
Relevance []float64 `json:"relevance"`
}

type rerankUsage struct {
Expand All @@ -40,6 +41,7 @@ func (e *execution) taskRerank(in *structpb.Struct) (*structpb.Struct, error) {
}
documents = append(documents, &document)
}

returnDocument := true
rankFields := []string{"text"}
req := cohereSDK.RerankRequest{
Expand All @@ -54,7 +56,9 @@ func (e *execution) taskRerank(in *structpb.Struct) (*structpb.Struct, error) {
return nil, err
}
newRanking := []string{}
relevance := []float64{}
for _, rankResult := range resp.Results {
relevance = append(relevance, rankResult.RelevanceScore)
newRanking = append(newRanking, rankResult.Document.Text)
}

Expand All @@ -67,8 +71,9 @@ func (e *execution) taskRerank(in *structpb.Struct) (*structpb.Struct, error) {
}

outputStruct := rerankOutput{
Ranking: newRanking,
Usage: rerankUsage{Search: int(*bills.SearchUnits)},
Ranking: newRanking,
Usage: rerankUsage{Search: int(*bills.SearchUnits)},
Relevance: relevance,
}

outputJSON, err := json.Marshal(outputStruct)
Expand Down

0 comments on commit d4f640c

Please sign in to comment.