diff --git a/packages/langchain/lib/src/core/runnable/base.dart b/packages/langchain/lib/src/core/runnable/base.dart index 071fa46a..14e994e7 100644 --- a/packages/langchain/lib/src/core/runnable/base.dart +++ b/packages/langchain/lib/src/core/runnable/base.dart @@ -1,5 +1,7 @@ import 'dart:async'; +import 'package:meta/meta.dart'; + import '../base.dart'; import 'binding.dart'; import 'function.dart'; @@ -101,11 +103,43 @@ abstract class Runnable stream( + final RunInput input, { + final CallOptions? options, + }) { + return streamFromInputStream( + Stream.value(input).asBroadcastStream(), + options: options, + ); + } + + /// Streams the output of invoking the [Runnable] on the given [inputStream]. + /// + /// - [inputStream] - the input stream to invoke the [Runnable] on. + /// - [options] - the options to use when invoking the [Runnable]. + @protected + Stream streamFromInputStream( + final Stream inputStream, { + final CallOptions? options, + }) { + // By default, it just emits the result of calling invoke + // Subclasses should override this method if they support streaming output + return inputStream.asyncMap( + // ignore: discarded_futures + (final input) => invoke(input, options: options), + ); + } + /// Pipes the output of this [Runnable] into another [Runnable]. /// /// - [next] - the [Runnable] to pipe the output into. - RunnableSequence pipe( - final Runnable next, + RunnableSequence pipe( + final Runnable next, ) { return RunnableSequence( first: this, diff --git a/packages/langchain/lib/src/core/runnable/binding.dart b/packages/langchain/lib/src/core/runnable/binding.dart index b07efbbe..18cca773 100644 --- a/packages/langchain/lib/src/core/runnable/binding.dart +++ b/packages/langchain/lib/src/core/runnable/binding.dart @@ -62,4 +62,12 @@ class RunnableBinding streamFromInputStream( + final Stream inputStream, { + final CallOptions? options, + }) { + return bound.streamFromInputStream(inputStream, options: options ?? this.options); + } } diff --git a/packages/langchain/lib/src/core/runnable/extensions.dart b/packages/langchain/lib/src/core/runnable/extensions.dart index 378b903b..90eaec09 100644 --- a/packages/langchain/lib/src/core/runnable/extensions.dart +++ b/packages/langchain/lib/src/core/runnable/extensions.dart @@ -19,7 +19,7 @@ extension RunnableX< /// /// - [next] - the [Runnable] to pipe the output into. RunnableSequence operator |( - final Runnable next, + final Runnable next, ) { return pipe(next); } diff --git a/packages/langchain/lib/src/core/runnable/map.dart b/packages/langchain/lib/src/core/runnable/map.dart index b7d3d0cc..5910485b 100644 --- a/packages/langchain/lib/src/core/runnable/map.dart +++ b/packages/langchain/lib/src/core/runnable/map.dart @@ -1,3 +1,5 @@ +import 'package:async/async.dart' show StreamGroup; + import '../base.dart'; import 'base.dart'; @@ -64,4 +66,18 @@ class RunnableMap return output; } + + @override + Stream> streamFromInputStream( + final Stream inputStream, { + final BaseLangChainOptions? options, + }) { + return StreamGroup.merge( + steps.entries.map((final entry) { + return entry.value.streamFromInputStream(inputStream, options: options).map( + (final output) => {entry.key: output}, + ); + }), + ); + } } diff --git a/packages/langchain/lib/src/core/runnable/sequence.dart b/packages/langchain/lib/src/core/runnable/sequence.dart index fefd888c..30e1490b 100644 --- a/packages/langchain/lib/src/core/runnable/sequence.dart +++ b/packages/langchain/lib/src/core/runnable/sequence.dart @@ -109,18 +109,33 @@ class RunnableSequence return last.invoke(nextStepInput, options: options); } + @override + Stream streamFromInputStream( + final Stream inputStream, { + final BaseLangChainOptions? options, + }) { + var nextStepStream = first.streamFromInputStream(inputStream); + + for (final step in middle) { + nextStepStream = step.streamFromInputStream(nextStepStream); + } + + return last.streamFromInputStream(nextStepStream); + } + /// Pipes the output of this [RunnableSequence] into another [Runnable]. /// /// - [next] - the [Runnable] to pipe the output into. @override - RunnableSequence pipe( - final Runnable next, + RunnableSequence pipe( + final Runnable next, ) { if (next is RunnableSequence) { + final nextSeq = next as RunnableSequence; return RunnableSequence( first: first, - middle: [...middle, last, next.first, ...next.middle], - last: next.last, + middle: [...middle, last, nextSeq.first, ...nextSeq.middle], + last: nextSeq.last, ); } else { return RunnableSequence( diff --git a/packages/langchain/lib/src/model_io/chat_models/fake.dart b/packages/langchain/lib/src/model_io/chat_models/fake.dart index 5b130e1e..50b841ba 100644 --- a/packages/langchain/lib/src/model_io/chat_models/fake.dart +++ b/packages/langchain/lib/src/model_io/chat_models/fake.dart @@ -40,7 +40,8 @@ class FakeChatModel extends SimpleChatModel { /// {@template fake_echo_llm} /// Fake Chat Model for testing. -/// It just returns the content of the last message of the prompt. +/// It just returns the content of the last message of the prompt +/// or streams it char by char. /// {@endtemplate} class FakeEchoChatModel extends SimpleChatModel { /// {@macro fake_echo_llm} @@ -57,6 +58,24 @@ class FakeEchoChatModel extends SimpleChatModel { return Future.value(messages.last.content); } + @override + Stream streamFromInputStream( + final Stream inputStream, { + final ChatModelOptions? options, + }) { + return inputStream.asyncExpand( + (final input) { + final prompt = input.toChatMessages().first.content.split(''); + return Stream.fromIterable(prompt).map( + (final char) => ChatResult( + generations: [ChatGeneration(ChatMessage.ai(char))], + streaming: true, + ), + ); + }, + ); + } + @override Future> tokenize(final PromptValue promptValue) async { return promptValue diff --git a/packages/langchain/lib/src/model_io/chat_models/models/models.dart b/packages/langchain/lib/src/model_io/chat_models/models/models.dart index 328e8dd8..783daa47 100644 --- a/packages/langchain/lib/src/model_io/chat_models/models/models.dart +++ b/packages/langchain/lib/src/model_io/chat_models/models/models.dart @@ -11,28 +11,8 @@ class ChatModelOptions extends LanguageModelOptions { const ChatModelOptions(); } -/// {@template chat_result} /// Result returned by the Chat Model. -/// {@endtemplate} -@immutable -class ChatResult extends LanguageModelResult { - /// {@macro chat_result} - const ChatResult({ - required super.generations, - super.usage, - super.modelOutput, - }); - - @override - String toString() { - return ''' -ChatResult{ - generations: $generations, - usage: $usage, - modelOutput: $modelOutput}, -'''; - } -} +typedef ChatResult = LanguageModelResult; /// {@template chat_generation} /// Output of a single generation. @@ -48,6 +28,19 @@ class ChatGeneration extends LanguageModelGeneration { @override String get outputAsString => output.content; + @override + LanguageModelGeneration concat( + final LanguageModelGeneration other, + ) { + return ChatGeneration( + output.concat(other.output), + generationInfo: { + ...?generationInfo, + ...?other.generationInfo, + }, + ); + } + @override String toString() { return ''' @@ -118,6 +111,9 @@ sealed class ChatMessage { /// The content of the message. final String content; + + /// Merges this message with another by concatenating the content. + ChatMessage concat(final ChatMessage other); } /// {@template system_chat_message} @@ -140,6 +136,11 @@ class SystemChatMessage extends ChatMessage { @override int get hashCode => content.hashCode; + @override + SystemChatMessage concat(final ChatMessage other) { + return SystemChatMessage(content: content + other.content); + } + @override String toString() { return ''' @@ -184,6 +185,14 @@ class HumanChatMessage extends ChatMessage { @override int get hashCode => content.hashCode ^ example.hashCode; + @override + HumanChatMessage concat(final ChatMessage other) { + return HumanChatMessage( + content: content + other.content, + example: example, + ); + } + @override String toString() { return ''' @@ -237,6 +246,32 @@ class AIChatMessage extends ChatMessage { @override int get hashCode => content.hashCode ^ example.hashCode; + @override + AIChatMessage concat(final ChatMessage other) { + if (other is AIChatMessage) { + return AIChatMessage( + content: content + other.content, + functionCall: functionCall != null || other.functionCall != null + ? AIChatMessageFunctionCall( + name: (functionCall?.name ?? '') + + (other.functionCall?.name ?? ''), + argumentsRaw: (functionCall?.argumentsRaw ?? '') + + (other.functionCall?.argumentsRaw ?? ''), + arguments: { + ...?functionCall?.arguments, + ...?other.functionCall?.arguments, + }, + ) + : null, + example: example, + ); + } + return AIChatMessage( + content: content + other.content, + example: example, + ); + } + @override String toString() { return ''' @@ -275,12 +310,16 @@ class AIChatMessageFunctionCall { /// {@macro ai_chat_message_function_call} const AIChatMessageFunctionCall({ required this.name, + required this.argumentsRaw, required this.arguments, }); /// The name of the function that the model wants to call. final String name; + /// The raw arguments JSON string (needed to parse streaming responses). + final String argumentsRaw; + /// The arguments that the model wants to pass to the function. final Map arguments; @@ -296,17 +335,21 @@ class AIChatMessageFunctionCall { bool operator ==(covariant final AIChatMessageFunctionCall other) { final mapEquals = const MapEquality().equals; return identical(this, other) || - name == other.name && mapEquals(arguments, other.arguments); + name == other.name && + argumentsRaw == other.argumentsRaw && + mapEquals(arguments, other.arguments); } @override - int get hashCode => name.hashCode ^ arguments.hashCode; + int get hashCode => + name.hashCode ^ argumentsRaw.hashCode ^ arguments.hashCode; @override String toString() { return ''' AIChatMessageFunctionCall{ name: $name, + argumentsRaw: $argumentsRaw, arguments: $arguments, } '''; @@ -348,6 +391,14 @@ class FunctionChatMessage extends ChatMessage { @override int get hashCode => content.hashCode; + @override + FunctionChatMessage concat(final ChatMessage other) { + return FunctionChatMessage( + content: content + other.content, + name: name + (other is FunctionChatMessage ? other.name : ''), + ); + } + @override String toString() { return ''' @@ -390,6 +441,14 @@ class CustomChatMessage extends ChatMessage { @override int get hashCode => content.hashCode ^ role.hashCode; + @override + CustomChatMessage concat(final ChatMessage other) { + return CustomChatMessage( + role: role, + content: content + other.content, + ); + } + @override String toString() { return ''' diff --git a/packages/langchain/lib/src/model_io/language_models/models/models.dart b/packages/langchain/lib/src/model_io/language_models/models/models.dart index 30531656..d1cb4209 100644 --- a/packages/langchain/lib/src/model_io/language_models/models/models.dart +++ b/packages/langchain/lib/src/model_io/language_models/models/models.dart @@ -1,3 +1,4 @@ +import 'package:collection/collection.dart'; import 'package:meta/meta.dart'; import '../../../core/core.dart'; @@ -15,14 +16,19 @@ abstract class LanguageModelOptions extends BaseLangChainOptions { /// Result returned by the model. /// {@endtemplate} @immutable -abstract class LanguageModelResult { +class LanguageModelResult { /// {@macro language_model} const LanguageModelResult({ + this.id, required this.generations, this.usage, this.modelOutput, + this.streaming = false, }); + /// Result id. + final String? id; + /// Generated outputs. final List> generations; @@ -32,10 +38,67 @@ abstract class LanguageModelResult { /// For arbitrary model provider specific output. final Map? modelOutput; + /// Whether the result of the language model is being streamed. + final bool streaming; + /// Returns the first output as a string. String get firstOutputAsString { return generations.firstOrNull?.outputAsString ?? ''; } + + @override + bool operator ==(covariant final LanguageModelResult other) => + identical(this, other) || + runtimeType == other.runtimeType && + id == other.id && + ListEquality>().equals( + generations, + other.generations, + ) && + usage == other.usage && + const MapEquality().equals( + modelOutput, + other.modelOutput, + ) && + streaming == other.streaming; + + @override + int get hashCode => + id.hashCode ^ + ListEquality>().hash(generations) ^ + usage.hashCode ^ + const MapEquality().hash(modelOutput) ^ + streaming.hashCode; + + /// Merges this result with another by concatenating the outputs. + LanguageModelResult concat(final LanguageModelResult other) { + return LanguageModelResult( + id: id, + generations: generations.mapIndexed( + (final index, final generation) { + return generation.concat(other.generations[index]); + }, + ).toList(growable: false), + usage: usage, + modelOutput: { + ...?modelOutput, + ...?other.modelOutput, + }, + streaming: streaming, + ); + } + + @override + String toString() { + return ''' +LanguageModelResult{ + id: $id, + generations: $generations, + usage: $usage, + modelOutput: $modelOutput, + streaming: $streaming +}'''; + } } /// {@template language_model_usage} @@ -72,6 +135,24 @@ class LanguageModelUsage { /// The total number of tokens in the prompt and completion. final int? totalTokens; + @override + bool operator ==(covariant final LanguageModelUsage other) => + identical(this, other) || + runtimeType == other.runtimeType && + promptTokens == other.promptTokens && + promptBillableCharacters == other.promptBillableCharacters && + responseTokens == other.responseTokens && + responseBillableCharacters == other.responseBillableCharacters && + totalTokens == other.totalTokens; + + @override + int get hashCode => + promptTokens.hashCode ^ + promptBillableCharacters.hashCode ^ + responseTokens.hashCode ^ + responseBillableCharacters.hashCode ^ + totalTokens.hashCode; + @override String toString() { return ''' @@ -105,4 +186,31 @@ abstract class LanguageModelGeneration { /// Returns the output as string. String get outputAsString; + + @override + bool operator ==(covariant final LanguageModelGeneration other) => + identical(this, other) || + runtimeType == other.runtimeType && + output == other.output && + const MapEquality().equals( + generationInfo, + other.generationInfo, + ); + + @override + int get hashCode => + output.hashCode ^ + const MapEquality().hash(generationInfo); + + /// Merges this generation with another by concatenating the outputs. + LanguageModelGeneration concat(final LanguageModelGeneration other); + + @override + String toString() { + return ''' +LanguageModelGeneration{ + output: $output, + generationInfo: $generationInfo +}'''; + } } diff --git a/packages/langchain/lib/src/model_io/llms/fake.dart b/packages/langchain/lib/src/model_io/llms/fake.dart index f92f0c80..6a3ddf06 100644 --- a/packages/langchain/lib/src/model_io/llms/fake.dart +++ b/packages/langchain/lib/src/model_io/llms/fake.dart @@ -40,7 +40,7 @@ class FakeListLLM extends SimpleLLM { /// {@template fake_echo_llm} /// Fake LLM for testing. -/// It just returns the prompt. +/// It just returns the prompt or streams it char by char. /// {@endtemplate} class FakeEchoLLM extends SimpleLLM { /// {@macro fake_echo_llm} @@ -57,6 +57,24 @@ class FakeEchoLLM extends SimpleLLM { return Future.value(prompt); } + @override + Stream streamFromInputStream( + final Stream inputStream, { + final LLMOptions? options, + }) { + return inputStream.asyncExpand( + (final prompt) { + final promptChars = prompt.toString().split(''); + return Stream.fromIterable(promptChars).map( + (final item) => LLMResult( + generations: [LLMGeneration(item)], + streaming: true, + ), + ); + }, + ); + } + @override Future> tokenize(final PromptValue promptValue) async { return promptValue diff --git a/packages/langchain/lib/src/model_io/llms/models/models.dart b/packages/langchain/lib/src/model_io/llms/models/models.dart index ac5db587..c4c9e95d 100644 --- a/packages/langchain/lib/src/model_io/llms/models/models.dart +++ b/packages/langchain/lib/src/model_io/llms/models/models.dart @@ -10,28 +10,8 @@ class LLMOptions extends LanguageModelOptions { const LLMOptions(); } -/// {@template llm_result} /// Class that contains all relevant information for an LLM Result. -/// {@endtemplate} -@immutable -class LLMResult extends LanguageModelResult { - /// {@macro llm_result} - const LLMResult({ - required super.generations, - super.usage, - super.modelOutput, - }); - - @override - String toString() { - return ''' -LLMResult{ - generations: $generations, - usage: $usage, - modelOutput: $modelOutput}, -'''; - } -} +typedef LLMResult = LanguageModelResult; /// {@template llm_generation} /// Output of a single generation. @@ -47,6 +27,19 @@ class LLMGeneration extends LanguageModelGeneration { @override String get outputAsString => output; + @override + LanguageModelGeneration concat( + final LanguageModelGeneration other, + ) { + return LLMGeneration( + output + other.output, + generationInfo: { + ...?generationInfo, + ...?other.generationInfo, + }, + ); + } + @override String toString() { return ''' diff --git a/packages/langchain/lib/src/model_io/prompts/models/models.dart b/packages/langchain/lib/src/model_io/prompts/models/models.dart index 222e92ee..89b46732 100644 --- a/packages/langchain/lib/src/model_io/prompts/models/models.dart +++ b/packages/langchain/lib/src/model_io/prompts/models/models.dart @@ -1,3 +1,4 @@ +import 'package:collection/collection.dart'; import 'package:meta/meta.dart'; import '../../../utils/exception.dart'; @@ -57,6 +58,14 @@ class StringPromptValue implements PromptValue { List toChatMessages() { return [ChatMessage.human(value)]; } + + @override + bool operator ==(covariant final StringPromptValue other) => + identical(this, other) || + runtimeType == other.runtimeType && value == other.value; + + @override + int get hashCode => value.hashCode; } /// {@template chat_prompt_value} @@ -91,6 +100,16 @@ class ChatPromptValue implements PromptValue { List toChatMessages() { return messages; } + + @override + bool operator ==(covariant final ChatPromptValue other) { + return identical(this, other) || + runtimeType == other.runtimeType && + const ListEquality().equals(messages, other.messages); + } + + @override + int get hashCode => const ListEquality().hash(messages); } /// Input values used to format a prompt. diff --git a/packages/langchain/pubspec.yaml b/packages/langchain/pubspec.yaml index ff6e1078..1d4cc40c 100644 --- a/packages/langchain/pubspec.yaml +++ b/packages/langchain/pubspec.yaml @@ -16,6 +16,7 @@ environment: sdk: ">=3.0.0 <4.0.0" dependencies: + async: ^2.11.0 beautiful_soup_dart: ^0.3.0 characters: ^1.3.0 collection: ^1.17.1 diff --git a/packages/langchain/test/core/runnable/binding_test.dart b/packages/langchain/test/core/runnable/binding_test.dart index 9111a176..b51b5080 100644 --- a/packages/langchain/test/core/runnable/binding_test.dart +++ b/packages/langchain/test/core/runnable/binding_test.dart @@ -7,7 +7,7 @@ void main() { test('RunnableBinding from Runnable.bind', () async { final prompt = PromptTemplate.fromTemplate('Hello {input}'); const model = _FakeOptionsChatModel(); - const outputParser = StringOutputParser(); + const outputParser = StringOutputParser(); final chain = prompt | model.bind(const _FakeOptionsChatModelOptions('world')) | outputParser; @@ -15,6 +15,24 @@ void main() { final res = await chain.invoke({'input': 'world'}); expect(res, 'Hello '); }); + + test('Streaming RunnableBinding', () async { + final prompt = PromptTemplate.fromTemplate('Hello {input}'); + const model = _FakeOptionsChatModel(); + const outputParser = StringOutputParser(); + + final chain = prompt + .pipe(model.bind(const _FakeOptionsChatModelOptions('world'))) + .pipe(outputParser); + final stream = chain.stream({'input': 'world'}); + + final streamList = await stream.toList(); + expect(streamList.length, 6); + expect(streamList, isA>()); + + final output = streamList.join(); + expect(output, 'Hello '); + }); }); } @@ -35,6 +53,28 @@ class _FakeOptionsChatModel ); } + @override + Stream streamFromInputStream( + final Stream inputStream, { + final _FakeOptionsChatModelOptions? options, + }) { + return inputStream.asyncExpand( + (final input) { + final prompt = input + .toChatMessages() + .first + .content + .replaceAll(options?.stop ?? '', '') + .split(''); + return Stream.fromIterable(prompt).map( + (final char) => ChatResult( + generations: [ChatGeneration(ChatMessage.ai(char))], + ), + ); + }, + ); + } + @override Future> tokenize(final PromptValue promptValue) async { return promptValue diff --git a/packages/langchain/test/core/runnable/function_test.dart b/packages/langchain/test/core/runnable/function_test.dart index f75109b7..79d29fcf 100644 --- a/packages/langchain/test/core/runnable/function_test.dart +++ b/packages/langchain/test/core/runnable/function_test.dart @@ -18,5 +18,19 @@ void main() { final res = await chain.invoke({'input': 'world'}); expect(res, 12); }); + + test('Streaming RunnableFunction', () async { + final function = Runnable.fromFunction( + (final input, final options) => input.length, + ); + final stream = function.stream('world'); + + final streamList = await stream.toList(); + expect(streamList.length, 1); + expect(streamList.first, isA()); + + final item = streamList.first; + expect(item, 5); + }); }); } diff --git a/packages/langchain/test/core/runnable/input_getter_test.dart b/packages/langchain/test/core/runnable/input_getter_test.dart index 4ab72a0c..4fa3bb55 100644 --- a/packages/langchain/test/core/runnable/input_getter_test.dart +++ b/packages/langchain/test/core/runnable/input_getter_test.dart @@ -10,14 +10,36 @@ void main() { final res = await chain.invoke({'foo': 'foo1', 'bar': 'bar1'}); expect(res, 'foo1'); }); - }); - group('RunnableMapFromItem tests', () { test('RunnableMapFromItem from Runnable.getMapFromItem', () async { final chain = Runnable.getMapFromItem('foo'); final res = await chain.invoke('foo1'); expect(res, {'foo': 'foo1'}); }); + + test('Streaming RunnableItemFromMap', () async { + final chain = Runnable.getItemFromMap('foo'); + final stream = chain.stream({'foo': 'foo1', 'bar': 'bar1'}); + + final streamList = await stream.toList(); + expect(streamList.length, 1); + expect(streamList.first, isA()); + + final item = streamList.first; + expect(item, 'foo1'); + }); + + test('Streaming RunnableMapFromItem', () async { + final chain = Runnable.getMapFromItem('foo'); + final stream = chain.stream('foo1'); + + final streamList = await stream.toList(); + expect(streamList.length, 1); + expect(streamList.first, isA>()); + + final item = streamList.first; + expect(item, {'foo': 'foo1'}); + }); }); } diff --git a/packages/langchain/test/core/runnable/base_test.dart b/packages/langchain/test/core/runnable/invoke_test.dart similarity index 100% rename from packages/langchain/test/core/runnable/base_test.dart rename to packages/langchain/test/core/runnable/invoke_test.dart diff --git a/packages/langchain/test/core/runnable/map_test.dart b/packages/langchain/test/core/runnable/map_test.dart index 94f3226d..e7f04f63 100644 --- a/packages/langchain/test/core/runnable/map_test.dart +++ b/packages/langchain/test/core/runnable/map_test.dart @@ -1,4 +1,5 @@ // ignore_for_file: unused_element +import 'package:collection/collection.dart'; import 'package:langchain/langchain.dart'; import 'package:test/test.dart'; @@ -20,5 +21,33 @@ void main() { {'left': 'Hello world!', 'right': 'Bye world!'}, ); }); + + test('Streaming RunnableMap', () async { + final prompt1 = PromptTemplate.fromTemplate('Hello {input}!'); + final prompt2 = PromptTemplate.fromTemplate('Bye {input}!'); + const model = FakeEchoLLM(); + const outputParser = StringOutputParser(); + final chain = Runnable.fromMap({ + 'left': prompt1 | model | outputParser, + 'right': prompt2 | model | outputParser, + }); + final stream = chain.stream({'input': 'world'}); + + final streamList = await stream.toList(); + expect(streamList.length, 22); + expect(streamList, isA>>()); + + final left = streamList + .map((final it) => it['left']) // + .whereNotNull() + .join(); + final right = streamList + .map((final it) => it['right']) // + .whereNotNull() + .join(); + + expect(left, 'Hello world!'); + expect(right, 'Bye world!'); + }); }); } diff --git a/packages/langchain/test/core/runnable/passthrough_test.dart b/packages/langchain/test/core/runnable/passthrough_test.dart index b9710b10..6e5369e3 100644 --- a/packages/langchain/test/core/runnable/passthrough_test.dart +++ b/packages/langchain/test/core/runnable/passthrough_test.dart @@ -22,5 +22,17 @@ void main() { }, ); }); + + test('Streaming RunnablePassthrough', () async { + final passthrough = Runnable.passthrough(); + final stream = passthrough.stream('world'); + + final streamList = await stream.toList(); + expect(streamList.length, 1); + expect(streamList.first, isA()); + + final item = streamList.first; + expect(item, 'world'); + }); }); } diff --git a/packages/langchain/test/core/runnable/sequence_test.dart b/packages/langchain/test/core/runnable/sequence_test.dart index eabbd19d..d4e8d01d 100644 --- a/packages/langchain/test/core/runnable/sequence_test.dart +++ b/packages/langchain/test/core/runnable/sequence_test.dart @@ -33,5 +33,20 @@ void main() { final res = await chain.invoke({'input': 'world'}); expect(res, 'Hello world!'); }); + + test('Streaming RunnableSequence', () async { + final prompt = PromptTemplate.fromTemplate('Hello {input}!'); + const model = FakeEchoLLM(); + const outputParser = StringOutputParser(); + final chain = prompt.pipe(model).pipe(outputParser); + final stream = chain.stream({'input': 'world'}); + + final streamList = await stream.toList(); + expect(streamList.length, 12); + expect(streamList, isA>()); + + final res = streamList.join(); + expect(res, 'Hello world!'); + }); }); } diff --git a/packages/langchain/test/core/runnable/stream_test.dart b/packages/langchain/test/core/runnable/stream_test.dart new file mode 100644 index 00000000..e75a6f14 --- /dev/null +++ b/packages/langchain/test/core/runnable/stream_test.dart @@ -0,0 +1,151 @@ +// ignore_for_file: unused_element +import 'package:langchain/langchain.dart'; +import 'package:test/test.dart'; + +void main() { + group('Runnable stream tests', () { + test('Test streaming PromptTemplate', () async { + final run = PromptTemplate.fromTemplate('This is a {input}'); + final stream = run.stream({'input': 'test'}); + + final streamList = await stream.toList(); + expect(streamList.length, 1); + expect(streamList.first, isA()); + + final item = streamList.first; + expect(item.toString(), 'This is a test'); + }); + + test('Test streaming ChatPromptTemplate', () async { + final run = ChatPromptTemplate.fromPromptMessages([ + SystemChatMessagePromptTemplate.fromTemplate( + 'You are a helpful chatbot', + ), + HumanChatMessagePromptTemplate.fromTemplate('{input}'), + ]); + final stream = run.stream({'input': 'test'}); + + final streamList = await stream.toList(); + expect(streamList.length, 1); + expect(streamList.first, isA()); + + final item = streamList.first; + expect( + item.toChatMessages(), + equals([ + ChatMessage.system('You are a helpful chatbot'), + ChatMessage.human('test'), + ]), + ); + }); + + test('Test streaming', () async { + const doc = Document( + id: '1', + pageContent: 'This is a test', + ); + const run = FakeRetriever([doc]); + final stream = run.stream('test'); + + final streamList = await stream.toList(); + expect(streamList.length, 1); + expect(streamList.first, isA>()); + + final item = streamList.first; + expect(item, [doc]); + }); + + test('Streaming DocumentTransformer', () async { + const run = CharacterTextSplitter( + separator: ' ', + chunkSize: 7, + chunkOverlap: 3, + ); + final stream = run.stream([ + const Document(pageContent: 'foo bar baz 123'), + ]); + + final streamList = await stream.toList(); + expect(streamList.length, 1); + expect(streamList.first, isA>()); + + final item = streamList.first; + expect( + item, + [ + const Document(pageContent: 'foo bar'), + const Document(pageContent: 'bar baz'), + const Document(pageContent: 'baz 123'), + ], + ); + }); + + test('Streaming LLM', () async { + const run = FakeEchoLLM(); + final stream = run.stream(PromptValue.string('Hello world!')); + + final streamList = await stream.toList(); + expect(streamList.length, 12); + expect(streamList, isA>()); + + final res = streamList.map((final i) => i.firstOutputAsString).join(); + + expect(res, 'Hello world!'); + }); + + test('Streaming ChatModel', () async { + const run = FakeEchoChatModel(); + final stream = run.stream(PromptValue.string('Hello world!')); + + final streamList = await stream.toList(); + expect(streamList.length, 12); + expect(streamList, isA>()); + + final res = streamList.map((final i) => i.firstOutputAsString).join(); + expect(res, 'Hello world!'); + }); + + + test('Streaming Chain', () async { + final model = FakeListLLM(responses: ['Hello world!']); + final prompt = PromptTemplate.fromTemplate('Print {foo}'); + final run = LLMChain(prompt: prompt, llm: model); + final stream = run.stream({'foo': 'Hello world!'}); + + final streamList = await stream.toList(); + expect(streamList.length, 1); + expect(streamList.first, isA>()); + + final res = streamList.first; + expect(res[LLMChain.defaultOutputKey], 'Hello world!'); + }); + + test('Streaming OutputParser', () async { + const run = StringOutputParser(); + final stream = run.stream( + const LLMResult( + generations: [LLMGeneration('Hello world!')], + ), + ); + + final streamList = await stream.toList(); + expect(streamList.length, 1); + expect(streamList.first, isA()); + + final res = streamList.first; + expect(res, 'Hello world!'); + }); + + test('Streaming Tool', () async { + final run = CalculatorTool(); + final stream = run.stream({'input': '1+1'}); + + final streamList = await stream.toList(); + expect(streamList.length, 1); + expect(streamList.first, isA()); + + final res = streamList.first; + expect(res, '2.0'); + }); + }); +} diff --git a/packages/langchain/test/model_io/output_parsers/functions_test.dart b/packages/langchain/test/model_io/output_parsers/functions_test.dart index 1b8f75c6..2bcfd86c 100644 --- a/packages/langchain/test/model_io/output_parsers/functions_test.dart +++ b/packages/langchain/test/model_io/output_parsers/functions_test.dart @@ -10,6 +10,7 @@ void main() { '', functionCall: const AIChatMessageFunctionCall( name: 'test', + argumentsRaw: '{"foo":"bar","bar":"foo"}', arguments: { 'foo': 'bar', 'bar': 'foo',