Skip to content

Commit

Permalink
Minor fix for the Artifact filter model (#3334)
Browse files Browse the repository at this point in the history
* fixing the artifact filter model

* creating a new taggable filter

* fixed import

* imports

* fixed the artifacts

* Update src/zenml/models/v2/base/scoped.py

Co-authored-by: Michael Schuster <schustmi@users.noreply.github.com>

---------

Co-authored-by: Michael Schuster <schustmi@users.noreply.github.com>
  • Loading branch information
bcdurak and schustmi authored Feb 5, 2025
1 parent 64e12ba commit b8a5a8a
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 30 deletions.
6 changes: 4 additions & 2 deletions src/zenml/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
BaseDatedResponseBody,
)
from zenml.models.v2.base.scoped import (
TaggableFilter,
UserScopedRequest,
UserScopedFilter,
UserScopedResponse,
Expand All @@ -39,7 +40,7 @@
WorkspaceScopedResponseBody,
WorkspaceScopedResponseMetadata,
WorkspaceScopedResponseResources,
WorkspaceScopedTaggableFilter,
WorkspaceScopedFilter,
)
from zenml.models.v2.base.filter import (
BaseFilter,
Expand Down Expand Up @@ -497,12 +498,13 @@
"WorkspaceScopedResponseBody",
"WorkspaceScopedResponseMetadata",
"WorkspaceScopedResponseResources",
"WorkspaceScopedTaggableFilter",
"WorkspaceScopedFilter",
"BaseFilter",
"StrFilter",
"BoolFilter",
"NumericFilter",
"UUIDFilter",
"TaggableFilter",
"Page",
# V2 Core
"ActionFilter",
Expand Down
8 changes: 4 additions & 4 deletions src/zenml/models/v2/base/scoped.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,19 +464,19 @@ def apply_sorting(
return super().apply_sorting(query=query, table=table)


class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter):
"""Model to enable advanced scoping with workspace and tagging."""
class TaggableFilter(BaseFilter):
"""Model to enable filtering and sorting by tags."""

tag: Optional[str] = Field(
description="Tag to apply to the filter query.", default=None
)

FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
*BaseFilter.FILTER_EXCLUDE_FIELDS,
"tag",
]
CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS,
*BaseFilter.CUSTOM_SORTING_OPTIONS,
"tags",
]

Expand Down
6 changes: 3 additions & 3 deletions src/zenml/models/v2/core/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
BaseResponseMetadata,
BaseResponseResources,
)
from zenml.models.v2.base.scoped import WorkspaceScopedTaggableFilter
from zenml.models.v2.base.scoped import TaggableFilter
from zenml.models.v2.core.tag import TagResponse

if TYPE_CHECKING:
Expand Down Expand Up @@ -183,14 +183,14 @@ def versions(self) -> Dict[str, "ArtifactVersionResponse"]:
# ------------------ Filter Model ------------------


class ArtifactFilter(WorkspaceScopedTaggableFilter):
class ArtifactFilter(TaggableFilter):
"""Model to enable advanced filtering of artifacts."""

name: Optional[str] = None
has_custom_name: Optional[bool] = None

CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS,
*TaggableFilter.CUSTOM_SORTING_OPTIONS,
SORT_BY_LATEST_VERSION_KEY,
]

Expand Down
17 changes: 14 additions & 3 deletions src/zenml/models/v2/core/artifact_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@
from zenml.metadata.metadata_types import MetadataType
from zenml.models.v2.base.filter import FilterGenerator, StrFilter
from zenml.models.v2.base.scoped import (
TaggableFilter,
WorkspaceScopedFilter,
WorkspaceScopedRequest,
WorkspaceScopedResponse,
WorkspaceScopedResponseBody,
WorkspaceScopedResponseMetadata,
WorkspaceScopedResponseResources,
WorkspaceScopedTaggableFilter,
)
from zenml.models.v2.core.artifact import ArtifactResponse
from zenml.models.v2.core.tag import TagResponse
Expand Down Expand Up @@ -469,11 +470,12 @@ def visualize(self, title: Optional[str] = None) -> None:
# ------------------ Filter Model ------------------


