Skip to content

Commit

Permalink
NaNs in prices
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Schmelzer committed Nov 26, 2023
1 parent 6c5a7f1 commit 30e758e
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
26 changes: 24 additions & 2 deletions cvx/simulator/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,12 @@ def set_weights(self, time: datetime, weights: pd.Series) -> None:
:param weights: series of weights
"""
assert isinstance(weights, pd.Series), "weights must be a pandas Series"
self[time] = (self._state.nav * weights) / self._state.prices
valid = self._state.prices.dropna().index
# check that you have weights exactly for those indices
if not set(weights.dropna().index) == set(valid):
raise ValueError("weights must have same index as prices")

self[time] = (self._state.nav * weights[valid]) / self._state.prices[valid]

def set_cashposition(self, time: datetime, cashposition: pd.Series) -> None:
"""
Expand All @@ -305,7 +310,13 @@ def set_cashposition(self, time: datetime, cashposition: pd.Series) -> None:
assert isinstance(
cashposition, pd.Series
), "cashposition must be a pandas Series"
self[time] = cashposition / self._state.prices

valid = self._state.prices.dropna().index
# check that you have weights exactly for those indices
if not set(cashposition.dropna().index) == set(valid):
raise ValueError("cashposition must have same index as prices")

self[time] = cashposition[valid] / self._state.prices[valid]

def set_position(self, time: datetime, position: pd.Series) -> None:
"""
Expand All @@ -315,6 +326,12 @@ def set_position(self, time: datetime, position: pd.Series) -> None:
:param position: series of number of stocks
"""
assert isinstance(position, pd.Series), "position must be a pandas Series"

valid = self._state.prices.dropna().index
# check that you have weights exactly for those indices
if not set(position.dropna().index) == set(valid):
raise ValueError("position must have same index as prices")

self[time] = position

def __iter__(self) -> Generator[tuple[pd.DatetimeIndex, _State], None, None]:
Expand Down Expand Up @@ -365,6 +382,11 @@ def __setitem__(self, time: datetime, position: pd.Series) -> None:
assert isinstance(position, pd.Series)
assert set(position.index).issubset(set(self.assets))

valid = self._state.prices.dropna().index
# check that you have weights exactly for those indices
if not set(position.dropna().index) == set(valid):
raise ValueError("position must have same index as prices")

if self.market_cap is not None:
# compute capitalization of desired position
cap = position * self._state.prices
Expand Down
6 changes: 6 additions & 0 deletions tests/resources/priceNaN.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
,A,B,C,D,E,F,G
2013-01-01,1673.78,23311.98,,3735.12,1462.42,2711.25,2518.99
2013-01-02,1686.9,23311.98,,3735.12,1462.42,2711.25,2518.99
2013-01-03,1663.95,23398.6,63312.46,3714.99,1459.37,,2509.51
2013-01-04,1655.65,23331.09,62523.06,3689.34,1466.47,,2516.81
2013-01-07,1646.95,23329.75,61932.54,3699.14,1461.89,,2523.77
32 changes: 32 additions & 0 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,35 @@ def test_input_data(prices):
for t, state in b:
print(state.volume)
pd.testing.assert_series_equal(state.prices, state.input_data["volume"])


def test_weights_on_wrong_days(resource_dir):
prices = pd.read_csv(
resource_dir / "priceNaN.csv", index_col=0, parse_dates=True, header=0
)

b = _builder(prices=prices, initial_cash=50000)
t = prices.index

for t, state in b:
with pytest.raises(ValueError):
b.set_weights(
t[-1], pd.Series(index={"A", "B", "C"}, data=[0.5, 0.25, 0.25])
)

with pytest.raises(ValueError):
b.set_cashposition(t[-1], pd.Series(index={"A", "B", "C"}, data=[5, 5, 5]))

with pytest.raises(ValueError):
b.set_position(t[-1], pd.Series(index={"A", "B", "C"}, data=[5, 5, 5]))

for t, state in b:
b.set_weights(
t[-1],
pd.Series(
index=prices.loc[t[-1]].dropna().index,
data=np.random.rand(6),
dtype=float,
),
)
# assert False

0 comments on commit 30e758e

Please sign in to comment.