Skip to content

Commit

Permalink
Update test_st_trace.py
Browse files Browse the repository at this point in the history
  • Loading branch information
TaiSakuma committed Dec 5, 2024
1 parent 7d00568 commit eb64885
Showing 1 changed file with 64 additions and 32 deletions.
96 changes: 64 additions & 32 deletions src/nextline_rdb/models/strategies/tests/test_st_trace.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,71 @@
from typing import Optional, TypedDict

from hypothesis import Phase, given, settings
from hypothesis import strategies as st
from sqlalchemy import select

from nextline_rdb.db import DB
from nextline_test_utils import safe_compare as sc
from nextline_test_utils.strategies import st_none_or
from nextline_test_utils.strategies.misc import st_graphql_ints
from nextline_test_utils.strategies import st_graphql_ints, st_none_or

from ... import Model, Trace
from ... import Model, Run
from .. import st_model_run, st_model_trace, st_thread_task_no
from .funcs import assert_model_persistence


class StModelTraceKwargs(TypedDict, total=False):
run: Optional[Run]
trace_no: Optional[int]
thread_task_no: Optional[tuple[int, int | None]]
generate_trace_calls: bool
generate_prompts: bool


@st.composite
def st_st_model_trace_kwargs(draw: st.DrawFn) -> StModelTraceKwargs:
kwargs = StModelTraceKwargs()

if draw(st.booleans()):
kwargs['run'] = draw(st_none_or(st_model_run(generate_traces=False)))

if draw(st.booleans()):
kwargs['trace_no'] = draw(st_none_or(st_graphql_ints(min_value=1)))

if draw(st.booleans()):
kwargs['thread_task_no'] = draw(st_none_or(st_thread_task_no()))

if kwargs.get('run') is None and draw(st.booleans()):
kwargs['generate_trace_calls'] = draw(st.booleans())

if kwargs.get('generate_trace_calls', False) and draw(st.booleans()):
kwargs['generate_prompts'] = draw(st.booleans())

return kwargs


@given(kwargs=st_st_model_trace_kwargs())
def test_st_model_trace_kwargs(kwargs: StModelTraceKwargs) -> None:
if kwargs.get('run') is not None:
assert not kwargs.get('generate_trace_calls')
assert not kwargs.get('generate_prompts')

if not kwargs.get('generate_trace_calls', False):
assert not kwargs.get('generate_prompts')


@settings(phases=(Phase.generate,)) # Avoid shrinking
@given(st.data())
async def test_options(data: st.DataObject) -> None:
run = data.draw(st_none_or(st_model_run(generate_traces=False)))
trace_no = data.draw(st_none_or(st_graphql_ints(min_value=1)))
thread_task_no = data.draw(st_none_or(st_thread_task_no()))
generate_trace_calls = False if run else data.draw(st.booleans())
generate_prompts = (
False if run or not generate_trace_calls else data.draw(st.booleans())
)

trace = data.draw(
st_model_trace(
run=run,
trace_no=trace_no,
thread_task_no=thread_task_no,
generate_trace_calls=generate_trace_calls,
generate_prompts=generate_prompts,
)
)
@given(data=st.data())
def test_options(data: st.DataObject) -> None:
# Generate options of the strategy to be tested
kwargs = data.draw(st_st_model_trace_kwargs())

# Call the strategy to be tested
trace = data.draw(st_model_trace(**kwargs))

# Assert the generated values
run = kwargs.get('run')
trace_no = kwargs.get('trace_no')
thread_task_no = kwargs.get('thread_task_no')
generate_trace_calls = kwargs.get('generate_trace_calls', False)
generate_prompts = kwargs.get('generate_prompts', False)

assert run is None or run is trace.run
run = trace.run
Expand All @@ -45,11 +80,8 @@ async def test_options(data: st.DataObject) -> None:
assert not generate_trace_calls or trace.trace_calls
assert not generate_prompts or trace.prompts

async with DB(use_migration=False, model_base_class=Model) as db:
async with db.session.begin() as session:
session.add(trace)
async with db.session() as session:
stmt = select(Trace)
trace_ = await session.scalar(stmt)
session.expunge_all()
assert repr(trace) == repr(trace_)

@settings(phases=(Phase.generate,)) # Avoid shrinking
@given(instance=st_model_trace())
async def test_db(instance: Model) -> None:
await assert_model_persistence([instance])

0 comments on commit eb64885

Please sign in to comment.