-
Notifications
You must be signed in to change notification settings - Fork 726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Scaling ortholearners using Ray #800
Changes from 7 commits
1128981
d5af162
bb772c2
9b7540d
e1d3aba
b10c804
274e788
b190e8f
eae67fd
f09f2e3
6d81b37
f877ae0
d28199f
3c2eb4b
3289831
b67bdab
d97b132
57b78b1
9fbbe7d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -114,12 +114,12 @@ jobs: | |
kind: [except-customer-scenarios, customer-scenarios] | ||
include: | ||
- kind: "except-customer-scenarios" | ||
extras: "[tf,plt]" | ||
extras: "[tf,plt,ray]" | ||
pattern: "(?!CustomerScenarios)" | ||
install_graphviz: true | ||
version: '3.8' # no supported version of tensorflow for 3.9 | ||
- kind: "customer-scenarios" | ||
extras: "[plt,dowhy]" | ||
extras: "[plt,dowhy,ray]" | ||
pattern: "CustomerScenarios" | ||
version: '3.9' | ||
install_graphviz: false | ||
|
@@ -193,19 +193,19 @@ jobs: | |
include: | ||
- kind: serial | ||
opts: '-m "serial" -n 1' | ||
extras: "[tf,plt]" | ||
extras: "[tf,plt,ray]" | ||
- kind: other | ||
opts: '-m "cate_api" -n auto' | ||
extras: "[tf,plt]" | ||
extras: "[tf,plt,ray]" | ||
- kind: dml | ||
opts: '-m "dml"' | ||
extras: "[tf,plt]" | ||
extras: "[tf,plt,ray]" | ||
- kind: main | ||
opts: '-m "not (notebook or automl or dml or serial or cate_api or treatment_featurization)" -n 2' | ||
extras: "[tf,plt,dowhy]" | ||
extras: "[tf,plt,dowhy,ray]" | ||
- kind: treatment | ||
opts: '-m "treatment_featurization" -n auto' | ||
extras: "[tf,plt]" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in latest commit |
||
extras: "[tf,plt,ray]" | ||
fail-fast: false | ||
runs-on: ${{ matrix.os }} | ||
steps: | ||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -47,7 +47,67 @@ class in this module implements the general logic in a very versatile way | |||||||||
inverse_onehot, jacify_featurizer, ndim, reshape, shape, transpose) | ||||||||||
|
||||||||||
|
||||||||||
def _crossfit(model, folds, *args, **kwargs): | ||||||||||
def _fit_fold(model, train_idxs, test_idxs, calculate_scores, args, kwargs): | ||||||||||
""" | ||||||||||
Fits a single model on the training data and calculates the nuisance value on the test data. | ||||||||||
model: object | ||||||||||
An object that supports fit and predict. Fit must accept all the args | ||||||||||
and the keyword arguments kwargs. Similarly predict must all accept | ||||||||||
all the args as arguments and kwards as keyword arguments. The fit | ||||||||||
function estimates a model of the nuisance function, based on the input | ||||||||||
data to fit. Predict evaluates the fitted nuisance function on the input | ||||||||||
data to predict. | ||||||||||
|
||||||||||
train_idxs (array-like): Indices for the training data. | ||||||||||
test_idxs (array-like): Indices for the test data. | ||||||||||
calculate_scores (bool): Whether to calculate scores after fitting. | ||||||||||
|
||||||||||
args : a sequence of (numpy matrices or None) | ||||||||||
Each matrix is a data variable whose first index corresponds to a sample | ||||||||||
kwargs : a sequence of key-value args, with values being (numpy matrices or None) | ||||||||||
Each keyword argument is of the form Var=x, with x a numpy array. Each | ||||||||||
of these arrays are data variables. The model fit and predict will be | ||||||||||
called with signature: `model.fit(*args, **kwargs)` and | ||||||||||
`model.predict(*args, **kwargs)`. Key-value arguments that have value | ||||||||||
None, are ommitted from the two calls. So all the args and the non None | ||||||||||
kwargs variables must be part of the models signature. | ||||||||||
Returns: | ||||||||||
-------- | ||||||||||
-Tuple containing: | ||||||||||
nuisance_temp (tuple): Predictions or values of interest from the model. | ||||||||||
fitted_model: The fitted model after training. | ||||||||||
test_idxs (array-like): Indices of the test data. | ||||||||||
score_temp (tuple or None): Scores calculated after fitting if `calculate_scores` is True, otherwise None. | ||||||||||
""" | ||||||||||
model = clone(model, safe=False) | ||||||||||
|
||||||||||
if len(np.intersect1d(train_idxs, test_idxs)) > 0: | ||||||||||
raise AttributeError( | ||||||||||
"Invalid crossfitting fold structure. Train and test indices of each fold must be disjoint. {},{}".format( | ||||||||||
train_idxs, test_idxs)) | ||||||||||
|
||||||||||
args_train = tuple(var[train_idxs] if var is not None else None for var in args) | ||||||||||
args_test = tuple(var[test_idxs] if var is not None else None for var in args) | ||||||||||
|
||||||||||
kwargs_train = {key: var[train_idxs] for key, var in kwargs.items()} | ||||||||||
kwargs_test = {key: var[test_idxs] for key, var in kwargs.items()} | ||||||||||
|
||||||||||
model.fit(*args_train, **kwargs_train) | ||||||||||
nuisance_temp = model.predict(*args_test, **kwargs_test) | ||||||||||
|
||||||||||
if not isinstance(nuisance_temp, tuple): | ||||||||||
nuisance_temp = (nuisance_temp,) | ||||||||||
|
||||||||||
if calculate_scores: | ||||||||||
score_temp = model.score(*args_test, **kwargs_test) | ||||||||||
|
||||||||||
if not isinstance(score_temp, tuple): | ||||||||||
score_temp = (score_temp,) | ||||||||||
|
||||||||||
return nuisance_temp, model, test_idxs, (score_temp if calculate_scores else None) | ||||||||||
|
||||||||||
|
||||||||||
def _crossfit(model, use_ray, folds, ray_remote_fun_option, *args, **kwargs): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it makes more sense for folds to come before the ray arguments (and certainly for the ray arguments to be adjacent), and these changes make the specification match the docstring.
Suggested change
|
||||||||||
""" | ||||||||||
General crossfit based calculation of nuisance parameters. | ||||||||||
|
||||||||||
|
@@ -60,6 +120,10 @@ def _crossfit(model, folds, *args, **kwargs): | |||||||||
function estimates a model of the nuisance function, based on the input | ||||||||||
data to fit. Predict evaluates the fitted nuisance function on the input | ||||||||||
data to predict. | ||||||||||
use_ray: bool, default False (optional) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
having a default implies optional |
||||||||||
Flag to indicate whether to use ray to parallelize the cross-fitting step. | ||||||||||
ray_remote_fun_option: dict, default None (optional) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Having a default implies optional |
||||||||||
Options to pass to the ray.remote decorator. | ||||||||||
folds : list of tuple or None | ||||||||||
The crossfitting fold structure. Every entry in the list is a tuple whose | ||||||||||
first element are the training indices of the args and kwargs data and | ||||||||||
|
@@ -115,7 +179,10 @@ def predict(self, X, y, W=None): | |||||||||
y = X[:, 0] + np.random.normal(size=(5000,)) | ||||||||||
folds = list(KFold(2).split(X, y)) | ||||||||||
model = Lasso(alpha=0.01) | ||||||||||
nuisance, model_list, fitted_inds, scores = _crossfit(Wrapper(model), folds, X, y, W=y, Z=None) | ||||||||||
use_ray = False | ||||||||||
ray_remote_fun_option = {} | ||||||||||
nuisance, model_list, fitted_inds, scores = _crossfit(Wrapper(model),use_ray, folds,ray_remote_fun_option, | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
X, y,W=y, Z=None) | ||||||||||
|
||||||||||
>>> nuisance | ||||||||||
(array([-1.105728... , -1.537566..., -2.451827... , ..., 1.106287..., | ||||||||||
|
@@ -127,9 +194,7 @@ def predict(self, X, y, W=None): | |||||||||
|
||||||||||
""" | ||||||||||
model_list = [] | ||||||||||
fitted_inds = [] | ||||||||||
calculate_scores = hasattr(model, 'score') | ||||||||||
|
||||||||||
# remove None arguments | ||||||||||
kwargs = filter_none_kwargs(**kwargs) | ||||||||||
|
||||||||||
|
@@ -150,46 +215,43 @@ def predict(self, X, y, W=None): | |||||||||
first_arr = args[0] if args else kwargs.items()[0][1] | ||||||||||
return nuisances, model_list, np.arange(first_arr.shape[0]), scores | ||||||||||
|
||||||||||
|
||||||||||
folds = list(folds) | ||||||||||
fold_refs = [] | ||||||||||
if use_ray: | ||||||||||
import ray | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Remove this line, and instead use the try:
import ray
except ImportError as exn:
from .utilities import MissingModule
# make any access to ray throw an exception
ray = MissingModule("Ray is not a dependency of the base econml package; install econml[ray] or econml[all] to require it, "
"or install ray separately, to use functionality that depends on ray", exn) This will make it so that any attempted use of the ray module within the file will fail with this message, without needing to put any additional try-except blocks into the various places throughout the file where that might happen |
||||||||||
# Adding the kwargs to ray object store to be used by remote functions for each fold to avoid IO overhead | ||||||||||
ray_args = ray.put(kwargs) | ||||||||||
for idx, (train_idxs, test_idxs) in enumerate(folds): | ||||||||||
fold_refs.append( | ||||||||||
ray.remote(_fit_fold).options(**ray_remote_fun_option).remote(model, train_idxs, test_idxs, | ||||||||||
calculate_scores, args, ray_args)) | ||||||||||
fitted_inds = [] | ||||||||||
for idx, (train_idxs, test_idxs) in enumerate(folds): | ||||||||||
model_list.append(clone(model, safe=False)) | ||||||||||
if len(np.intersect1d(train_idxs, test_idxs)) > 0: | ||||||||||
raise AttributeError("Invalid crossfitting fold structure." + | ||||||||||
"Train and test indices of each fold must be disjoint.") | ||||||||||
if use_ray: | ||||||||||
nuisance_temp, model, test_idxs, score_temp = ray.get(fold_refs[idx]) | ||||||||||
else: | ||||||||||
nuisance_temp, model, test_idxs, score_temp = _fit_fold(model, train_idxs, test_idxs, | ||||||||||
calculate_scores, args, kwargs) | ||||||||||
|
||||||||||
if len(np.intersect1d(fitted_inds, test_idxs)) > 0: | ||||||||||
raise AttributeError("Invalid crossfitting fold structure. The same index appears in two test folds.") | ||||||||||
raise AttributeError( | ||||||||||
"Invalid crossfitting fold structure. The same index appears in two test folds.") | ||||||||||
fitted_inds = np.concatenate((fitted_inds, test_idxs)) | ||||||||||
|
||||||||||
args_train = tuple(var[train_idxs] if var is not None else None for var in args) | ||||||||||
args_test = tuple(var[test_idxs] if var is not None else None for var in args) | ||||||||||
|
||||||||||
kwargs_train = {key: var[train_idxs] for key, var in kwargs.items()} | ||||||||||
kwargs_test = {key: var[test_idxs] for key, var in kwargs.items()} | ||||||||||
|
||||||||||
model_list[idx].fit(*args_train, **kwargs_train) | ||||||||||
|
||||||||||
nuisance_temp = model_list[idx].predict(*args_test, **kwargs_test) | ||||||||||
|
||||||||||
if not isinstance(nuisance_temp, tuple): | ||||||||||
nuisance_temp = (nuisance_temp,) | ||||||||||
|
||||||||||
if idx == 0: | ||||||||||
nuisances = tuple([np.full((args[0].shape[0],) + nuis.shape[1:], np.nan) for nuis in nuisance_temp]) | ||||||||||
|
||||||||||
for it, nuis in enumerate(nuisance_temp): | ||||||||||
nuisances[it][test_idxs] = nuis | ||||||||||
|
||||||||||
if calculate_scores: | ||||||||||
score_temp = model_list[idx].score(*args_test, **kwargs_test) | ||||||||||
|
||||||||||
if not isinstance(score_temp, tuple): | ||||||||||
score_temp = (score_temp,) | ||||||||||
|
||||||||||
if idx == 0: | ||||||||||
scores = tuple([] for _ in score_temp) | ||||||||||
|
||||||||||
for it, score in enumerate(score_temp): | ||||||||||
scores[it].append(score) | ||||||||||
|
||||||||||
model_list.append(model) | ||||||||||
|
||||||||||
return nuisances, model_list, np.sort(fitted_inds.astype(int)), (scores if calculate_scores else None) | ||||||||||
|
||||||||||
|
||||||||||
|
@@ -297,6 +359,10 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator): | |||||||||
mc_agg: {'mean', 'median'}, default 'mean' | ||||||||||
How to aggregate the nuisance value for each sample across the `mc_iters` monte carlo iterations of | ||||||||||
cross-fitting. | ||||||||||
use_ray: bool, default False | ||||||||||
Whether to use ray to parallelize the cross-fitting step. | ||||||||||
ray_remote_func_options: dict, default None | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
Options to pass to the ray.remote decorator. | ||||||||||
|
||||||||||
Examples | ||||||||||
-------- | ||||||||||
|
@@ -346,6 +412,11 @@ def _gen_ortho_learner_model_final(self): | |||||||||
discrete_instrument=False, categories='auto', random_state=None) | ||||||||||
est.fit(y, X[:, 0], W=X[:, 1:]) | ||||||||||
|
||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry if my previous comment was unclear: I think including the comment is helpful for understanding why |
||||||||||
OR (for parallelization using ray)) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Note that for us to run our doctests, this must be valid python code. Also note that with your change we'll define |
||||||||||
est = OrthoLearner(cv=2, discrete_treatment=False, treatment_featurizer=None, | ||||||||||
discrete_instrument=False, categories='auto', random_state=None,use_ray=True) | ||||||||||
est.fit(y, X[:, 0], W=X[:, 1:]) | ||||||||||
|
||||||||||
>>> est.score_ | ||||||||||
0.00756830... | ||||||||||
>>> est.const_marginal_effect() | ||||||||||
|
@@ -434,7 +505,8 @@ def _gen_ortho_learner_model_final(self): | |||||||||
def __init__(self, *, | ||||||||||
discrete_treatment, treatment_featurizer, | ||||||||||
discrete_instrument, categories, cv, random_state, | ||||||||||
mc_iters=None, mc_agg='mean'): | ||||||||||
mc_iters=None, mc_agg='mean', use_ray=False, **ray_remote_func_options): | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Make the dict explicit, as described in a review comment elsewhere. |
||||||||||
self.actors = [] | ||||||||||
self.cv = cv | ||||||||||
self.discrete_treatment = discrete_treatment | ||||||||||
self.treatment_featurizer = treatment_featurizer | ||||||||||
|
@@ -443,6 +515,8 @@ def __init__(self, *, | |||||||||
self.categories = categories | ||||||||||
self.mc_iters = mc_iters | ||||||||||
self.mc_agg = mc_agg | ||||||||||
self.use_ray = use_ray | ||||||||||
self.ray_remote_func_opt = ray_remote_func_options | ||||||||||
super().__init__() | ||||||||||
|
||||||||||
@abstractmethod | ||||||||||
|
@@ -646,9 +720,30 @@ def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=N | |||||||||
sample_weight_nuisances = sample_weight | ||||||||||
|
||||||||||
self._models_nuisance = [] | ||||||||||
|
||||||||||
if self.use_ray: | ||||||||||
# Initialize Ray Connection if not already initialized | ||||||||||
try: | ||||||||||
import ray | ||||||||||
except ImportError: | ||||||||||
raise ImportError("To use `use_ray=True`, try installing econml via pip3 install econml['ray'].") | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
The MissingModule pattern should make this unnecessary. |
||||||||||
if not ray.is_initialized(): | ||||||||||
ray.init(ignore_reinit_error=True) | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we want to ignore reinit errors? And would we ever hit them, if we get to this line because ray is not initialized already? |
||||||||||
|
||||||||||
# Define Ray remote function (Ray remote wrapper of the _fit_nuisances function) | ||||||||||
def _fit_nuisances(Y, T, X, W, Z, sample_weight, groups): | ||||||||||
return self._fit_nuisances(Y, T, X, W, Z, sample_weight=sample_weight, groups=groups) | ||||||||||
|
||||||||||
# Create Ray remote jobs for parallel processing | ||||||||||
self.nuiances_ref = [ray.remote(_fit_nuisances).options(**self.ray_remote_func_opt).remote( | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why store these in an attribute when they only appear to be needed again on line 743 and then never again? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And on a minor note, there's a typo if you do want to store it in an attribute:
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its because or limitation on ray end , remote needs to be called on method or class. I will take care of the typo |
||||||||||
Y, T, X, W, Z, sample_weight_nuisances, groups) for _ in range(self.mc_iters or 1)] | ||||||||||
|
||||||||||
for idx in range(self.mc_iters or 1): | ||||||||||
nuisances, fitted_models, new_inds, scores = self._fit_nuisances( | ||||||||||
Y, T, X, W, Z, sample_weight=sample_weight_nuisances, groups=groups) | ||||||||||
if self.use_ray: | ||||||||||
nuisances, fitted_models, new_inds, scores = ray.get(self.nuiances_ref[idx]) | ||||||||||
else: | ||||||||||
nuisances, fitted_models, new_inds, scores = self._fit_nuisances( | ||||||||||
Y, T, X, W, Z, sample_weight=sample_weight_nuisances, groups=groups) | ||||||||||
all_nuisances.append(nuisances) | ||||||||||
self._models_nuisance.append(fitted_models) | ||||||||||
if scores is None: | ||||||||||
|
@@ -663,6 +758,10 @@ def fit(self, Y, T, *, X=None, W=None, Z=None, sample_weight=None, freq_weight=N | |||||||||
elif not np.array_equal(fitted_inds, new_inds): | ||||||||||
raise AttributeError("Different indices were fit by different folds, so they cannot be aggregated") | ||||||||||
|
||||||||||
# Shutdown Ray Connection | ||||||||||
if self.use_ray and ray.is_initialized: | ||||||||||
ray.shutdown() | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this always safe? Would this cause problems, for example, if someone tries to fit an OrthoLearner within a ray remote call? |
||||||||||
|
||||||||||
if self.mc_iters is not None: | ||||||||||
if self.mc_agg == 'mean': | ||||||||||
nuisances = tuple(np.mean(nuisance_mc_variants, axis=0) | ||||||||||
|
@@ -790,9 +889,9 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, group | |||||||||
else: | ||||||||||
folds = splitter.split(to_split, strata) | ||||||||||
|
||||||||||
nuisances, fitted_models, fitted_inds, scores = _crossfit(self._ortho_learner_model_nuisance, folds, | ||||||||||
Y, T, X=X, W=W, Z=Z, | ||||||||||
sample_weight=sample_weight, groups=groups) | ||||||||||
nuisances, fitted_models, fitted_inds, scores = _crossfit(self._ortho_learner_model_nuisance, self.use_ray, | ||||||||||
folds, self.ray_remote_func_opt, Y, T, X=X, W=W, | ||||||||||
Z=Z, sample_weight=sample_weight, groups=groups, ) | ||||||||||
return nuisances, fitted_models, fitted_inds, scores | ||||||||||
|
||||||||||
def _fit_final(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, | ||||||||||
|
@@ -817,30 +916,35 @@ def const_marginal_effect(self, X=None): | |||||||||
return self._ortho_learner_model_final.predict() | ||||||||||
else: | ||||||||||
return self._ortho_learner_model_final.predict(X) | ||||||||||
|
||||||||||
const_marginal_effect.__doc__ = LinearCateEstimator.const_marginal_effect.__doc__ | ||||||||||
|
||||||||||
def const_marginal_effect_interval(self, X=None, *, alpha=0.05): | ||||||||||
X, = check_input_arrays(X) | ||||||||||
self._check_fitted_dims(X) | ||||||||||
return super().const_marginal_effect_interval(X, alpha=alpha) | ||||||||||
|
||||||||||
const_marginal_effect_interval.__doc__ = LinearCateEstimator.const_marginal_effect_interval.__doc__ | ||||||||||
|
||||||||||
def const_marginal_effect_inference(self, X=None): | ||||||||||
X, = check_input_arrays(X) | ||||||||||
self._check_fitted_dims(X) | ||||||||||
return super().const_marginal_effect_inference(X) | ||||||||||
|
||||||||||
const_marginal_effect_inference.__doc__ = LinearCateEstimator.const_marginal_effect_inference.__doc__ | ||||||||||
|
||||||||||
def effect_interval(self, X=None, *, T0=0, T1=1, alpha=0.05): | ||||||||||
X, T0, T1 = check_input_arrays(X, T0, T1) | ||||||||||
self._check_fitted_dims(X) | ||||||||||
return super().effect_interval(X, T0=T0, T1=T1, alpha=alpha) | ||||||||||
|
||||||||||
effect_interval.__doc__ = LinearCateEstimator.effect_interval.__doc__ | ||||||||||
|
||||||||||
def effect_inference(self, X=None, *, T0=0, T1=1): | ||||||||||
X, T0, T1 = check_input_arrays(X, T0, T1) | ||||||||||
self._check_fitted_dims(X) | ||||||||||
return super().effect_inference(X, T0=T0, T1=T1) | ||||||||||
|
||||||||||
effect_inference.__doc__ = LinearCateEstimator.effect_inference.__doc__ | ||||||||||
|
||||||||||
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None): | ||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -110,7 +110,7 @@ def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None | |||||||
effects = self._model_final.predict(X).reshape((-1, Y_res.shape[1], T_res.shape[1])) | ||||||||
Y_res_pred = np.einsum('ijk,ik->ij', effects, T_res).reshape(Y_res.shape) | ||||||||
if sample_weight is not None: | ||||||||
return np.mean(np.average((Y_res - Y_res_pred)**2, weights=sample_weight, axis=0)) | ||||||||
return np.mean(np.average((Y_res - Y_res_pred) ** 2, weights=sample_weight, axis=0)) | ||||||||
else: | ||||||||
return np.mean((Y_res - Y_res_pred) ** 2) | ||||||||
|
||||||||
|
@@ -272,15 +272,17 @@ def _gen_rlearner_model_final(self): | |||||||
""" | ||||||||
|
||||||||
def __init__(self, *, discrete_treatment, treatment_featurizer, categories, | ||||||||
cv, random_state, mc_iters=None, mc_agg='mean'): | ||||||||
cv, random_state, mc_iters=None, mc_agg='mean', use_ray=False, **ray_remote_func_options): | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
super().__init__(discrete_treatment=discrete_treatment, | ||||||||
treatment_featurizer=treatment_featurizer, | ||||||||
discrete_instrument=False, # no instrument, so doesn't matter | ||||||||
categories=categories, | ||||||||
cv=cv, | ||||||||
random_state=random_state, | ||||||||
mc_iters=mc_iters, | ||||||||
mc_agg=mc_agg) | ||||||||
mc_agg=mc_agg, | ||||||||
use_ray=use_ray, | ||||||||
**ray_remote_func_options) | ||||||||
|
||||||||
@abstractmethod | ||||||||
def _gen_model_y(self): | ||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -466,7 +466,10 @@ def __init__(self, *, | |||||
cv=2, | ||||||
mc_iters=None, | ||||||
mc_agg='mean', | ||||||
random_state=None): | ||||||
random_state=None, | ||||||
use_ray=False, | ||||||
kbattocchi marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
**ray_remote_func_options | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I think it would be better to make this an explicit dictionary argument, rather than having it implicitly include any other keyword arguments passed to the DML initializer since in the future we might want similar arguments for other compute backends. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (This also applies all the way up the hierarchy, to the RLearner and OrthoLearner initializer arguments) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in latest commit |
||||||
): | ||||||
# TODO: consider whether we need more care around stateful featurizers, | ||||||
# since we clone it and fit separate copies | ||||||
self.fit_cate_intercept = fit_cate_intercept | ||||||
|
@@ -481,7 +484,10 @@ def __init__(self, *, | |||||
cv=cv, | ||||||
mc_iters=mc_iters, | ||||||
mc_agg=mc_agg, | ||||||
random_state=random_state) | ||||||
random_state=random_state, | ||||||
use_ray=use_ray, | ||||||
**ray_remote_func_options | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
) | ||||||
|
||||||
def _gen_featurizer(self): | ||||||
return clone(self.featurizer, safe=False) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unless you make any changes to the notebooks to take advantage of the new ray functionality, these changes should not be necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in latest commit