Skip to content

Commit 1a5245c

Browse files
adrinjalalithomasjpfan
authored andcommitted
FIX clone to handle dict and GridSearchCV changing original params (scikit-learn#26786)
Co-authored-by: Thomas J. Fan <[email protected]>
1 parent 0b359d7 commit 1a5245c

File tree

6 files changed

+64
-14
lines changed

6 files changed

+64
-14
lines changed

doc/whats_new/v1.4.rst

+12
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ Changelog
5151
passed to the ``fit`` method of the the estimator. :pr:`26506` by `Adrin
5252
Jalali`_.
5353

54+
- |Enhancement| :func:`base.clone` now supports `dict` as input and creates a
55+
copy. :pr:`26786` by `Adrin Jalali`_.
56+
5457
:mod:`sklearn.decomposition`
5558
............................
5659

@@ -85,6 +88,15 @@ Changelog
8588
to :ref:`metadata routing user guide <metadata_routing>`. :pr:`26789` by
8689
`Adrin Jalali`_.
8790

91+
:mod:`sklearn.model_selection`
92+
..............................
93+
94+
- |Fix| :class:`model_selection.GridSearchCV`,
95+
:class:`model_selection.RandomizedSearchCV`, and
96+
:class:`model_selection.HalvingGridSearchCV` now don't change the given
97+
object in the parameter grid if it's an estimator. :pr:`26786` by `Adrin
98+
Jalali`_.
99+
88100
:mod:`sklearn.tree`
89101
...................
90102

sklearn/base.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ def _clone_parametrized(estimator, *, safe=True):
8080
"""Default implementation of clone. See :func:`sklearn.base.clone` for details."""
8181

8282
estimator_type = type(estimator)
83-
# XXX: not handling dictionaries
84-
if estimator_type in (list, tuple, set, frozenset):
83+
if estimator_type is dict:
84+
return {k: clone(v, safe=safe) for k, v in estimator.items()}
85+
elif estimator_type in (list, tuple, set, frozenset):
8586
return estimator_type([clone(e, safe=safe) for e in estimator])
8687
elif not hasattr(estimator, "get_params") or isinstance(estimator, type):
8788
if not safe:

sklearn/model_selection/_search.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -923,11 +923,14 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
923923
self.best_params_ = results["params"][self.best_index_]
924924

925925
if self.refit:
926-
# we clone again after setting params in case some
927-
# of the params are estimators as well.
928-
self.best_estimator_ = clone(
929-
clone(base_estimator).set_params(**self.best_params_)
926+
# here we clone the estimator as well as the parameters, since
927+
# sometimes the parameters themselves might be estimators, e.g.
928+
# when we search over different estimators in a pipeline.
929+
# ref: https://github.com/scikit-learn/scikit-learn/pull/26786
930+
self.best_estimator_ = clone(base_estimator).set_params(
931+
**clone(self.best_params_, safe=False)
930932
)
933+
931934
refit_start_time = time.time()
932935
if y is not None:
933936
self.best_estimator_.fit(X, y, **fit_params)

sklearn/model_selection/_validation.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -726,14 +726,11 @@ def _fit_and_score(
726726
fit_params = _check_fit_params(X, fit_params, train)
727727

728728
if parameters is not None:
729-
# clone after setting parameters in case any parameters
730-
# are estimators (like pipeline steps)
731-
# because pipeline doesn't clone steps in fit
732-
cloned_parameters = {}
733-
for k, v in parameters.items():
734-
cloned_parameters[k] = clone(v, safe=False)
735-
736-
estimator = estimator.set_params(**cloned_parameters)
729+
# here we clone the parameters, since sometimes the parameters
730+
# themselves might be estimators, e.g. when we search over different
731+
# estimators in a pipeline.
732+
# ref: https://github.com/scikit-learn/scikit-learn/pull/26786
733+
estimator = estimator.set_params(**clone(parameters, safe=False))
737734

738735
start_time = time.time()
739736

sklearn/model_selection/tests/test_search.py

+30
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
make_multilabel_classification,
2323
)
2424
from sklearn.ensemble import HistGradientBoostingClassifier
25+
from sklearn.experimental import enable_halving_search_cv # noqa
2526
from sklearn.impute import SimpleImputer
2627
from sklearn.linear_model import LinearRegression, Ridge, SGDClassifier
2728
from sklearn.metrics import (
@@ -38,6 +39,7 @@
3839
GridSearchCV,
3940
GroupKFold,
4041
GroupShuffleSplit,
42+
HalvingGridSearchCV,
4143
KFold,
4244
LeaveOneGroupOut,
4345
LeavePGroupsOut,
@@ -2420,3 +2422,31 @@ def test_search_cv_verbose_3(capsys, return_train_score):
24202422
else:
24212423
match = re.findall(r"score=[\d\.]+", captured)
24222424
assert len(match) == 3
2425+
2426+
2427+
@pytest.mark.parametrize(
2428+
"SearchCV, param_search",
2429+
[
2430+
(GridSearchCV, "param_grid"),
2431+
(RandomizedSearchCV, "param_distributions"),
2432+
(HalvingGridSearchCV, "param_grid"),
2433+
],
2434+
)
2435+
def test_search_estimator_param(SearchCV, param_search):
2436+
# test that SearchCV object doesn't change the object given in the parameter grid
2437+
X, y = make_classification(random_state=42)
2438+
2439+
params = {"clf": [LinearSVC(dual="auto")], "clf__C": [0.01]}
2440+
orig_C = params["clf"][0].C
2441+
2442+
pipe = Pipeline([("trs", MinimalTransformer()), ("clf", None)])
2443+
2444+
param_grid_search = {param_search: params}
2445+
gs = SearchCV(pipe, refit=True, cv=2, scoring="accuracy", **param_grid_search).fit(
2446+
X, y
2447+
)
2448+
2449+
# testing that the original object in params is not changed
2450+
assert params["clf"][0].C == orig_C
2451+
# testing that the GS is setting the parameter of the step correctly
2452+
assert gs.best_estimator_.named_steps["clf"].C == 0.01

sklearn/tests/test_base.py

+7
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,13 @@ def test_clone_nan():
184184
assert clf.empty is clf2.empty
185185

186186

187+
def test_clone_dict():
188+
# test that clone creates a clone of a dict
189+
orig = {"a": MyEstimator()}
190+
cloned = clone(orig)
191+
assert orig["a"] is not cloned["a"]
192+
193+
187194
def test_clone_sparse_matrices():
188195
sparse_matrix_classes = [
189196
cls

0 commit comments

Comments
 (0)