Skip to content

Commit

Permalink
refactor: Simplify how tools are passed to the internal Firebase clie…
Browse files Browse the repository at this point in the history
…nt (#459)
  • Loading branch information
davidmigloz authored Jun 15, 2024
1 parent d3c96c5 commit 7f77239
Showing 1 changed file with 11 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import 'package:firebase_core/firebase_core.dart';
import 'package:firebase_vertexai/firebase_vertexai.dart';
import 'package:langchain_core/chat_models.dart';
import 'package:langchain_core/prompts.dart';
import 'package:langchain_core/tools.dart';
import 'package:uuid/uuid.dart';

import 'mappers.dart';
Expand Down Expand Up @@ -193,24 +192,20 @@ class ChatFirebaseVertexAI extends BaseChatModel<ChatFirebaseVertexAIOptions> {
/// The current system instruction set in [_firebaseClient];
String? _currentSystemInstruction;

/// The current tools set in [_firebaseClient];
List<ToolSpec>? _currentTools;

/// The current tool choice set in [_firebaseClient];
ChatToolChoice? _currentToolChoice;

@override
Future<ChatResult> invoke(
final PromptValue input, {
final ChatFirebaseVertexAIOptions? options,
}) async {
final id = _uuid.v4();
final (model, prompt, safetySettings, generationConfig) =
final (model, prompt, safetySettings, generationConfig, tools, toolConfig) =
_generateCompletionRequest(input.toChatMessages(), options: options);
final completion = await _firebaseClient.generateContent(
prompt,
safetySettings: safetySettings,
generationConfig: generationConfig,
tools: tools,
toolConfig: toolConfig,
);
return completion.toChatResult(id, model);
}
Expand All @@ -221,13 +216,15 @@ class ChatFirebaseVertexAI extends BaseChatModel<ChatFirebaseVertexAIOptions> {
final ChatFirebaseVertexAIOptions? options,
}) {
final id = _uuid.v4();
final (model, prompt, safetySettings, generationConfig) =
final (model, prompt, safetySettings, generationConfig, tools, toolConfig) =
_generateCompletionRequest(input.toChatMessages(), options: options);
return _firebaseClient
.generateContentStream(
prompt,
safetySettings: safetySettings,
generationConfig: generationConfig,
tools: tools,
toolConfig: toolConfig,
)
.map((final completion) => completion.toChatResult(id, model));
}
Expand All @@ -238,6 +235,8 @@ class ChatFirebaseVertexAI extends BaseChatModel<ChatFirebaseVertexAIOptions> {
Iterable<Content> prompt,
List<SafetySetting>? safetySettings,
GenerationConfig? generationConfig,
List<Tool>? tools,
ToolConfig? toolConfig,
) _generateCompletionRequest(
final List<ChatMessage> messages, {
final ChatFirebaseVertexAIOptions? options,
Expand All @@ -260,6 +259,8 @@ class ChatFirebaseVertexAI extends BaseChatModel<ChatFirebaseVertexAIOptions> {
topP: options?.topP ?? defaultOptions.topP,
topK: options?.topK ?? defaultOptions.topK,
),
(options?.tools ?? defaultOptions.tools)?.toToolList(),
(options?.toolChoice ?? defaultOptions.toolChoice)?.toToolConfig(),
);
}

Expand Down Expand Up @@ -288,8 +289,6 @@ class ChatFirebaseVertexAI extends BaseChatModel<ChatFirebaseVertexAIOptions> {
GenerativeModel _createFirebaseClient(
final String model, {
final String? systemInstruction,
final List<ToolSpec>? tools,
final ChatToolChoice? toolChoice,
}) {
return FirebaseVertexAI.instanceFor(
app: app,
Expand All @@ -300,23 +299,17 @@ class ChatFirebaseVertexAI extends BaseChatModel<ChatFirebaseVertexAIOptions> {
model: model,
systemInstruction:
systemInstruction != null ? Content.system(systemInstruction) : null,
tools: tools?.toToolList(),
toolConfig: toolChoice?.toToolConfig(),
);
}

/// Recreate the [GenerativeModel] instance.
void _recreateFirebaseClient(
final String model,
final String? systemInstruction,
final List<ToolSpec>? tools,
final ChatToolChoice? toolChoice,
) {
_firebaseClient = _createFirebaseClient(
model,
systemInstruction: systemInstruction,
tools: tools,
toolChoice: toolChoice,
);
}

Expand All @@ -332,9 +325,6 @@ class ChatFirebaseVertexAI extends BaseChatModel<ChatFirebaseVertexAIOptions> {
? messages.firstOrNull?.contentAsString
: null;

final tools = options?.tools ?? defaultOptions.tools;
final toolChoice = options?.toolChoice ?? defaultOptions.toolChoice;

bool recreate = false;
if (model != _currentModel) {
_currentModel = model;
Expand All @@ -344,17 +334,9 @@ class ChatFirebaseVertexAI extends BaseChatModel<ChatFirebaseVertexAIOptions> {
_currentSystemInstruction = systemInstruction;
recreate = true;
}
if (!const ListEquality<ToolSpec>().equals(tools, _currentTools)) {
_currentTools = tools;
recreate = true;
}
if (toolChoice != _currentToolChoice) {
_currentToolChoice = toolChoice;
recreate = true;
}

if (recreate) {
_recreateFirebaseClient(model, systemInstruction, tools, toolChoice);
_recreateFirebaseClient(model, systemInstruction);
}
}
}

0 comments on commit 7f77239

Please sign in to comment.