fix: complete sklearn 1.9 support — drop self.alphas overwrite (closes #1032)#1042
Open
immu4989 wants to merge 1 commit into
Open
fix: complete sklearn 1.9 support — drop self.alphas overwrite (closes #1032)#1042immu4989 wants to merge 1 commit into
immu4989 wants to merge 1 commit into
Conversation
…y#1032) py-why#1032 reports econml errors on import with scikit-learn 1.9. e546416 ("Fix scikit-learn 1.7+ FutureWarnings ...") added a >=1.7 dispatch in WeightedLassoCV / WeightedMultiTaskLassoCV that translates n_alphas=<int> into alphas=<int> on the super().__init__() call, which fixes the import-time TypeError. But it then ran self.alphas = alphas at the end of the dispatch, overwriting the value sklearn's __init__ had correctly recorded back to the constructor's alphas kwarg (None by default). On sklearn 1.7-1.8 the loose param-validation tolerated the resulting self.alphas = None. sklearn 1.9's stricter _param_validation rejects it, so SparseLinearDML.fit (via _DebiasedLasso.fit -> WeightedLassoCV) and DebiasedLasso.fit raise InvalidParameterError. Drop the overwrite. super().__init__(...) already records self.alphas from the translated value. self.n_alphas is still set so callers can introspect the original wrapper kwarg. Also bump scikit-learn pin from < 1.9 to < 1.10 so 1.9 installs. Adds test_default_alphas_fits_on_strict_sklearn covering both WeightedLassoCV and WeightedMultiTaskLassoCV; verifies on sklearn 1.9 locally that 48 tests pass across test_linear_model, test_dml, and test_treatment_featurization (no regressions). Signed-off-by: Imran Ahamed <immu4989@gmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Closes #1032.
What's happening
#1032 (filed by @jakevdp) reports that
import econml.orferrors on scikit-learn 1.9 becausen_alphaswas removed fromLassoCV.__init__.The recent commit
e546416added a>= 1.7version dispatch inWeightedLassoCV.__init__/WeightedMultiTaskLassoCV.__init__that translatesn_alphas=<int>intoalphas=<int>on thesuper().__init__()call. That fixes the import-timeTypeErrorjakevdp reported.But e546416 also added
self.alphas = alphasat the end of the dispatch, which clobbers the value sklearn's__init__had correctly recorded back to the constructor'salphaskwarg (Noneby default):On sklearn 1.7–1.8 the loose param-validation tolerated the resulting
self.alphas = None. On sklearn 1.9 the stricter_param_validationrejects it, soSparseLinearDML.fit(via_DebiasedLasso.fit→WeightedLassoCV) andDebiasedLasso.fitraise:Fix
Drop the overwrite.
super().__init__(...)already recordsself.alphasfrom the translated value.self.n_alphasis still set so callers can introspect the original wrapper kwarg.Same edit applied to both
WeightedLassoCV.__init__andWeightedMultiTaskLassoCV.__init__.Also bumps
pyproject.tomlscikit-learn >= 1.0, < 1.9→< 1.10so users can actually install sklearn 1.9.Verification (locally, sklearn 1.9.0 + narwhals)
import econml.orf✓ (jakevdp's exact repro)LinearDML,SparseLinearDML,CausalForestDML,DMLOrthoForest,LinearDRLearner,SparseLinearDRLearnertest_default_alphas_fits_on_strict_sklearnregression test inTestLassoExtensions:WeightedLassoCV().alphas is not NoneandWeightedMultiTaskLassoCV().alphas is not None(the dispatched value must survive).fit(...)on each, verifying sklearn 1.9 strict validation accepts the resultAssertionError: WeightedLassoCV.alphas was clobbered to None by __init__ (#1032)test_linear_model.py+test_dml.py+test_treatment_featurization.pyon sklearn 1.9.