feat: fix bug for local model provider (casibase#750)
* fix: bug for local model provider

* fix: bug for local model provider

* fix: bug for local model provider

* fix: add flushData for local model provider

* Update local.go

---------

Co-authored-by: Eric Luo <hsluoyz@qq.com>
MartinRepo and hsluoyz authored Mar 4, 2024
1 parent d434386 commit 79924de
Showing 2 changed files with 5 additions and 5 deletions.
model/local.go: 8 changes (4 additions & 4 deletions)
@@ -184,6 +184,7 @@ func (p *LocalModelProvider) QueryText(question string, writer io.Writer, histor
 	var flushData func(string, io.Writer) error
 	if p.typ == "Local" {
 		client = getLocalClientFromUrl(p.secretKey, p.providerUrl)
+		flushData = flushDataOpenai
 	} else if p.typ == "Azure" {
 		client = getAzureClientFromToken(p.deploymentName, p.secretKey, p.providerUrl, p.apiVersion)
 		flushData = flushDataAzure
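Note: this hunk only wires up flushData = flushDataOpenai for the "Local" provider type; the body of flushDataOpenai is not part of this diff. A minimal sketch of what a function matching the func(string, io.Writer) error signature could look like, assuming the chunks are pushed to the client as server-sent events (the flushDataSSE name and the SSE framing are illustrative assumptions, not the project's actual implementation):

// flushDataSSE writes one chunk of streamed model output as a server-sent
// event and flushes it immediately, so the client sees tokens as they arrive.
func flushDataSSE(data string, writer io.Writer) error {
	if _, err := fmt.Fprintf(writer, "event: message\ndata: %s\n\n", data); err != nil {
		return err
	}
	if flusher, ok := writer.(http.Flusher); ok {
		flusher.Flush()
	}
	return nil
}

(imports assumed: fmt, io, net/http)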
@@ -335,6 +336,7 @@ func (p *LocalModelProvider) QueryText(question string, writer io.Writer, histor
 		defer respStream.Close()
 
 		isLeadingReturn := true
+		var response strings.Builder
 		for {
 			completion, streamErr := respStream.Recv()
 			if streamErr != nil {
@@ -358,15 +360,13 @@ func (p *LocalModelProvider) QueryText(question string, writer io.Writer, histor
 				return nil, err
 			}
 
-			modelResult.PromptTokenCount += completion.Usage.PromptTokens
-			modelResult.ResponseTokenCount += completion.Usage.CompletionTokens
-			modelResult.TotalTokenCount += completion.Usage.TotalTokens
-			err = p.calculatePrice(modelResult)
+			_, err = response.WriteString(data)
 			if err != nil {
 				return nil, err
 			}
 		}
 
+		modelResult, err = getDefaultModelResult(model, question, response.String())
 		return modelResult, nil
 	} else {
 		return nil, fmt.Errorf("QueryText() error: unknown model type: %s", p.subType)
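The two hunks above change how the streamed answer becomes a ModelResult: instead of summing completion.Usage token counts on every chunk (streamed chunks from OpenAI-compatible local servers typically carry no usage data, so those sums stayed at zero), the loop now appends each delta to a strings.Builder and getDefaultModelResult is called once at the end, presumably computing the counts from the full prompt and response text. A condensed sketch of the same pattern against the go-openai client (the buildStreamedAnswer helper name is illustrative; the real code keeps this loop inline in QueryText):

// buildStreamedAnswer reads delta chunks from the stream, forwards each one to
// the caller via flushData, and accumulates the full answer so token counts can
// be computed once after the stream ends.
func buildStreamedAnswer(stream *openai.ChatCompletionStream, writer io.Writer, flushData func(string, io.Writer) error) (string, error) {
	var response strings.Builder
	for {
		completion, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			break // stream finished normally
		}
		if err != nil {
			return "", err
		}

		data := completion.Choices[0].Delta.Content
		if err := flushData(data, writer); err != nil {
			return "", err
		}
		if _, err := response.WriteString(data); err != nil {
			return "", err
		}
	}
	return response.String(), nil
}

(imports assumed: errors, io, strings, github.com/sashabaranov/go-openai)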
model/openai_util.go: 2 changes (1 addition & 1 deletion)
@@ -68,7 +68,7 @@ func getOpenAiModelType(model string) string {
 	completionModels := []string{
 		"text-davinci-003", "text-davinci-002", "text-curie-001",
 		"text-babbage-001", "text-ada-001", "text-davinci-001",
-		"davinci-instruct-beta", "davinci", "curie-instruct-beta", "curie", "ada", "babbage",
+		"davinci-instruct-beta", "davinci", "curie-instruct-beta", "curie", "ada", "babbage", "custom-model",
 	}
 
 	for _, chatModel := range chatModels {
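This list lives in getOpenAiModelType, which classifies a sub-type string as a chat or completion model; adding "custom-model" lets a locally hosted model registered under that name be classified instead of presumably falling through to the "unknown model type" error shown in QueryText above. Roughly, assuming the function returns "Chat" or "Completion" depending on which list matches (the return values are inferred, not shown in this diff):

// Illustrative shape of the classifier; the real lists in openai_util.go are longer.
func getModelType(model string, chatModels []string, completionModels []string) string {
	for _, m := range chatModels {
		if m == model {
			return "Chat"
		}
	}
	for _, m := range completionModels {
		if m == model {
			return "Completion"
		}
	}
	return ""
}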
