Skip to content

Commit

Permalink
feat(llms): Add streaming support to OpenAI (davidmigloz#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored and KennethKnudsen97 committed Apr 22, 2024
1 parent d4e5cb3 commit 3f3cf0f
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 9 deletions.
3 changes: 2 additions & 1 deletion packages/langchain_openai/lib/src/llms/models/mappers.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import 'package:langchain/langchain.dart';
import 'package:openai_dart/openai_dart.dart';

extension CreateCompletionResponseMapper on CreateCompletionResponse {
LLMResult toLLMResult() {
LLMResult toLLMResult({final bool streaming = false}) {
return LLMResult(
generations: choices
.map((final choice) => choice.toLLMGeneration())
Expand All @@ -13,6 +13,7 @@ extension CreateCompletionResponseMapper on CreateCompletionResponse {
'created': created,
'model': model,
},
streaming: streaming,
);
}
}
Expand Down
39 changes: 39 additions & 0 deletions packages/langchain_openai/lib/src/llms/openai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,45 @@ class OpenAI extends BaseLLM<OpenAIOptions> {
return completion.toLLMResult();
}

/// Streams completion chunks from the OpenAI completions API for [input].
///
/// Builds a [CreateCompletionRequest] from the instance-level configuration,
/// letting the per-call [options] override only the `stop` sequences and the
/// `user` identifier. Every event is mapped with `toLLMResult(streaming: true)`
/// so downstream consumers can tell partial results from final ones.
@override
Stream<LLMResult> stream(
  final PromptValue input, {
  final OpenAIOptions? options,
}) {
  // Copy to a local so the null check promotes and no `!` is needed.
  final stopSequences = options?.stop;
  final request = CreateCompletionRequest(
    model: CompletionModel.string(model),
    prompt: CompletionPrompt.string(input.toString()),
    bestOf: bestOf,
    frequencyPenalty: frequencyPenalty,
    logitBias: logitBias,
    logprobs: logprobs,
    maxTokens: maxTokens,
    n: n,
    presencePenalty: presencePenalty,
    stop: stopSequences != null
        ? CompletionStop.arrayString(stopSequences)
        : null,
    suffix: suffix,
    temperature: temperature,
    topP: topP,
    user: options?.user ?? user,
  );
  return _client
      .createCompletionStream(request: request)
      .map((final completion) => completion.toLLMResult(streaming: true));
}

/// Streams results for every prompt emitted by [inputStream].
///
/// Prompts are handled sequentially: `asyncExpand` fully drains the stream
/// produced for one input before pulling the next input, forwarding [options]
/// to each [stream] call unchanged.
@override
Stream<LLMResult> streamFromInputStream(
  final Stream<PromptValue> inputStream, {
  final OpenAIOptions? options,
}) {
  return inputStream
      .asyncExpand((final input) => stream(input, options: options));
}

/// Tokenizes the given prompt using tiktoken with the encoding used by the
/// [model]. If an encoding model is specified in [encoding] field, that
/// encoding is used instead.
Expand Down
37 changes: 29 additions & 8 deletions packages/langchain_openai/test/llms/openai_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -67,36 +67,57 @@ void main() {
});

test('Test OpenAI wrapper with multiple completions', () async {
final chat = OpenAI(apiKey: openaiApiKey, n: 5, bestOf: 5);
final res = await chat.generate('Hello, how are you?');
final llm = OpenAI(apiKey: openaiApiKey, n: 5, bestOf: 5);
final res = await llm.generate('Hello, how are you?');
expect(res.generations.length, 5);
for (final generation in res.generations) {
expect(generation.output, isNotEmpty);
}
});

test('Test tokenize', () async {
final chat = OpenAI(apiKey: openaiApiKey);
final llm = OpenAI(apiKey: openaiApiKey);
const text = 'Hello, how are you?';

final tokens = await chat.tokenize(PromptValue.string(text));
final tokens = await llm.tokenize(PromptValue.string(text));
expect(tokens, [15496, 11, 703, 389, 345, 30]);
});

test('Test different encoding than the model', () async {
final chat = OpenAI(apiKey: openaiApiKey, encoding: 'cl100k_base');
final llm = OpenAI(apiKey: openaiApiKey, encoding: 'cl100k_base');
const text = 'Hello, how are you?';

final tokens = await chat.tokenize(PromptValue.string(text));
final tokens = await llm.tokenize(PromptValue.string(text));
expect(tokens, [9906, 11, 1268, 527, 499, 30]);
});

test('Test countTokens', () async {
final chat = OpenAI(apiKey: openaiApiKey);
final llm = OpenAI(apiKey: openaiApiKey);
const text = 'Hello, how are you?';

final numTokens = await chat.countTokens(PromptValue.string(text));
final numTokens = await llm.countTokens(PromptValue.string(text));
expect(numTokens, 6);
});

test('Test streaming', () async {
  // Prompt designed to produce a response long enough to arrive in
  // multiple streamed chunks.
  final promptTemplate = PromptTemplate.fromTemplate(
    'List the numbers from 1 to {max_num} in order without any spaces or commas',
  );
  final llm = OpenAI(apiKey: openaiApiKey);
  const stringOutputParser = StringOutputParser<String>();

  final chain = promptTemplate.pipe(llm).pipe(stringOutputParser);

  // Drain the stream, concatenating chunks and counting emitted events.
  final buffer = StringBuffer();
  var chunkCount = 0;
  await for (final chunk in chain.stream({'max_num': '9'})) {
    buffer.write(chunk);
    chunkCount++;
  }

  // Streaming must deliver more than one event, and the concatenation of
  // all chunks must contain the complete expected sequence.
  expect(chunkCount, greaterThan(1));
  expect(buffer.toString(), contains('123456789'));
});
});
}

0 comments on commit 3f3cf0f

Please sign in to comment.