Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions econml/dml/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,8 +1134,10 @@ def __init__(self, *,
model_y='auto', model_t='auto',
alpha='auto',
n_alphas=100,
alphas=None,
alpha_cov='auto',
n_alphas_cov=10,
alphas_cov=None,
max_iter=1000,
tol=1e-4,
n_jobs=None,
Expand All @@ -1153,10 +1155,19 @@ def __init__(self, *,
allow_missing=False,
use_ray=False,
ray_remote_func_options=None):
import warnings
if n_alphas != 100:
warnings.warn("The n_alphas parameter is deprecated and will be removed in a future release. "
"Use the alphas parameter instead.", FutureWarning)
if n_alphas_cov != 10:
warnings.warn("The n_alphas_cov parameter is deprecated and will be removed in a future release. "
"Use the alphas_cov parameter instead.", FutureWarning)
self.alpha = alpha
self.n_alphas = n_alphas
self.alphas = alphas
self.alpha_cov = alpha_cov
self.n_alphas_cov = n_alphas_cov
self.alphas_cov = alphas_cov
self.max_iter = max_iter
self.tol = tol
self.n_jobs = n_jobs
Expand Down Expand Up @@ -1185,8 +1196,10 @@ def _gen_allowed_missing_vars(self):
def _gen_model_final(self):
return MultiOutputDebiasedLasso(alpha=self.alpha,
n_alphas=self.n_alphas,
alphas=self.alphas,
alpha_cov=self.alpha_cov,
n_alphas_cov=self.n_alphas_cov,
alphas_cov=self.alphas_cov,
fit_intercept=False,
max_iter=self.max_iter,
tol=self.tol,
Expand Down
13 changes: 13 additions & 0 deletions econml/dr/_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1616,8 +1616,10 @@ def __init__(self, *,
discrete_outcome=False,
alpha='auto',
n_alphas=100,
alphas=None,
alpha_cov='auto',
n_alphas_cov=10,
alphas_cov=None,
max_iter=1000,
tol=1e-4,
n_jobs=None,
Expand All @@ -1632,11 +1634,20 @@ def __init__(self, *,
use_ray=False,
ray_remote_func_options=None):

import warnings
if n_alphas != 100:
warnings.warn("The n_alphas parameter is deprecated and will be removed in a future release. "
"Use the alphas parameter instead.", FutureWarning)
if n_alphas_cov != 10:
warnings.warn("The n_alphas_cov parameter is deprecated and will be removed in a future release. "
"Use the alphas_cov parameter instead.", FutureWarning)
self.fit_cate_intercept = fit_cate_intercept
self.alpha = alpha
self.n_alphas = n_alphas
self.alphas = alphas
self.alpha_cov = alpha_cov
self.n_alphas_cov = n_alphas_cov
self.alphas_cov = alphas_cov
self.max_iter = max_iter
self.tol = tol
self.n_jobs = n_jobs
Expand All @@ -1663,8 +1674,10 @@ def _gen_allowed_missing_vars(self):
def _gen_model_final(self):
return DebiasedLasso(alpha=self.alpha,
n_alphas=self.n_alphas,
alphas=self.alphas,
alpha_cov=self.alpha_cov,
n_alphas_cov=self.n_alphas_cov,
alphas_cov=self.alphas_cov,
fit_intercept=self.fit_cate_intercept,
max_iter=self.max_iter,
tol=self.tol,
Expand Down
13 changes: 13 additions & 0 deletions econml/iv/dr/_dr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1790,8 +1790,10 @@ def __init__(self, *,
fit_cate_intercept=True,
alpha='auto',
n_alphas=100,
alphas=None,
alpha_cov='auto',
n_alphas_cov=10,
alphas_cov=None,
max_iter=1000,
tol=1e-4,
n_jobs=None,
Expand All @@ -1809,10 +1811,19 @@ def __init__(self, *,
allow_missing=False,
use_ray=False,
ray_remote_func_options=None):
import warnings
if n_alphas != 100:
warnings.warn("The n_alphas parameter is deprecated and will be removed in a future release. "
"Use the alphas parameter instead.", FutureWarning)
if n_alphas_cov != 10:
warnings.warn("The n_alphas_cov parameter is deprecated and will be removed in a future release. "
"Use the alphas_cov parameter instead.", FutureWarning)
self.alpha = alpha
self.n_alphas = n_alphas
self.alphas = alphas
self.alpha_cov = alpha_cov
self.n_alphas_cov = n_alphas_cov
self.alphas_cov = alphas_cov
self.max_iter = max_iter
self.tol = tol
self.n_jobs = n_jobs
Expand Down Expand Up @@ -1849,8 +1860,10 @@ def __init__(self, *,
def _gen_model_final(self):
return DebiasedLasso(alpha=self.alpha,
n_alphas=self.n_alphas,
alphas=self.alphas,
alpha_cov=self.alpha_cov,
n_alphas_cov=self.n_alphas_cov,
alphas_cov=self.alphas_cov,
fit_intercept=False,
max_iter=self.max_iter,
tol=self.tol,
Expand Down
26 changes: 19 additions & 7 deletions econml/sklearn_extensions/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,13 +594,14 @@ def fit(self, X, y, sample_weight=None):
return self


def _get_theta_coefs_and_tau_sq(i, X, sample_weight, alpha_cov, n_alphas_cov, max_iter, tol, random_state):
def _get_theta_coefs_and_tau_sq(i, X, sample_weight, alpha_cov, n_alphas_cov, max_iter, tol, random_state,
alphas_cov=None):
n_samples, n_features = X.shape
y = X[:, i]
X_reduced = X[:, list(range(i)) + list(range(i + 1, n_features))]
# Call weighted lasso on reduced design matrix
if alpha_cov == 'auto':
local_wlasso = WeightedLassoCV(cv=3, n_alphas=n_alphas_cov,
local_wlasso = WeightedLassoCV(cv=3, n_alphas=n_alphas_cov, alphas=alphas_cov,
fit_intercept=False,
max_iter=max_iter,
tol=tol, n_jobs=1,
Expand Down Expand Up @@ -726,11 +727,20 @@ class DebiasedLasso(WeightedLasso):
def __init__(self, alpha='auto', n_alphas=100, alpha_cov='auto', n_alphas_cov=10,
fit_intercept=True, precompute=False, copy_X=True, max_iter=1000,
tol=1e-4, warm_start=False,
random_state=None, selection='cyclic', n_jobs=None):
random_state=None, selection='cyclic', n_jobs=None, *, alphas=None, alphas_cov=None):
import warnings
if n_alphas != 100:
warnings.warn("The n_alphas parameter is deprecated and will be removed in a future release. "
"Use the alphas parameter instead.", FutureWarning)
if n_alphas_cov != 10:
warnings.warn("The n_alphas_cov parameter is deprecated and will be removed in a future release. "
"Use the alphas_cov parameter instead.", FutureWarning)
self.n_jobs = n_jobs
self.n_alphas = n_alphas
self.alphas = alphas
self.alpha_cov = alpha_cov
self.n_alphas_cov = n_alphas_cov
self.alphas_cov = alphas_cov
super().__init__(
alpha=alpha, fit_intercept=fit_intercept,
precompute=precompute, copy_X=copy_X,
Expand Down Expand Up @@ -929,7 +939,8 @@ def _get_coef_correction(self, X, y, y_pred, sample_weight, theta_hat):

def _get_optimal_alpha(self, X, y, sample_weight):
# To be done once per target. Assumes y can be flattened.
cv_estimator = WeightedLassoCV(cv=5, n_alphas=self.n_alphas, fit_intercept=self.fit_intercept,
cv_estimator = WeightedLassoCV(cv=5, n_alphas=self.n_alphas, alphas=self.alphas,
fit_intercept=self.fit_intercept,
precompute=self.precompute, copy_X=True,
max_iter=self.max_iter, tol=self.tol,
random_state=self.random_state,
Expand All @@ -950,7 +961,7 @@ def _get_theta_hat(self, X, sample_weight):
results = Parallel(n_jobs=self.n_jobs)(
delayed(_get_theta_coefs_and_tau_sq)(i, X, sample_weight,
self.alpha_cov, self.n_alphas_cov,
self.max_iter, self.tol, self.random_state)
self.max_iter, self.tol, self.random_state, self.alphas_cov)
for i in range(n_features))
coefs, tausq = zip(*results)
coefs = np.array(coefs)
Expand Down Expand Up @@ -1072,8 +1083,9 @@ def __init__(self, alpha='auto', n_alphas=100, alpha_cov='auto', n_alphas_cov=10
fit_intercept=True,
precompute=False, copy_X=True, max_iter=1000,
tol=1e-4, warm_start=False,
random_state=None, selection='cyclic', n_jobs=None):
self.estimator = DebiasedLasso(alpha=alpha, n_alphas=n_alphas, alpha_cov=alpha_cov, n_alphas_cov=n_alphas_cov,
random_state=None, selection='cyclic', n_jobs=None, *, alphas=None, alphas_cov=None):
self.estimator = DebiasedLasso(alpha=alpha, n_alphas=n_alphas, alphas=alphas, alpha_cov=alpha_cov,
n_alphas_cov=n_alphas_cov, alphas_cov=alphas_cov,
fit_intercept=fit_intercept,
precompute=precompute, copy_X=copy_X, max_iter=max_iter,
tol=tol, warm_start=warm_start,
Expand Down