If you work with machine learning models, overfitting is a challenge you will run into sooner or later. We can overcome it by collecting more data, but that can be costly, time-consuming, or sometimes simply impossible. So, what do we do? We can turn to regularization, and weight decay is one of the most widely used regularization techniques. Let’s start by recapping the overfitting problem.
Overfitting (High Variance)
Let’s say we have some data points like those in the top-left of the figure above. We want to develop a model that gives us the best-fit line. The top-right figure shows that we cannot fit the data points well with a straight line, which is a 1st-degree polynomial, i.e. \(y = ax + b\). So we try a 2nd-degree polynomial, i.e. \(y = ax^2 + bx + c\), and it fits the data points quite well. But if we use an even higher-degree polynomial, i.e. for \( n > 2 \), \(y = a_1x^n + a_2x^{n-1} + \dots + a_nx + a_{n+1}\), the model becomes too complex and tends to overfit, like the bottom-right figure.
This suggests that we can limit the complexity of our model in order to prevent overfitting, and indeed that is a popular technique. However, simply reducing the complexity is not always ideal, especially in deep learning, because it can reduce the model’s capacity to solve complex problems. So, how do we fight overfitting? Weight decay can be the weapon we are looking for!
Weight Decay
Weight decay is commonly known as L2 regularization and is based on the L2 norm. Recall that the norm of a vector indicates how big the vector is; we can also view it as a measure of distance. The familiar Euclidean distance is, in fact, the L2 norm. We can compute the L2 norm of a vector \( \mathbf{x} \) as: \[ || \mathbf{x} ||_2 = \sqrt{\sum_{i=1}^{n}x_i^2} \]
Very often we omit the subscript 2, so \( || \mathbf{x} || \) means the same thing as \( || \mathbf{x} ||_2 \).
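As a quick sanity check, for the vector \( \mathbf{x} = (3, 4) \): \[ || \mathbf{x} ||_2 = \sqrt{3^2 + 4^2} = 5 \] which is exactly the Euclidean length of that vector.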
Motivation behind weight decay
This technique is motivated by a basic intuition stated as:
Among all functions \( f \), the function \( f = 0 \) is the simplest one in some sense. So, we can measure the complexity of a function by its distance from zero.
But there is no single precise way to measure the distance between a function and zero. For simplicity, we can measure the complexity of a linear function \( f(\mathbf{x}) = \mathbf{w}^T\mathbf{x} \) by some norm of its weight vector. The most common way to ensure a small weight vector is to add its norm to the loss as a penalty term. Thus we replace our original objective, minimizing the prediction loss, with a new one: minimizing the sum of the prediction loss and the penalty term. Now, if the weight vector grows too large, our learning algorithm will focus on shrinking the weight norm rather than just minimizing the training error, and that is exactly what we want.
Arriving at Weight Decay
Let’s consider a model with weights \( \mathbf{w} \), bias \( b \), and loss function: \[ L(\mathbf{w}, b) = \frac{1}{n} \sum_{i=1}^{n}\frac{1}{2}(\hat{y}_i - y_i)^2 \] Here \( \hat{y}_i \) is the predicted output, calculated as \( \hat{y}_i = \mathbf{w}^T\mathbf{x}_i + b \). Now, we add the squared L2 norm of the weight vector \( \mathbf{w} \) to the loss function as a penalty term. The new loss function becomes: \[ L(\mathbf{w}, b) = \frac{1}{n} \sum_{i=1}^{n}\frac{1}{2}(\hat{y}_i - y_i)^2 + \frac{\lambda}{2}||\mathbf{w}||^2 \]
Here \( \lambda \) is the regularization constant, which lets us trade off between the standard loss and the penalty term. Two questions might pop up at this point: 1) why do we use the L2 norm, and 2) why do we square it? One reason to work with the L2 norm is that it places an outsize penalty on large components of the weight vector, which biases our learning algorithm toward distributing the weight evenly across a large number of features. We square the norm for computational convenience: it removes the square root and makes the derivative easy to compute. Finally, we divide by 2 by convention, so that the extra factor of 2 produced by differentiation cancels.
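To see why, take the gradient of the penalty term alone with respect to \( \mathbf{w} \): \[ \nabla_{\mathbf{w}}\left( \frac{\lambda}{2}||\mathbf{w}||^2 \right) = \frac{\lambda}{2}\cdot 2\mathbf{w} = \lambda \mathbf{w} \] The factor of \( \frac{1}{2} \) cancels the 2 from differentiation, and the penalty contributes a term proportional to \( \mathbf{w} \) itself to the gradient, which is exactly what shrinks the weights in the update rule below.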
Now we can update the weights with the following equation: \[ \mathbf{w} \leftarrow ( 1 - \eta \lambda )\mathbf{w} - \frac{\eta }{| \beta |} \sum_{i\in \beta }\mathbf{x}_i (\hat{y}_i - y_i) \] Here \( \eta \), \( \beta \), and \( | \beta | \) are the learning rate, the minibatch, and the batch size, respectively. We now not only update the weights based on the difference between our prediction and the target, but also shrink \( \mathbf{w} \) towards zero. That is why this method is called weight decay.
Implementation
Now that we have gathered enough intuition and mathematics, let’s move on to the implementation. We will implement weight decay from scratch and then observe its impact. To begin, let’s create a sample dataset and define the model.
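Here is a minimal sketch of this step in PyTorch. The problem sizes, the generating function, and the helper names (`synthetic_data`, `load_array`, `init_params`, `linreg`) are illustrative assumptions: a tiny, high-dimensional linear-regression problem where a plain linear model overfits easily.

```python
# A minimal sketch; sizes and helper names are illustrative assumptions.
import torch
from torch.utils import data

# Few training examples and many features, so overfitting is easy to provoke.
n_train, num_inputs, batch_size = 20, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05

def synthetic_data(w, b, num_examples):
    """Generate X and y = Xw + b + small Gaussian noise."""
    X = torch.randn(num_examples, len(w))
    y = torch.matmul(X, w) + b + 0.01 * torch.randn(num_examples, 1)
    return X, y

def load_array(data_arrays, batch_size, is_train=True):
    """Wrap tensors into a PyTorch data iterator."""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

train_iter = load_array(synthetic_data(true_w, true_b, n_train), batch_size)

def init_params():
    """Randomly initialize the weights and bias of the linear model."""
    w = torch.randn(num_inputs, 1, requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

def linreg(X, w, b):
    """The linear model: y_hat = Xw + b."""
    return torch.matmul(X, w) + b
```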
Now we define the train function. It takes lambda as a parameter and adds the penalty term to the loss.
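Continuing the sketch above, a possible training loop looks like this; the helpers `squared_loss`, `l2_penalty`, and `sgd`, as well as the default learning rate and epoch count, are assumptions for illustration. The key line is the one that adds `lambd * l2_penalty(w)` to the loss.

```python
def squared_loss(y_hat, y):
    """Per-example squared error, 0.5 * (y_hat - y)^2."""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

def l2_penalty(w):
    """The penalty term 0.5 * ||w||^2 from the modified loss above."""
    return torch.sum(w ** 2) / 2

def sgd(params, lr, batch_size):
    """Minibatch stochastic gradient descent step."""
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

def train(lambd, num_epochs=100, lr=0.003):
    """Train the linear model; lambd is the regularization constant."""
    w, b = init_params()
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # Add the weight-decay penalty directly to the loss.
            l = squared_loss(linreg(X, w, b), y) + lambd * l2_penalty(w)
            l.sum().backward()
            sgd([w, b], lr, batch_size)
    print('L2 norm of w:', torch.norm(w).item())
    return w, b
```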
First, we train the model with no weight decay, which we get by passing lambda = 0. We also report the L2 norm of the weight vector at the end of training.
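With the sketch above, that run is a single call:

```python
# No weight decay: lambd = 0. With such a tiny training set we would
# expect the model to overfit and the printed L2 norm of w to stay large.
w_plain, b_plain = train(lambd=0)
```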
Finally, we train the model with weight decay, this time passing lambda = 2.
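And the run with the penalty switched on:

```python
# Weight decay enabled: lambd = 2. The printed L2 norm of w should come
# out noticeably smaller than in the run without the penalty.
w_decayed, b_decayed = train(lambd=2)
```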
So weight decay can keep the weights from growing too large and thereby prevent overfitting. It provides a continuous way of adjusting a function’s complexity: smaller values of \( \lambda \) correspond to a less constrained \( \mathbf{w} \), while larger values of \( \lambda \) constrain \( \mathbf{w} \) more strongly.
References
While learning about this topic and writing this post, the following resources were a great help:
- Dive into Deep Learning
- This thing called Weight Decay
That’s all for this post on Weight Decay. I am a learner who is picking up new things and trying to share them with others, so let me know your thoughts on this post. Get more machine learning-related posts here.