Skip to content

Commit

Permalink
Merge pull request #204 from cvxgrp/203-address-nans
Browse files Browse the repository at this point in the history
input data
  • Loading branch information
tschm authored Nov 20, 2023
2 parents 471d52b + acbe45b commit 8088713
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 27 deletions.
22 changes: 15 additions & 7 deletions cvx/simulator/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,9 @@ def builder(
min_cap_fraction: float | None = None,
max_trade_fraction: float | None = None,
min_trade_fraction: float | None = None,
input_data: dict[str, Any] = field(default_factory=dict),
**kwargs,
# input_data: dict[str, Any] = field(default_factory=dict),
# **kwargs,
) -> _Builder:
"""The builder function creates an instance of the _Builder class, which
is used to construct a portfolio of assets. The function takes in a pandas
Expand All @@ -187,9 +188,12 @@ def builder(
index=prices.index, columns=prices.columns, data=0.0, dtype=float
)

# print(input_data)
# if input_data is None:

builder = _Builder(
stocks=stocks,
prices=prices.ffill(),
prices=prices,
initial_cash=float(initial_cash),
trading_cost_model=trading_cost_model,
market_cap=market_cap,
Expand All @@ -198,8 +202,7 @@ def builder(
min_cap_fraction=min_cap_fraction,
max_trade_fraction=max_trade_fraction,
min_trade_fraction=min_trade_fraction,
input_data=input_data,
parameter=dict(kwargs),
input_data=dict(kwargs),
)

if weights is not None:
Expand All @@ -209,6 +212,10 @@ def builder(
return builder


def empty():
return dict()


@dataclass(frozen=True)
class _Builder:
prices: pd.DataFrame
Expand All @@ -222,7 +229,7 @@ class _Builder:
min_cap_fraction: float | None = None
max_trade_fraction: float | None = None
min_trade_fraction: float | None = None
input_data: dict[str, Any] = field(default_factory=dict)
input_data: dict[str, Any] = field(default_factory=empty)
parameter: dict[str, Any] = field(default_factory=dict)

def __post_init__(self) -> None:
Expand Down Expand Up @@ -332,9 +339,10 @@ def __iter__(self) -> Generator[tuple[pd.DatetimeIndex, _State], None, None]:
for t in self.index:
# valuation of the current position
self._state.prices = self.prices.loc[t]
self._state.input_data = {key: data.loc[t] for key, data in self.input_data.items()}


self._state.input_data = {
key: data.loc[t] for key, data in self.input_data.items()
}

yield self.index[self.index <= t], self._state

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def prices(resource_dir):
"""prices fixture"""
return pd.read_csv(
resource_dir / "price.csv", index_col=0, parse_dates=True, header=0
)
).ffill()


@pytest.fixture()
Expand Down
20 changes: 1 addition & 19 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,25 +271,7 @@ def test_cov(prices):
assert np.all(np.isfinite(mat))


def test_parameter(prices):
b = _builder(
prices=prices,
initial_cash=50000,
a=2,
b=3,
c="wurst",
d=[1, 2, 3],
e={"a": 1, "b": 2},
)
assert b.parameter["a"] == 2
assert b.parameter["b"] == 3
assert b.parameter["c"] == "wurst"
assert b.parameter["d"] == [1, 2, 3]
assert b.parameter["e"] == {"a": 1, "b": 2}


def test_input_data(prices):
b = _builder(prices=prices, initial_cash=50000, a=2, input_data={"volume": prices.ffill()})
b = _builder(prices=prices, initial_cash=50000, volume=prices.ffill())
for t, state in b:
assert b.parameter["a"] == 2
pd.testing.assert_series_equal(state.prices, state.input_data["volume"])

0 comments on commit 8088713

Please sign in to comment.