How to Decorate Your Scikit-learn Models Like a Christmas Tree

Christmas Tree © 2017 Fleur Duivis

Most machine learning applications have the same high-level architecture, with components for:

  • Ingesting data from various sources
  • Cleaning that data and computing new features
  • Applying one or more models on the refined data
  • Serving aggregates and model prediction results to end users

My experience from consultancy projects taught me that as applications like these grow, new models are added that require new or different features, which in turn require new or different ways of cleaning and preprocessing. Time pressure and limited software engineering skills often turn the cleaning and feature computation component into an ever-growing monolith of unstructured code, leading to several symptoms:

  • The team loses track of which feature is used by which model, often leading to redundancy and clutter (e.g. having multiple features/columns with the same or nearly-the-same data),
  • Obscure features that are only used by a single model are precomputed and stored for all incoming data instead of the subset of data that will be used by the actual model,
  • Different models require different (and sometimes conflicting) strategies for e.g. scaling or null handling of input data, leading to unnecessarily complex decision logic,
  • When moving models from analysis to production environments, it becomes hard to make sure that all data preprocessing that was done in the analysis environment will be done in exactly the same way on the production system.

At the moment, I'm creating an MVP for validating a startup idea based on analysis of sports activity data, so let's explore the issues above in some more detail using examples from that domain.

Global Features vs. Model-specific Features

My target users are (amateur) athletes who use GPS-based devices to track their running or cycling activities. Their devices generate streams of second-by-second measurements of signals such as location, heart rate, time, accumulated distance, etc. Our feature computation component processes these streams and aggregates the stream data into activity-level features such as total distance or average heart rate.
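To make the aggregation step concrete, here is a minimal sketch using pandas. The column names and the sample values are assumptions invented for this illustration, not the actual schema of the application.

```python
import pandas as pd

# Hypothetical second-by-second stream for two activities.
# Column names and values are assumptions made for this sketch.
stream = pd.DataFrame({
    "activity_id": [1, 1, 1, 2, 2],
    "heart_rate": [120, 130, 140, 100, 110],
    "accumulated_distance_m": [0.0, 5.0, 11.0, 0.0, 3.0],
})

# Collapse each stream into one row of activity-level features.
activity_features = stream.groupby("activity_id").agg(
    total_distance_m=("accumulated_distance_m", "max"),
    average_heart_rate=("heart_rate", "mean"),
)
```

Each activity ends up as a single row, which is the shape the models downstream would consume.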

It's not hard to imagine that an activity's distance is a major feature globally: It can be used by many models to make predictions and its distribution tells us interesting things about the user population. Therefore, it makes sense to compute this feature globally and store it for future reuse.

In contrast, imagine an activity feature such as group size, i.e. how many other athletes participated in the same activity. For cycling activities, riding in a group can be a significant advantage because it greatly reduces aerodynamic drag. But this effect is not linear in the group size: the difference between riding alone and riding with 3 others is far greater than the difference between riding in a pack of 40 and a pack of 80 people. For this reason, it makes sense to apply a log transformation to the group size when it is used in certain models related to cycling.
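A small numeric sketch of why the log helps. It assumes the group size counts the athlete as well, so a solo ride has size 1 and the logarithm stays defined; that convention is an assumption of this example.

```python
import numpy as np

# Assumed convention: group size includes the athlete, so solo == 1.
group_size = np.array([1, 4, 40, 80])
log_group_size = np.log2(group_size)

# Raw differences are 3 (solo vs. group of 4) and 40 (pack of 40 vs. 80).
# After the transformation the ordering of the two gaps is reversed.
small_group_gap = log_group_size[1] - log_group_size[0]
large_group_gap = log_group_size[3] - log_group_size[2]
```

On the log scale, going from a solo ride to a group of four is a bigger step than doubling a pack of 40, which matches the intuition about drag.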

For running activities, however, due to the lower speeds there is no significant advantage of being in a group, and as such no need at all for group size as a feature in models. In this case, it makes no sense to compute and store the log of the group size as a global feature.

The main challenge then is how to encapsulate all model-specific logic and reduce the clutter in the global feature computation component. It turns out that Scikit-learn's API design can greatly help in this task.

Decorating Scikit-learn Estimators

Although I had been using Scikit-learn for some time, it never occurred to me that implementing e.g. BaseEstimator and ClassifierMixin could serve any purpose other than implementing some brand new ML algorithm. Then, during a refactoring session, I suddenly made the connection with the decorator design pattern. Based on this pattern, we can also "decorate" Scikit-learn estimators. The base component would be one of the existing sklearn estimators, e.g. LogisticRegression. The decorator then both inherits from this base component (or implements the mixins) and holds a reference to that estimator as an instance variable. A (hypothetical) example would be a classifier that detects which cycling activities are mountain bike rides:

import numpy as np

from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

# The decorator
# (mixins are listed before BaseEstimator, as scikit-learn recommends)
class MTBDetection(ClassifierMixin, BaseEstimator):  # 1
    def __init__(self, estimator=LogisticRegression(), null_strategy='median'):  # 2
        self.estimator = estimator
        self.null_strategy = null_strategy
        self.extra_transformations = {  # 3
            'group_size': np.log2
        }

    def _handle_nulls(self, X):
        # Null handling logic, using self.null_strategy.
        # Placeholder (an assumption, not the real logic): only 'median' is sketched.
        if self.null_strategy == 'median':
            return X.fillna(X.median())
        return X

    def _transform_columns(self, X):  # 3
        # Apply each configured function to its column of the DataFrame X
        return X.assign(**{column: transform_function(X[column])
                           for column, transform_function in self.extra_transformations.items()})

    def fit(self, X, y):
        # Optional: insert logic to configure our pipeline based on properties of X
        self.fitted_estimator_ = make_pipeline(  # 4
            ColumnSelector([  # 5
                'average_speed', 'distance', 'group_size'  # etc.
            ]),
            FunctionTransformer(self._handle_nulls, validate=False),
            FunctionTransformer(self._transform_columns, validate=False),
            self.estimator  # the decorated estimator is the final step
        ).fit(X, y)

        return self

    def predict(self, X):
        return self.fitted_estimator_.predict(X)

# A transformer class (the form to use when state must be captured)
class ColumnSelector(TransformerMixin, BaseEstimator):  # 5
    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X.loc[:, self.columns]
  1. By inheriting from BaseEstimator and ClassifierMixin and implementing the fit() and predict() methods, our decorator class implements the same "interface" (Python doesn't have real interfaces) as the other sklearn estimators. This means it can be used in a higher-level Pipeline, or in a GridSearchCV.
  2. As in the other sklearn estimators, all parameters used for fitting a model should be passed to the constructor with a default value. In this example, we have a parameter to indicate some hypothetical null handling strategy. Passing the "inner" estimator to our constructor allows us to use this decorator for different types of classifiers.
  3. This is just a trivial example to show how we can use the decorator to store potentially complex configuration for our model fitting and prediction logic. Note that, as used here, DataFrame.assign() computes each (new) column from the original X in isolation. For chains of column transformations that depend on each other, we can consider using DataFrame.pipe().
  4. All preprocessing and model fitting operations are chained in a Pipeline. This ensures that when the model is trained, all the preprocessing operations (and potential state they might have) are stored together with the fitted model. This makes us less vulnerable to mistakenly using different preprocessing steps when applying the model at another time. In addition, parameters to be used in the pipeline could be dynamically decided upon at this moment, instead of making the caller of fit() responsible for that. This enhances loose coupling and makes the code easier to maintain and test.
  5. In such a pipeline, we can use anything implementing TransformerMixin when it's necessary to capture the state of the transformer (e.g. with a scaler that needs to remember the mean and standard deviation of the original data it was fitted with). The example above uses a trivial transformer for selecting columns. Alternatively, we can use a FunctionTransformer for stateless transformations.
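As a rough illustration of points 1 and 2, the sketch below tunes the inner estimator of a decorator through GridSearchCV. LogGroupSizeClassifier is a hypothetical, stripped-down stand-in for MTBDetection, the data is synthetic, and cloning the inner estimator inside fit() is a choice made for this sketch rather than something the example above does.

```python
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer


class LogGroupSizeClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical, stripped-down decorator around an inner classifier."""

    def __init__(self, estimator=None):
        self.estimator = estimator

    def _log_group_size(self, X):
        # Model-specific preprocessing lives inside the decorator.
        return X.assign(group_size=np.log2(X["group_size"]))

    def fit(self, X, y):
        self.fitted_estimator_ = make_pipeline(
            FunctionTransformer(self._log_group_size, validate=False),
            clone(self.estimator),  # sketch choice: fit a copy, not the original
        ).fit(X, y)
        self.classes_ = self.fitted_estimator_.steps[-1][1].classes_
        return self

    def predict(self, X):
        return self.fitted_estimator_.predict(X)


# Synthetic data, invented for this sketch only.
rng = np.random.default_rng(0)
y = np.tile([0, 1], 20)
X = pd.DataFrame({
    "average_speed": 20 + 10 * y + rng.normal(0, 3, size=40),
    "group_size": rng.integers(1, 80, size=40),
})

# "estimator__C" reaches through the decorator into the inner estimator.
search = GridSearchCV(
    LogGroupSizeClassifier(estimator=LogisticRegression()),
    param_grid={"estimator__C": [0.1, 1.0, 10.0]},
    cv=2,
)
search.fit(X, y)
predictions = search.predict(X)
```

Because the constructor stores its arguments unchanged, get_params() and set_params() from BaseEstimator work without extra code, which is what lets the grid search address the nested parameter.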


Some issues I ran into using the above approach (not necessarily related to the decorator pattern):

  • When the preprocessing pipeline becomes more complex than in the example above, one or more transformers may change the order of rows in the input data, e.g. due to group-by or join operations. While the order is not important when fitting the model, the predicted labels returned by predict() will no longer be in the same order as the observations. So always make sure that the order of observations is maintained within the pipeline.
  • When working in a Jupyter notebook to develop a model, the code for that model would probably be edited in an IDE and %autoreloaded into the notebook. After updating the model code and retraining the model, we'd then pickle it to ship it to the production system. It turns out that pickle has difficulties (and crashes with meaningless error messages) when serializing objects that had updated class definitions from (auto)reloads.
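For the row-order issue in the first point, one possible guard is to restore the original index after any step that might reorder rows. The sketch below assumes the input has a unique index and that the step keeps the index labels; reordering_step and order_preserving are hypothetical names used only here.

```python
import pandas as pd

observations = pd.DataFrame(
    {"athlete": ["b", "a", "c"], "distance_km": [40.0, 25.0, 60.0]},
    index=[10, 11, 12],
)


def reordering_step(X):
    # Stand-in for a group-by or join that returns rows in another order.
    return X.sort_values("athlete")


def order_preserving(step):
    # Wrap a step so its output rows follow the order of its input rows.
    def wrapped(X):
        return step(X).reindex(X.index)
    return wrapped


shuffled = reordering_step(observations)
restored = order_preserving(reordering_step)(observations)
```

A wrapped function like this can be passed to a FunctionTransformer in the same way as the unwrapped one.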


Applying the decorator pattern to encapsulate model-specific preprocessing greatly enhances the correctness, maintainability, and testability of our code. It also makes it easy to transport models between analysis and production environments through serialization mechanisms such as pickle.
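The serialization claim can be checked with a simple round trip. This sketch deliberately uses only built-in sklearn components and synthetic data, so it shows that fitted preprocessing state travels with the model, not how a custom decorator class behaves under pickle.

```python
import pickle

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic data, invented for this sketch only.
rng = np.random.default_rng(0)
y = np.tile([0, 1], 20)
X = rng.normal(size=(40, 2))
X[:, 0] += 3 * y

model = make_pipeline(StandardScaler(), LogisticRegression()).fit(X, y)

# Serialize and restore, as if shipping the model to another system.
restored = pickle.loads(pickle.dumps(model))

same_predictions = np.array_equal(model.predict(X), restored.predict(X))
same_scaler_state = np.allclose(
    model.named_steps["standardscaler"].mean_,
    restored.named_steps["standardscaler"].mean_,
)
```

The restored pipeline carries the scaler's fitted mean along with the classifier, so no preprocessing has to be re-implemented on the receiving side.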
