Skip to content

Commit

Permalink
Improve queries for pipelines, run templates, models and artifacts (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi authored Feb 5, 2025
1 parent 46d4628 commit dff5ba8
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 33 deletions.
35 changes: 31 additions & 4 deletions src/zenml/zen_stores/schemas/artifact_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

from pydantic import ValidationError
from sqlalchemy import TEXT, Column, UniqueConstraint
from sqlmodel import Field, Relationship
from sqlalchemy.orm import object_session
from sqlmodel import Field, Relationship, desc, select

from zenml.config.source import Source
from zenml.enums import (
Expand Down Expand Up @@ -90,6 +91,32 @@ class ArtifactSchema(NamedSchema, table=True):
),
)

@property
def latest_version(self) -> Optional["ArtifactVersionSchema"]:
    """Look up the most recently created version of this artifact.

    The lookup is a single-row query issued through the session this
    schema object is attached to.

    Raises:
        RuntimeError: If the schema object has no session.

    Returns:
        The newest version by creation time, or `None` if the artifact
        has no versions.
    """
    db_session = object_session(self)
    if not db_session:
        raise RuntimeError(
            "Missing DB session to fetch latest version for artifact."
        )

    # Newest first, capped at one row so only the latest is fetched.
    newest_first = (
        select(ArtifactVersionSchema)
        .where(ArtifactVersionSchema.artifact_id == self.id)
        .order_by(desc(ArtifactVersionSchema.created))
        .limit(1)
    )
    rows = db_session.execute(newest_first).scalars()
    return rows.one_or_none()

