Linearity-Test Your Data [Short Post]

Oftentimes, choosing a simpler model makes it easier to debug and interpret, but we don’t always have a straightforward metric to help us make that decision. Linear models come with a clear recipe for interpretation, which is something we usually (though not always) give up with a more complex model. That is the main reason I wrote this post, but there is another reason this matters: it doesn’t make sense to fit a complex non-linear model on a linear dataset anyway. With that said, let’s look at how to understand the distribution of our data and make that decision.

Step 1: Fit a linear model (linear regression, etc.)

Take a naïve first pass at fitting your data with an out-of-the-box linear model. Needless to say, follow the usual drill: split your data into train and test sets, then fit on the training set.

                                          from sklearn.linear_model import LinearRegression
                                          model = LinearRegression().fit(X_train, y_train)
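For reference, here is a self-contained sketch of Step 1. It assumes scikit-learn and NumPy, and it swaps in hypothetical synthetic data (a quadratic relationship plus Gaussian noise) for “your data”. The 80/20 split and the `random_state` values are arbitrary choices, not the author’s.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Hypothetical synthetic data standing in for "your data":
# y = x^2 plus Gaussian noise, with x drawn uniformly from [0, 3].
rng = np.random.default_rng(0)
X = rng.uniform(0, 3, size=(500, 1))
y = X[:, 0] ** 2 + rng.normal(scale=0.5, size=500)

# The usual drill: hold out a test split and fit on the training split only.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)
model = LinearRegression().fit(X_train, y_train)
print("slope:", model.coef_[0], "intercept:", model.intercept_)
```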

Step 2: Predict on test

Self-explanatory: predict on your test set.

                                          y_pred = model.predict(X_test)
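A sketch of Step 2 under the same assumptions (hypothetical synthetic data, scikit-learn). It also computes the residuals, i.e. ground truth minus prediction, and a test R² score. The true relationship in this data is quadratic, yet the R² should still come out fairly high, which illustrates why a single metric may not settle the linear vs. non-linear question and why the plot in Step 3 is worth making.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

# Hypothetical synthetic data: quadratic relationship plus noise.
rng = np.random.default_rng(0)
X = rng.uniform(0, 3, size=(500, 1))
y = X[:, 0] ** 2 + rng.normal(scale=0.5, size=500)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)
model = LinearRegression().fit(X_train, y_train)

# Predict on the held-out test split only.
y_pred = model.predict(X_test)

# Residual = ground truth minus prediction; this is what Step 3 inspects.
residuals = y_test - y_pred
print(f"test R^2: {r2_score(y_test, y_pred):.3f}")
print(f"mean residual: {residuals.mean():.3f}")
```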

Step 3: Plot residuals

Now it’s as simple as plotting your predictions on one axis against the ground truth on the other.
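One way to draw this with matplotlib, again on hypothetical synthetic data (this is not the author’s graph). The left panel is the predictions-vs-ground-truth plot described in this step; the right panel, an addition here, plots residuals against predictions, where a curve is often easier to spot. The `Agg` backend and the output filename `linearity_check.png` are arbitrary choices so the script runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # render to a file; no display needed
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Hypothetical synthetic data: quadratic relationship plus noise.
rng = np.random.default_rng(0)
X = rng.uniform(0, 3, size=(500, 1))
y = X[:, 0] ** 2 + rng.normal(scale=0.5, size=500)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)
model = LinearRegression().fit(X_train, y_train)
y_pred = model.predict(X_test)
residuals = y_test - y_pred

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Left: ground truth on one axis, predictions on the other,
# with the y = x line a perfect model would fall on.
ax1.scatter(y_test, y_pred, s=10)
lims = [min(y_test.min(), y_pred.min()), max(y_test.max(), y_pred.max())]
ax1.plot(lims, lims, "k--", label="perfect prediction")
ax1.set_xlabel("ground truth")
ax1.set_ylabel("prediction")
ax1.legend()

# Right: residuals against predictions, with a zero reference line.
ax2.scatter(y_pred, residuals, s=10)
ax2.axhline(0, color="k", linestyle="--")
ax2.set_xlabel("prediction")
ax2.set_ylabel("residual (ground truth - prediction)")

fig.tight_layout()
fig.savefig("linearity_check.png")
```

On this data the residual panel should show a U-shaped curve, since a straight line underestimates both ends of a parabola and overestimates the middle.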

From the plot above, a couple of inferences can be made:

  1. Non-linear – The data is likely non-linear if there are visible patterns in your 2D plot. In my graph (shown above), there is a visible relationship between the ground truth and the predictions. This means there is structure in the data still to be learned, but the linear model doesn’t capture it; a non-linear model is needed to attempt that.
  2. Linear – The data is likely linear if there is no visible pattern. If the residuals are scattered randomly on either side of a line (one we could fit over this distribution), then it’s simply a matter of adding more data.
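If you want a number to go with the eyeball test, one rough heuristic (my own addition, not something this post prescribes) is to fit a low-degree polynomial of the predictions to the residuals and measure how much residual variance it explains. A value near zero suggests patternless scatter (the linear case); a clearly positive value suggests a curve the linear model missed (the non-linear case). This is close in spirit to Ramsey’s RESET test, which adds powers of the fitted values as extra regressors and tests their joint significance. The helpers `residual_pattern_score` and `check`, the degree of 2, and the two synthetic targets are all hypothetical.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split


def residual_pattern_score(y_true, y_pred, degree=2):
    """Hypothetical heuristic: fraction of residual variance explained
    by a degree-`degree` polynomial in the predictions (an R^2 on residuals)."""
    residuals = y_true - y_pred
    coeffs = np.polyfit(y_pred, residuals, deg=degree)
    fitted = np.polyval(coeffs, y_pred)
    ss_res = np.sum((residuals - fitted) ** 2)
    ss_tot = np.sum((residuals - residuals.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


def check(make_target, seed=0):
    """Hypothetical helper: generate synthetic data from `make_target`,
    fit a linear model on a train split, and score the test residuals."""
    rng = np.random.default_rng(seed)
    X = rng.uniform(0, 3, size=(500, 1))
    y = make_target(X[:, 0]) + rng.normal(scale=0.5, size=500)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=seed
    )
    y_pred = LinearRegression().fit(X_train, y_train).predict(X_test)
    return residual_pattern_score(y_test, y_pred)


linear_score = check(lambda x: 2 * x + 1)   # truly linear target
nonlinear_score = check(lambda x: x ** 2)   # quadratic target
print(f"linear data:     {linear_score:.3f}")
print(f"non-linear data: {nonlinear_score:.3f}")
```

There is no principled cutoff here; treat the score as a prompt to look at the plot again, not as a verdict.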

That’s it. Bye now.
