Skip to content

Commit

Permalink
feat(lcel): Add streaming support in LangChain Expression Language (d…
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored and KennethKnudsen97 committed Apr 22, 2024
1 parent 4318631 commit 6802cb6
Show file tree
Hide file tree
Showing 21 changed files with 631 additions and 57 deletions.
38 changes: 36 additions & 2 deletions packages/langchain/lib/src/core/runnable/base.dart
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import 'dart:async';

import 'package:meta/meta.dart';

import '../base.dart';
import 'binding.dart';
import 'function.dart';
Expand Down Expand Up @@ -101,11 +103,43 @@ abstract class Runnable<RunInput extends Object,
final CallOptions? options,
});

/// Streams the output of invoking the [Runnable] on the given [input].
///
/// - [input] - the input to invoke the [Runnable] on.
/// - [options] - the options to use when invoking the [Runnable].
Stream<RunOutput> stream(
final RunInput input, {
final CallOptions? options,
}) {
return streamFromInputStream(
Stream.value(input).asBroadcastStream(),
options: options,
);
}

/// Streams the output of invoking the [Runnable] on the given [inputStream].
///
/// - [inputStream] - the input stream to invoke the [Runnable] on.
/// - [options] - the options to use when invoking the [Runnable].
@protected
Stream<RunOutput> streamFromInputStream(
final Stream<RunInput> inputStream, {
final CallOptions? options,
}) {
// By default, it just emits the result of calling invoke
// Subclasses should override this method if they support streaming output
return inputStream.asyncMap(
// ignore: discarded_futures
(final input) => invoke(input, options: options),
);
}

/// Pipes the output of this [Runnable] into another [Runnable].
///
/// - [next] - the [Runnable] to pipe the output into.
RunnableSequence<RunInput, NewRunOutput> pipe<NewRunOutput extends Object>(
final Runnable<RunOutput, CallOptions, NewRunOutput> next,
RunnableSequence<RunInput, NewRunOutput> pipe<NewRunOutput extends Object,
NewCallOptions extends BaseLangChainOptions>(
final Runnable<RunOutput, NewCallOptions, NewRunOutput> next,
) {
return RunnableSequence<RunInput, NewRunOutput>(
first: this,
Expand Down
8 changes: 8 additions & 0 deletions packages/langchain/lib/src/core/runnable/binding.dart
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,12 @@ class RunnableBinding<RunInput extends Object,
}) async {
return bound.invoke(input, options: options ?? this.options);
}

@override
Stream<RunOutput> streamFromInputStream(
final Stream<RunInput> inputStream, {
final CallOptions? options,
}) {
return bound.streamFromInputStream(inputStream, options: options ?? this.options);
}
}
2 changes: 1 addition & 1 deletion packages/langchain/lib/src/core/runnable/extensions.dart
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ extension RunnableX<
///
/// - [next] - the [Runnable] to pipe the output into.
RunnableSequence<RunInput, NewRunOutput> operator |(
final Runnable<RunOutput, CallOptions, NewRunOutput> next,
final Runnable<RunOutput, BaseLangChainOptions, NewRunOutput> next,
) {
return pipe(next);
}
Expand Down
16 changes: 16 additions & 0 deletions packages/langchain/lib/src/core/runnable/map.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import 'package:async/async.dart' show StreamGroup;

import '../base.dart';
import 'base.dart';

Expand Down Expand Up @@ -64,4 +66,18 @@ class RunnableMap<RunInput extends Object>

return output;
}

@override
Stream<Map<String, dynamic>> streamFromInputStream(
final Stream<RunInput> inputStream, {
final BaseLangChainOptions? options,
}) {
return StreamGroup.merge(
steps.entries.map((final entry) {
return entry.value.streamFromInputStream(inputStream, options: options).map(
(final output) => {entry.key: output},
);
}),
);
}
}
23 changes: 19 additions & 4 deletions packages/langchain/lib/src/core/runnable/sequence.dart
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,33 @@ class RunnableSequence<RunInput extends Object, RunOutput extends Object>
return last.invoke(nextStepInput, options: options);
}

