From 8158572b15c0525b9caa9bc71fbbbee6ab4458fe Mon Sep 17 00:00:00 2001 From: David Miguel Lozano Date: Sat, 20 Jan 2024 12:10:52 +0100 Subject: [PATCH] refactor: Remove tiktoken in favour of countTokens API on VertexAI (#307) --- .../chat_models/vertex_ai/chat_vertex_ai.dart | 50 +++++++++++++++---- .../lib/src/llms/vertex_ai.dart | 37 ++++++++++---- packages/langchain_google/pubspec.yaml | 1 - .../test/chat_models/chat_vertex_ai_test.dart | 13 +---- .../test/llms/vertex_ai_test.dart | 11 ---- 5 files changed, 66 insertions(+), 46 deletions(-) diff --git a/packages/langchain_google/lib/src/chat_models/vertex_ai/chat_vertex_ai.dart b/packages/langchain_google/lib/src/chat_models/vertex_ai/chat_vertex_ai.dart index cb261b9a..97874217 100644 --- a/packages/langchain_google/lib/src/chat_models/vertex_ai/chat_vertex_ai.dart +++ b/packages/langchain_google/lib/src/chat_models/vertex_ai/chat_vertex_ai.dart @@ -1,6 +1,5 @@ import 'package:http/http.dart' as http; import 'package:langchain/langchain.dart'; -import 'package:langchain_tiktoken/langchain_tiktoken.dart'; import 'package:uuid/uuid.dart'; import 'package:vertex_ai/vertex_ai.dart'; @@ -190,20 +189,49 @@ class ChatVertexAI extends BaseChatModel { return result.toChatResult(id, model); } - /// Tokenizes the given prompt using tiktoken. - /// - /// Currently Google does not provide a tokenizer for Vertex AI models. - /// So we use tiktoken and cl100k_base encoding to get an approximation - /// for counting tokens. Mind that the actual tokens will be totally - /// different from the ones used by the Vertex AI model. - /// - /// - [promptValue] The prompt to tokenize. @override Future> tokenize( final PromptValue promptValue, { final ChatVertexAIOptions? options, }) async { - final encoding = getEncoding('cl100k_base'); - return encoding.encode(promptValue.toString()); + throw UnsupportedError( + 'ChatVertexAI does not support tokenize, only countTokens', + ); + } + + @override + Future countTokens( + final PromptValue promptValue, { + final ChatVertexAIOptions? options, + }) async { + final messages = promptValue.toChatMessages(); + String? context; + final vertexMessages = []; + for (final message in messages) { + if (message is SystemChatMessage) { + context = message.content; + continue; + } else { + vertexMessages.add(message.toVertexAIChatMessage()); + } + } + final examples = (options?.examples ?? defaultOptions.examples) + ?.map((final e) => e.toVertexAIChatExample()) + .toList(growable: false); + final model = + options?.model ?? defaultOptions.model ?? throwNullModelError(); + + final res = await client.chat.countTokens( + context: context, + examples: examples, + messages: vertexMessages, + publisher: options?.publisher ?? + ArgumentError.checkNotNull( + defaultOptions.publisher, + 'VertexAIOptions.publisher', + ), + model: model, + ); + return res.totalTokens; } } diff --git a/packages/langchain_google/lib/src/llms/vertex_ai.dart b/packages/langchain_google/lib/src/llms/vertex_ai.dart index 8c63a212..8f1fb6cb 100644 --- a/packages/langchain_google/lib/src/llms/vertex_ai.dart +++ b/packages/langchain_google/lib/src/llms/vertex_ai.dart @@ -1,6 +1,5 @@ import 'package:http/http.dart' as http; import 'package:langchain/langchain.dart'; -import 'package:langchain_tiktoken/langchain_tiktoken.dart'; import 'package:vertex_ai/vertex_ai.dart'; import 'models/mappers.dart'; @@ -80,6 +79,10 @@ import 'models/models.dart'; /// - `text-bison-32k` /// * Max input and output tokens combined: 32k /// * Training data: Up to Aug 2023 +/// - `text-unicorn` +/// * Max input token: 8192 +/// * Max output tokens: 1024 +/// * Training data: Up to Feb 2023 /// /// The previous list of models may not be exhaustive or up-to-date. Check out /// the [Vertex AI documentation](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models) @@ -170,20 +173,32 @@ class VertexAI extends BaseLLM { return result.toLLMResult(model); } - /// Tokenizes the given prompt using tiktoken. - /// - /// Currently Google does not provide a tokenizer for Vertex AI models. - /// So we use tiktoken and cl100k_base encoding to get an approximation - /// for counting tokens. Mind that the actual tokens will be totally - /// different from the ones used by the Vertex AI model. - /// - /// - [promptValue] The prompt to tokenize. @override Future> tokenize( final PromptValue promptValue, { final VertexAIOptions? options, }) async { - final encoding = getEncoding('cl100k_base'); - return encoding.encode(promptValue.toString()); + throw UnsupportedError( + 'VertexAI does not support tokenize, only countTokens', + ); + } + + @override + Future countTokens( + final PromptValue promptValue, { + final VertexAIOptions? options, + }) async { + final model = + options?.model ?? defaultOptions.model ?? throwNullModelError(); + final res = await client.text.countTokens( + prompt: promptValue.toString(), + publisher: options?.publisher ?? + ArgumentError.checkNotNull( + defaultOptions.publisher, + 'VertexAIOptions.publisher', + ), + model: model, + ); + return res.totalTokens; } } diff --git a/packages/langchain_google/pubspec.yaml b/packages/langchain_google/pubspec.yaml index 652c2875..9f1abe8d 100644 --- a/packages/langchain_google/pubspec.yaml +++ b/packages/langchain_google/pubspec.yaml @@ -24,7 +24,6 @@ dependencies: googleapis_auth: ^1.4.1 http: ^1.1.0 langchain: ^0.3.2 - langchain_tiktoken: ^1.0.1 meta: ^1.9.1 uuid: ^4.0.0 vertex_ai: ^0.0.8 diff --git a/packages/langchain_google/test/chat_models/chat_vertex_ai_test.dart b/packages/langchain_google/test/chat_models/chat_vertex_ai_test.dart index f8e328b7..cc0bb432 100644 --- a/packages/langchain_google/test/chat_models/chat_vertex_ai_test.dart +++ b/packages/langchain_google/test/chat_models/chat_vertex_ai_test.dart @@ -181,17 +181,6 @@ void main() async { expect(res2.generations.length, 5); }); - test('Test tokenize', () async { - final chat = ChatVertexAI( - httpClient: authHttpClient, - project: Platform.environment['VERTEX_AI_PROJECT_ID']!, - ); - const text = 'Hello, how are you?'; - - final tokens = await chat.tokenize(PromptValue.string(text)); - expect(tokens, [9906, 11, 1268, 527, 499, 30]); - }); - test('Test countTokens string', () async { final chat = ChatVertexAI( httpClient: authHttpClient, @@ -226,7 +215,7 @@ void main() async { ]; final numTokens = await chat.countTokens(PromptValue.chat(messages)); - expect(numTokens, 41); + expect(numTokens, 37); }); }); } diff --git a/packages/langchain_google/test/llms/vertex_ai_test.dart b/packages/langchain_google/test/llms/vertex_ai_test.dart index b44b1234..e72fc7f8 100644 --- a/packages/langchain_google/test/llms/vertex_ai_test.dart +++ b/packages/langchain_google/test/llms/vertex_ai_test.dart @@ -146,17 +146,6 @@ Future main() async { expect(res2.generations.length, 5); }); - test('Test tokenize', () async { - final llm = VertexAI( - httpClient: authHttpClient, - project: Platform.environment['VERTEX_AI_PROJECT_ID']!, - ); - const text = 'Hello, how are you?'; - - final tokens = await llm.tokenize(PromptValue.string(text)); - expect(tokens, [9906, 11, 1268, 527, 499, 30]); - }); - test('Test countTokens', () async { final llm = VertexAI( httpClient: authHttpClient,