class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
class ArtifactVersionFilter(WorkspaceScopedFilter, TaggableFilter):
"""Model to enable advanced filtering of artifact versions."""

FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
*TaggableFilter.FILTER_EXCLUDE_FIELDS,
"name",
"only_unused",
"has_custom_name",
Expand All @@ -482,6 +484,15 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter):
"model_version_id",
"run_metadata",
]
CUSTOM_SORTING_OPTIONS = [
*WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS,
*TaggableFilter.CUSTOM_SORTING_OPTIONS,
]
CLI_EXCLUDE_FIELDS = [
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
*TaggableFilter.CLI_EXCLUDE_FIELDS,
]

artifact_id: Optional[Union[UUID, str]] = Field(
default=None,
description="ID of the artifact to which this version belongs.",
Expand Down
16 changes: 13 additions & 3 deletions src/zenml/models/v2/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
TEXT_FIELD_MAX_LENGTH,
)
from zenml.models.v2.base.scoped import (
TaggableFilter,
WorkspaceScopedFilter,
WorkspaceScopedRequest,
WorkspaceScopedResponse,
WorkspaceScopedResponseBody,
WorkspaceScopedResponseMetadata,
WorkspaceScopedResponseResources,
WorkspaceScopedTaggableFilter,
)
from zenml.utils.pagination_utils import depaginate

Expand Down Expand Up @@ -322,18 +323,27 @@ def versions(self) -> List["Model"]:
# ------------------ Filter Model ------------------


class ModelFilter(WorkspaceScopedTaggableFilter):
class ModelFilter(WorkspaceScopedFilter, TaggableFilter):
"""Model to enable advanced filtering of all Workspaces."""

name: Optional[str] = Field(
default=None,
description="Name of the Model",
)

FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
*TaggableFilter.FILTER_EXCLUDE_FIELDS,
]
CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS,
*WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS,
*TaggableFilter.CUSTOM_SORTING_OPTIONS,
SORT_BY_LATEST_VERSION_KEY,
]
CLI_EXCLUDE_FIELDS = [
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
*TaggableFilter.CLI_EXCLUDE_FIELDS,
]

def apply_sorting(
self,
Expand Down
16 changes: 13 additions & 3 deletions src/zenml/models/v2/core/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@
from zenml.models.v2.base.filter import AnyQuery
from zenml.models.v2.base.page import Page
from zenml.models.v2.base.scoped import (
TaggableFilter,
WorkspaceScopedFilter,
WorkspaceScopedRequest,
WorkspaceScopedResponse,
WorkspaceScopedResponseBody,
WorkspaceScopedResponseMetadata,
WorkspaceScopedResponseResources,
WorkspaceScopedTaggableFilter,
)
from zenml.models.v2.core.service import ServiceResponse
from zenml.models.v2.core.tag import TagResponse
Expand Down Expand Up @@ -576,13 +577,22 @@ def set_stage(
# ------------------ Filter Model ------------------


class ModelVersionFilter(WorkspaceScopedTaggableFilter):
class ModelVersionFilter(WorkspaceScopedFilter, TaggableFilter):
"""Filter model for model versions."""

FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
*TaggableFilter.FILTER_EXCLUDE_FIELDS,
"run_metadata",
]
CUSTOM_SORTING_OPTIONS = [
*WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS,
*TaggableFilter.CUSTOM_SORTING_OPTIONS,
]
CLI_EXCLUDE_FIELDS = [
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
*TaggableFilter.CLI_EXCLUDE_FIELDS,
]

name: Optional[str] = Field(
default=None,
Expand Down
15 changes: 11 additions & 4 deletions src/zenml/models/v2/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@
from zenml.enums import ExecutionStatus
from zenml.models.v2.base.base import BaseUpdate
from zenml.models.v2.base.scoped import (
TaggableFilter,
WorkspaceScopedFilter,
WorkspaceScopedRequest,
WorkspaceScopedResponse,
WorkspaceScopedResponseBody,
WorkspaceScopedResponseMetadata,
WorkspaceScopedResponseResources,
WorkspaceScopedTaggableFilter,
)
from zenml.models.v2.core.tag import TagResponse

Expand Down Expand Up @@ -256,17 +257,23 @@ def tags(self) -> List[TagResponse]:
# ------------------ Filter Model ------------------


class PipelineFilter(WorkspaceScopedTaggableFilter):
class PipelineFilter(WorkspaceScopedFilter, TaggableFilter):
"""Pipeline filter model."""

CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS,
*WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS,
*TaggableFilter.CUSTOM_SORTING_OPTIONS,
SORT_PIPELINES_BY_LATEST_RUN_KEY,
]
FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
*TaggableFilter.FILTER_EXCLUDE_FIELDS,
"latest_run_status",
]
CLI_EXCLUDE_FIELDS = [
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
*TaggableFilter.CLI_EXCLUDE_FIELDS,
]

