feat: Add support for Structured Outputs in ChatOpenAI
davidmigloz committed Aug 17, 2024
1 parent c757407 commit c15448b
Showing 11 changed files with 183 additions and 55 deletions.
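
At the call-site level, `ChatOpenAIResponseFormat` changes from a plain class configured through a `ChatOpenAIResponseFormatType` enum into a sealed class hierarchy with `text` and `jsonObject` constants plus a new `jsonSchema` factory for Structured Outputs. A minimal before/after sketch of the migration, based on the docs changes below (`openAiApiKey` is assumed to be defined elsewhere):

// Before this commit: format selected via an enum field.
final model = ChatOpenAI(
  apiKey: openAiApiKey,
  defaultOptions: const ChatOpenAIOptions(
    responseFormat: ChatOpenAIResponseFormat(
      type: ChatOpenAIResponseFormatType.jsonObject,
    ),
  ),
);

// After this commit: dedicated constant on the sealed class.
final model = ChatOpenAI(
  apiKey: openAiApiKey,
  defaultOptions: const ChatOpenAIOptions(
    responseFormat: ChatOpenAIResponseFormat.jsonObject,
  ),
);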
docs/expression_language/primitives/mapper.md (4 changes: 1 addition & 3 deletions)

@@ -54,9 +54,7 @@ In the following example, the model streams the output in chunks and the output
 final model = ChatOpenAI(
   apiKey: openAiApiKey,
   defaultOptions: ChatOpenAIOptions(
-    responseFormat: ChatOpenAIResponseFormat(
-      type: ChatOpenAIResponseFormatType.jsonObject,
-    ),
+    responseFormat: ChatOpenAIResponseFormat.jsonObject,
   ),
 );
 final parser = JsonOutputParser<ChatResult>();
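
In the docs, JSON mode is consistently paired with `JsonOutputParser`, which re-emits a progressively more complete `Map` while the model streams. A minimal end-to-end sketch under the same assumptions (`openAiApiKey` defined; the prompt text is illustrative):

final model = ChatOpenAI(
  apiKey: openAiApiKey,
  defaultOptions: const ChatOpenAIOptions(
    responseFormat: ChatOpenAIResponseFormat.jsonObject,
  ),
);
final parser = JsonOutputParser<ChatResult>();
final chain = model.pipe(parser);
// Each emitted value is the parser's best-effort parse of the JSON so far.
await for (final map in chain.stream(
  PromptValue.string('Output a list of three EU countries in JSON format'),
)) {
  print(map);
}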
docs/expression_language/streaming.md (4 changes: 1 addition & 3 deletions)

@@ -124,9 +124,7 @@ Let’s see such a parser in action to understand what this means.
 final model = ChatOpenAI(
   apiKey: openAiApiKey,
   defaultOptions: const ChatOpenAIOptions(
-    responseFormat: ChatOpenAIResponseFormat(
-      type: ChatOpenAIResponseFormatType.jsonObject,
-    ),
+    responseFormat: ChatOpenAIResponseFormat.jsonObject,
   ),
 );
 final parser = JsonOutputParser<ChatResult>();
@@ -127,9 +127,7 @@ final llm = ChatOpenAI(
   defaultOptions: const ChatOpenAIOptions(
     model: 'gpt-4-turbo',
     temperature: 0,
-    responseFormat: ChatOpenAIResponseFormat(
-      type: ChatOpenAIResponseFormatType.jsonObject,
-    ),
+    responseFormat: ChatOpenAIResponseFormat.jsonObject,
   ),
 );
 final chain = llm.pipe(JsonOutputParser());
docs/modules/model_io/output_parsers/json.md (4 changes: 1 addition & 3 deletions)

@@ -21,9 +21,7 @@ final model = ChatOpenAI(
   apiKey: openAiApiKey,
   defaultOptions: ChatOpenAIOptions(
     model: 'gpt-4-turbo',
-    responseFormat: ChatOpenAIResponseFormat(
-      type: ChatOpenAIResponseFormatType.jsonObject,
-    ),
+    responseFormat: ChatOpenAIResponseFormat.jsonObject,
   ),
 );
 final parser = JsonOutputParser<ChatResult>();
@@ -79,9 +79,7 @@ Future<void> _inputStreams() async {
 final model = ChatOpenAI(
   apiKey: openAiApiKey,
   defaultOptions: const ChatOpenAIOptions(
-    responseFormat: ChatOpenAIResponseFormat(
-      type: ChatOpenAIResponseFormatType.jsonObject,
-    ),
+    responseFormat: ChatOpenAIResponseFormat.jsonObject,
   ),
 );
 final parser = JsonOutputParser<ChatResult>();

@@ -125,9 +123,7 @@ Future<void> _inputStreamMapper() async {
 final model = ChatOpenAI(
   apiKey: openAiApiKey,
   defaultOptions: const ChatOpenAIOptions(
-    responseFormat: ChatOpenAIResponseFormat(
-      type: ChatOpenAIResponseFormatType.jsonObject,
-    ),
+    responseFormat: ChatOpenAIResponseFormat.jsonObject,
   ),
 );
 final parser = JsonOutputParser<ChatResult>();
@@ -63,9 +63,7 @@ Future<void> _mapInputStream() async {
 final model = ChatOpenAI(
   apiKey: openAiApiKey,
   defaultOptions: const ChatOpenAIOptions(
-    responseFormat: ChatOpenAIResponseFormat(
-      type: ChatOpenAIResponseFormatType.jsonObject,
-    ),
+    responseFormat: ChatOpenAIResponseFormat.jsonObject,
   ),
 );
 final parser = JsonOutputParser<ChatResult>();
@@ -131,9 +131,7 @@ Future<void> _chatOpenAIJsonMode() async {
   defaultOptions: const ChatOpenAIOptions(
     model: 'gpt-4-turbo',
     temperature: 0,
-    responseFormat: ChatOpenAIResponseFormat(
-      type: ChatOpenAIResponseFormatType.jsonObject,
-    ),
+    responseFormat: ChatOpenAIResponseFormat.jsonObject,
   ),
 );
 final chain = llm.pipe(JsonOutputParser());
@@ -22,9 +22,7 @@ Future<void> _invoke() async {
 apiKey: openAiApiKey,
 defaultOptions: const ChatOpenAIOptions(
   model: 'gpt-4-turbo',
-  responseFormat: ChatOpenAIResponseFormat(
-    type: ChatOpenAIResponseFormatType.jsonObject,
-  ),
+  responseFormat: ChatOpenAIResponseFormat.jsonObject,
 ),
 );
 final parser = JsonOutputParser<ChatResult>();

@@ -51,9 +49,7 @@ Future<void> _streaming() async {
 apiKey: openAiApiKey,
 defaultOptions: const ChatOpenAIOptions(
   model: 'gpt-4-turbo',
-  responseFormat: ChatOpenAIResponseFormat(
-    type: ChatOpenAIResponseFormatType.jsonObject,
-  ),
+  responseFormat: ChatOpenAIResponseFormat.jsonObject,
 ),
 );
packages/langchain_openai/lib/src/chat_models/mappers.dart (22 changes: 13 additions & 9 deletions)

@@ -248,15 +248,19 @@ extension CreateChatCompletionStreamResponseMapper
 }

 extension ChatOpenAIResponseFormatMapper on ChatOpenAIResponseFormat {
-  ChatCompletionResponseFormat toChatCompletionResponseFormat() {
-    return ChatCompletionResponseFormat(
-      type: switch (type) {
-        ChatOpenAIResponseFormatType.text =>
-          ChatCompletionResponseFormatType.text,
-        ChatOpenAIResponseFormatType.jsonObject =>
-          ChatCompletionResponseFormatType.jsonObject,
-      },
-    );
+  ResponseFormat toChatCompletionResponseFormat() {
+    return switch (this) {
+      ChatOpenAIResponseFormatText() => const ResponseFormat.text(),
+      ChatOpenAIResponseFormatJsonObject() => const ResponseFormat.jsonObject(),
+      final ChatOpenAIResponseFormatJsonSchema res => ResponseFormat.jsonSchema(
+        jsonSchema: JsonSchemaObject(
+          name: res.jsonSchema.name,
+          description: res.jsonSchema.description,
+          schema: res.jsonSchema.schema,
+          strict: res.jsonSchema.strict,
+        ),
+      ),
+    };
   }
 }
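
The rewritten mapper relies on Dart 3 sealed classes: a switch expression over `this` must be exhaustive, so adding a new `ChatOpenAIResponseFormat` subclass without extending the mapper becomes a compile-time error rather than a silent fallthrough. A self-contained sketch of the pattern (all names here are illustrative, not from the package):

// Illustrative sealed hierarchy mirroring the mapper's structure.
sealed class Format {
  const Format();
}

class TextFormat extends Format {
  const TextFormat();
}

class JsonSchemaFormat extends Format {
  const JsonSchemaFormat(this.name);
  final String name;
}

// Exhaustive switch expression: the compiler rejects this function
// if any Format subclass is left unhandled.
String wireType(Format format) => switch (format) {
      TextFormat() => 'text',
      final JsonSchemaFormat s => 'json_schema:${s.name}',
    };

void main() {
  print(wireType(const JsonSchemaFormat('Companies'))); // json_schema:Companies
}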
packages/langchain_openai/lib/src/chat_models/types.dart (113 changes: 100 additions & 13 deletions)

@@ -261,24 +261,111 @@ class ChatOpenAIOptions extends ChatModelOptions {
 /// {@template chat_openai_response_format}
 /// An object specifying the format that the model must output.
 /// {@endtemplate}
-class ChatOpenAIResponseFormat {
-  /// {@macro chat_openai_response_format}
-  const ChatOpenAIResponseFormat({
-    required this.type,
-  });
-
-  /// The format type.
-  final ChatOpenAIResponseFormatType type;
-}
-
-/// Types of response formats.
-enum ChatOpenAIResponseFormatType {
-  /// Standard text mode.
-  text,
-
-  /// [ChatOpenAIResponseFormatType.jsonObject] enables JSON mode, which
-  /// guarantees the message the model generates is valid JSON.
-  jsonObject,
-}
+sealed class ChatOpenAIResponseFormat {
+  const ChatOpenAIResponseFormat();
+
+  /// The model will respond with text.
+  static const text = ChatOpenAIResponseFormatText();
+
+  /// The model will respond with a valid JSON object.
+  static const jsonObject = ChatOpenAIResponseFormatJsonObject();
+
+  /// The model will respond with a valid JSON object that adheres to the
+  /// specified schema.
+  factory ChatOpenAIResponseFormat.jsonSchema(
+    final ChatOpenAIJsonSchema jsonSchema,
+  ) =>
+      ChatOpenAIResponseFormatJsonSchema(jsonSchema: jsonSchema);
+}
+
+/// {@template chat_openai_response_format_text}
+/// The model will respond with text.
+/// {@endtemplate}
+class ChatOpenAIResponseFormatText extends ChatOpenAIResponseFormat {
+  /// {@macro chat_openai_response_format_text}
+  const ChatOpenAIResponseFormatText();
+}
+
+/// {@template chat_openai_response_format_json_object}
+/// The model will respond with a valid JSON object.
+/// {@endtemplate}
+class ChatOpenAIResponseFormatJsonObject extends ChatOpenAIResponseFormat {
+  /// {@macro chat_openai_response_format_json_object}
+  const ChatOpenAIResponseFormatJsonObject();
+}
+
+/// {@template chat_openai_response_format_json_schema}
+/// The model will respond with a valid JSON object that adheres to the
+/// specified schema.
+/// {@endtemplate}
+@immutable
+class ChatOpenAIResponseFormatJsonSchema extends ChatOpenAIResponseFormat {
+  /// {@macro chat_openai_response_format_json_schema}
+  const ChatOpenAIResponseFormatJsonSchema({
+    required this.jsonSchema,
+  });
+
+  /// The JSON schema that the model must adhere to.
+  final ChatOpenAIJsonSchema jsonSchema;
+
+  @override
+  bool operator ==(covariant ChatOpenAIResponseFormatJsonSchema other) {
+    return identical(this, other) ||
+        runtimeType == other.runtimeType && jsonSchema == other.jsonSchema;
+  }
+
+  @override
+  int get hashCode => jsonSchema.hashCode;
+}
+
+/// {@template chat_openai_json_schema}
+/// Specifies the schema for the response format.
+/// {@endtemplate}
+@immutable
+class ChatOpenAIJsonSchema {
+  /// {@macro chat_openai_json_schema}
+  const ChatOpenAIJsonSchema({
+    required this.name,
+    required this.schema,
+    this.description,
+    this.strict = false,
+  });
+
+  /// The name of the response format. Must be a-z, A-Z, 0-9, or contain
+  /// underscores and dashes, with a maximum length of 64.
+  final String name;
+
+  /// A description of what the response format is for, used by the model to
+  /// determine how to respond in the format.
+  final String? description;
+
+  /// The schema for the response format, described as a JSON Schema object.
+  final Map<String, dynamic> schema;
+
+  /// Whether to enable strict schema adherence when generating the output.
+  /// If set to true, the model will always follow the exact schema defined in
+  /// the `schema` field. Only a subset of JSON Schema is supported when
+  /// `strict` is `true`. To learn more, read the
+  /// [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
+  final bool strict;
+
+  @override
+  bool operator ==(covariant ChatOpenAIJsonSchema other) {
+    return identical(this, other) ||
+        runtimeType == other.runtimeType &&
+            name == other.name &&
+            description == other.description &&
+            const MapEquality<String, dynamic>()
+                .equals(schema, other.schema) &&
+            strict == other.strict;
+  }
+
+  @override
+  int get hashCode {
+    return name.hashCode ^
+        description.hashCode ^
+        const MapEquality<String, dynamic>().hash(schema) ^
+        strict.hashCode;
+  }
+}

 /// Specifies the latency tier to use for processing the request.
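
Putting the new types together, a usage sketch of Structured Outputs (the schema, prompt, and expected output are illustrative; `jsonDecode` comes from `dart:convert`, and `openAiApiKey` is assumed to be defined):

import 'dart:convert';

final model = ChatOpenAI(
  apiKey: openAiApiKey,
  defaultOptions: ChatOpenAIOptions(
    responseFormat: ChatOpenAIResponseFormat.jsonSchema(
      const ChatOpenAIJsonSchema(
        name: 'Person',
        description: 'A single person record',
        strict: true,
        // Strict mode supports a subset of JSON Schema: every property
        // must be listed in `required` and `additionalProperties` must
        // be false.
        schema: {
          'type': 'object',
          'properties': {
            'name': {'type': 'string'},
            'age': {'type': 'integer'},
          },
          'additionalProperties': false,
          'required': ['name', 'age'],
        },
      ),
    ),
  ),
);
final res = await model.invoke(
  PromptValue.chat([
    ChatMessage.humanText('Extract the person: Ada Lovelace, age 36.'),
  ]),
);
// With strict: true the reply is guaranteed to conform to the schema,
// so it can be decoded without defensive checks.
final person = jsonDecode(res.output.content) as Map<String, dynamic>;
print(person['name']); // "Ada Lovelace"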
packages/langchain_openai/test/chat_models/chat_openai_test.dart (63 changes: 60 additions & 3 deletions)

@@ -357,11 +357,68 @@ void main() {
       final llm = ChatOpenAI(
         apiKey: openaiApiKey,
         defaultOptions: const ChatOpenAIOptions(
-          model: 'gpt-4-1106-preview',
+          model: defaultModel,
           temperature: 0,
           seed: 9999,
-          responseFormat: ChatOpenAIResponseFormat(
-            type: ChatOpenAIResponseFormatType.jsonObject,
-          ),
+          responseFormat: ChatOpenAIResponseFormat.jsonObject,
         ),
       );

+      final res = await llm.invoke(prompt);
+      final outputMsg = res.output;
+      final outputJson =
+          json.decode(outputMsg.content) as Map<String, dynamic>;
+      expect(outputJson['companies'], isNotNull);
+      final companies = outputJson['companies'] as List<dynamic>;
+      expect(companies, hasLength(2));
+      final firstCompany = companies.first as Map<String, dynamic>;
+      expect(firstCompany['name'], 'Google');
+      expect(firstCompany['origin'], 'USA');
+      final secondCompany = companies.last as Map<String, dynamic>;
+      expect(secondCompany['name'], 'Deepmind');
+      expect(secondCompany['origin'], 'UK');
+    });
+
+    test('Test Structured Output', () async {
+      final prompt = PromptValue.chat([
+        ChatMessage.system(
+          'Extract the data of any companies mentioned in the '
+          'following statement. Return a JSON list.',
+        ),
+        ChatMessage.humanText(
+          'Google was founded in the USA, while Deepmind was founded in the UK',
+        ),
+      ]);
+      final llm = ChatOpenAI(
+        apiKey: openaiApiKey,
+        defaultOptions: ChatOpenAIOptions(
+          model: defaultModel,
+          temperature: 0,
+          seed: 9999,
+          responseFormat: ChatOpenAIResponseFormat.jsonSchema(
+            const ChatOpenAIJsonSchema(
+              name: 'Companies',
+              description: 'A list of companies',
+              strict: true,
+              schema: {
+                'type': 'object',
+                'properties': {
+                  'companies': {
+                    'type': 'array',
+                    'items': {
+                      'type': 'object',
+                      'properties': {
+                        'name': {'type': 'string'},
+                        'origin': {'type': 'string'},
+                      },
+                      'additionalProperties': false,
+                      'required': ['name', 'origin'],
+                    },
+                  },
+                },
+                'additionalProperties': false,
+                'required': ['companies'],
+              },
+            ),
+          ),
+        ),
+      );
