What is Learning Rate?
Ever been confused about how the learning rate works and why it matters? Let's take that off your plate today!!!
Whenever you read about the learning rate, the first image you come across on the internet is a graph of how the learning rate affects your model's performance, plotting the loss against epochs for several learning rates. Let's explore the whys and hows of this graph.
What is learning rate?
The learning rate is a hyper-parameter that controls how much we adjust the weights of our network with respect to the loss gradient. The weights are updated as
New weight = Existing weight - Learning rate * Gradient
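As a minimal sketch in Python with NumPy, the update rule above is one line of array arithmetic. The function name `gradient_descent_step` and the weight and gradient values are hypothetical, chosen only to illustrate the arithmetic.

```python
import numpy as np

def gradient_descent_step(weights, gradient, learning_rate):
    """New weight = Existing weight - Learning rate * Gradient, element-wise."""
    return weights - learning_rate * gradient

# Hypothetical values, for illustration only.
weights = np.array([0.5, -1.2])
gradient = np.array([0.2, -0.4])

new_weights = gradient_descent_step(weights, gradient, learning_rate=0.1)
print(new_weights)  # approximately [0.48, -1.16]
```

Each weight moves against its gradient: the first weight has a positive gradient and decreases by 0.1 * 0.2 = 0.02, while the second has a negative gradient and increases by 0.1 * 0.4 = 0.04.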
Let's take a deeper look at the terms in this equation by answering a few questions.
What do you want your model to do while training?
The main aim of neural network training is to minimize the cost function, ideally reaching its global minimum, using an optimizer. At each training step you compute the gradient (slope) of the cost with respect to the weights, which tells you which direction is downhill. The gradient is multiplied by the learning rate, so the learning rate sets the size of each step and therefore how fast you move toward the minimum.
If the learning rate is too low: you take tiny steps, so you reach the minimum very late, which means longer training time.
If the learning rate is too high: you take steps so large that you keep jumping over the minimum and may never reach it.
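To see both failure modes, here is a small sketch on a toy cost f(w) = w**2, whose gradient is 2*w and whose minimum is at w = 0. The three learning rates (0.01, 0.3 and 1.1), the starting point and the number of steps are arbitrary choices for illustration.

```python
def descend(learning_rate, w0=1.0, steps=20):
    """Plain gradient descent on f(w) = w**2 (gradient 2*w, minimum at w = 0)."""
    w = w0
    for _ in range(steps):
        w = w - learning_rate * 2 * w  # New weight = Existing weight - lr * gradient
    return w


for lr in (0.01, 0.3, 1.1):
    print(f"lr={lr}: w after 20 steps = {descend(lr):.3g}")
```

With this cost each step multiplies w by (1 - 2 * learning_rate), so 0.01 shrinks w by only 2% per step, 0.3 shrinks it by 60%, and 1.1 multiplies it by -1.2: w flips sign every step and moves farther from the minimum. (Any learning rate between 0.5 and 1 also overshoots, but by less each time, so it still converges.)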
Seems easy in theory, but how do you know whether a learning rate is good or bad before fully training the model?
General ways of setting the learning rate
You can estimate a good learning rate by starting training with a very low learning rate and increasing it (either linearly or exponentially) at each iteration. If we record the loss at each iteration and plot it against the learning rate, we will see that as the learning rate increases, there comes a point where the loss stops decreasing and starts to increase. In practice, the learning rate should ideally be somewhere to the left of the lowest point of the graph. Refer to the image below to see where the optimal learning rate lies.
Rule of thumb: the best learning rate is associated with the steepest drop in loss.
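The procedure above (often called a learning rate range test) can be sketched in plain NumPy. This is a minimal sketch under assumptions of my own: a hypothetical one-parameter linear regression stands in for your network, the learning rate grows exponentially from 1e-4 to 10 over 100 full-batch gradient-descent steps, the test stops once the loss exceeds four times its best value, and "steepest drop" is measured as the slope of the loss with respect to log10 of the learning rate.

```python
import numpy as np

# Hypothetical toy problem standing in for a network: fit y = w * x (true w = 3).
rng = np.random.default_rng(0)
x = rng.normal(size=256)
y = 3.0 * x + 0.1 * rng.normal(size=256)


def loss_and_grad(w):
    """Mean squared error of the toy model and its derivative with respect to w."""
    err = w * x - y
    return np.mean(err ** 2), 2.0 * np.mean(err * x)


def lr_range_test(lr_min=1e-4, lr_max=10.0, num_steps=100):
    """Train while growing the learning rate exponentially; record (lr, loss) pairs."""
    w = 0.0
    lrs, losses = [], []
    best = np.inf
    for i in range(num_steps):
        lr = lr_min * (lr_max / lr_min) ** (i / (num_steps - 1))
        loss, grad = loss_and_grad(w)
        lrs.append(lr)
        losses.append(loss)
        best = min(best, loss)
        if loss > 4 * best:  # the loss has blown up, so stop the test early
            break
        w -= lr * grad  # plain gradient-descent update with the current lr
    return np.array(lrs), np.array(losses)


lrs, losses = lr_range_test()
# losses[i + 1] - losses[i] is the effect of the update made with lrs[i].
slopes = np.diff(losses) / np.diff(np.log10(lrs))
suggested_lr = lrs[np.argmin(slopes)]  # learning rate with the steepest drop in loss
print(f"suggested learning rate: {suggested_lr:.3g}")
```

For this toy problem plain gradient descent diverges once the learning rate exceeds 1 / mean(x**2) (roughly 1), and the suggested value lands below that point, to the left of where the loss turns upward. On a real network the recorded losses come from noisy mini-batches, so in practice they are usually smoothed before looking for the steepest drop.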
Now we know how to set a learning rate for a given model and dataset. In practice, though, we often use transfer learning: taking a pre-trained model and adapting it to our application. A pre-trained model has usually been trained on massive data with its own learning rate and optimization technique, so retraining it on your data can overwrite what it has already learnt.
How do you set the right learning rate for a pre-trained model?
To find the right learning rate, let's see how the learning rate affects the learnt weights. A high learning rate increases the risk of losing previous knowledge, because the higher the learning rate, the more each update changes the learnt (pre-trained) weights. So it's smart to use a small learning rate: assuming the pre-trained model has been well trained, a small learning rate ensures that you don't distort the CNN weights too soon or too much.
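A toy sketch can make this concrete. Below, a hypothetical linear model plays the role of the pre-trained network (the weights, data and step count are made up for illustration), and we fine-tune it for the same number of gradient-descent steps with a small and a large learning rate, measuring how far the weights move from their pre-trained values.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical "pre-trained" weights of a linear model, plus a small new dataset
# whose best-fitting weights are close to, but not equal to, the pre-trained ones.
w_pretrained = np.array([1.0, -2.0, 0.5, 3.0])
X = rng.normal(size=(200, 4))
y = X @ np.array([1.2, -1.8, 0.4, 2.5]) + 0.1 * rng.normal(size=200)


def fine_tune(learning_rate, steps=20):
    """Full-batch gradient descent on mean squared error, starting from w_pretrained."""
    w = w_pretrained.copy()
    for _ in range(steps):
        grad = 2.0 * X.T @ (X @ w - y) / len(y)
        w -= learning_rate * grad
    return w


drift = {}
for lr in (0.001, 0.1):
    drift[lr] = np.linalg.norm(fine_tune(lr) - w_pretrained)
    print(f"lr={lr}: distance moved from the pre-trained weights = {drift[lr]:.3f}")
```

With the larger learning rate the weights travel most of the way to the new data's optimum in 20 steps, while the small learning rate keeps them close to where they started. In a real network, moving far from the pre-trained weights is what puts the previously learnt knowledge at risk.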
Another way of using the learning rate is differential learning rates: a method where you set different learning rates for different layers of the network during training. How do you do that efficiently without distorting the pre-trained weights too much?
Initial layers: small learning rate (the initial layers have learnt general, low-level features such as edges and textures, so please don't disturb them)
Middle layers: relatively larger learning rate (letting your model adapt to your data by capturing its more detailed features)
Later layers: highest learning rate (these layers are the most specific to your task, and obviously you want your predictions to be accurate)
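Here is a minimal NumPy sketch of the idea. The three layer names, the learning rates 1e-4, 1e-3 and 1e-2, and the random weights and constant gradients are all hypothetical stand-ins; the only point is that each layer is updated with its own learning rate.

```python
import numpy as np

# Hypothetical per-layer learning rates: small for early layers, larger for later ones.
LAYER_LRS = {"initial": 1e-4, "middle": 1e-3, "later": 1e-2}


def differential_sgd_step(params, grads, layer_lrs):
    """Return updated parameters, applying each layer's own learning rate."""
    return {name: params[name] - layer_lrs[name] * grads[name] for name in params}


rng = np.random.default_rng(0)
# Stand-ins for pre-trained weights and one mini-batch of gradients (hypothetical shapes).
params = {name: rng.normal(size=(4, 4)) for name in LAYER_LRS}
grads = {name: np.ones((4, 4)) for name in LAYER_LRS}

new_params = differential_sgd_step(params, grads, LAYER_LRS)
for name in LAYER_LRS:
    change = np.abs(new_params[name] - params[name]).max()
    print(f"{name:>7}: max weight change = {change:.0e}")
```

With identical gradients, the later layer moves 100 times further than the initial one. In PyTorch, for example, the same idea is usually expressed by passing a list of parameter groups, each with its own `lr`, to an optimizer such as `torch.optim.SGD`.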
Now I hope you can play with the learning rate in the right way, whether it is your own model or a pre-trained one.
Coming back to the main aim of this blog: did you figure out what exactly the graph we saw at the start tells us?
If yes, great!!!
If not, let me list them out. All you can see are four different behaviours, one for each learning rate.
Very high learning rate. Why? The loss shoots up to the sky, which means training diverges and you will never reach the global minimum.
High learning rate. Why? The loss decreases quickly at first but then saturates at a relatively high value: you move toward the minimum at a fast pace but can't settle into it, because near the minimum you need to take smaller steps.
Low learning rate. Why? The loss is decreasing, but at a very slow pace.
Good learning rate. Why? The loss decreases roughly exponentially with respect to epochs, and yes, that's what you want.
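You can reproduce the four behaviours on a toy problem. The sketch below is only an illustration under my own assumptions: a hypothetical one-parameter linear regression trained with mini-batch SGD (batch size 4), with learning rates picked by hand for this particular problem; on a real network the values would be very different.

```python
import numpy as np

# Hypothetical toy problem: fit y = w * x with mini-batch SGD (true w = 3, noise std 0.5).
rng = np.random.default_rng(0)
x = rng.normal(size=1000)
y = 3.0 * x + 0.5 * rng.normal(size=1000)

# Illustrative learning rates for the four behaviours (chosen for this toy problem only).
LEARNING_RATES = {"very high": 2.0, "high": 0.3, "low": 1e-4, "good": 0.01}


def full_loss(w):
    """Mean squared error over the whole dataset."""
    return float(np.mean((w * x - y) ** 2))


def train_sgd(lr, epochs=20, batch_size=4, seed=1):
    """Mini-batch SGD from w = 0; returns the full-data loss after every epoch."""
    order_rng = np.random.default_rng(seed)
    w = 0.0
    history = [full_loss(w)]
    for _ in range(epochs):
        batches = np.array_split(order_rng.permutation(len(x)), len(x) // batch_size)
        for idx in batches:
            w -= lr * 2.0 * np.mean((w * x[idx] - y[idx]) * x[idx])
            if abs(w) > 1e6:  # diverged: record the loss and stop before overflow
                history.append(full_loss(w))
                return history
        history.append(full_loss(w))
    return history


runs = {name: train_sgd(lr) for name, lr in LEARNING_RATES.items()}
for name, history in runs.items():
    print(f"{name:>9}: loss {history[0]:.3g} -> {history[-1]:.3g} after {len(history) - 1} epochs")
```

On this problem the "very high" run blows up within the first epoch, the "low" run is still far from the minimum after 20 epochs, and the "good" run gets down to the noise floor (a loss of about 0.25, the variance of the added noise). The "high" run typically ends a little above the "good" one: the noise in each mini-batch gradient keeps the large steps bouncing around the minimum instead of settling into it, which is the saturation described above.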
Hope this helps!!! In the next blog, let's take a deep dive into more complex learning rate behaviour.
Spoiler alert: what if the loss saturates after a while and neither decreasing nor increasing the learning rate helps?
Keep thinking, and don't forget to clap!!!