Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrades to support remote zoo models #633

Merged
merged 2 commits into from
Sep 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 111 additions & 48 deletions eta/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,12 +1038,17 @@ class Model(Serializable):

Attributes:
base_name: the base name of the model (no version info)
base_filename: the base filename of the model (if any, no version info)
base_filename: the base filename or directory of the model (if any)
(no version info)
subdir: the model's subdirectory (if any)
manager: the ModelManager instance that describes the remote storage
location of the models_dir (if any)
version: the version of the model (if any)
author (optional): the author of the model
version (optional): the model version
url (optional): the URL where the model is hosted
source (optional): the source of the model
license (optional): the license under which the model is distributed
description: the description of the model (if any)
source: the source of the model (if any)
size_bytes: the size of the model on disk (if any)
default_deployment_config_dict: a dictionary representation of an
`eta.core.learning.ModelConfig` describing the recommended settings
Expand All @@ -1061,10 +1066,14 @@ def __init__(
self,
base_name,
base_filename=None,
subdir=None,
manager=None,
author=None,
version=None,
description=None,
url=None,
source=None,
license=None,
description=None,
size_bytes=None,
default_deployment_config_dict=None,
requirements=None,
Expand All @@ -1076,10 +1085,14 @@ def __init__(
Args:
base_name: the base name of the model
base_filename (optional): the base filename for the model
subdir: the model's subdirectory (if any)
manager (optional): the ModelManager for the model
author (optional): the author of the model
version (optional): the model version
url (optional): the URL where the model is hosted
source (optional): the source of the model
license (optional): the license under which the model is distributed
description: (optional) the description of the model
source: (optional) the source of the model
size_bytes: (optional) the size of the model on disk
default_deployment_config_dict: (optional) a dictionary
representation of an `eta.core.learning.ModelConfig` describing
Expand All @@ -1090,10 +1103,14 @@ def __init__(
"""
self.base_name = base_name
self.base_filename = base_filename
self.subdir = subdir
self.manager = manager
self.author = author
self.version = version or None
self.description = description
self.url = url
self.source = source
self.license = license
self.description = description
self.size_bytes = size_bytes
self.default_deployment_config_dict = default_deployment_config_dict
self.requirements = requirements
Expand All @@ -1112,14 +1129,19 @@ def name(self):
@property
def filename(self):
"""The version-aware filename of the model."""
if not self.has_version:
return self.base_filename

if self.base_filename is None:
return None

base, ext = os.path.splitext(self.base_filename)
return base + "-v" + self.version + ext
if self.has_version:
base, ext = os.path.splitext(self.base_filename)
filename = base + "-v" + self.version + ext
else:
filename = self.base_filename

if self.subdir is not None:
filename = os.path.join(self.subdir, filename)

return filename

@property
def has_manager(self):
Expand Down Expand Up @@ -1383,17 +1405,11 @@ def parse_name(name):
Returns:
base_name: the base name of the model
version: the version of the model, or None if no version was found

Raises:
ModelError: if the model name was invalid
"""
chunks = name.split("@")
chunks = name.rsplit("@", 1)
if len(chunks) == 1:
return name, None

if chunks[1] == "" or len(chunks) > 2:
raise ModelError("Invalid model name '%s'" % name)

return chunks[0], chunks[1]

@staticmethod
Expand All @@ -1406,7 +1422,7 @@ def has_version_str(name):
Returns:
True/False
"""
return bool(Model.parse_name(name)[1])
return Model.parse_name(name)[1] is not None

def attributes(self):
"""Returns a list of class attributes to be serialized.
Expand All @@ -1417,9 +1433,12 @@ def attributes(self):
return [
"base_name",
"base_filename",
"author",
"version",
"description",
"url",
"source",
"license",
"description",
"size_bytes",
"manager",
"default_deployment_config_dict",
Expand All @@ -1429,11 +1448,12 @@ def attributes(self):
]

@classmethod
def from_dict(cls, d):
def from_dict(cls, d, subdir=None):
"""Constructs a Model from a JSON dictionary.

Args:
d: a JSON dictionary
subdir (optional): a subdirectory for the model

Returns:
a Model instance
Expand All @@ -1453,10 +1473,14 @@ def from_dict(cls, d):
return cls(
d["base_name"],
base_filename=d.get("base_filename", None),
subdir=subdir,
manager=manager,
author=d.get("author", None),
version=d.get("version", None),
description=d.get("description", None),
url=d.get("url", None),
source=d.get("source", None),
license=d.get("license", None),
description=d.get("description", None),
size_bytes=d.get("size_bytes", None),
default_deployment_config_dict=d.get(
"default_deployment_config_dict", None
Expand All @@ -1472,70 +1496,103 @@ class ModelsManifest(Serializable):

_MODEL_CLS = Model

def __init__(self, models=None):
def __init__(self, models=None, name=None, url=None):
"""Creates a ModelsManifest instance.

Args:
models: a list of Model instances
name (optional): a name for the manifest
url (optional): the source location of the manifest
"""
self.models = models or []
if models is None:
models = []

if name is not None:
subdir = os.path.join(*name.split("/"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like it is a no-op? You're splitting then joining it back together again. Do you mean to take out the file name? If so, why not use os.path.dirname(name)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's for Windows: it converts org-name/model-name to org-name\model-name if necessary.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I see. It's a little confusing that name is assigned to subdir, then.

for model in models:
model.subdir = subdir
else:
subdir = None

self.models = models
self.name = name
self.url = url
self._subdir = subdir

def __iter__(self):
return iter(self.models)

def add_model(self, model):
@property
def subdir(self):
return self._subdir

def add_model(self, model, error_level=0):
"""Adds the given model to the manifest.

Args:
model: a Model instance
error_level: the error level to use, defined as:

Raises:
ModelError: if the model conflicts with an existing model in the
manifest
0: raise error if the model cannot be added
1: log warning if the model cannot be added
2: ignore models that cannot be added
"""
if self.has_model_with_name(model.name):
raise ModelError(
error_msg = (
"Manifest already contains model called '%s'" % model.name
)
etau.handle_error(ModelError(error_msg), error_level)
return

if model.filename is not None and self.has_model_with_filename(
model.filename
):
raise ModelError(
if self.has_model_with_filename(model):
error_msg = (
"Manifest already contains model with filename '%s'"
% (model.filename)
% model.filename
)
etau.handle_error(ModelError(error_msg), error_level)
return

if self.has_model_with_name(model.base_name):
raise ModelError(
error_msg = (
"Manifest already contains a versionless model called '%s', "
"so a versioned model is not allowed" % model.base_name
)
"so a versioned model is not allowed"
) % model.base_name
etau.handle_error(ModelError(error_msg), error_level)
return

self.models.append(model)

def remove_model(self, name):
def remove_model(self, name, error_level=0):
"""Removes the model with the given name from the ModelsManifest.

Args:
name: the name of the model
error_level: the error level to use, defined as:

Raises:
ModelError: if the model was not found
0: raise error if the model cannot be removed
1: log warning if the model cannot be removed
2: ignore models that cannot be removed
"""
if not self.has_model_with_name(name):
raise ModelError("Manifest does not contain model '%s'" % name)
error_msg = "Manifest does not contain model '%s'" % name
etau.handle_error(ModelError(error_msg), error_level)
return

self.models = [model for model in self.models if model.name != name]

def merge(self, models_manifest):
def merge(self, models_manifest, error_level=0):
"""Merges the models manifest into this one.

Args:
models_manifest: a ModelsManifest
error_level: the error level to use, defined as:

0: raise error if a model cannot be added
1: log warning if a model cannot be added
2: ignore models that cannot be added
"""
for model in models_manifest:
self.add_model(model)
self.add_model(model, error_level=error_level)

def get_model_with_name(self, name):
"""Gets the model with the given name.
Expand Down Expand Up @@ -1593,17 +1650,20 @@ def has_model_with_name(self, name):
"""
return any(name == model.name for model in self.models)

def has_model_with_filename(self, filename):
"""Determines whether this manifest contains a model with the given
def has_model_with_filename(self, model):
"""Determines whether this manifest contains a model with a conflicting
filename.

Args:
filename: the filename
model: a Model instance

Returns:
True/False
"""
return any(filename == model.filename for model in self.models)
if model.filename is None:
return False

return any(model.filename == m.filename for m in self.models)

@staticmethod
def make_manifest_path(models_dir):
Expand Down Expand Up @@ -1664,7 +1724,10 @@ def from_dict(cls, d):
Returns:
a ModelsManifest
"""
return cls(models=[cls._MODEL_CLS.from_dict(md) for md in d["models"]])
models = [cls._MODEL_CLS.from_dict(md) for md in d.get("models", [])]
name = d.get("name", None)
url = d.get("url", None)
return cls(models=models, name=name, url=url)


class ModelManager(Configurable, Serializable):
Expand Down