Skip to content

Commit

Permalink
fixing the filtering bug and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
bcdurak committed Feb 13, 2025
1 parent cf7cbca commit c19750a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
4 changes: 1 addition & 3 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7550,9 +7550,7 @@ def create_tag(
if color is not None:
request_model.color = color

return self.zen_store.create_tag(
tag=request_model
)
return self.zen_store.create_tag(tag=request_model)

def delete_tag(self, tag_name_or_id: Union[str, UUID]) -> None:
"""Deletes a tag.
Expand Down
14 changes: 10 additions & 4 deletions src/zenml/models/v2/base/scoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,15 +554,21 @@ def get_custom_filters(
custom_filters = super().get_custom_filters(table)

if self.tags is not None:
from sqlmodel import exists
from sqlmodel import exists, select

from zenml.zen_stores.schemas import TagSchema
from zenml.zen_stores.schemas import TagResourceSchema, TagSchema

for tag in self.tags:
conditions = self.generate_custom_query_conditions_for_column(
condition = self.generate_custom_query_conditions_for_column(
value=tag, table=TagSchema, column="name"
)
exists_subquery = exists().where(conditions)
exists_subquery = exists(
select(TagResourceSchema)
.join(TagSchema, TagSchema.id == TagResourceSchema.tag_id)
.where(
TagResourceSchema.resource_id == table.id, condition
)
)
custom_filters.append(exists_subquery)

return custom_filters
Expand Down
1 change: 0 additions & 1 deletion src/zenml/zen_stores/schemas/tag_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# permissions and limitations under the License.
"""SQLModel implementation of tag tables."""

import random
from typing import Any, List
from uuid import UUID

Expand Down
11 changes: 8 additions & 3 deletions src/zenml/zen_stores/sql_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10617,7 +10617,7 @@ def _create_model_version(
self._attach_tags_to_resource(
tag_names=model_version.tags,
resource=model_version_schema,
resource_type=TaggableResourceTypes.MODEL_VERSION
resource_type=TaggableResourceTypes.MODEL_VERSION,
)

return self.get_model_version(model_version_id)
Expand Down Expand Up @@ -11105,7 +11105,7 @@ def _attach_tags_to_resource(
self,
tag_names: List[str],
resource: AnySchema,
resource_type: TaggableResourceTypes
resource_type: TaggableResourceTypes,
) -> None:
"""Creates a tag<>resource link if not present.
Expand All @@ -11126,15 +11126,19 @@ def _attach_tags_to_resource(
if resource_type == TaggableResourceTypes.PIPELINE_RUN:
older_runs = self.list_runs(
PipelineRunFilter(
id=f"notequals:{resource.id}",
pipeline_id=resource.pipeline_id,
tags=[tag.name],
)
)
if older_runs.items:
detach_resource_id = older_runs.items[0].id
elif resource_type == TaggableResourceTypes.ARTIFACT_VERSION:
elif (
resource_type == TaggableResourceTypes.ARTIFACT_VERSION
):
older_versions = self.list_artifact_versions(
ArtifactVersionFilter(
id=f"notequals:{resource.id}",
artifact_id=resource.artifact_id,
tags=[tag.name],
)
Expand All @@ -11144,6 +11148,7 @@ def _attach_tags_to_resource(
elif resource_type == TaggableResourceTypes.RUN_TEMPLATE:
older_templates = self.list_run_templates(
RunTemplateFilter(
id=f"notequals:{resource.id}",
tags=[tag.name],
)
)
Expand Down

0 comments on commit c19750a

Please sign in to comment.