
feat: Add temperature, top_p and response format to Assistants API (#384)
davidmigloz committed Apr 17, 2024
1 parent 8b9979e commit 1d18290
Showing 11 changed files with 2,236 additions and 307 deletions.
@@ -35,6 +35,23 @@ class CreateAssistantRequest with _$CreateAssistantRequest {

/// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maximum of 512 characters long.
@JsonKey(includeIfNull: false) Map<String, dynamic>? metadata,

/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
@JsonKey(includeIfNull: false) @Default(1.0) double? temperature,

/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
///
/// We generally recommend altering this or temperature but not both.
@JsonKey(name: 'top_p', includeIfNull: false) @Default(1.0) double? topP,

/// Specifies the format that the model must output. Compatible with [GPT-4 Turbo](https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.
///
/// Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.
///
/// **Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.
@_CreateAssistantRequestResponseFormatConverter()
@JsonKey(name: 'response_format', includeIfNull: false)
CreateAssistantRequestResponseFormat? responseFormat,
}) = _CreateAssistantRequest;

/// Object construction from a JSON representation
@@ -49,13 +66,22 @@ class CreateAssistantRequest with _$CreateAssistantRequest {
'instructions',
'tools',
'file_ids',
'metadata'
'metadata',
'temperature',
'top_p',
'response_format'
];

/// Validation constants
static const nameMaxLengthValue = 256;
static const descriptionMaxLengthValue = 512;
static const instructionsMaxLengthValue = 256000;
static const temperatureDefaultValue = 1.0;
static const temperatureMinValue = 0.0;
static const temperatureMaxValue = 2.0;
static const topPDefaultValue = 1.0;
static const topPMinValue = 0.0;
static const topPMaxValue = 1.0;

/// Perform validations on the schema property values
String? validateSchema() {
@@ -70,6 +96,18 @@ class CreateAssistantRequest with _$CreateAssistantRequest {
instructions!.length > instructionsMaxLengthValue) {
return "The length of 'instructions' cannot be > $instructionsMaxLengthValue characters";
}
if (temperature != null && temperature! < temperatureMinValue) {
return "The value of 'temperature' cannot be < $temperatureMinValue";
}
if (temperature != null && temperature! > temperatureMaxValue) {
return "The value of 'temperature' cannot be > $temperatureMaxValue";
}
if (topP != null && topP! < topPMinValue) {
return "The value of 'topP' cannot be < $topPMinValue";
}
if (topP != null && topP! > topPMaxValue) {
return "The value of 'topP' cannot be > $topPMaxValue";
}
return null;
}

@@ -83,6 +121,9 @@ class CreateAssistantRequest with _$CreateAssistantRequest {
'tools': tools,
'file_ids': fileIds,
'metadata': metadata,
'temperature': temperature,
'top_p': topP,
'response_format': responseFormat,
};
}
}
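The new bounds checks surface through the existing `validateSchema()` helper rather than throwing. A minimal sketch of what that looks like for an out-of-range `topP`, assuming the package is imported as `openai_dart` and that the `AssistantModel.modelId` union constructor is the right way to pass a model id (neither is shown in this hunk):

```dart
import 'package:openai_dart/openai_dart.dart';

void main() {
  final request = CreateAssistantRequest(
    model: AssistantModel.modelId('gpt-4-turbo'), // constructor name assumed
    temperature: 0.2,
    topP: 1.5, // out of range: topPMaxValue is 1.0
  );

  // validateSchema() returns an error string instead of throwing.
  print(request.validateSchema()); // The value of 'topP' cannot be > 1.0
}
```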
@@ -185,3 +226,91 @@ class _AssistantModelConverter
};
}
}

// ==========================================
// ENUM: CreateAssistantResponseFormatMode
// ==========================================

/// `auto` is the default value
enum CreateAssistantResponseFormatMode {
@JsonValue('none')
none,
@JsonValue('auto')
auto,
}

// ==========================================
// CLASS: CreateAssistantRequestResponseFormat
// ==========================================

/// Specifies the format that the model must output. Compatible with [GPT-4 Turbo](https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.
///
/// Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.
///
/// **Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.
@freezed
sealed class CreateAssistantRequestResponseFormat
with _$CreateAssistantRequestResponseFormat {
const CreateAssistantRequestResponseFormat._();

/// `auto` is the default value
const factory CreateAssistantRequestResponseFormat.mode(
CreateAssistantResponseFormatMode value,
) = CreateAssistantRequestResponseFormatEnumeration;

/// No Description
const factory CreateAssistantRequestResponseFormat.format(
AssistantsResponseFormat value,
) = CreateAssistantRequestResponseFormatAssistantsResponseFormat;

/// Object construction from a JSON representation
factory CreateAssistantRequestResponseFormat.fromJson(
Map<String, dynamic> json) =>
_$CreateAssistantRequestResponseFormatFromJson(json);
}

/// Custom JSON converter for [CreateAssistantRequestResponseFormat]
class _CreateAssistantRequestResponseFormatConverter
implements JsonConverter<CreateAssistantRequestResponseFormat?, Object?> {
const _CreateAssistantRequestResponseFormatConverter();

@override
CreateAssistantRequestResponseFormat? fromJson(Object? data) {
if (data == null) {
return null;
}
if (data is String &&
_$CreateAssistantResponseFormatModeEnumMap.values.contains(data)) {
return CreateAssistantRequestResponseFormatEnumeration(
_$CreateAssistantResponseFormatModeEnumMap.keys.elementAt(
_$CreateAssistantResponseFormatModeEnumMap.values
.toList()
.indexOf(data),
),
);
}
if (data is Map<String, dynamic>) {
try {
return CreateAssistantRequestResponseFormatAssistantsResponseFormat(
AssistantsResponseFormat.fromJson(data),
);
} catch (e) {}
}
throw Exception(
'Unexpected value for CreateAssistantRequestResponseFormat: $data',
);
}

@override
Object? toJson(CreateAssistantRequestResponseFormat? data) {
return switch (data) {
CreateAssistantRequestResponseFormatEnumeration(value: final v) =>
_$CreateAssistantResponseFormatModeEnumMap[v]!,
CreateAssistantRequestResponseFormatAssistantsResponseFormat(
value: final v
) =>
v.toJson(),
null => null,
};
}
}
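A sketch of the two JSON shapes the converter above accepts and emits, assuming `CreateAssistantRequest.fromJson` routes the `model` string through its own converter (that part is not shown in this hunk):

```dart
// Bare string -> enumeration variant (`auto` / `none`).
final fromString = CreateAssistantRequest.fromJson({
  'model': 'gpt-4-turbo',
  'response_format': 'auto',
});

// Object -> AssistantsResponseFormat variant (JSON mode).
final fromObject = CreateAssistantRequest.fromJson({
  'model': 'gpt-4-turbo',
  'response_format': {'type': 'json_object'},
});

// Round-tripping preserves the shape: a mode serializes back to a string,
// a format object serializes back to a map.
assert(fromString.toJson()['response_format'] == 'auto');
assert(fromObject.toJson()['response_format'] is Map);
```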
@@ -21,7 +21,7 @@ class CreateFineTuningJobRequest with _$CreateFineTuningJobRequest {

/// The ID of an uploaded file that contains training data.
///
/// See [upload file](https://platform.openai.com/docs/api-reference/files/upload) for how to upload a file.
/// See [upload file](https://platform.openai.com/docs/api-reference/files/create) for how to upload a file.
///
/// Your dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`.
///
@@ -43,6 +43,11 @@ class CreateRunRequest with _$CreateRunRequest {
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
@JsonKey(includeIfNull: false) @Default(1.0) double? temperature,

/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
///
/// We generally recommend altering this or temperature but not both.
@JsonKey(name: 'top_p', includeIfNull: false) @Default(1.0) double? topP,

/// The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.
@JsonKey(name: 'max_prompt_tokens', includeIfNull: false)
int? maxPromptTokens,
@@ -90,6 +95,7 @@ class CreateRunRequest with _$CreateRunRequest {
'tools',
'metadata',
'temperature',
'top_p',
'max_prompt_tokens',
'max_completion_tokens',
'truncation_strategy',
@@ -102,6 +108,9 @@ class CreateRunRequest with _$CreateRunRequest {
static const temperatureDefaultValue = 1.0;
static const temperatureMinValue = 0.0;
static const temperatureMaxValue = 2.0;
static const topPDefaultValue = 1.0;
static const topPMinValue = 0.0;
static const topPMaxValue = 1.0;
static const maxPromptTokensMinValue = 256;
static const maxCompletionTokensMinValue = 256;

@@ -113,6 +122,12 @@ class CreateRunRequest with _$CreateRunRequest {
if (temperature != null && temperature! > temperatureMaxValue) {
return "The value of 'temperature' cannot be > $temperatureMaxValue";
}
if (topP != null && topP! < topPMinValue) {
return "The value of 'topP' cannot be < $topPMinValue";
}
if (topP != null && topP! > topPMaxValue) {
return "The value of 'topP' cannot be > $topPMaxValue";
}
if (maxPromptTokens != null && maxPromptTokens! < maxPromptTokensMinValue) {
return "The value of 'maxPromptTokens' cannot be < $maxPromptTokensMinValue";
}
@@ -134,6 +149,7 @@ class CreateRunRequest with _$CreateRunRequest {
'tools': tools,
'metadata': metadata,
'temperature': temperature,
'top_p': topP,
'max_prompt_tokens': maxPromptTokens,
'max_completion_tokens': maxCompletionTokens,
'truncation_strategy': truncationStrategy,
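The same pair of sampling knobs is now available per run. A small sketch (the assistant id is a placeholder) showing that `topP` is written out under the `top_p` wire name:

```dart
final run = CreateRunRequest(
  assistantId: 'asst_abc123', // placeholder id
  temperature: 0.2,
  topP: 0.9,
);

// The @JsonKey(name: 'top_p') mapping produces the snake_case wire name.
assert(run.toJson()['top_p'] == 0.9);
```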
@@ -38,6 +38,11 @@ class CreateThreadAndRunRequest with _$CreateThreadAndRunRequest {
/// What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.
@JsonKey(includeIfNull: false) @Default(1.0) double? temperature,

/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
///
/// We generally recommend altering this or temperature but not both.
@JsonKey(name: 'top_p', includeIfNull: false) @Default(1.0) double? topP,

/// The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.
@JsonKey(name: 'max_prompt_tokens', includeIfNull: false)
int? maxPromptTokens,
@@ -84,6 +89,7 @@ class CreateThreadAndRunRequest with _$CreateThreadAndRunRequest {
'tools',
'metadata',
'temperature',
'top_p',
'max_prompt_tokens',
'max_completion_tokens',
'truncation_strategy',
@@ -96,6 +102,9 @@ class CreateThreadAndRunRequest with _$CreateThreadAndRunRequest {
static const temperatureDefaultValue = 1.0;
static const temperatureMinValue = 0.0;
static const temperatureMaxValue = 2.0;
static const topPDefaultValue = 1.0;
static const topPMinValue = 0.0;
static const topPMaxValue = 1.0;
static const maxPromptTokensMinValue = 256;
static const maxCompletionTokensMinValue = 256;

@@ -107,6 +116,12 @@ class CreateThreadAndRunRequest with _$CreateThreadAndRunRequest {
if (temperature != null && temperature! > temperatureMaxValue) {
return "The value of 'temperature' cannot be > $temperatureMaxValue";
}
if (topP != null && topP! < topPMinValue) {
return "The value of 'topP' cannot be < $topPMinValue";
}
if (topP != null && topP! > topPMaxValue) {
return "The value of 'topP' cannot be > $topPMaxValue";
}
if (maxPromptTokens != null && maxPromptTokens! < maxPromptTokensMinValue) {
return "The value of 'maxPromptTokens' cannot be < $maxPromptTokensMinValue";
}
@@ -127,6 +142,7 @@ class CreateThreadAndRunRequest with _$CreateThreadAndRunRequest {
'tools': tools,
'metadata': metadata,
'temperature': temperature,
'top_p': topP,
'max_prompt_tokens': maxPromptTokens,
'max_completion_tokens': maxCompletionTokens,
'truncation_strategy': truncationStrategy,
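The same bounds apply when creating a thread and run in one call. Per the doc comment, adjust either `temperature` or `topP`, not both; a sketch with a placeholder assistant id and other fields left at their defaults:

```dart
final request = CreateThreadAndRunRequest(
  assistantId: 'asst_abc123', // placeholder id
  topP: 0.2, // temperature left at its default of 1.0
);

// In range (0.0 <= 0.2 <= 1.0), so no validation error is reported.
assert(request.validateSchema() == null);
```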

