Scikit-learn compatibility
All estimators implemented in gwlearn are designed to provide a scikit-learn-compatible API. However, this is occasionally tricky because geometry has to be passed as an additional argument to fit, predict, or score. Luckily, scikit-learn's (experimental) support for metadata routing provides a way around this.
import geopandas as gpd
from geodatasets import get_path
from sklearn import set_config
from sklearn.model_selection import GridSearchCV
from gwlearn.linear_model import GWLinearRegression
Metadata routing is not enabled by default, so if you'd like to use it, you first need to enable it in scikit-learn's configuration.
set_config(enable_metadata_routing=True)
You can then initialise the estimator and use its set_fit_request, set_predict_request and set_score_request methods to indicate that the geometry argument should be routed alongside X and y.
gwlr = GWLinearRegression(fixed=False, keep_models=True, bandwidth=25)
gwlr.set_fit_request(geometry=True)
gwlr.set_predict_request(geometry=True)
gwlr.set_score_request(geometry=True)
GWLinearRegression(bandwidth=25, keep_models=True)
Parameters

| Parameter | Value |
| --- | --- |
| bandwidth | 25 |
| fixed | False |
| kernel | 'bisquare' |
| include_focal | True |
| graph | None |
| n_jobs | -1 |
| fit_global_model | True |
| strict | False |
| keep_models | True |
| temp_folder | None |
| batch_size | None |
| verbose | False |
An estimator configured like this can be used inside scikit-learn's Pipeline or GridSearchCV, which then route the geometry to it alongside X and y.
gdf = gpd.read_file(get_path("geoda.guerry"))
gs = GridSearchCV(gwlr, {'include_focal': [True, False]}, cv=2)
gs.fit(
    gdf[['Crm_prp', 'Litercy', 'Donatns', 'Lottery']],
    gdf["Suicids"],
    geometry=gdf.representative_point(),
)
GridSearchCV(cv=2, estimator=GWLinearRegression(bandwidth=25, keep_models=True),
             param_grid={'include_focal': [True, False]})
You can then retrieve the outcome of the search, for example the selected value of include_focal.
gs.best_estimator_.include_focal
True
This way, gwlearn estimators can be included in your scikit-learn workflows with ease.
It is recommended to disable metadata routing once you’re done.
set_config(enable_metadata_routing=False)