Merge pull request #224 from markstur/add_embedding_task
Add embedding task
gkumbhat authored Nov 20, 2023
2 parents 798a3c0 + 2b8b134 commit 316ead6
Showing 10 changed files with 563 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@ __pycache__/

# Distribution / packaging
data/classification_data/
models/
tmp_models/
models_to_upload/
upload_models/
3 changes: 2 additions & 1 deletion caikit_nlp/data_model/__init__.py
@@ -15,5 +15,6 @@
"""

# Local
from . import generation
from . import embedding_vectors, generation
from .embedding_vectors import *
from .generation import *
163 changes: 163 additions & 0 deletions caikit_nlp/data_model/embedding_vectors.py
@@ -0,0 +1,163 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data structures for embedding vector representations
"""
# Standard
from dataclasses import dataclass, field
from typing import List, Union
import json

# Third Party
from google.protobuf import json_format
import numpy as np

# First Party
from caikit.core import DataObjectBase, dataobject
from caikit.core.exceptions import error_handler
import alog

log = alog.use_channel("DATAM")
error = error_handler.get(log)


@dataobject(package="caikit_data_model.caikit_nlp")
@dataclass
class PyFloatSequence(DataObjectBase):
    values: List[float] = field(default_factory=list)


@dataobject(package="caikit_data_model.caikit_nlp")
@dataclass
class NpFloat32Sequence(DataObjectBase):
    values: List[np.float32]

    @classmethod
    def from_proto(cls, proto):
        values = np.asarray(proto.values, dtype=np.float32)
        return cls(values)


@dataobject(package="caikit_data_model.caikit_nlp")
@dataclass
class NpFloat64Sequence(DataObjectBase):
    values: List[np.float64]

    @classmethod
    def from_proto(cls, proto):
        values = np.asarray(proto.values, dtype=np.float64)
        return cls(values)


@dataobject(package="caikit_data_model.caikit_nlp")
@dataclass
class Vector1D(DataObjectBase):
"""Data representation for a 1 dimension vector of float-type data."""

data: Union[
PyFloatSequence,
NpFloat32Sequence,
NpFloat64Sequence,
]

def __post_init__(self):
error.value_check(
"<NLP92989048E>",
hasattr(self.data, "values"),
ValueError("Vector1D requires a float sequence data object with values."),
)

@classmethod
def from_vector(cls, vector):
if vector.dtype == np.float32:
data = NpFloat32Sequence(vector)
elif vector.dtype == np.float64:
data = NpFloat64Sequence(vector)
else:
data = PyFloatSequence(vector)
return cls(data=data)

@classmethod
def from_json(cls, json_str):
"""JSON does not have different float types. Move data into data_pyfloatsequence"""

json_obj = json.loads(json_str) if isinstance(json_str, str) else json_str
data = json_obj.pop("data")
if data is not None:
json_obj["data_pyfloatsequence"] = data

json_str = json.dumps(json_obj)
try:
# Parse given JSON into google.protobufs.pyext.cpp_message.GeneratedProtocolMessageType
parsed_proto = json_format.Parse(
json_str, cls.get_proto_class()(), ignore_unknown_fields=False
)

# Use from_proto to return the DataBase object from the parsed proto
return cls.from_proto(parsed_proto)

except json_format.ParseError as ex:
error("<NLP39795399E>", ValueError(ex))

def to_dict(self) -> dict:
"""to_dict is needed to make things serializable"""
values = self.data.values if self.data.values is not None else []
return {
"data": {
# coerce numpy.ndarray and numpy.float32 into JSON serializable list of floats
"values": values.tolist()
if isinstance(values, np.ndarray)
else values
}
}

@classmethod
def from_proto(cls, proto):
"""Wrap the data in an appropriate float sequence, wrapped by this class"""
woo = proto.WhichOneof("data")
if woo is None:
return cls(PyFloatSequence())

woo_data = getattr(proto, woo)
if woo == "data_npfloat64sequence":
ret = cls(NpFloat64Sequence.from_proto(woo_data))
elif woo == "data_npfloat32sequence":
ret = cls(NpFloat32Sequence.from_proto(woo_data))
else:
ret = cls(PyFloatSequence.from_proto(woo_data))
return ret

def fill_proto(self, proto):
"""Fill in the data in an appropriate data_<float type sequence>"""
values = self.data.values
if values is not None and len(values) > 0:
sample = values[0]
error.type_check(
"<NLP47515960E>", float, np.float32, np.float64, sample=sample
)
if isinstance(sample, np.float64):
proto.data_npfloat64sequence.values.extend(values)
elif isinstance(sample, np.float32):
proto.data_npfloat32sequence.values.extend(values)
else:
proto.data_pyfloatsequence.values.extend(values)

return proto


@dataobject(package="caikit_data_model.caikit_nlp")
@dataclass
class EmbeddingResult(DataObjectBase):
"""Result from text embedding task"""

result: Vector1D
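
For orientation, a minimal sketch (not part of the diff) of how these data objects behave, assuming numpy and caikit's standard to_dict()/to_proto() helpers; the array values are illustrative:

# Illustrative sketch only -- not from the PR
import numpy as np

from caikit_nlp.data_model.embedding_vectors import Vector1D

# float32 and float64 numpy arrays are routed to the matching sequence type
vec32 = Vector1D.from_vector(np.array([0.1, 0.2, 0.3], dtype=np.float32))
vec64 = Vector1D.from_vector(np.array([0.1, 0.2, 0.3], dtype=np.float64))

# to_dict() coerces the numpy array into a plain list of Python floats
print(vec32.to_dict())  # {'data': {'values': [0.1..., 0.2..., 0.3...]}}

# A protobuf round-trip keeps the float width via the data_* oneof fields
restored = Vector1D.from_proto(vec32.to_proto())
print(restored.data.values.dtype)  # float32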
2 changes: 1 addition & 1 deletion caikit_nlp/modules/__init__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Local
from . import text_classification, text_generation, token_classification
from . import text_classification, text_embedding, text_generation, token_classification
17 changes: 17 additions & 0 deletions caikit_nlp/modules/text_embedding/__init__.py
@@ -0,0 +1,17 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Local
from .embedding import EmbeddingModule
from .embedding_tasks import EmbeddingTask
138 changes: 138 additions & 0 deletions caikit_nlp/modules/text_embedding/embedding.py
@@ -0,0 +1,138 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import os

# Third Party
from sentence_transformers import SentenceTransformer

# First Party
from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module
from caikit.core.exceptions import error_handler
import alog

# Local
from .embedding_tasks import EmbeddingTask
from caikit_nlp.data_model.embedding_vectors import EmbeddingResult, Vector1D

logger = alog.use_channel("TXT_EMB")
error = error_handler.get(logger)


@module(
"eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f",
"EmbeddingModule",
"0.0.1",
EmbeddingTask,
)
class EmbeddingModule(ModuleBase):

    _ARTIFACTS_PATH_KEY = "artifacts_path"
    _ARTIFACTS_PATH_DEFAULT = "artifacts"

    def __init__(
        self,
        model: SentenceTransformer,
    ):
        super().__init__()
        self.model = model

    @classmethod
    def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
        """Load model
        Args:
            model_path: str
                Path to the config dir under the model_id (where the config.yml lives)
        Returns:
            EmbeddingModule
                Instance of this class built from the model.
        """

        config = ModuleConfig.load(model_path)
        artifacts_path = config.get(cls._ARTIFACTS_PATH_KEY)

        error.value_check(
            "<NLP07391618E>",
            artifacts_path,
            ValueError(f"Model config missing '{cls._ARTIFACTS_PATH_KEY}'"),
        )

        artifacts_path = os.path.abspath(os.path.join(model_path, artifacts_path))
        error.dir_check("<NLP34197772E>", artifacts_path)

        return cls.bootstrap(model_name_or_path=artifacts_path)

    def run(
        self, input: str, **kwargs  # pylint: disable=redefined-builtin
    ) -> EmbeddingResult:
        """Run inference on model.
        Args:
            input: str
                Input text to be processed
        Returns:
            EmbeddingResult: the result vector nicely wrapped up
        """
        error.type_check("<NLP27491611E>", str, input=input)

        return EmbeddingResult(Vector1D.from_vector(self.model.encode(input)))

    @classmethod
    def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule":
        """Bootstrap a sentence-transformers model
        Args:
            model_name_or_path: str
                Model name (Hugging Face hub) or path to model to load.
        """
        return cls(model=SentenceTransformer(model_name_or_path=model_name_or_path))

    def save(self, model_path: str, *args, **kwargs):
        """Save model using config in model_path
        Args:
            model_path: str
                Path to model config
        """

        model_config_path = model_path  # because the param name is misleading

        error.type_check("<NLP82314992E>", str, model_path=model_config_path)
        error.value_check(
            "<NLP40145207E>",
            model_config_path is not None and model_config_path.strip(),
            f"model_path '{model_config_path}' is invalid",
        )

        model_config_path = os.path.abspath(
            model_config_path.strip()
        )  # No leading/trailing spaces sneaky weirdness

        # Only allow new dirs because there are not enough controls to safely update in-place
        os.makedirs(model_config_path, exist_ok=False)

        saver = ModuleSaver(
            module=self,
            model_path=model_config_path,
        )
        artifacts_path = self._ARTIFACTS_PATH_DEFAULT
        saver.update_config({self._ARTIFACTS_PATH_KEY: artifacts_path})

        # Save the model
        self.model.save(os.path.join(model_config_path, artifacts_path))

        # Save the config
        ModuleConfig(saver.config).save(model_config_path)
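
For context, a hedged usage sketch of the new module (not part of the diff; the model name and paths are illustrative):

# Illustrative sketch only -- not from the PR
from caikit_nlp.modules.text_embedding import EmbeddingModule

# Bootstrap from a Hugging Face hub name or a local sentence-transformers directory
module = EmbeddingModule.bootstrap("sentence-transformers/all-MiniLM-L6-v2")

# run() type-checks the input string and wraps the encoded vector in an EmbeddingResult
result = module.run("Hello world")
print(len(result.result.data.values))  # embedding dimension, e.g. 384 for this model

# save() requires a directory that does not exist yet; load() reads it back via config.yml
module.save("my_embedding_model")
reloaded = EmbeddingModule.load("my_embedding_model")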
29 changes: 29 additions & 0 deletions caikit_nlp/modules/text_embedding/embedding_tasks.py
@@ -0,0 +1,29 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard

# First Party
from caikit.core import TaskBase, task

# Local
from ...data_model import EmbeddingResult


@task(
    required_parameters={"input": str},
    output_type=EmbeddingResult,
)
class EmbeddingTask(TaskBase):
    pass
1 change: 1 addition & 0 deletions pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
"pandas>=1.5.0",
"scikit-learn>=1.1",
"scipy>=1.8.1",
"sentence-transformers~=2.2.2",
"tokenizers>=0.13.3",
"torch>=2.0.1",
"tqdm>=4.65.0",
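
An aside on the new dependency pin (illustrative, not from the PR): ~=2.2.2 is PEP 440's compatible-release operator, so patch releases are accepted but 2.3.x is not. Assuming the packaging library, the constraint behaves like this:

# Illustrative sketch only -- not from the PR
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=2.2.2")  # equivalent to ">=2.2.2, ==2.2.*"
print("2.2.9" in spec)  # True: patch updates are accepted
print("2.3.0" in spec)  # False: minor version bumps are excluded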