Skip to content

Commit

Permalink
Per-review feedback: save() code is cleaner and GUIDS are guids.
Browse files Browse the repository at this point in the history
* Removed some unnecessary things in save()
* Lowercase module GUID for consistency

Signed-off-by: markstur <mark.sturdevant@ibm.com>
  • Loading branch information
markstur committed Nov 14, 2023
1 parent 2d61326 commit 2b8b134
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions caikit_nlp/modules/text_embedding/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


@module(
"EEB12558-B4FA-4F34-A9FD-3F5890E9CD3F",
"eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f",
"EmbeddingModule",
"0.0.1",
EmbeddingTask,
Expand Down Expand Up @@ -121,22 +121,18 @@ def save(self, model_path: str, *args, **kwargs):
model_config_path.strip()
) # No leading/trailing spaces sneaky weirdness

# Only allow new dirs because there are not enough controls to safely update in-place
os.makedirs(model_config_path, exist_ok=False)

saver = ModuleSaver(
module=self,
model_path=model_config_path,
)

# Get and update config (artifacts_path)
artifacts_path = saver.config.get(self._ARTIFACTS_PATH_KEY)
if not artifacts_path:
artifacts_path = self._ARTIFACTS_PATH_DEFAULT
saver.update_config({self._ARTIFACTS_PATH_KEY: artifacts_path})
artifacts_path = self._ARTIFACTS_PATH_DEFAULT
saver.update_config({self._ARTIFACTS_PATH_KEY: artifacts_path})

# Save the model
self.model.save(
os.path.join(model_config_path, artifacts_path), create_model_card=True
)
self.model.save(os.path.join(model_config_path, artifacts_path))

# Save the config
ModuleConfig(saver.config).save(model_config_path)

0 comments on commit 2b8b134

Please sign in to comment.