Try Using A Stratified Split

Datascience George
Sep 7, 2020

By George Bennett

Oftentimes data is not completely balanced. The imbalance may be in the classes or in any other feature. If there are major differences in the data, typically in categorical data, then they should be accounted for when splitting the data. Say I was building a machine learning model to help detect heart disease. If my data was twenty-five percent women and seventy-five percent men, and I were to use a normal train test split, there is a chance the test set will contain very few women!

This kind of imbalance often shows up in the target variable, but it can occur with any variable. The solution is to do what is called stratification. Say a statistician wanted to deploy a survey to customers of a store. Ninety percent of the store's sales are in person and ten percent come from online. The statistician would call these two groups “strata”, and would make sure to give out the surveys proportionately to each group. If online customers got eighty percent of the surveys, then that would introduce a lot of bias. Likewise, not splitting your data correctly for machine learning can bias your models and make their test scores unreliable.
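To make the survey example concrete, here is a minimal sketch in plain Python. The customer list, the `stratified_sample` helper, and the ten percent sampling fraction are all made up for illustration; the point is only that each stratum is sampled at the same rate, so the sample keeps the ninety/ten proportions.

```python
import random
from collections import Counter

def stratified_sample(records, key, fraction, seed=0):
    """Sample the same fraction from every stratum defined by key(record).

    Hypothetical helper for illustration, not a library function.
    """
    rng = random.Random(seed)
    # Group the records into strata
    strata = {}
    for record in records:
        strata.setdefault(key(record), []).append(record)
    # Draw proportionally from each stratum
    sample = []
    for members in strata.values():
        k = round(len(members) * fraction)
        sample.extend(rng.sample(members, k))
    return sample

# Made-up customer list: 900 in-person customers, 100 online customers
customers = [("in_person", i) for i in range(900)] + [("online", i) for i in range(100)]

surveyed = stratified_sample(customers, key=lambda c: c[0], fraction=0.1)
counts = Counter(channel for channel, _ in surveyed)
print(counts["in_person"], counts["online"])  # 90 in-person surveys, 10 online surveys
```

Which customers get picked depends on the seed, but the number drawn from each group does not.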

To put this into practice, you can use scikit-learn’s “StratifiedShuffleSplit”. I will outline below how to use it. First create an instance of the “StratifiedShuffleSplit” class. Call the “split” method on the instance, passing your dataset (as a pandas dataframe) in as the “X” argument and the variable you would like to create strata from in as the “y” argument. Note that “split” stratifies on whatever is passed as “y”; the “groups” argument is accepted only for compatibility and is ignored, so passing the column there has no effect. The “split” method returns a generator, so call the “next” function on it and you will receive two arrays of numbers. These are indices (row positions). Simply use these indices on the original dataframe with “iloc” and you are left with your training and testing sets.

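import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# df is assumed to be a pandas DataFrame with a "gender" column.
# n_splits, test_size and random_state are example values; adjust as needed.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)

# split() stratifies on its second argument (y); "groups" would be ignored.
train_indices, test_indices = next(splitter.split(df, df.gender))

# The returned indices are row positions, so use iloc.
train = df.iloc[train_indices]
test = df.iloc[test_indices]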
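If you want to check that the split really preserves the proportions, here is a self-contained sketch on synthetic data. The dataframe, its column names, and the 80/20 split size are assumptions made up for the demo. It also shows `train_test_split` with its `stratify` argument, which should give the same kind of split in one call.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit, train_test_split

# Synthetic stand-in data (made up for illustration): 75 men, 25 women
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "gender": ["M"] * 75 + ["F"] * 25,
    "age": rng.integers(30, 80, size=100),
    "target": rng.integers(0, 2, size=100),
})

# Stratify on gender by passing it as y
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, test_idx = next(splitter.split(df, df["gender"]))
train, test = df.iloc[train_idx], df.iloc[test_idx]

# Both sets should keep the 75/25 ratio: 60 M / 20 F in train, 15 M / 5 F in test
print(train["gender"].value_counts().to_dict())
print(test["gender"].value_counts().to_dict())

# One-call alternative using the stratify argument
train2, test2 = train_test_split(
    df, test_size=0.2, stratify=df["gender"], random_state=0
)
print(test2["gender"].value_counts().to_dict())
```

With real data the counts will not always divide this evenly, so expect the proportions to match only approximately.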