name: Optional[str] = Field(
default=None,
Expand Down
17 changes: 12 additions & 5 deletions src/zenml/models/v2/core/pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@
from zenml.enums import ExecutionStatus
from zenml.metadata.metadata_types import MetadataType
from zenml.models.v2.base.scoped import (
TaggableFilter,
WorkspaceScopedFilter,
WorkspaceScopedRequest,
WorkspaceScopedResponse,
WorkspaceScopedResponseBody,
WorkspaceScopedResponseMetadata,
WorkspaceScopedResponseResources,
WorkspaceScopedTaggableFilter,
)
from zenml.models.v2.core.model_version import ModelVersionResponse
from zenml.models.v2.core.tag import TagResponse
Expand Down Expand Up @@ -589,20 +590,21 @@ def tags(self) -> List[TagResponse]:
# ------------------ Filter Model ------------------


class PipelineRunFilter(WorkspaceScopedTaggableFilter):
class PipelineRunFilter(WorkspaceScopedFilter, TaggableFilter):
"""Model to enable advanced filtering of all Workspaces."""

CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS,
*WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS,
*TaggableFilter.CUSTOM_SORTING_OPTIONS,
"tag",
"stack",
"pipeline",
"model",
"model_version",
]

FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
*TaggableFilter.FILTER_EXCLUDE_FIELDS,
"unlisted",
"code_repository_id",
"build_id",
Expand All @@ -618,6 +620,11 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter):
"templatable",
"run_metadata",
]
CLI_EXCLUDE_FIELDS = [
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
*TaggableFilter.CLI_EXCLUDE_FIELDS,
]

name: Optional[str] = Field(
default=None,
description="Name of the Pipeline Run",
Expand Down
16 changes: 13 additions & 3 deletions src/zenml/models/v2/core/run_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@
from zenml.enums import ExecutionStatus
from zenml.models.v2.base.base import BaseUpdate
from zenml.models.v2.base.scoped import (
TaggableFilter,
WorkspaceScopedFilter,
WorkspaceScopedRequest,
WorkspaceScopedResponse,
WorkspaceScopedResponseBody,
WorkspaceScopedResponseMetadata,
WorkspaceScopedResponseResources,
WorkspaceScopedTaggableFilter,
)
from zenml.models.v2.core.code_reference import (
CodeReferenceResponse,
Expand Down Expand Up @@ -307,11 +308,12 @@ def tags(self) -> List[TagResponse]:
# ------------------ Filter Model ------------------


class RunTemplateFilter(WorkspaceScopedTaggableFilter):
class RunTemplateFilter(WorkspaceScopedFilter, TaggableFilter):
"""Model for filtering of run templates."""

FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
*WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS,
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
*TaggableFilter.FILTER_EXCLUDE_FIELDS,
"code_repository_id",
"stack_id",
"build_id",
Expand All @@ -320,6 +322,14 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter):
"pipeline",
"stack",
]
CUSTOM_SORTING_OPTIONS = [
*WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS,
*TaggableFilter.CUSTOM_SORTING_OPTIONS,
]
CLI_EXCLUDE_FIELDS = [
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
*TaggableFilter.CLI_EXCLUDE_FIELDS,
]

name: Optional[str] = Field(
default=None,
Expand Down

0 comments on commit b8a5a8a

Please sign in to comment.