From 40bc46124cd65b054393dbf12e9abcebd267864c Mon Sep 17 00:00:00 2001
From: Maggie Hei
Date: Thu, 18 Mar 2021 16:08:01 -0400
Subject: [PATCH 1/3] fix the deprecated args in ForestDRLearner and add a
 stronger condition on search params in the dowhy wrapper

---
 econml/dowhy.py            | 45 +++++++++++++++++++++++++-------------
 econml/dr/_drlearner.py    |  4 ++--
 econml/tests/test_dowhy.py |  3 ++-
 3 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/econml/dowhy.py b/econml/dowhy.py
index a7c112530..d6171e631 100644
--- a/econml/dowhy.py
+++ b/econml/dowhy.py
@@ -38,14 +38,18 @@ def _get_params(self):
         # to represent
         init_signature = inspect.signature(init)
         parameters = init_signature.parameters.values()
+        params = []
         for p in parameters:
             if p.kind == p.VAR_POSITIONAL or p.kind == p.VAR_KEYWORD:
                 raise RuntimeError("cate estimators should always specify their parameters in the signature "
                                    "of their __init__ (no varargs, no varkwargs). "
                                    f"{self._cate_estimator} with constructor {init_signature} doesn't "
                                    "follow this convention.")
+            # if the argument is deprecated, ignore it
+            if p.default != "deprecated":
+                params.append(p.name)
         # Extract and sort argument names excluding 'self'
-        return sorted([p.name for p in parameters])
+        return sorted(params)
 
     def fit(self, Y, T, X=None, W=None, Z=None, *, outcome_names=None, treatment_names=None, feature_names=None,
             confounder_names=None, instrument_names=None, graph=None, estimand_type="nonparametric-ate",
@@ -106,30 +110,41 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, outcome_names=None, treatment_nam
         -------
         self
         """
-
-        Y, T, X, W, Z = check_input_arrays(Y, T, X, W, Z)
-
-        # create dataframe
-        n_obs = Y.shape[0]
-        Y, T, X, W, Z = reshape_arrays_2dim(n_obs, Y, T, X, W, Z)
-
-        # currently dowhy only support single outcome and single treatment
-        assert Y.shape[1] == 1, "Can only accept single dimensional outcome."
-        assert T.shape[1] == 1, "Can only accept single dimensional treatment."
-
         # column names
         if outcome_names is None:
             outcome_names = get_input_columns(Y, prefix="Y")
         if treatment_names is None:
             treatment_names = get_input_columns(T, prefix="T")
         if feature_names is None:
-            feature_names = get_input_columns(X, prefix="X")
+            if X is not None:
+                feature_names = get_input_columns(X, prefix="X")
+            else:
+                feature_names = []
         if confounder_names is None:
-            confounder_names = get_input_columns(W, prefix="W")
+            if W is not None:
+                confounder_names = get_input_columns(W, prefix="W")
+            else:
+                confounder_names = []
         if instrument_names is None:
-            instrument_names = get_input_columns(Z, prefix="Z")
+            if Z is not None:
+                instrument_names = get_input_columns(Z, prefix="Z")
+            else:
+                instrument_names = []
         column_names = outcome_names + treatment_names + feature_names + confounder_names + instrument_names
+
+        # convert inputs to numpy arrays
+        Y, T, X, W, Z = check_input_arrays(Y, T, X, W, Z)
+        # convert inputs to 2d arrays
+        n_obs = Y.shape[0]
+        Y, T, X, W, Z = reshape_arrays_2dim(n_obs, Y, T, X, W, Z)
+        # create dataframe
         df = pd.DataFrame(np.hstack((Y, T, X, W, Z)), columns=column_names)
+
+        # currently dowhy only supports single outcome and single treatment
+        assert Y.shape[1] == 1, "Can only accept single dimensional outcome."
+        assert T.shape[1] == 1, "Can only accept single dimensional treatment."
+ + # call dowhy self.dowhy_ = CausalModel( data=df, treatment=treatment_names, diff --git a/econml/dr/_drlearner.py b/econml/dr/_drlearner.py index 30435e652..73637a2cc 100644 --- a/econml/dr/_drlearner.py +++ b/econml/dr/_drlearner.py @@ -1523,7 +1523,7 @@ def n_crossfit_splits(self, value): @property def criterion(self): - return self.criterion + return "mse" @criterion.setter def criterion(self, value): @@ -1533,7 +1533,7 @@ def criterion(self, value): @property def max_leaf_nodes(self): - return self.max_leaf_nodes + return None @max_leaf_nodes.setter def max_leaf_nodes(self, value): diff --git a/econml/tests/test_dowhy.py b/econml/tests/test_dowhy.py index f007e4b0f..56b780a69 100644 --- a/econml/tests/test_dowhy.py +++ b/econml/tests/test_dowhy.py @@ -5,7 +5,7 @@ import unittest from econml.dml import LinearDML, CausalForestDML from econml.orf import DROrthoForest -from econml.dr import DRLearner +from econml.dr import DRLearner, ForestDRLearner from econml.metalearners import XLearner from econml.iv.dml import DMLATEIV from sklearn.linear_model import LinearRegression, LogisticRegression, Lasso @@ -33,6 +33,7 @@ def clf(): linear_first_stages=False), "dr": DRLearner(model_propensity=clf(), model_regression=reg(), model_final=reg()), + "forestdr": ForestDRLearner(model_propensity=clf(), model_regression=reg()), "xlearner": XLearner(models=reg(), cate_models=reg(), propensity_model=clf()), "cfdml": CausalForestDML(model_y=reg(), model_t=clf(), discrete_treatment=True), "orf": DROrthoForest(n_trees=10, propensity_model=clf(), model_Y=reg()), From 62c065c2963f3728032bf10e8214906c92619de0 Mon Sep 17 00:00:00 2001 From: Maggie Hei Date: Fri, 19 Mar 2021 11:54:20 -0400 Subject: [PATCH 2/3] add tests --- econml/tests/test_dowhy.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/econml/tests/test_dowhy.py b/econml/tests/test_dowhy.py index 56b780a69..7ba75c96f 100644 --- a/econml/tests/test_dowhy.py +++ b/econml/tests/test_dowhy.py @@ -2,10 +2,11 @@ # Licensed under the MIT License. 
import numpy as np +import pandas as pd import unittest from econml.dml import LinearDML, CausalForestDML from econml.orf import DROrthoForest -from econml.dr import DRLearner, ForestDRLearner +from econml.dr import DRLearner, ForestDRLearner, LinearDRLearner from econml.metalearners import XLearner from econml.iv.dml import DMLATEIV from sklearn.linear_model import LinearRegression, LogisticRegression, Lasso @@ -65,3 +66,19 @@ def clf(): num_simulations=3) est_dowhy.refute_estimate(method_name="data_subset_refuter", subset_fraction=0.8, num_simulations=3) + + def test_store_dataframe_name(self): + Y, T, X, W, Z = self._get_data() + Y_name = "outcome" + Y = pd.Series(Y, name=Y_name) + T_name = "treatment" + T = pd.Series(T, name=T_name) + X_name = ["feature"] + X = pd.DataFrame(X, columns=X_name) + W_name = ["control1", "control2", "control3", "control4"] + W = pd.DataFrame(W, columns=W_name) + est = LinearDRLearner().dowhy.fit(Y, T, X, W) + np.testing.assert_array_equal(est._common_causes, X_name + W_name) + np.testing.assert_array_equal(est._effect_modifiers, X_name) + np.testing.assert_array_equal(est._treatment, [T_name]) + np.testing.assert_array_equal(est._outcome, [Y_name]) From 5b069f0c3d7d47927a9d8ec4d7cb14249e291521 Mon Sep 17 00:00:00 2001 From: Maggie Hei Date: Fri, 19 Mar 2021 18:01:44 -0400 Subject: [PATCH 3/3] remove deprecated args --- README.md | 10 ++--- econml/_ortho_learner.py | 18 +------- econml/dml/_rlearner.py | 4 +- econml/dml/causal_forest.py | 19 -------- econml/dml/dml.py | 16 ------- econml/dr/_drlearner.py | 87 ------------------------------------- econml/iv/dml/_dml.py | 12 ----- econml/iv/dr/_dr.py | 8 ---- 8 files changed, 7 insertions(+), 167 deletions(-) diff --git a/README.md b/README.md index d229e57f3..4d34bfcb7 100644 --- a/README.md +++ b/README.md @@ -396,19 +396,19 @@ See the References section for more details. reg = lambda: RandomForestRegressor(min_samples_leaf=20) clf = lambda: RandomForestClassifier(min_samples_leaf=20) models = [('ldml', LinearDML(model_y=reg(), model_t=clf(), discrete_treatment=True, - linear_first_stages=False, n_splits=3)), + linear_first_stages=False, cv=3)), ('xlearner', XLearner(models=reg(), cate_models=reg(), propensity_model=clf())), ('dalearner', DomainAdaptationLearner(models=reg(), final_models=reg(), propensity_model=clf())), ('slearner', SLearner(overall_model=reg())), ('drlearner', DRLearner(model_propensity=clf(), model_regression=reg(), - model_final=reg(), n_splits=3)), + model_final=reg(), cv=3)), ('rlearner', NonParamDML(model_y=reg(), model_t=clf(), model_final=reg(), - discrete_treatment=True, n_splits=3)), + discrete_treatment=True, cv=3)), ('dml3dlasso', DML(model_y=reg(), model_t=clf(), model_final=LassoCV(cv=3, fit_intercept=False), discrete_treatment=True, featurizer=PolynomialFeatures(degree=3), - linear_first_stages=False, n_splits=3)) + linear_first_stages=False, cv=3)) ] # fit cate models on train data @@ -416,7 +416,7 @@ See the References section for more details. 
# score cate models on validation data scorer = RScorer(model_y=reg(), model_t=clf(), - discrete_treatment=True, n_splits=3, mc_iters=2, mc_agg='median') + discrete_treatment=True, cv=3, mc_iters=2, mc_agg='median') scorer.fit(Y_val, T_val, X=X_val) rscore = [scorer.score(mdl) for _, mdl in models] # select the best model diff --git a/econml/_ortho_learner.py b/econml/_ortho_learner.py index dfa56e1e4..19bd85049 100644 --- a/econml/_ortho_learner.py +++ b/econml/_ortho_learner.py @@ -428,9 +428,8 @@ def _gen_ortho_learner_model_final(self): def __init__(self, *, discrete_treatment, discrete_instrument, categories, cv, random_state, - n_splits='raise', mc_iters=None, mc_agg='mean'): + mc_iters=None, mc_agg='mean'): self.cv = cv - self.n_splits = n_splits self.discrete_treatment = discrete_treatment self.discrete_instrument = discrete_instrument self.random_state = random_state @@ -855,18 +854,3 @@ def models_nuisance_(self): if not hasattr(self, '_models_nuisance'): raise AttributeError("Model is not fitted!") return self._models_nuisance - - ####################################################### - # These should be removed once `n_splits` is deprecated - ####################################################### - - @property - def n_splits(self): - return self.cv - - @n_splits.setter - def n_splits(self, value): - if value != 'raise': - warn("Parameter `n_splits` has been deprecated and will be removed in the next version. " - "Use parameter `cv` instead.") - self.cv = value diff --git a/econml/dml/_rlearner.py b/econml/dml/_rlearner.py index 831b65faa..eaadb7143 100644 --- a/econml/dml/_rlearner.py +++ b/econml/dml/_rlearner.py @@ -261,13 +261,11 @@ def _gen_rlearner_model_final(self): is multidimensional, then the average of the MSEs for each dimension of Y is returned. 
""" - def __init__(self, *, discrete_treatment, categories, cv, random_state, - n_splits='raise', mc_iters=None, mc_agg='mean'): + def __init__(self, *, discrete_treatment, categories, cv, random_state, mc_iters=None, mc_agg='mean'): super().__init__(discrete_treatment=discrete_treatment, discrete_instrument=False, # no instrument, so doesn't matter categories=categories, cv=cv, - n_splits=n_splits, random_state=random_state, mc_iters=mc_iters, mc_agg=mc_agg) diff --git a/econml/dml/causal_forest.py b/econml/dml/causal_forest.py index 49879470d..4d4ac5eb8 100644 --- a/econml/dml/causal_forest.py +++ b/econml/dml/causal_forest.py @@ -490,7 +490,6 @@ def __init__(self, *, discrete_treatment=False, categories='auto', cv=2, - n_crossfit_splits='raise', mc_iters=None, mc_agg='mean', drate=True, @@ -541,13 +540,9 @@ def __init__(self, *, self.subforest_size = subforest_size self.n_jobs = n_jobs self.verbose = verbose - self.n_crossfit_splits = n_crossfit_splits - if self.n_crossfit_splits != 'raise': - cv = self.n_crossfit_splits super().__init__(discrete_treatment=discrete_treatment, categories=categories, cv=cv, - n_splits=n_crossfit_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -971,17 +966,3 @@ def __getitem__(self, index): def __iter__(self): """Return iterator over estimators in the ensemble.""" return self.model_cate.__iter__() - - ####################################################### - # These should be removed once `n_splits` is deprecated - ####################################################### - - @property - def n_crossfit_splits(self): - return self.cv - - @n_crossfit_splits.setter - def n_crossfit_splits(self, value): - if value != 'raise': - warn("Deprecated by parameter `n_crossfit_splits` and will be removed in next version.") - self.cv = value diff --git a/econml/dml/dml.py b/econml/dml/dml.py index 45ad2d012..c5b4e1ecf 100644 --- a/econml/dml/dml.py +++ b/econml/dml/dml.py @@ -420,7 +420,6 @@ def __init__(self, *, discrete_treatment=False, categories='auto', cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', random_state=None): @@ -435,7 +434,6 @@ def __init__(self, *, super().__init__(discrete_treatment=discrete_treatment, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -596,7 +594,6 @@ def __init__(self, *, discrete_treatment=False, categories='auto', cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', random_state=None): @@ -609,7 +606,6 @@ def __init__(self, *, discrete_treatment=discrete_treatment, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state,) @@ -790,7 +786,6 @@ def __init__(self, *, discrete_treatment=False, categories='auto', cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', random_state=None): @@ -810,7 +805,6 @@ def __init__(self, *, discrete_treatment=discrete_treatment, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -974,7 +968,6 @@ def __init__(self, model_y='auto', model_t='auto', dim=20, bw=1.0, cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', random_state=None): self.dim = dim @@ -987,7 +980,6 @@ def __init__(self, model_y='auto', model_t='auto', discrete_treatment=discrete_treatment, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -1087,7 +1079,6 @@ def __init__(self, *, discrete_treatment=False, categories='auto', cv=2, - n_splits='raise', 
mc_iters=None, mc_agg='mean', random_state=None): @@ -1101,7 +1092,6 @@ def __init__(self, *, super().__init__(discrete_treatment=discrete_treatment, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -1190,7 +1180,6 @@ def ForestDML(model_y, model_t, discrete_treatment=False, categories='auto', cv=2, - n_crossfit_splits='raise', mc_iters=None, mc_agg='mean', n_estimators=100, @@ -1245,10 +1234,6 @@ def ForestDML(model_y, model_t, Unless an iterable is used, we call `split(concat[W, X], T)` to generate the splits. If all W, X are None, then we call `split(ones((T.shape[0], 1)), T)`. - n_crossfit_splits: int or 'raise', optional (default='raise') - Deprecated by parameter `cv` and will be removed in next version. Can be used - interchangeably with `cv`. - mc_iters: int, optional (default=None) The number of times to rerun the first stage models to reduce the variance of the nuisances. @@ -1375,7 +1360,6 @@ def ForestDML(model_y, model_t, discrete_treatment=discrete_treatment, categories=categories, cv=cv, - n_crossfit_splits=n_crossfit_splits, mc_iters=mc_iters, mc_agg=mc_agg, n_estimators=n_estimators, diff --git a/econml/dr/_drlearner.py b/econml/dr/_drlearner.py index 73637a2cc..f94a85222 100644 --- a/econml/dr/_drlearner.py +++ b/econml/dr/_drlearner.py @@ -397,7 +397,6 @@ def __init__(self, *, min_propensity=1e-6, categories='auto', cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', random_state=None): @@ -408,7 +407,6 @@ def __init__(self, *, self.featurizer = clone(featurizer, safe=False) self.min_propensity = min_propensity super().__init__(cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, discrete_treatment=True, @@ -810,7 +808,6 @@ def __init__(self, *, min_propensity=1e-6, categories='auto', cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', random_state=None): @@ -823,7 +820,6 @@ def __init__(self, *, min_propensity=min_propensity, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -1087,7 +1083,6 @@ def __init__(self, *, min_propensity=1e-6, categories='auto', cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', random_state=None): @@ -1107,7 +1102,6 @@ def __init__(self, *, min_propensity=min_propensity, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -1239,10 +1233,6 @@ class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner): Unless an iterable is used, we call `split(concat[W, X], T)` to generate the splits. If all W, X are None, then we call `split(ones((T.shape[0], 1)), T)`. - n_crossfit_splits: int or 'raise', optional (default='raise') - Deprecated by parameter `cv` and will be removed in next version. Can be used - interchangeably with `cv`. - mc_iters: int, optional (default=None) The number of times to rerun the first stage models to reduce the variance of the nuisances. @@ -1255,12 +1245,6 @@ class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner): forest of sqrt(n_estimators) sub-forests, where each sub-forest contains sqrt(n_estimators) trees. - criterion : string, optional (default="mse") - The function to measure the quality of a split. Supported criteria - are "mse" for the mean squared error, which is equal to variance - reduction as feature selection criterion, and "mae" for the mean - absolute error. - max_depth : integer or None, optional (default=None) The maximum depth of the tree. 
If None, then nodes are expanded until all leaves are pure or until all leaves contain less than @@ -1312,11 +1296,6 @@ class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner): valid partition of the node samples is found, even if it requires to effectively inspect more than ``max_features`` features. - max_leaf_nodes : int or None, optional (default=None) - Grow trees with ``max_leaf_nodes`` in best-first fashion. - Best nodes are defined as relative reduction in impurity. - If None then unlimited number of leaf nodes. - min_impurity_decrease : float, optional (default=0.) A node will be split if this split induces a decrease of the impurity greater than or equal to this value. @@ -1333,16 +1312,6 @@ class ForestDRLearner(ForestModelFinalCateEstimatorDiscreteMixin, DRLearner): ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum, if ``sample_weight`` is passed. - subsample_fr : float or 'auto', optional (default='auto') - The fraction of the half-samples that are used on each tree. Each tree - will be built on subsample_fr * n_samples/2. - - If 'auto', then the subsampling fraction is set to:: - - (n_samples/2)**(1-1/(2*n_features+2))/(n_samples/2) - - which is sufficient to guarantee asympotitcally valid inference. - honest : boolean, optional (default=True) Whether to use honest trees, i.e. half of the samples are used for creating the tree structure and the other half for the estimation at @@ -1371,19 +1340,15 @@ def __init__(self, *, min_propensity=1e-6, categories='auto', cv=2, - n_crossfit_splits='raise', mc_iters=None, mc_agg='mean', n_estimators=1000, - criterion='deprecated', max_depth=None, min_samples_split=5, min_samples_leaf=5, min_weight_fraction_leaf=0., max_features="auto", - max_leaf_nodes='deprecated', min_impurity_decrease=0., - subsample_fr='deprecated', max_samples=.45, min_balancedness_tol=.45, honest=True, @@ -1404,12 +1369,6 @@ def __init__(self, *, self.subforest_size = subforest_size self.n_jobs = n_jobs self.verbose = verbose - self.n_crossfit_splits = n_crossfit_splits - if self.n_crossfit_splits != 'raise': - cv = self.n_crossfit_splits - self.subsample_fr = subsample_fr - self.max_leaf_nodes = max_leaf_nodes - self.criterion = criterion super().__init__(model_regression=model_regression, model_propensity=model_propensity, model_final=None, @@ -1418,7 +1377,6 @@ def __init__(self, *, min_propensity=min_propensity, categories=categories, cv=cv, - n_splits='raise', mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -1506,48 +1464,3 @@ def model_final(self): def model_final(self, model): if model is not None: raise ValueError("Parameter `model_final` cannot be altered for this estimator!") - - #################################################################### - # Everything below should be removed once parameters are deprecated - #################################################################### - - @property - def n_crossfit_splits(self): - return self.cv - - @n_crossfit_splits.setter - def n_crossfit_splits(self, value): - if value != 'raise': - warn("Deprecated by parameter `n_splits` and will be removed in next version.") - self.cv = value - - @property - def criterion(self): - return "mse" - - @criterion.setter - def criterion(self, value): - if value != 'deprecated': - warn("The parameter 'criterion' has been deprecated and will be removed in the next version. 
" - "Only the 'mse' criterion is supported.") - - @property - def max_leaf_nodes(self): - return None - - @max_leaf_nodes.setter - def max_leaf_nodes(self, value): - if value != 'deprecated': - warn("The parameter 'max_leaf_nodes' has been deprecated and will be removed in the next version.") - - @property - def subsample_fr(self): - return 2 * self.max_samples - - @subsample_fr.setter - def subsample_fr(self, value): - if value != 'deprecated': - warn("The parameter 'subsample_fr' has been deprecated and will be removed in the next version. " - "Use 'max_samples' instead, with the convention that " - "'subsample_fr=x' is equivalent to 'max_samples=x/2'.") - self.max_samples = .45 if value == 'auto' else value / 2 diff --git a/econml/iv/dml/_dml.py b/econml/iv/dml/_dml.py index b8eb5829d..87d8458b9 100644 --- a/econml/iv/dml/_dml.py +++ b/econml/iv/dml/_dml.py @@ -69,7 +69,6 @@ def __init__(self, discrete_instrument=False, discrete_treatment=False, categories='auto', cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', random_state=None): @@ -77,7 +76,6 @@ def __init__(self, discrete_instrument=False, discrete_instrument=discrete_instrument, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -221,7 +219,6 @@ def __init__(self, *, discrete_instrument=False, categories='auto', cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', random_state=None): @@ -232,7 +229,6 @@ def __init__(self, *, discrete_treatment=discrete_treatment, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -298,7 +294,6 @@ def __init__(self, *, discrete_instrument=False, categories='auto', cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', random_state=None): @@ -309,7 +304,6 @@ def __init__(self, *, discrete_treatment=discrete_treatment, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -485,14 +479,12 @@ class _BaseDMLIV(_OrthoLearner): def __init__(self, discrete_instrument=False, discrete_treatment=False, categories='auto', cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', random_state=None): super().__init__(discrete_treatment=discrete_treatment, discrete_instrument=discrete_instrument, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -765,7 +757,6 @@ def __init__(self, *, featurizer=None, fit_cate_intercept=True, cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', discrete_instrument=False, discrete_treatment=False, @@ -777,7 +768,6 @@ def __init__(self, *, self.featurizer = clone(featurizer, safe=False) self.fit_cate_intercept = fit_cate_intercept super().__init__(cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, discrete_instrument=discrete_instrument, @@ -895,7 +885,6 @@ def __init__(self, *, model_final, featurizer=None, cv=2, - n_splits='raise', mc_iters=None, mc_agg='mean', discrete_instrument=False, @@ -908,7 +897,6 @@ def __init__(self, *, self.model_final = clone(model_final, safe=False) self.featurizer = clone(featurizer, safe=False) super().__init__(cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, discrete_instrument=discrete_instrument, diff --git a/econml/iv/dr/_dr.py b/econml/iv/dr/_dr.py index e87e31cc0..75e4b6c4f 100644 --- a/econml/iv/dr/_dr.py +++ b/econml/iv/dr/_dr.py @@ -211,7 +211,6 @@ def __init__(self, *, discrete_treatment=False, categories='auto', cv=2, - n_splits='raise', mc_iters=None, 
mc_agg='mean', random_state=None): @@ -224,7 +223,6 @@ def __init__(self, *, discrete_treatment=discrete_treatment, categories=categories, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, random_state=random_state) @@ -425,7 +423,6 @@ def __init__(self, *, fit_cate_intercept=True, cov_clip=.1, cv=3, - n_splits='raise', mc_iters=None, mc_agg='mean', opt_reweighted=False, @@ -442,7 +439,6 @@ def __init__(self, *, fit_cate_intercept=fit_cate_intercept, cov_clip=cov_clip, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, discrete_instrument=True, @@ -560,7 +556,6 @@ def __init__(self, *, fit_cate_intercept=True, cov_clip=.1, cv=3, - n_splits='raise', mc_iters=None, mc_agg='mean', opt_reweighted=False, @@ -575,7 +570,6 @@ def __init__(self, *, fit_cate_intercept=fit_cate_intercept, cov_clip=cov_clip, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, opt_reweighted=opt_reweighted, @@ -696,7 +690,6 @@ def __init__(self, *, fit_cate_intercept=True, cov_clip=.1, cv=3, - n_splits='raise', mc_iters=None, mc_agg='mean', categories='auto', @@ -709,7 +702,6 @@ def __init__(self, *, model_final=None, cov_clip=cov_clip, cv=cv, - n_splits=n_splits, mc_iters=mc_iters, mc_agg=mc_agg, opt_reweighted=False,
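
Note on the search-parameter filtering introduced in PATCH 1/3: `_get_params`
now drops any constructor argument whose default is the sentinel string
"deprecated", so those arguments are excluded from the wrapper's search
params. A minimal, self-contained sketch of that rule (the `get_search_params`
helper and `Example` class below are illustrative, not part of econml):

    import inspect

    def get_search_params(estimator_cls):
        """Collect constructor parameter names, skipping deprecated ones."""
        init_signature = inspect.signature(estimator_cls.__init__)
        params = []
        for p in init_signature.parameters.values():
            if p.name == "self":
                continue
            # same convention as the wrapper: no varargs / varkwargs allowed
            if p.kind == p.VAR_POSITIONAL or p.kind == p.VAR_KEYWORD:
                raise RuntimeError("estimators should declare explicit parameters")
            # arguments defaulting to the "deprecated" sentinel are skipped
            if p.default != "deprecated":
                params.append(p.name)
        return sorted(params)

    class Example:
        def __init__(self, cv=2, n_splits="deprecated"):
            self.cv = cv

    print(get_search_params(Example))  # prints ['cv']

After PATCH 3/3 the old keywords are removed outright, so callers still
passing e.g. n_splits=3 should migrate to cv=3; once the parameter is gone
from the signature, the old keyword raises a TypeError rather than a
deprecation warning.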