Skip to content

Commit

Permalink
Fixed np.diag bugs (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
Miruna Oprescu authored Dec 6, 2019
1 parent 3c36288 commit 7c13bf5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
9 changes: 4 additions & 5 deletions econml/sklearn_extensions/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,8 @@ def _fit_weighted_linear_model(self, X, y, sample_weight, check_input=None):
# Weight inputs
normalized_weights = X.shape[0] * sample_weight / np.sum(sample_weight)
sqrt_weights = np.sqrt(normalized_weights)
weight_mat = np.diag(sqrt_weights)
X_weighted = np.matmul(weight_mat, X)
y_weighted = np.matmul(weight_mat, y)
X_weighted = sqrt_weights.reshape(-1, 1) * X
y_weighted = sqrt_weights.reshape(-1, 1) * y if y.ndim > 1 else sqrt_weights * y
fit_params['X'] = X_weighted
fit_params['y'] = y_weighted
if self.fit_intercept:
Expand Down Expand Up @@ -842,8 +841,8 @@ def _get_theta_hat(self, X, sample_weight):

def _get_unscaled_coef_var(self, X, theta_hat, sample_weight):
if sample_weight is not None:
weights_mat = np.diag(sample_weight / np.sum(sample_weight))
sigma = X.T @ weights_mat @ X
norm_weights = sample_weight / np.sum(sample_weight)
sigma = X.T @ (norm_weights.reshape(-1, 1) * X)
else:
sigma = np.matmul(X.T, X) / X.shape[0]
_unscaled_coef_var = np.matmul(
Expand Down
6 changes: 4 additions & 2 deletions econml/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,10 +769,12 @@ def predict(self, X):
return self.model_instance.predict(X)

def _weighted_inputs(self, X, y, sample_weight):
X, y = check_X_y(X, y, y_numeric=True, multi_output=True)
normalized_weights = sample_weight * X.shape[0] / np.sum(sample_weight)
sqrt_weights = np.sqrt(normalized_weights)
weight_mat = np.diag(sqrt_weights)
return np.matmul(weight_mat, X), np.matmul(weight_mat, y)
weighted_X = sqrt_weights.reshape(-1, 1) * X
weighted_y = sqrt_weights.reshape(-1, 1) * y if y.ndim > 1 else sqrt_weights * y
return weighted_X, weighted_y

def _sampled_inputs(self, X, y, sample_weight):
# Normalize weights
Expand Down

0 comments on commit 7c13bf5

Please sign in to comment.