From 3f3cf0fdb05000d262292e9404730bc0f53f811a Mon Sep 17 00:00:00 2001
From: David Miguel Lozano
Date: Mon, 6 Nov 2023 22:09:17 +0100
Subject: [PATCH] feat(llms): Add streaming support to OpenAI (#196)

---
 .../lib/src/llms/models/mappers.dart          |  3 +-
 .../langchain_openai/lib/src/llms/openai.dart | 39 +++++++++++++++++++
 .../test/llms/openai_test.dart                | 37 ++++++++++++++----
 3 files changed, 70 insertions(+), 9 deletions(-)

diff --git a/packages/langchain_openai/lib/src/llms/models/mappers.dart b/packages/langchain_openai/lib/src/llms/models/mappers.dart
index 2bb1ae76..53564f40 100644
--- a/packages/langchain_openai/lib/src/llms/models/mappers.dart
+++ b/packages/langchain_openai/lib/src/llms/models/mappers.dart
@@ -2,7 +2,7 @@ import 'package:langchain/langchain.dart';
 import 'package:openai_dart/openai_dart.dart';
 
 extension CreateCompletionResponseMapper on CreateCompletionResponse {
-  LLMResult toLLMResult() {
+  LLMResult toLLMResult({final bool streaming = false}) {
     return LLMResult(
       generations: choices
           .map((final choice) => choice.toLLMGeneration())
@@ -13,6 +13,7 @@ extension CreateCompletionResponseMapper on CreateCompletionResponse {
         'created': created,
         'model': model,
       },
+      streaming: streaming,
     );
   }
 }
diff --git a/packages/langchain_openai/lib/src/llms/openai.dart b/packages/langchain_openai/lib/src/llms/openai.dart
index 5a0d9a1e..10099d2f 100644
--- a/packages/langchain_openai/lib/src/llms/openai.dart
+++ b/packages/langchain_openai/lib/src/llms/openai.dart
@@ -257,6 +257,45 @@ class OpenAI extends BaseLLM<OpenAIOptions> {
     return completion.toLLMResult();
   }
 
+  @override
+  Stream<LLMResult> stream(
+    final PromptValue input, {
+    final OpenAIOptions? options,
+  }) {
+    return _client
+        .createCompletionStream(
+          request: CreateCompletionRequest(
+            model: CompletionModel.string(model),
+            prompt: CompletionPrompt.string(input.toString()),
+            bestOf: bestOf,
+            frequencyPenalty: frequencyPenalty,
+            logitBias: logitBias,
+            logprobs: logprobs,
+            maxTokens: maxTokens,
+            n: n,
+            presencePenalty: presencePenalty,
+            stop: options?.stop != null
+                ? CompletionStop.arrayString(options!.stop!)
+                : null,
+            suffix: suffix,
+            temperature: temperature,
+            topP: topP,
+            user: options?.user ?? user,
+          ),
+        )
+        .map((final completion) => completion.toLLMResult(streaming: true));
+  }
+
+  @override
+  Stream<LLMResult> streamFromInputStream(
+    final Stream<PromptValue> inputStream, {
+    final OpenAIOptions? options,
+  }) {
+    return inputStream.asyncExpand((final input) {
+      return stream(input, options: options);
+    });
+  }
+
   /// Tokenizes the given prompt using tiktoken with the encoding used by the
   /// [model]. If an encoding model is specified in [encoding] field, that
   /// encoding is used instead.
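A note on the `streamFromInputStream` override above: `asyncExpand` subscribes to one inner stream at a time, so prompts from the input stream are completed strictly in sequence. A minimal, self-contained sketch of that behavior in plain Dart (`fakeStream` is a hypothetical stand-in for the real `stream` call):

import 'dart:async';

// Hypothetical stand-in for stream(input): emits three chunks per prompt.
Stream<String> fakeStream(final String prompt) async* {
  for (var i = 1; i <= 3; i++) {
    yield '$prompt-chunk$i';
  }
}

Future<void> main() async {
  final prompts = Stream.fromIterable(['a', 'b']);
  // asyncExpand pauses the outer stream while each inner stream is active,
  // so every 'a' chunk is emitted before any 'b' chunk.
  await prompts.asyncExpand(fakeStream).forEach(print);
}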
diff --git a/packages/langchain_openai/test/llms/openai_test.dart b/packages/langchain_openai/test/llms/openai_test.dart
index 74d1f796..aa65f092 100644
--- a/packages/langchain_openai/test/llms/openai_test.dart
+++ b/packages/langchain_openai/test/llms/openai_test.dart
@@ -67,8 +67,8 @@ void main() {
     });
 
     test('Test OpenAI wrapper with multiple completions', () async {
-      final chat = OpenAI(apiKey: openaiApiKey, n: 5, bestOf: 5);
-      final res = await chat.generate('Hello, how are you?');
+      final llm = OpenAI(apiKey: openaiApiKey, n: 5, bestOf: 5);
+      final res = await llm.generate('Hello, how are you?');
       expect(res.generations.length, 5);
       for (final generation in res.generations) {
         expect(generation.output, isNotEmpty);
@@ -76,27 +76,48 @@ void main() {
     test('Test tokenize', () async {
-      final chat = OpenAI(apiKey: openaiApiKey);
+      final llm = OpenAI(apiKey: openaiApiKey);
       const text = 'Hello, how are you?';
-      final tokens = await chat.tokenize(PromptValue.string(text));
+      final tokens = await llm.tokenize(PromptValue.string(text));
       expect(tokens, [15496, 11, 703, 389, 345, 30]);
     });
 
     test('Test different encoding than the model', () async {
-      final chat = OpenAI(apiKey: openaiApiKey, encoding: 'cl100k_base');
+      final llm = OpenAI(apiKey: openaiApiKey, encoding: 'cl100k_base');
       const text = 'Hello, how are you?';
-      final tokens = await chat.tokenize(PromptValue.string(text));
+      final tokens = await llm.tokenize(PromptValue.string(text));
       expect(tokens, [9906, 11, 1268, 527, 499, 30]);
     });
 
     test('Test countTokens', () async {
-      final chat = OpenAI(apiKey: openaiApiKey);
+      final llm = OpenAI(apiKey: openaiApiKey);
       const text = 'Hello, how are you?';
-      final numTokens = await chat.countTokens(PromptValue.string(text));
+      final numTokens = await llm.countTokens(PromptValue.string(text));
       expect(numTokens, 6);
     });
+
+    test('Test streaming', () async {
+      final promptTemplate = PromptTemplate.fromTemplate(
+        'List the numbers from 1 to {max_num} in order without any spaces or commas',
+      );
+      final llm = OpenAI(apiKey: openaiApiKey);
+      const stringOutputParser = StringOutputParser();
+
+      final chain = promptTemplate.pipe(llm).pipe(stringOutputParser);
+
+      final stream = chain.stream({'max_num': '9'});
+
+      String content = '';
+      int count = 0;
+      await for (final res in stream) {
+        content += res;
+        count++;
+      }
+      expect(count, greaterThan(1));
+      expect(content, contains('123456789'));
+    });
   });
 }
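For completeness, a minimal usage sketch of the new streaming API outside an LCEL-style chain, assuming the public exports of `langchain` and `langchain_openai` (the API key value and prompt text are placeholders):

import 'dart:io';

import 'package:langchain/langchain.dart';
import 'package:langchain_openai/langchain_openai.dart';

Future<void> main() async {
  final llm = OpenAI(apiKey: 'sk-placeholder'); // placeholder API key
  // Each event is a partial LLMResult, marked streaming: true by the mapper.
  final stream = llm.stream(PromptValue.string('Tell me a joke'));
  await for (final chunk in stream) {
    stdout.write(chunk.generations.first.output);
  }
}

For chain-based streaming, the new 'Test streaming' test above shows the same flow through PromptTemplate, OpenAI, and StringOutputParser.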