Machine learning basics: the cost function

Machine learning is ultimately a way to make a program perform a task and get better at it over time. A cost function defines how well or badly a program performs that task, and pretty much every machine learning problem consists of making the value of the cost function as small as possible.

For our example, we will use a very simple dataset consisting of two variables: car speed and the distance needed to stop. Our ultimate goal will be, given a speed we have never seen before, to predict the distance needed to stop.

Let's define some common vocabulary:

  • \(X\): The observations; in our case, the car speed.
  • \(y\): The correct answers to our observations; in this case, the stopping distance.
  • \(\hat{y}\): Our own predictions for a given \(X\).

Notice that all of the values above are actually vectors, or if you prefer, lists (possibly a friendlier term for a developer). This means that each of their elements can be accessed by index; for example, \(X_1\) is the first observed speed and \(y_1\) its stopping distance.

This leads us to define another element:

  • \(n\): The total number of observations; in this case, how many elements we have in \(X\) and \(y\).
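In Python terms, these definitions map naturally onto lists. Here is a small sketch (the values are placeholders for illustration, not our real dataset yet):

```python
# X: the observations (car speeds), y: the correct answers (stopping distances)
X = [4, 7, 8]
y = [2, 4, 16]

# y_hat: our predictions for each X (made-up numbers, just for illustration)
y_hat = [3, 5, 14]

# n: the total number of observations
n = len(X)

# Elements are accessed by index (Python counts from 0, the math notation from 1)
print(X[0], y[0])  # first observation: speed 4, distance 2
```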

Now, this is the data we are going to work with:

Speed (X)   Distance (y)
4           2
7           4
8           16
9           10
10          18
11          17
12          14
13          26
14          26
15          20
16          32
17          32
18          42
19          36
20          32
22          66
23          54
24          70
25          85

We have a total of 19 observations here. Now let's plot them:

The distance required to stop a car depending on its speed.
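If you want to follow along in code, here is the same dataset as two Python lists:

```python
# The full dataset: 19 (speed, stopping distance) pairs, as plotted above
speed = [4, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23, 24, 25]
distance = [2, 4, 16, 10, 18, 17, 14, 26, 26, 20, 32, 32, 42, 36, 32, 66, 54, 70, 85]

n = len(speed)
print(n)  # 19
```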

With all this in our hands, we can start defining our cost function. A good intuition would be to say that our cost function is simply the difference between our predictions and the actual values.

For example, at speed \(15\) km/h we need \(20\) meters to stop. Imagine that we have:

  • \(ModelA\), which predicts that \(25\) meters are needed; the error would be \(25 - 20 = 5\).
  • \(ModelB\), which predicts that \(22\) meters are needed; the error would be \(22 - 20 = 2\), already smaller than the previous one.
  • \(ModelC\), which predicts that \(19\) meters are needed; the error would be \(19 - 20 = -1\).

That last result is a bit weird: we want our error to be close to 0, not negative. The solution is to use the squared error instead, so the result is always positive. Let's recalculate the errors using squares:

  • \(ModelA = (25 - 20)^2 = 25\)
  • \(ModelB = (22 - 20)^2 = 4\)
  • \(ModelC = (19 - 20)^2 = 1\)

More generally, we can simply say \( Error_i = (\hat{y}_i - y_i)^2 \)

With this we can quickly conclude that the best model is the one with the smallest value for the cost function, in this case, that would be \(ModelC\).
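The three hand calculations above take only a couple of lines to reproduce (the model names and predictions are the hypothetical ones from the example):

```python
y_true = 20  # actual stopping distance at speed 15

# Hypothetical predictions from the three models in the example
predictions = {"ModelA": 25, "ModelB": 22, "ModelC": 19}

# Squared error: always non-negative, and smaller means better
for name, y_hat in predictions.items():
    print(name, (y_hat - y_true) ** 2)  # 25, 4 and 1 respectively
```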

The next step is to apply this to every point in the problem: our model should be able to predict the distance required to stop for any given speed, and we should be able to calculate the error of each prediction. The solution? Apply exactly the same logic to the whole dataset. As we mentioned before, distance and speed are both vectors, so we can simply do

\( Error = (\hat{y}_1 - y_1)^2 + (\hat{y}_2 - y_2)^2 + \cdots + (\hat{y}_n - y_n)^2 \)

Or, if we want to use better mathematical notation:

\( Error = \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 \)

Do not let the math intimidate you: the term \( \sum_{i=1}^{n} \) is just a loop over the elements of the vectors.

We cannot simply keep adding terms, though. Think about it: if we have a dataset with a lot of observations, our error will grow just because we have more observations. The solution is to use the mean error instead, so let's add that to our formula:

\( Error = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 \)

So what we have now is the mean of all the squared errors. This function is, unsurprisingly, called "Mean Squared Error", or simply \(MSE\), and it will be an important concept for the rest of this post:

\( MSE = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 \)
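Translated to Python, the formula is just a loop, a sum and a division. A minimal sketch, with no input validation:

```python
def mse(y_hat, y):
    """Mean Squared Error: the average of the squared differences."""
    n = len(y)
    return sum((p - a) ** 2 for p, a in zip(y_hat, y)) / n

print(mse([6, 3], [2, 4]))  # ((6-2)^2 + (3-4)^2) / 2 = 8.5
```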

Making predictions

Now, we have been talking a lot about \(\hat{y}\), but how can we calculate it? In linear regression this is done by applying a simple formula:

\( \hat{y_i} = wX_i + b \)

If we want to generalize it, we can simply say

\( \hat{y} = wX + b \)

This introduces 2 new values, \(w\) and \(b\):

  • \(w\): Represents the weight that we need to calculate; this is the value by which we will multiply \(X\).
  • \(b\): Represents the bias; we will simply add this term, and we will NOT relate it to \(X\).

An example will make this clearer. Let's say \(w=-1, b=10\):

This shows a terrible prediction: our red line (that is, our model) does not align at all with our actual observations. The interesting part here is to quantify how bad it is. In order to do so, let's look at just the first \(5\) data points, so we can do all calculations by hand.

Speed (X)   Distance (y)   Prediction (ŷ)
4           2              6
7           4              3
8           16             2
9           10             1
10          18             0

We will take the data point where \(i=1\), that is, the first row. So

\(X_1=4; y_1=2; \hat{y}_1=6\) so \(Error_1 = (\hat{y}_1 - y_1)^2 = 16\)

If we apply \(MSE = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2\) we get \(MSE = \frac{1}{5}\left[(6-2)^2 + (3-4)^2 + (2-16)^2 + (1-10)^2 + (0-18)^2\right] = 123.6\)
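We can double-check the hand calculation with a few lines of Python:

```python
X = [4, 7, 8, 9, 10]
y = [2, 4, 16, 10, 18]
w, b = -1, 10

y_hat = [w * x + b for x in X]  # [6, 3, 2, 1, 0]
mse = sum((p - a) ** 2 for p, a in zip(y_hat, y)) / len(y)
print(mse)  # 123.6
```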

Now, let's consider another model where \(w=3; b=-12\); then we get this:

It is already obvious that this model is much better at predicting the distance. The question, however, is how much better? Again the answer lies in \(MSE\); the values are

Speed (X)   Distance (y)   Prediction (ŷ)
4           2              0
7           4              9
8           16             12
9           10             15
10          18             18

So we can again calculate \( MSE = \frac{1}{5}\left[(0-2)^2 + (9-4)^2 + (12-16)^2 + (15-10)^2 + (18-18)^2\right] = 14 \)
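The same kind of script checks this model too; note that the term at speed \(9\) uses the actual distance \(y=10\) from the table:

```python
X = [4, 7, 8, 9, 10]
y = [2, 4, 16, 10, 18]
w, b = 3, -12

y_hat = [w * x + b for x in X]  # [0, 9, 12, 15, 18]
mse = sum((p - a) ** 2 for p, a in zip(y_hat, y)) / len(y)
print(mse)  # 14.0
```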

This gives us critical information: not only can we figure out which model is better, we can also quantify how much better it is. That becomes very important; imagine, for example, how relevant this could be for autonomous driving.

Cost functions for other problems.

\(MSE\) is a good cost function, but it only helps us with regression problems, that is, problems where our output is a number: for example, predicting how warm a day will be based on some variables, or predicting the value of a security in the stock market.

However, there are many problems where we want to classify values. An example would be knowing whether or not a car can stop completely, or whether it would have an accident. In this scenario \(MSE\) does not help us; we need another cost function.

The logistic function.

For binary classification problems, where our output can only take two possible values, we want to use this little function: \( logistic(\hat{y}) = \frac{1}{1+e^{-\hat{y}}} \). It does not look very intuitive, but if we actually plot it, we get:

Logistic sigmoid function

The interesting thing about this function is that it only takes values between 0 and 1. So we can apply a similar measure of error by simply comparing \(y\) with our \(\hat{y}\): \(\hat{y}\) will take values from 0 to 1, while \(y\) will be either 0 or 1.
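A minimal sketch of the logistic function in Python:

```python
import math

def logistic(z):
    """Squash any real number into the open interval (0, 1)."""
    return 1 / (1 + math.exp(-z))

print(logistic(0))    # 0.5, right in the middle
print(logistic(10))   # very close to 1
print(logistic(-10))  # very close to 0
```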


Cost functions are at the core of understanding machine learning, as they ultimately provide a measure of success for a given model. They are also at the center of fundamental algorithms such as gradient descent.

It really helped me in the early days to calculate some of these functions by hand to fully understand their meaning.

There are many other cost functions one needs to be aware of, but these two are the core ones to start with. I strongly recommend going through a couple of examples with \(MSE\).

Happy coding.

Why developers need to understand basic math.

These days coding has become easier than ever. Some languages, such as Python or Ruby, are great for diving into coding, and there are even programming languages for kids, like Scratch. This ease has made some people think that one can get into coding without really understanding math… I disagree, at least for those who want to code at a professional level.

Before you jump on me, keep reading.

Back in 2001 I was in my first university year in Madrid, and we had a professor who insisted on one idea:

Developers do not need to be experts in calculus, but they do need a high level of basic math.

To illustrate this, I will add a couple of examples.

How can you code an algorithm that adds up the numbers from 1 to 1000?

I still remember that all of us said, "well, just use a for loop and an accumulator", and for the record, that is what I would expect if I asked this question to anyone. However, it turns out there is a much, much better way to do this.

What's 1000+1? And then, what's 999+2? And 998+3? Easy, right? Each of the sums is 1001, and there are a total of 500 pairs, so the answer is 500500.

It turns out that the first person who realized this was Gauss. Apparently, back in the day, whenever teachers wanted to be left alone they would assign this kind of busywork to the kids; Gauss figured this out while he was still a child.

Gauss figured out an effective way to add the numbers 1 to N

Now, do not get me wrong, I do not expect any developer to be a genius like Gauss, but I do expect any developer to appreciate the incredible advantage of this method: essentially it allows us to go from O(n) to O(1) running time.
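Here are both approaches in Python, the loop and Gauss's closed formula:

```python
def sum_loop(n):
    """O(n): add the numbers 1..n one by one."""
    total = 0
    for i in range(1, n + 1):
        total += i
    return total

def sum_gauss(n):
    """O(1): Gauss's pairing trick, n * (n + 1) / 2."""
    return n * (n + 1) // 2

print(sum_loop(1000), sum_gauss(1000))  # 500500 500500
```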

Say you have 23 people in one room. What do you think is the chance that two of them have the same birthday?

This is an interesting problem. Here I am not particularly interested in the math behind it (which has to do with probability); I am more interested in how terrible our intuition can sometimes be. What would be your guess? Mine was around 10%. Well, it turns out that with 23 people the chance of two of them having the same birthday is actually about 50%.

The probability of at least two people having the same birthday goes up much quicker than expected

Surprised? Well, that's actually the whole point. This is called the Birthday Paradox, and it is used to show how bad human beings are at guessing probabilities.
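You do not have to take the 50% on faith; the exact probability is easy to compute. The trick is to compute the chance that all birthdays are different and subtract it from 1:

```python
def p_shared_birthday(people, days=365):
    """Probability that at least two people share a birthday."""
    p_all_different = 1.0
    for k in range(people):
        # The (k+1)-th person must avoid the k birthdays already taken
        p_all_different *= (days - k) / days
    return 1 - p_all_different

print(p_shared_birthday(23))  # ~0.507, just over 50%
```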

So what's the deal with this and programming? Essentially, the same math shows how easy it is to get hash collisions, and that is normally related to security, so yes, this IS very relevant. It is called the Birthday attack, unsurprisingly. I have known developers who experienced the huge pain of being bitten by this (essentially, two users ended up with the same login hash and could see each other's data).

Imagine you have a really large sheet of paper of standard thickness (0.1 mm). You fold it 50 times. What is the new thickness?

Let's do the math manually. After 1 fold we have a thickness of 0.2 mm, then 0.4, 0.8, 1.6, 3.2… We are dealing with an exponential function here: the thickness after 50 folds is \(0.1 \times 2^{50}\) mm. Since \(2^{50} = 1125899906842624\), that is 1125899906842624 tenths of a millimeter; divide by 10 to get millimeters, by 1000 to get meters, and by 1000 again to get kilometers, and the answer is about 112589990 kilometers. Now think about one thing: the distance to the Moon is 384400 kilometers.

This means that folding that hypothetical, long-enough sheet of paper would cover the distance from the Earth to the Moon 292 times!
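The arithmetic takes seconds to verify:

```python
THICKNESS_MM = 0.1        # a standard sheet of paper
MOON_DISTANCE_KM = 384400

folded_mm = THICKNESS_MM * 2 ** 50  # thickness doubles with every fold
folded_km = folded_mm / 1_000_000   # mm -> km

print(int(folded_km))                     # 112589990 km
print(int(folded_km / MOON_DISTANCE_KM))  # 292 full Earth-Moon distances
```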

Your intuition (and mine) tells you that such a thing is simply not possible; do the math and prove yourself wrong. This particular example comes from a great talk by Eduardo Sáenz de Cabezón, "Las matemáticas son para siempre" (math is forever), which is available on TED.

This example shows how the exponential function can get really crazy really fast, and that is important for developers too.


None of the examples I provided requires advanced math or calculus; all of them can be explained using the simplest operations. Yet they yield important results and relevant lessons: programming requires thinking about how things scale, and math provides the tools for that kind of thinking.

So, in short: do not be afraid of math; embrace it, for it is the best companion of a good developer.