@override
Stream<RunOutput> streamFromInputStream(
final Stream<RunInput> inputStream, {
final BaseLangChainOptions? options,
}) {
var nextStepStream = first.streamFromInputStream(inputStream);

for (final step in middle) {
nextStepStream = step.streamFromInputStream(nextStepStream);
}

return last.streamFromInputStream(nextStepStream);
}

/// Pipes the output of this [RunnableSequence] into another [Runnable].
///
/// - [next] - the [Runnable] to pipe the output into.
@override
RunnableSequence<RunInput, NewRunOutput> pipe<NewRunOutput extends Object>(
final Runnable<RunOutput, BaseLangChainOptions, NewRunOutput> next,
RunnableSequence<RunInput, NewRunOutput> pipe<NewRunOutput extends Object, NewCallOptions extends BaseLangChainOptions>(
final Runnable<RunOutput, NewCallOptions, NewRunOutput> next,
) {
if (next is RunnableSequence<RunOutput, NewRunOutput>) {
final nextSeq = next as RunnableSequence<RunOutput, NewRunOutput>;
return RunnableSequence(
first: first,
middle: [...middle, last, next.first, ...next.middle],
last: next.last,
middle: [...middle, last, nextSeq.first, ...nextSeq.middle],
last: nextSeq.last,
);
} else {
return RunnableSequence(
Expand Down
21 changes: 20 additions & 1 deletion packages/langchain/lib/src/model_io/chat_models/fake.dart
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class FakeChatModel extends SimpleChatModel {

/// {@template fake_echo_llm}
/// Fake Chat Model for testing.
/// It just returns the content of the last message of the prompt.
/// It just returns the content of the last message of the prompt
/// or streams it char by char.
/// {@endtemplate}
class FakeEchoChatModel extends SimpleChatModel {
/// {@macro fake_echo_llm}
Expand All @@ -57,6 +58,24 @@ class FakeEchoChatModel extends SimpleChatModel {
return Future<String>.value(messages.last.content);
}

@override
Stream<ChatResult> streamFromInputStream(
final Stream<PromptValue> inputStream, {
final ChatModelOptions? options,
}) {
return inputStream.asyncExpand(
(final input) {
final prompt = input.toChatMessages().first.content.split('');
return Stream.fromIterable(prompt).map(
(final char) => ChatResult(
generations: [ChatGeneration(ChatMessage.ai(char))],
streaming: true,
),
);
},
);
}

@override
Future<List<int>> tokenize(final PromptValue promptValue) async {
return promptValue
Expand Down
105 changes: 82 additions & 23 deletions packages/langchain/lib/src/model_io/chat_models/models/models.dart
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,8 @@ class ChatModelOptions extends LanguageModelOptions {
const ChatModelOptions();
}

/// {@template chat_result}
/// Result returned by the Chat Model.
/// {@endtemplate}
@immutable
class ChatResult extends LanguageModelResult<ChatMessage> {
/// {@macro chat_result}
const ChatResult({
required super.generations,
super.usage,
super.modelOutput,
});

@override
String toString() {
return '''
ChatResult{
generations: $generations,
usage: $usage,
modelOutput: $modelOutput},
''';
}
}
typedef ChatResult = LanguageModelResult<ChatMessage>;

/// {@template chat_generation}
/// Output of a single generation.
Expand All @@ -48,6 +28,19 @@ class ChatGeneration extends LanguageModelGeneration<ChatMessage> {
@override
String get outputAsString => output.content;

@override
LanguageModelGeneration<ChatMessage> concat(
final LanguageModelGeneration<ChatMessage> other,
) {
return ChatGeneration(
output.concat(other.output),
generationInfo: {
...?generationInfo,
...?other.generationInfo,
},
);
}

@override
String toString() {
return '''
Expand Down Expand Up @@ -118,6 +111,9 @@ sealed class ChatMessage {

/// The content of the message.
final String content;

/// Merges this message with another by concatenating the content.
ChatMessage concat(final ChatMessage other);
}

/// {@template system_chat_message}
Expand All @@ -140,6 +136,11 @@ class SystemChatMessage extends ChatMessage {
@override
int get hashCode => content.hashCode;

@override
SystemChatMessage concat(final ChatMessage other) {
return SystemChatMessage(content: content + other.content);
}

@override
String toString() {
return '''
Expand Down Expand Up @@ -184,6 +185,14 @@ class HumanChatMessage extends ChatMessage {
@override
int get hashCode => content.hashCode ^ example.hashCode;

@override
HumanChatMessage concat(final ChatMessage other) {
return HumanChatMessage(
content: content + other.content,
example: example,
);
}

@override
String toString() {
return '''
Expand Down Expand Up @@ -237,6 +246,32 @@ class AIChatMessage extends ChatMessage {
@override
int get hashCode => content.hashCode ^ example.hashCode;

@override
AIChatMessage concat(final ChatMessage other) {
if (other is AIChatMessage) {
return AIChatMessage(
content: content + other.content,
functionCall: functionCall != null || other.functionCall != null
? AIChatMessageFunctionCall(
name: (functionCall?.name ?? '') +
(other.functionCall?.name ?? ''),
argumentsRaw: (functionCall?.argumentsRaw ?? '') +
(other.functionCall?.argumentsRaw ?? ''),
arguments: {
...?functionCall?.arguments,
...?other.functionCall?.arguments,
},
)
: null,
example: example,
);
}
return AIChatMessage(
content: content + other.content,
example: example,
);
}

@override
String toString() {
return '''
Expand Down Expand Up @@ -275,12 +310,16 @@ class AIChatMessageFunctionCall {
/// {@macro ai_chat_message_function_call}
const AIChatMessageFunctionCall({
required this.name,
required this.argumentsRaw,
required this.arguments,
});

/// The name of the function that the model wants to call.
final String name;

/// The raw arguments JSON string (needed to parse streaming responses).
final String argumentsRaw;

/// The arguments that the model wants to pass to the function.
final Map<String, dynamic> arguments;

Expand All @@ -296,17 +335,21 @@ class AIChatMessageFunctionCall {
bool operator ==(covariant final AIChatMessageFunctionCall other) {
final mapEquals = const MapEquality<String, dynamic>().equals;
return identical(this, other) ||
name == other.name && mapEquals(arguments, other.arguments);
name == other.name &&
argumentsRaw == other.argumentsRaw &&
mapEquals(arguments, other.arguments);
}

@override
int get hashCode => name.hashCode ^ arguments.hashCode;
int get hashCode =>
name.hashCode ^ argumentsRaw.hashCode ^ arguments.hashCode;

@override
String toString() {
return '''
AIChatMessageFunctionCall{
name: $name,
argumentsRaw: $argumentsRaw,
arguments: $arguments,
}
''';
Expand Down Expand Up @@ -348,6 +391,14 @@ class FunctionChatMessage extends ChatMessage {
@override
int get hashCode => content.hashCode;

@override
FunctionChatMessage concat(final ChatMessage other) {
return FunctionChatMessage(
content: content + other.content,
name: name + (other is FunctionChatMessage ? other.name : ''),
);
}

@override
String toString() {
return '''
Expand Down Expand Up @@ -390,6 +441,14 @@ class CustomChatMessage extends ChatMessage {
@override
int get hashCode => content.hashCode ^ role.hashCode;

@override
CustomChatMessage concat(final ChatMessage other) {
return CustomChatMessage(
role: role,
content: content + other.content,
);
}

@override
String toString() {
return '''
Expand Down
Loading

0 comments on commit 6802cb6

Please sign in to comment.