@classmethod
def from_request(
cls,
Expand Down Expand Up @@ -127,9 +154,9 @@ def to_model(
The created `ArtifactResponse`.
"""
latest_id, latest_name = None, None
if self.versions:
latest_version = max(self.versions, key=lambda x: x.created)
latest_id, latest_name = latest_version.id, latest_version.version
if latest_version := self.latest_version:
latest_id = latest_version.id
latest_name = latest_version.version

# Create the body of the model
body = ArtifactResponseBody(
Expand Down
37 changes: 31 additions & 6 deletions src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
Column,
UniqueConstraint,
)
from sqlmodel import Field, Relationship
from sqlalchemy.orm import object_session
from sqlmodel import Field, Relationship, desc, select

from zenml.enums import (
ArtifactType,
Expand Down Expand Up @@ -126,6 +127,32 @@ class ModelSchema(NamedSchema, table=True):
sa_relationship_kwargs={"cascade": "delete"},
)

@property
def latest_version(self) -> Optional["ModelVersionSchema"]:
    """Look up the version of this model with the highest number.

    The lookup is a single-row query issued through the session this
    schema object is attached to.

    Raises:
        RuntimeError: If the schema object has no session.

    Returns:
        The version with the largest `number`, or `None` if the model
        has no versions.
    """
    db_session = object_session(self)
    if not db_session:
        raise RuntimeError(
            "Missing DB session to fetch latest version for model."
        )

    # Highest version number first, capped at one row.
    highest_first = (
        select(ModelVersionSchema)
        .where(ModelVersionSchema.model_id == self.id)
        .order_by(desc(ModelVersionSchema.number))
        .limit(1)
    )
    rows = db_session.execute(highest_first).scalars()
    return rows.one_or_none()

@classmethod
def from_request(cls, model_request: ModelRequest) -> "ModelSchema":
"""Convert an `ModelRequest` to an `ModelSchema`.
Expand Down Expand Up @@ -169,11 +196,9 @@ def to_model(
"""
tags = [tag.to_model() for tag in self.tags]

if self.model_versions:
version_numbers = [mv.number for mv in self.model_versions]
latest_version_idx = version_numbers.index(max(version_numbers))
latest_version_name = self.model_versions[latest_version_idx].name
latest_version_id = self.model_versions[latest_version_idx].id
if latest_version := self.latest_version:
latest_version_name = latest_version.name
latest_version_id = latest_version.id
else:
latest_version_name = None
latest_version_id = None
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/zen_stores/schemas/pipeline_run_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
stack: Optional["StackSchema"] = Relationship()
build: Optional["PipelineBuildSchema"] = Relationship()
schedule: Optional["ScheduleSchema"] = Relationship()
pipeline: Optional["PipelineSchema"] = Relationship(back_populates="runs")
pipeline: Optional["PipelineSchema"] = Relationship()
trigger_execution: Optional["TriggerExecutionSchema"] = Relationship()

services: List["ServiceSchema"] = Relationship(
Expand Down
43 changes: 35 additions & 8 deletions src/zenml/zen_stores/schemas/pipeline_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from uuid import UUID

from sqlalchemy import TEXT, Column, UniqueConstraint
from sqlmodel import Field, Relationship
from sqlalchemy.orm import object_session
from sqlmodel import Field, Relationship, desc, select

from zenml.enums import TaggableResourceTypes
from zenml.models import (
Expand Down Expand Up @@ -85,10 +86,6 @@ class PipelineSchema(NamedSchema, table=True):
schedules: List["ScheduleSchema"] = Relationship(
back_populates="pipeline",
)
runs: List["PipelineRunSchema"] = Relationship(
back_populates="pipeline",
sa_relationship_kwargs={"order_by": "PipelineRunSchema.created"},
)
builds: List["PipelineBuildSchema"] = Relationship(
back_populates="pipeline"
)
Expand All @@ -105,6 +102,34 @@ class PipelineSchema(NamedSchema, table=True):
),
)

@property
def latest_run(self) -> Optional["PipelineRunSchema"]:
    """Look up the most recently created run of this pipeline.

    The lookup is a single-row query issued through the session this
    schema object is attached to.

    Raises:
        RuntimeError: If the schema object has no session.

    Returns:
        The newest run by creation time, or `None` if the pipeline has
        no runs.
    """
    # Imported locally rather than at module level.
    from zenml.zen_stores.schemas import PipelineRunSchema

    db_session = object_session(self)
    if not db_session:
        raise RuntimeError(
            "Missing DB session to fetch latest run for pipeline."
        )

    # Newest first, capped at one row so only the latest is fetched.
    newest_first = (
        select(PipelineRunSchema)
        .where(PipelineRunSchema.pipeline_id == self.id)
        .order_by(desc(PipelineRunSchema.created))
        .limit(1)
    )
    rows = db_session.execute(newest_first).scalars()
    return rows.one_or_none()

@classmethod
def from_request(
cls,
Expand Down Expand Up @@ -141,10 +166,12 @@ def to_model(
Returns:
The created PipelineResponse.
"""
latest_run = self.latest_run

body = PipelineResponseBody(
user=self.user.to_model() if self.user else None,
latest_run_id=self.runs[-1].id if self.runs else None,
latest_run_status=self.runs[-1].status if self.runs else None,
latest_run_id=latest_run.id if latest_run else None,
latest_run_status=latest_run.status if latest_run else None,
created=self.created,
updated=self.updated,
)
Expand All @@ -158,7 +185,7 @@ def to_model(

resources = None
if include_resources:
latest_run_user = self.runs[-1].user if self.runs else None
latest_run_user = latest_run.user if latest_run else None

resources = PipelineResponseResources(
latest_run_user=latest_run_user.to_model()
Expand Down
56 changes: 42 additions & 14 deletions src/zenml/zen_stores/schemas/run_template_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

from sqlalchemy import Column, String, UniqueConstraint
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlmodel import Field, Relationship
from sqlalchemy.orm import object_session
from sqlmodel import Field, Relationship, col, desc, select

from zenml.constants import MEDIUMTEXT_MAX_LENGTH
from zenml.enums import TaggableResourceTypes
Expand Down Expand Up @@ -99,17 +100,6 @@ class RunTemplateSchema(BaseSchema, table=True):
}
)

runs: List["PipelineRunSchema"] = Relationship(
sa_relationship_kwargs={
"primaryjoin": "RunTemplateSchema.id==PipelineDeploymentSchema.template_id",
"secondaryjoin": "PipelineDeploymentSchema.id==PipelineRunSchema.deployment_id",
"secondary": "pipeline_deployment",
"cascade": "delete",
"viewonly": True,
"order_by": "PipelineRunSchema.created",
}
)

tags: List["TagSchema"] = Relationship(
sa_relationship_kwargs=dict(
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.RUN_TEMPLATE.value}', foreign(TagResourceSchema.resource_id)==RunTemplateSchema.id)",
Expand All @@ -120,6 +110,42 @@ class RunTemplateSchema(BaseSchema, table=True):
),
)

@property
def latest_run(self) -> Optional["PipelineRunSchema"]:
    """Look up the most recently created run of this template.

    Runs are matched to the template through their deployment: a run
    qualifies when its deployment's `template_id` equals this
    template's ID. The lookup is a single-row query issued through the
    session this schema object is attached to.

    Raises:
        RuntimeError: If the schema object has no session.

    Returns:
        The newest matching run by creation time, or `None` if there is
        none.
    """
    # Imported locally rather than at module level.
    from zenml.zen_stores.schemas import (
        PipelineDeploymentSchema,
        PipelineRunSchema,
    )

    db_session = object_session(self)
    if not db_session:
        raise RuntimeError(
            "Missing DB session to fetch latest run for template."
        )

    # Link each run to the deployment it was created from.
    run_to_deployment = col(PipelineDeploymentSchema.id) == col(
        PipelineRunSchema.deployment_id
    )
    # Newest first, capped at one row so only the latest is fetched.
    newest_first = (
        select(PipelineRunSchema)
        .join(PipelineDeploymentSchema, run_to_deployment)
        .where(PipelineDeploymentSchema.template_id == self.id)
        .order_by(desc(PipelineRunSchema.created))
        .limit(1)
    )
    rows = db_session.execute(newest_first).scalars()
    return rows.one_or_none()

@classmethod
def from_request(
cls,
Expand Down Expand Up @@ -184,13 +210,15 @@ def to_model(
):
runnable = True

latest_run = self.latest_run

body = RunTemplateResponseBody(
user=self.user.to_model() if self.user else None,
created=self.created,
updated=self.updated,
runnable=runnable,
latest_run_id=self.runs[-1].id if self.runs else None,
latest_run_status=self.runs[-1].status if self.runs else None,
latest_run_id=latest_run.id if latest_run else None,
latest_run_status=latest_run.status if latest_run else None,
)

metadata = None
Expand Down

0 comments on commit dff5ba8

Please sign in to comment.