Linear Regression
Starting to open the black box of machine learning
Non-Technical
Technical

Linear regression is a machine learning algorithm used to predict an output value based on a set of input values. For example, if we're trying to predict housing prices during the pandemic, we might look at various houses for which we have data.

Some variables of interest might be square footage, number of bedrooms, number of bathrooms, local inflation, and so on. In this case, the price of each house is the dependent variable, i.e., the variable we're trying to predict. The remaining variables are known as the independent variables; they're what we use to predict the dependent variable.

This is what some sample data might look like.

| Square footage | # of bedrooms | # of bathrooms | Distance to nearest school | Crime rate (per 100k population) | Price of house |
| --- | --- | --- | --- | --- | --- |
| 3,200 sqft | 3 | 2 | 1.5 mi | 98 | $400,000 |
| 1,800 sqft | 2 | 2 | 1.8 mi | 100 | $325,000 |
| 2,250 sqft | 2 | 1 | 0.8 mi | 67 | $350,000 |
| 1,200 sqft | 1 | 1 | 2.2 mi | 12 | $120,000 |
| 1,500 sqft | 2 | 2 | 2.4 mi | 45 | ??? |

The question we're trying to answer is: how can we best predict the price of the final house, given the data for the preceding houses?

Naturally, instead of having a human dive into the data and try to figure out the numbers by hand, we delegate the task to a machine, which can do it far more efficiently and deterministically.
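In practice, delegating this to the machine can take only a few lines of code. Here's a minimal sketch, assuming the scikit-learn library is available; the numbers are copied from the sample table above, and the exact prediction it prints depends on the fitted coefficients.

```python
# Minimal linear regression sketch using scikit-learn (assumed installed).
from sklearn.linear_model import LinearRegression

# Each row: [square footage, bedrooms, bathrooms, distance to school (mi), crime rate]
X = [
    [3200, 3, 2, 1.5, 98],
    [1800, 2, 2, 1.8, 100],
    [2250, 2, 1, 0.8, 67],
    [1200, 1, 1, 2.2, 12],
]
y = [400_000, 325_000, 350_000, 120_000]  # known house prices (dependent variable)

model = LinearRegression()
model.fit(X, y)  # learns one coefficient per independent variable, plus an intercept

# Predict the price of the final house from the table.
print(model.predict([[1500, 2, 2, 2.4, 45]]))
```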

One might ask, how exactly does a computer achieve this? Believe it or not, the answer is quite intuitive, even without looking at any math.

Let’s take a simpler, two-dimensional example that we can graph. Say we have the following points that we’re trying to fit a line to.

[Figure: scatter plot of the data points]

To find the line that best fits all these points, we need an objective numeric metric that lets us compare different lines. We aren't looking for a line that necessarily passes through all the points, or even any of them. We want a line that, holistically, is the best fit for all the points, not just some.

Let's clarify what "best fit" means. If each data point is regarded as the ground truth, then the y-coordinate of the line is our predicted value for that specific x-value.

The line is a function that, when given an x-value, outputs a y-value. We want our predictions to be as accurate as possible across the entire data set.
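As a tiny illustration, here's what that "line as a function" looks like in code; the slope and intercept are made-up numbers, not values fitted to any data.

```python
# A line is just a function: give it an x-value, get back a predicted y-value.
def predict(x, slope=0.8, intercept=1.0):  # illustrative, hand-picked parameters
    return slope * x + intercept

print(predict(2.0))  # predicted y for x = 2.0 -> 2.6
```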

Which of the following lines is the better fit, i.e., predicts the data more accurately? How did you know?

[Figures: two candidate lines drawn through the same data points]

Given that this is a stylized example, it's easy to determine objectively which line is the better fit.

What about this one?

[Figures: two more candidate lines through the data points]

This example is much more subjective. Thus, it’s imperative for us to have a metric to compare lines.

One common metric is called Ordinary Least Squares. For each point (the original data point), one draws a vertical segment straight up or down to the line (the predicted value). The length of this segment represents that point's error under the model, i.e., the line.

[Figures: vertical segments from each data point to each candidate line]

For each point, we just take the predicted value (the y-coordinate of the red circles on the line) and subtract it from the actual value (the y-coordinate of the blue points).

Before we sum everything up to measure our error, we have to remember that some of these differences will be negative: whenever our predicted value is greater than the actual value, i.e., the point lies below the line.

A negative error doesn't make sense here. If we simply added up the raw differences, positive and negative errors would cancel each other out, giving us an inaccurate metric. To account for this, we square each error before summing across all points.
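Here's a small sketch of that squared-error metric; the data points and the two candidate lines below are made up for illustration.

```python
# Sum of squared errors for a candidate line y = slope * x + intercept.
points = [(1.0, 1.2), (2.0, 1.9), (3.0, 3.2), (4.0, 3.8)]  # (x, actual y)

def sum_squared_error(points, slope, intercept):
    total = 0.0
    for x, actual in points:
        predicted = slope * x + intercept   # y-coordinate of the line at this x
        total += (actual - predicted) ** 2  # square so errors can't cancel out
    return total

# The line with the smaller total is the better fit.
print(sum_squared_error(points, slope=1.0, intercept=0.0))
print(sum_squared_error(points, slope=0.9, intercept=0.3))
```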

[Figures: the squared errors drawn as squares for each candidate line]

Each square represents the error of its associated point. The larger the square, the further away the point is from the line.

Thus, the line with smaller squares overall is the line that tends to better predict the values from the original dataset; holistically, it's the line with less error.

After summing up the areas, we find that the line on the left has a combined area of 2.01, while the one on the right has only 1.75! We prefer the latter.

The next question you might be asking yourself is: neither of the lines above looks optimal, so how do we find the one that is?

This is where the machine learning algorithm called Gradient Descent comes in. An animation of the process is inserted below. Its second half is sped up because the algorithm makes smaller and smaller changes as the line gets closer and closer to optimal.

The line begins with a slope and y-intercept of zero.

This algorithm also extends to curves of higher degree, such as polynomials (again, sped up).

What's happening here is, given a starting line, we tell the computer to nudge the line in the direction that reduces the sum of the squared errors (a.k.a. the combined area of the squares).

This is done by defining a cost function that takes in the parameters defining the line (its slope and y-intercept) and using gradients (yay, calculus) to minimize the total error iteratively. After each iteration, the line fits the dataset a little better. Over time, the algorithm should converge to the optimal weights.
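Here's a bare-bones sketch of gradient descent for a straight line; the data points and learning rate are made up, and this isn't the exact code behind the animations above, just the same idea.

```python
# Fit y = m * x + b by gradient descent on the mean squared error.
points = [(1.0, 1.2), (2.0, 1.9), (3.0, 3.2), (4.0, 3.8)]  # made-up (x, y) data
n = len(points)

m, b = 0.0, 0.0         # start with a slope and y-intercept of zero
learning_rate = 0.01

for _ in range(5000):   # each iteration nudges the line toward lower error
    grad_m = (2 / n) * sum((m * x + b - y) * x for x, y in points)
    grad_b = (2 / n) * sum((m * x + b - y) for x, y in points)
    m -= learning_rate * grad_m  # step opposite the gradient to reduce the cost
    b -= learning_rate * grad_b

print(m, b)  # should end up close to the least-squares slope and intercept
```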

Click the technical link above to learn more! Beware, there will be math involved.