import logging
import re
import string
import threading
import warnings
from concurrent.futures import ThreadPoolExecutor, wait
from enum import Enum, auto
from typing import Any, Dict, List, Literal, Optional, Tuple, Type

from google.api_core.exceptions import (
    Aborted,
    DeadlineExceeded,
    InternalServerError,
    InvalidArgument,
    ResourceExhausted,
    ServiceUnavailable,
)
from google.cloud.aiplatform import telemetry
from langchain_core._api.deprecation import deprecated
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import create_base_retry_decorator
from pydantic import ConfigDict, model_validator
from typing_extensions import Self
from vertexai.generative_models._generative_models import (  # type: ignore[import-untyped]
    SafetySettingsType as SafetySettingsType,
)
from vertexai.language_models import (  # type: ignore
    TextEmbeddingInput,
    TextEmbeddingModel,
)
from vertexai.vision_models import (  # type: ignore
    Image,
    MultiModalEmbeddingModel,
    MultiModalEmbeddingResponse,
)

from langchain_google_vertexai._base import _VertexAICommon
from langchain_google_vertexai._image_utils import ImageBytesLoader
from langchain_google_vertexai._utils import get_user_agent

logger = logging.getLogger(__name__)

_MAX_TOKENS_PER_BATCH = 20000
_MAX_BATCH_SIZE = 250
_MIN_BATCH_SIZE = 5


class GoogleEmbeddingModelType(str, Enum):
    TEXT = auto()
    MULTIMODAL = auto()

    @classmethod
    def _missing_(cls, value: Any) -> Optional["GoogleEmbeddingModelType"]:
        if value.lower().startswith("text"):
            return GoogleEmbeddingModelType.TEXT
        if "multimodalembedding" in value.lower():
            return GoogleEmbeddingModelType.MULTIMODAL
        return None
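
# A minimal sketch of how the `_missing_` hook above coerces model names into
# a model type; the model-name strings are illustrative, not an exhaustive
# list:
#
#     GoogleEmbeddingModelType("textembedding-gecko@001")  # -> TEXT
#     GoogleEmbeddingModelType("multimodalembedding@001")  # -> MULTIMODAL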


class GoogleEmbeddingModelVersion(str, Enum):
    EMBEDDINGS_JUNE_2023 = auto()
    EMBEDDINGS_NOV_2023 = auto()
    EMBEDDINGS_DEC_2023 = auto()
    EMBEDDINGS_MAY_2024 = auto()

    @classmethod
    def _missing_(cls, value: Any) -> "GoogleEmbeddingModelVersion":
        if "textembedding-gecko@001" in value.lower():
            return GoogleEmbeddingModelVersion.EMBEDDINGS_JUNE_2023
        if (
            "textembedding-gecko@002" in value.lower()
            or "textembedding-gecko-multilingual@001" in value.lower()
        ):
            return GoogleEmbeddingModelVersion.EMBEDDINGS_NOV_2023
        if "textembedding-gecko@003" in value.lower():
            return GoogleEmbeddingModelVersion.EMBEDDINGS_DEC_2023
        if (
            "text-embedding-004" in value.lower()
            or "text-multilingual-embedding-002" in value.lower()
            or "text-embedding-preview-0409" in value.lower()
            or "text-multilingual-embedding-preview-0409" in value.lower()
        ):
            return GoogleEmbeddingModelVersion.EMBEDDINGS_MAY_2024
        return GoogleEmbeddingModelVersion.EMBEDDINGS_JUNE_2023

    @property
    def task_type_supported(self) -> bool:
        """Checks if the model generation supports task type."""
        return self != GoogleEmbeddingModelVersion.EMBEDDINGS_JUNE_2023

    @property
    def output_dimensionality_supported(self) -> bool:
        """Checks if the model generation supports output dimensionality."""
        return self == GoogleEmbeddingModelVersion.EMBEDDINGS_MAY_2024
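
# An illustrative mapping from model name to version and its capabilities
# (the name is taken from the `_missing_` branches above):
#
#     v = GoogleEmbeddingModelVersion("text-embedding-004")
#     v.task_type_supported              # True: post-June-2023 generations
#     v.output_dimensionality_supported  # True: May 2024 generation only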


class VertexAIEmbeddings(_VertexAICommon, Embeddings):
    """Google Cloud VertexAI embedding models."""

    # Instance context
    instance: Dict[str, Any] = {}  #: :meta private:

    model_config = ConfigDict(
        extra="forbid",
        protected_namespaces=(),
    )

    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        """Validates that the python package exists in environment."""
        values = {
            "project": self.project,
            "location": self.location,
            "credentials": self.credentials,
            "api_transport": self.api_transport,
            "api_endpoint": self.api_endpoint,
            "default_metadata": self.default_metadata,
        }
        self._init_vertexai(values)
        _, user_agent = get_user_agent(f"{self.__class__.__name__}_{self.model_name}")
        with telemetry.tool_context_manager(user_agent):
            if (
                GoogleEmbeddingModelType(self.model_name)
                == GoogleEmbeddingModelType.MULTIMODAL
            ):
                self.client = MultiModalEmbeddingModel.from_pretrained(self.model_name)
            else:
                self.client = TextEmbeddingModel.from_pretrained(self.model_name)
        return self

    def __init__(
        self,
        model_name: Optional[str] = None,
        project: Optional[str] = None,
        location: str = "us-central1",
        request_parallelism: int = 5,
        max_retries: int = 6,
        credentials: Optional[Any] = None,
        **kwargs: Any,
    ):
        """Initialize the embeddings model and its retrying task executor."""
        if model_name:
            kwargs["model_name"] = model_name
        super().__init__(
            project=project,
            location=location,
            credentials=credentials,
            request_parallelism=request_parallelism,
            max_retries=max_retries,
            **kwargs,
        )
        self.instance["max_batch_size"] = kwargs.get("max_batch_size", _MAX_BATCH_SIZE)
        self.instance["batch_size"] = self.instance["max_batch_size"]
        self.instance["min_batch_size"] = kwargs.get("min_batch_size", _MIN_BATCH_SIZE)
        self.instance["min_good_batch_size"] = self.instance["min_batch_size"]
        self.instance["lock"] = threading.Lock()
        self.instance["batch_size_validated"] = False
        self.instance["task_executor"] = ThreadPoolExecutor(
            max_workers=request_parallelism
        )
        retry_errors: List[Type[BaseException]] = [
            ResourceExhausted,
            ServiceUnavailable,
            Aborted,
            DeadlineExceeded,
            InternalServerError,
        ]
        retry_decorator = create_base_retry_decorator(
            error_types=retry_errors, max_retries=self.max_retries
        )
        self.instance["get_embeddings_with_retry"] = retry_decorator(
            self.client.get_embeddings
        )

    @property
    def model_type(self) -> GoogleEmbeddingModelType:
        return GoogleEmbeddingModelType(self.model_name)

    @property
    def model_version(self) -> GoogleEmbeddingModelVersion:
        return GoogleEmbeddingModelVersion(self.model_name)

    @staticmethod
    def _split_by_punctuation(text: str) -> List[str]:
        """Splits a string by punctuation and whitespace characters."""
        split_by = string.punctuation + "\t\n "
        pattern = f"([{split_by}])"
        # Using re.split to split the text based on the pattern.
        return [segment for segment in re.split(pattern, text) if segment]
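
    # For illustration, the split above turns
    #     "Hello, world!"
    # into
    #     ["Hello", ",", " ", "world", "!"]
    # so the conservative token estimate used below is 2 * 5 = 10 tokens.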

    @staticmethod
    def _prepare_batches(texts: List[str], batch_size: int) -> List[List[str]]:
        """Splits texts in batches based on current maximum batch size
        and maximum tokens per request.
        """
        text_index = 0
        texts_len = len(texts)
        batch_token_len = 0
        batches: List[List[str]] = []
        current_batch: List[str] = []
        if texts_len == 0:
            return []
        while text_index < texts_len:
            current_text = texts[text_index]
            # The number of tokens per text is conservatively estimated
            # as 2 times the number of words, punctuation and whitespace
            # characters. Using the `count_tokens` API would make batching
            # too expensive, and utilizing a tokenizer would add a dependency
            # that would not necessarily be reused by the application using
            # this class.
            current_text_token_cnt = (
                len(VertexAIEmbeddings._split_by_punctuation(current_text)) * 2
            )
            end_of_batch = False
            if current_text_token_cnt > _MAX_TOKENS_PER_BATCH:
                # The current text is too big even for a single batch.
                # Such a request will fail, but we still make a batch
                # so that the app can get the error from the API.
                if len(current_batch) > 0:
                    # Add the current batch if not empty.
                    batches.append(current_batch)
                current_batch = [current_text]
                text_index += 1
                end_of_batch = True
            elif (
                batch_token_len + current_text_token_cnt > _MAX_TOKENS_PER_BATCH
                or len(current_batch) == batch_size
            ):
                end_of_batch = True
            else:
                if text_index == texts_len - 1:
                    # Last element - even though the batch may not be big,
                    # we still need to make it.
                    end_of_batch = True
                batch_token_len += current_text_token_cnt
                current_batch.append(current_text)
                text_index += 1
            if end_of_batch:
                batches.append(current_batch)
                current_batch = []
                batch_token_len = 0
        return batches
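
    # A sketch of the batching behavior above (sizes are illustrative):
    # with batch_size=3, six short texts become [[t1, t2, t3], [t4, t5, t6]];
    # a single text whose token estimate exceeds _MAX_TOKENS_PER_BATCH is
    # emitted as its own one-element batch so the API can surface the error.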

    def _get_embeddings_with_retry(
        self,
        texts: List[str],
        embeddings_type: Optional[str] = None,
        dimensions: Optional[int] = None,
    ) -> List[List[float]]:
        """Makes a Vertex AI model request with retry logic."""
        with telemetry.tool_context_manager(self._user_agent):
            if self.model_type == GoogleEmbeddingModelType.MULTIMODAL:
                return self._get_multimodal_embeddings_with_retry(texts, dimensions)
            return self._get_text_embeddings_with_retry(
                texts, embeddings_type=embeddings_type, output_dimensionality=dimensions
            )

    def _get_multimodal_embeddings_with_retry(
        self, texts: List[str], dimensions: Optional[int] = None
    ) -> List[List[float]]:
        tasks = []
        for text in texts:
            tasks.append(
                self.instance["task_executor"].submit(
                    self.instance["get_embeddings_with_retry"],
                    contextual_text=text,
                    dimension=dimensions,
                )
            )
        if len(tasks) > 0:
            wait(tasks)
        embeddings = [task.result().text_embedding for task in tasks]
        return embeddings

    def _get_text_embeddings_with_retry(
        self,
        texts: List[str],
        embeddings_type: Optional[str] = None,
        output_dimensionality: Optional[int] = None,
    ) -> List[List[float]]:
        """Makes a Vertex AI model request with retry logic."""
        if embeddings_type and self.model_version.task_type_supported:
            requests = [
                TextEmbeddingInput(text=t, task_type=embeddings_type) for t in texts
            ]
        else:
            requests = texts
        kwargs = {}
        if output_dimensionality and self.model_version.output_dimensionality_supported:
            kwargs["output_dimensionality"] = output_dimensionality
        embeddings = self.instance["get_embeddings_with_retry"](requests, **kwargs)
        return [embedding.values for embedding in embeddings]

    def _prepare_and_validate_batches(
        self, texts: List[str], embeddings_type: Optional[str] = None
    ) -> Tuple[List[List[float]], List[List[str]]]:
        """Prepares text batches with one-time validation of the batch size.

        Batch size varies between GCP regions and individual project quotas.
        Returns embeddings of the first text batch that went through,
        and text batches for the rest of the texts.
        """
        batches = VertexAIEmbeddings._prepare_batches(
            texts, self.instance["batch_size"]
        )
        # If the batch size is less than or equal to one that went through
        # before, then keep the batches as they are.
        if len(batches[0]) <= self.instance["min_good_batch_size"]:
            return [], batches
        with self.instance["lock"]:
            # If the largest possible batch size was validated
            # while waiting for the lock, then check for rebuilding
            # our batches, and return.
            if self.instance["batch_size_validated"]:
                if len(batches[0]) <= self.instance["batch_size"]:
                    return [], batches
                else:
                    return [], VertexAIEmbeddings._prepare_batches(
                        texts, self.instance["batch_size"]
                    )
            # Figure out the largest possible batch size by trying to push
            # batches and halving their size after every failure.
            first_batch = batches[0]
            first_result = []
            had_failure = False
            while True:
                try:
                    first_result = self._get_embeddings_with_retry(
                        first_batch, embeddings_type
                    )
                    break
                except InvalidArgument:
                    had_failure = True
                    first_batch_len = len(first_batch)
                    if first_batch_len == self.instance["min_batch_size"]:
                        raise
                    first_batch_len = max(
                        self.instance["min_batch_size"], int(first_batch_len / 2)
                    )
                    first_batch = first_batch[:first_batch_len]
            first_batch_len = len(first_batch)
            self.instance["min_good_batch_size"] = max(
                self.instance["min_good_batch_size"], first_batch_len
            )
            # If we had a failure and recovered,
            # or went through with the max size, then it's a legit batch size.
            if had_failure or first_batch_len == self.instance["max_batch_size"]:
                self.instance["batch_size"] = first_batch_len
                self.instance["batch_size_validated"] = True
                # If the batch size was updated,
                # rebuild the batches with the new batch size
                # (texts that went through are excluded here).
                if first_batch_len != self.instance["max_batch_size"]:
                    batches = VertexAIEmbeddings._prepare_batches(
                        texts[first_batch_len:], self.instance["batch_size"]
                    )
                else:
                    batches = batches[1:]
            else:
                # Still figuring out the max batch size.
                batches = batches[1:]
        # Return embeddings of the first text batch that went through,
        # and text batches for the rest of the texts.
        return first_result, batches
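
    # An illustration of the halving strategy above: starting from the
    # default maximum of 250 texts, each InvalidArgument failure shrinks the
    # probe batch as 250 -> 125 -> 62 -> 31 -> 15 -> 7 -> 5; at the minimum
    # of 5 the error is re-raised instead of retried.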

    def embed(
        self,
        texts: List[str],
        batch_size: int = 0,
        embeddings_task_type: Optional[
            Literal[
                "RETRIEVAL_QUERY",
                "RETRIEVAL_DOCUMENT",
                "SEMANTIC_SIMILARITY",
                "CLASSIFICATION",
                "CLUSTERING",
                "QUESTION_ANSWERING",
                "FACT_VERIFICATION",
            ]
        ] = None,
        dimensions: Optional[int] = None,
    ) -> List[List[float]]:
        """Embed a list of strings.

        Args:
            texts: List[str] The list of strings to embed.
            batch_size: [int] The batch size of embeddings to send to the model.
                If zero, then the largest batch size will be detected dynamically
                at the first request, starting from 250, down to 5.
            embeddings_task_type: [str] optional embeddings task type,
                one of the following:
                    RETRIEVAL_QUERY - Text is a query
                        in a search/retrieval setting.
                    RETRIEVAL_DOCUMENT - Text is a document
                        in a search/retrieval setting.
                    SEMANTIC_SIMILARITY - Embeddings will be used
                        for Semantic Textual Similarity (STS).
                    CLASSIFICATION - Embeddings will be used for classification.
                    CLUSTERING - Embeddings will be used for clustering.
                    The following are only supported on preview models:
                        QUESTION_ANSWERING
                        FACT_VERIFICATION
            dimensions: [int] optional. Output embeddings dimensions.
                Only supported on preview models.

        Returns:
            List of embeddings, one for each text.
        """
        if len(texts) == 0:
            return []
        embeddings: List[List[float]] = []
        first_batch_result: List[List[float]] = []
        if batch_size > 0:
            # Fixed batch size.
            batches = VertexAIEmbeddings._prepare_batches(texts, batch_size)
        else:
            # Dynamic batch size, starting from 250 at the first call.
            first_batch_result, batches = self._prepare_and_validate_batches(
                texts, embeddings_task_type
            )
        # The first batch result may already hold some embeddings.
        # In that case, the batches contain only the texts not yet processed.
        embeddings.extend(first_batch_result)
        tasks = []
        for batch in batches:
            tasks.append(
                self.instance["task_executor"].submit(
                    self._get_embeddings_with_retry,
                    texts=batch,
                    embeddings_type=embeddings_task_type,
                    dimensions=dimensions,
                )
            )
        if len(tasks) > 0:
            wait(tasks)
        for t in tasks:
            embeddings.extend(t.result())
        return embeddings
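
    # A hedged usage sketch for `embed` (model name and texts are
    # illustrative):
    #
    #     emb = VertexAIEmbeddings(model_name="text-embedding-004")
    #     vectors = emb.embed(
    #         ["a query", "a document"],
    #         embeddings_task_type="SEMANTIC_SIMILARITY",
    #     )
    #     assert len(vectors) == 2  # one embedding per input text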

    def embed_documents(
        self, texts: List[str], batch_size: int = 0
    ) -> List[List[float]]:
        """Embed a list of documents.

        Args:
            texts: List[str] The list of texts to embed.
            batch_size: [int] The batch size of embeddings to send to the model.
                If zero, then the largest batch size will be detected dynamically
                at the first request, starting from 250, down to 5.

        Returns:
            List of embeddings, one for each text.
        """
        return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")

    def embed_query(self, text: str) -> List[float]:
        """Embed a text.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self.embed([text], 1, "RETRIEVAL_QUERY")[0]
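
    # In a typical retrieval setup, documents and queries are embedded with
    # asymmetric task types, which the two wrappers above pin down. Given an
    # instance `emb` as in the sketch further up:
    #
    #     doc_vectors = emb.embed_documents(["doc one", "doc two"])
    #     query_vector = emb.embed_query("which doc?")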

    @deprecated(
        since="2.0.1", removal="3.0.0", alternative="VertexAIEmbeddings.embed_images()"
    )
    def embed_image(
        self,
        image_path: str,
        contextual_text: Optional[str] = None,
        dimensions: Optional[int] = None,
    ) -> List[float]:
        """Embed an image.

        Args:
            image_path: Path to image (Google Cloud Storage or web) to generate
                embeddings for.
            contextual_text: Text to generate embeddings for.

        Returns:
            Embedding for the image.
        """
        warnings.warn(
            "The `embed_image()` API is deprecated and replaced by "
            "`embed_images()`. Change your usage to "
            "`embed_images([image_path1, image_path2])` and note "
            "that the result returned will be a list of image embeddings."
        )
        if self.model_type != GoogleEmbeddingModelType.MULTIMODAL:
            raise NotImplementedError("Only supported for multimodal models")
        image_loader = ImageBytesLoader()
        bytes_image = image_loader.load_bytes(image_path)
        image = Image(bytes_image)
        result: MultiModalEmbeddingResponse = self.instance[
            "get_embeddings_with_retry"
        ](image=image, contextual_text=contextual_text, dimension=dimensions)
        return result.image_embedding

    def embed_images(
        self,
        uris: List[str],
        contextual_text: Optional[str] = None,
        dimensions: Optional[int] = None,
    ) -> List[List[float]]:
        """Embed a list of images.

        Args:
            uris: Paths to images (local, Google Cloud Storage or web) to
                generate embeddings for.
            contextual_text: Text to generate embeddings for.

        Returns:
            Embeddings, one for each image.
        """
        if self.model_type != GoogleEmbeddingModelType.MULTIMODAL:
            raise NotImplementedError("Only supported for multimodal models")
        image_loader = ImageBytesLoader()
        embeddings = []
        for image_path in uris:
            bytes_image = image_loader.load_bytes(image_path)
            image = Image(bytes_image)
            result: MultiModalEmbeddingResponse = self.instance[
                "get_embeddings_with_retry"
            ](image=image, contextual_text=contextual_text, dimension=dimensions)
            embeddings.append(result.image_embedding)
        return embeddings
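
    # A hedged multimodal sketch (the model name and URI are illustrative;
    # this requires a multimodal model such as "multimodalembedding@001"):
    #
    #     mm = VertexAIEmbeddings(model_name="multimodalembedding@001")
    #     image_vectors = mm.embed_images(
    #         ["gs://my-bucket/cat.png"],  # hypothetical GCS path
    #         contextual_text="a cat",
    #     )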
VertexAIEmbeddings.model_rebuild()