Fix deprecated args and dowhy wrapper #434

Merged · 4 commits · Mar 22, 2021
Changes from 1 commit
45 changes: 30 additions & 15 deletions econml/dowhy.py
@@ -38,14 +38,18 @@ def _get_params(self):
         # to represent
         init_signature = inspect.signature(init)
         parameters = init_signature.parameters.values()
+        params = []
         for p in parameters:
             if p.kind == p.VAR_POSITIONAL or p.kind == p.VAR_KEYWORD:
                 raise RuntimeError("cate estimators should always specify their parameters in the signature "
                                    "of their __init__ (no varargs, no varkwargs). "
                                    f"{self._cate_estimator} with constructor {init_signature} doesn't "
                                    "follow this convention.")
+            # if the argument is deprecated, ignore it
+            if p.default != "deprecated":
+                params.append(p.name)
         # Extract and sort argument names excluding 'self'
-        return sorted([p.name for p in parameters])
+        return sorted(params)
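
For context, a minimal, self-contained sketch of the behavior this hunk introduces (ToyEstimator and its parameter names are hypothetical, not taken from econml): any constructor argument whose default is the sentinel string "deprecated" is simply dropped from the reported parameter list.

import inspect

class ToyEstimator:
    # hypothetical estimator with one deprecated constructor argument
    def __init__(self, model_y=None, model_t=None, n_crossfit_splits="deprecated"):
        pass

params = []
for p in inspect.signature(ToyEstimator.__init__).parameters.values():
    if p.name == "self":
        continue
    # skip arguments whose default is the sentinel string "deprecated"
    if p.default != "deprecated":
        params.append(p.name)

print(sorted(params))  # ['model_t', 'model_y']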

     def fit(self, Y, T, X=None, W=None, Z=None, *, outcome_names=None, treatment_names=None, feature_names=None,
             confounder_names=None, instrument_names=None, graph=None, estimand_type="nonparametric-ate",
@@ -106,30 +110,41 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, outcome_names=None, treatment_nam
         -------
         self
         """

+        Y, T, X, W, Z = check_input_arrays(Y, T, X, W, Z)

+        # create dataframe
+        n_obs = Y.shape[0]
+        Y, T, X, W, Z = reshape_arrays_2dim(n_obs, Y, T, X, W, Z)

+        # currently dowhy only support single outcome and single treatment
+        assert Y.shape[1] == 1, "Can only accept single dimensional outcome."
+        assert T.shape[1] == 1, "Can only accept single dimensional treatment."

         # column names
         if outcome_names is None:
             outcome_names = get_input_columns(Y, prefix="Y")
         if treatment_names is None:
             treatment_names = get_input_columns(T, prefix="T")
         if feature_names is None:
-            feature_names = get_input_columns(X, prefix="X")
+            if X is not None:
+                feature_names = get_input_columns(X, prefix="X")
+            else:
+                feature_names = []
         if confounder_names is None:
-            confounder_names = get_input_columns(W, prefix="W")
+            if W is not None:
+                confounder_names = get_input_columns(W, prefix="W")
+            else:
+                confounder_names = []
         if instrument_names is None:
-            instrument_names = get_input_columns(Z, prefix="Z")
+            if Z is not None:
+                instrument_names = get_input_columns(Z, prefix="Z")
+            else:
+                instrument_names = []
         column_names = outcome_names + treatment_names + feature_names + confounder_names + instrument_names

-        # transfer input to numpy arrays
-        Y, T, X, W, Z = check_input_arrays(Y, T, X, W, Z)
-        # transfer input to 2d arrays
-        n_obs = Y.shape[0]
-        Y, T, X, W, Z = reshape_arrays_2dim(n_obs, Y, T, X, W, Z)
-        # create dataframe
         df = pd.DataFrame(np.hstack((Y, T, X, W, Z)), columns=column_names)

-        # currently dowhy only support single outcome and single treatment
-        assert Y.shape[1] == 1, "Can only accept single dimensional outcome."
-        assert T.shape[1] == 1, "Can only accept single dimensional treatment."

         # call dowhy
         self.dowhy_ = CausalModel(
             data=df,
             treatment=treatment_names,
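
To illustrate what the reordered, None-aware branches above produce, here is a rough, self-contained sketch. default_columns is a hypothetical stand-in for econml's get_input_columns (assumed to generate names like X0, X1, ...), and dropping the None blocks before np.hstack stands in for reshape_arrays_2dim, which in the real code presumably turns missing inputs into zero-width arrays so the hstack still works.

import numpy as np
import pandas as pd

def default_columns(arr, prefix):
    # hypothetical stand-in for econml's get_input_columns
    return [f"{prefix}{i}" for i in range(arr.shape[1])]

n = 100
Y = np.random.normal(size=(n, 1))
T = np.random.binomial(1, 0.5, size=(n, 1))
X = np.random.normal(size=(n, 3))
W, Z = None, None  # no controls, no instruments

# mirror the new None-aware naming branches in fit()
feature_names = default_columns(X, "X") if X is not None else []
confounder_names = default_columns(W, "W") if W is not None else []
instrument_names = default_columns(Z, "Z") if Z is not None else []
column_names = (default_columns(Y, "Y") + default_columns(T, "T")
                + feature_names + confounder_names + instrument_names)

# keep only the non-None blocks so the stacked data matches the column names
blocks = [arr for arr in (Y, T, X, W, Z) if arr is not None]
df = pd.DataFrame(np.hstack(blocks), columns=column_names)
print(df.columns.tolist())  # ['Y0', 'T0', 'X0', 'X1', 'X2']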
4 changes: 2 additions & 2 deletions econml/dr/_drlearner.py
@@ -1523,7 +1523,7 @@ def n_crossfit_splits(self, value):

     @property
     def criterion(self):
-        return self.criterion
+        return "mse"
Collaborator:
Seems a bit weird that we allow est.criterion = 'deprecated' but not est.criterion = 'mse', and yet we return "mse" here...

If this is identical to sklearn's approach then I guess we can stick with it, but seems unintuitive to me.

Contributor Author:
Not exactly the same approach: they have a decorator for deprecated args, and the deprecated args are not defined in __init__. Or could I just remove these deprecated args, since they should be removed for the next release anyway?

Contributor Author:
@kbattocchi @vsyrgkanis Any final conclusion on this? Then I can mark it ready to merge.

Collaborator:
How hard would it be to mimic sklearn's approach? I think it would be best to either do that or remove the deprecated args as you suggested.

Contributor Author:
I removed all the deprecated args, also for the other classes, and in particular cleaned up n_splits/n_crossfit_splits. It's not clear to me how to mirror sklearn; they seem to just add a warning. It might be worth creating an issue to explore how to handle this systematically.
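
For reference, a minimal sketch of the warning-on-access pattern alluded to above (illustrative only: this is neither the PR's final approach, which removes the deprecated args, nor sklearn's actual decorator; the parameter name and suggested replacement are just examples):

import warnings

class Estimator:
    def __init__(self):
        self._n_crossfit_splits = "deprecated"

    @property
    def n_crossfit_splits(self):
        # warn whenever the deprecated parameter is read
        warnings.warn("n_crossfit_splits is deprecated and will be removed in a future release; "
                      "use 'cv' instead.", DeprecationWarning)
        return self._n_crossfit_splits

    @n_crossfit_splits.setter
    def n_crossfit_splits(self, value):
        # warn when it is set, but still store the value for backward compatibility
        warnings.warn("n_crossfit_splits is deprecated; use 'cv' instead.", DeprecationWarning)
        self._n_crossfit_splits = value

est = Estimator()
est.n_crossfit_splits = 5  # emits a DeprecationWarning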


     @criterion.setter
     def criterion(self, value):
@@ -1533,7 +1533,7 @@ def criterion(self, value):

     @property
     def max_leaf_nodes(self):
-        return self.max_leaf_nodes
+        return None
Collaborator:
Parallel comment here.


     @max_leaf_nodes.setter
     def max_leaf_nodes(self, value):
3 changes: 2 additions & 1 deletion econml/tests/test_dowhy.py
@@ -5,7 +5,7 @@
 import unittest
 from econml.dml import LinearDML, CausalForestDML
 from econml.orf import DROrthoForest
-from econml.dr import DRLearner
+from econml.dr import DRLearner, ForestDRLearner
 from econml.metalearners import XLearner
 from econml.iv.dml import DMLATEIV
 from sklearn.linear_model import LinearRegression, LogisticRegression, Lasso
@@ -33,6 +33,7 @@ def clf():
                      linear_first_stages=False),
     "dr": DRLearner(model_propensity=clf(), model_regression=reg(),
                     model_final=reg()),
+    "forestdr": ForestDRLearner(model_propensity=clf(), model_regression=reg()),
     "xlearner": XLearner(models=reg(), cate_models=reg(), propensity_model=clf()),
     "cfdml": CausalForestDML(model_y=reg(), model_t=clf(), discrete_treatment=True),
     "orf": DROrthoForest(n_trees=10, propensity_model=clf(), model_Y=reg()),
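
For context, a sketch of how the newly added entry would typically be exercised through the dowhy wrapper (data shapes and model choices here are made up, and the refute_estimate call assumes the wrapper forwards dowhy's refutation API):

import numpy as np
from econml.dr import ForestDRLearner
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

# toy data; shapes are illustrative only
n = 500
X = np.random.normal(size=(n, 4))
W = np.random.normal(size=(n, 2))
T = np.random.binomial(1, 0.5, size=n)
Y = T * X[:, 0] + np.random.normal(size=n)

est = ForestDRLearner(model_propensity=RandomForestClassifier(),
                      model_regression=RandomForestRegressor())
# fit through the dowhy wrapper rather than calling est.fit directly
est_dowhy = est.dowhy.fit(Y, T, X=X, W=W)
est_dowhy.refute_estimate(method_name="random_common_cause")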