From 7f77239601fb216a01ec9d25680ec4d3dc4b97c7 Mon Sep 17 00:00:00 2001 From: David Miguel Lozano Date: Sat, 15 Jun 2024 10:27:18 +0200 Subject: [PATCH] refactor: Simplify how tools are passed to the internal Firebase client (#459) --- .../vertex_ai/chat_firebase_vertex_ai.dart | 40 +++++-------------- 1 file changed, 11 insertions(+), 29 deletions(-) diff --git a/packages/langchain_firebase/lib/src/chat_models/vertex_ai/chat_firebase_vertex_ai.dart b/packages/langchain_firebase/lib/src/chat_models/vertex_ai/chat_firebase_vertex_ai.dart index 3d58f8ea..47661d68 100644 --- a/packages/langchain_firebase/lib/src/chat_models/vertex_ai/chat_firebase_vertex_ai.dart +++ b/packages/langchain_firebase/lib/src/chat_models/vertex_ai/chat_firebase_vertex_ai.dart @@ -4,7 +4,6 @@ import 'package:firebase_core/firebase_core.dart'; import 'package:firebase_vertexai/firebase_vertexai.dart'; import 'package:langchain_core/chat_models.dart'; import 'package:langchain_core/prompts.dart'; -import 'package:langchain_core/tools.dart'; import 'package:uuid/uuid.dart'; import 'mappers.dart'; @@ -193,24 +192,20 @@ class ChatFirebaseVertexAI extends BaseChatModel { /// The current system instruction set in [_firebaseClient]; String? _currentSystemInstruction; - /// The current tools set in [_firebaseClient]; - List? _currentTools; - - /// The current tool choice set in [_firebaseClient]; - ChatToolChoice? _currentToolChoice; - @override Future invoke( final PromptValue input, { final ChatFirebaseVertexAIOptions? options, }) async { final id = _uuid.v4(); - final (model, prompt, safetySettings, generationConfig) = + final (model, prompt, safetySettings, generationConfig, tools, toolConfig) = _generateCompletionRequest(input.toChatMessages(), options: options); final completion = await _firebaseClient.generateContent( prompt, safetySettings: safetySettings, generationConfig: generationConfig, + tools: tools, + toolConfig: toolConfig, ); return completion.toChatResult(id, model); } @@ -221,13 +216,15 @@ class ChatFirebaseVertexAI extends BaseChatModel { final ChatFirebaseVertexAIOptions? options, }) { final id = _uuid.v4(); - final (model, prompt, safetySettings, generationConfig) = + final (model, prompt, safetySettings, generationConfig, tools, toolConfig) = _generateCompletionRequest(input.toChatMessages(), options: options); return _firebaseClient .generateContentStream( prompt, safetySettings: safetySettings, generationConfig: generationConfig, + tools: tools, + toolConfig: toolConfig, ) .map((final completion) => completion.toChatResult(id, model)); } @@ -238,6 +235,8 @@ class ChatFirebaseVertexAI extends BaseChatModel { Iterable prompt, List? safetySettings, GenerationConfig? generationConfig, + List? tools, + ToolConfig? toolConfig, ) _generateCompletionRequest( final List messages, { final ChatFirebaseVertexAIOptions? options, @@ -260,6 +259,8 @@ class ChatFirebaseVertexAI extends BaseChatModel { topP: options?.topP ?? defaultOptions.topP, topK: options?.topK ?? defaultOptions.topK, ), + (options?.tools ?? defaultOptions.tools)?.toToolList(), + (options?.toolChoice ?? defaultOptions.toolChoice)?.toToolConfig(), ); } @@ -288,8 +289,6 @@ class ChatFirebaseVertexAI extends BaseChatModel { GenerativeModel _createFirebaseClient( final String model, { final String? systemInstruction, - final List? tools, - final ChatToolChoice? toolChoice, }) { return FirebaseVertexAI.instanceFor( app: app, @@ -300,8 +299,6 @@ class ChatFirebaseVertexAI extends BaseChatModel { model: model, systemInstruction: systemInstruction != null ? Content.system(systemInstruction) : null, - tools: tools?.toToolList(), - toolConfig: toolChoice?.toToolConfig(), ); } @@ -309,14 +306,10 @@ class ChatFirebaseVertexAI extends BaseChatModel { void _recreateFirebaseClient( final String model, final String? systemInstruction, - final List? tools, - final ChatToolChoice? toolChoice, ) { _firebaseClient = _createFirebaseClient( model, systemInstruction: systemInstruction, - tools: tools, - toolChoice: toolChoice, ); } @@ -332,9 +325,6 @@ class ChatFirebaseVertexAI extends BaseChatModel { ? messages.firstOrNull?.contentAsString : null; - final tools = options?.tools ?? defaultOptions.tools; - final toolChoice = options?.toolChoice ?? defaultOptions.toolChoice; - bool recreate = false; if (model != _currentModel) { _currentModel = model; @@ -344,17 +334,9 @@ class ChatFirebaseVertexAI extends BaseChatModel { _currentSystemInstruction = systemInstruction; recreate = true; } - if (!const ListEquality().equals(tools, _currentTools)) { - _currentTools = tools; - recreate = true; - } - if (toolChoice != _currentToolChoice) { - _currentToolChoice = toolChoice; - recreate = true; - } if (recreate) { - _recreateFirebaseClient(model, systemInstruction, tools, toolChoice); + _recreateFirebaseClient(model, systemInstruction); } } }