The Obi-Wan Skywalker Algorithm
This Week on Gradient Ascent:
If knowledge distillation were Star Wars 💫
Why skip connections are awesome - The doodle edition 🎨
[Resource] High-quality image generation: Explained 🖼️
[Resource] Convert code into equations using code ➕
[Paper] The universal optimizer we've been waiting for 🤖
An Introduction To Knowledge Distillation - The Obi-Wan Skywalker Algorithm:
Close your eyes and imagine for a moment that you're on Tatooine. You're sitting next to a young Luke Skywalker and C-3PO, searching for R2-D2. Yes, the room inside the vehicle is a bit cramped, but hey, you're in the Star Wars saga so you'll take what you get. The desert air blows across your face as arid peaks and valleys whiz by. You squint into the vast expanse as Luke asks you if you see anything. Finally, you spot R2-D2. But before you can celebrate, Tusken Raiders attack, and your crew is hopelessly overpowered. Remember, this is Luke before he knew what the "Force" was. But wait! Out of nowhere, a hooded figure comes to save the day, driving away the pesky Raiders and rescuing you all.
I'll stop with the blow-by-blow account of Star Wars here, but that hooded figure becomes Luke's mentor, Obi-Wan Kenobi. He then teaches Luke the ways of the "Force" and we all know what happened thereafter.
Have you ever wondered if Luke would have discovered the ways of the "Force" if not for Obi-Wan? Imagine the plight of the galaxy if that were the case. The Jedi way of teaching from master to apprentice is potent. For one, it helps the padawan skip over many years of mistakes and false starts and become one with the Force sooner. The student also receives continuous feedback to improve and become a better Jedi.
Haven't you seen this in your own life? When you learn a new skill from someone who's ahead of you, you learn much faster than if you self-study. Wouldn't it be awesome if we could use this principle in machine learning? Turns out, we can.
This technique is called knowledge distillation.
What is knowledge distillation?
Knowledge distillation is a clever way of transferring knowledge from a computationally expensive teacher model to a smaller, lightweight student model. The large teacher model is really good at the task you want to solve. But it's really good for a reason: it's big, and it needs a lot of compute and memory to run. Consequently, you can't deploy it in an app, on mobile (edge) devices, or anywhere else resources are constrained. Instead of giving up, you can use knowledge distillation to transfer some of this hard-earned knowledge from the large model to a smaller model that you know can work in all of the scenarios above.
Do you know what's crazy though? The smaller model probably wouldn't have succeeded on its own. But when you perform knowledge distillation, you get the best of both worlds in that you get a very lightweight but smart model.
Another way to think about knowledge distillation is to consider it as a variation of compression, except that we compress by transference as opposed to pruning and removing neurons from an existing model.
Components of knowledge distillation
There are three components to knowledge distillation. The first is the teacher, the large model that you meticulously trained. Alternatively, this could be a very good pretrained model you found somewhere. The second component is the student. This is the lightweight model that you plan to use eventually, but it isn't good enough yet and needs to learn. What does lightweight mean here? Well, it could be a shallower version of the teacher model with fewer layers and fewer neurons per layer, a quantized version of the teacher, a different smaller network, and so on. The third and final component is the method used to transfer the knowledge over to the student model.
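To get a feel for what "fewer layers and neurons per layer" buys you, here's a minimal sketch that counts parameters in two fully connected networks. The layer sizes are made up purely for illustration, and `count_parameters` is a hypothetical helper, not something from a particular library.

```python
def count_parameters(layer_sizes):
    """Count weights and biases in a fully connected network.

    layer_sizes lists the width of every layer, input first, output last.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out  # weight matrix between the two layers
        total += n_out         # one bias per output unit
    return total


# Hypothetical architectures, chosen only for illustration.
teacher_layers = [784, 512, 512, 10]  # deeper and wider
student_layers = [784, 64, 10]        # shallower and narrower

teacher_params = count_parameters(teacher_layers)
student_params = count_parameters(student_layers)

print(f"teacher: {teacher_params:,} parameters")
print(f"student: {student_params:,} parameters")
print(f"the student is about {teacher_params / student_params:.1f}x smaller")
```

Dropping one hidden layer and narrowing the other shrinks this toy model by more than an order of magnitude, which is the kind of gap distillation tries to bridge.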
Understanding how knowledge distillation works
Let's go back to our Star Wars analogy. If Luke had to learn how to use the Force on his own, he probably would have after a lot of trial and error (after all, he's one of the main heroes in the Star Wars saga). But think of how long it would've taken him to really understand the Force and get good enough to save the day multiple times. Obi-Wan, the teacher, on the other hand, is very experienced and knows the Force inside and out. Through his mentorship, Luke's learning was accelerated significantly.
Luke got two sources of feedback. First, he felt a connection with the Force. In machine learning terms, that would be the ground truth that the model doesn't see but that we use to measure its performance. The second source of feedback for Luke is Obi-Wan telling him whether he's using the Force well or not. This feedback complements the first, and he can use it to improve his Jedi skills. In knowledge distillation terms, this is the signal from the teacher model to the student.
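Here's a rough sketch of how those two feedback sources are often combined into one training loss for a classifier. It follows the commonly used "soft targets" recipe (soften both models' outputs with a temperature, then compare them), which is an assumption on my part since nothing above pins down a specific loss. The `temperature` and `alpha` values and the example logits are made up for illustration.

```python
import math


def softmax(logits, temperature=1.0):
    """Turn raw scores into probabilities; higher temperature = softer."""
    scaled = [z / temperature for z in logits]
    largest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - largest) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, label):
    """Feedback source 1: how far the student is from the ground truth."""
    return -math.log(probs[label])


def kl_divergence(p, q):
    """Feedback source 2: how far the student (q) is from the teacher (p)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, label,
                      temperature=2.0, alpha=0.5):
    """Weighted mix of ground-truth loss and teacher-matching loss.

    alpha and temperature are illustrative defaults, not tuned values.
    """
    hard = cross_entropy(softmax(student_logits), label)
    soft = kl_divergence(softmax(teacher_logits, temperature),
                         softmax(student_logits, temperature))
    # The temperature**2 factor keeps the soft term's gradients on a
    # comparable scale when the temperature changes.
    return alpha * hard + (1 - alpha) * temperature ** 2 * soft


obi_wan = [4.0, 1.0, 0.5]      # teacher logits (made up)
luke_good = [3.5, 1.0, 0.5]    # student that mostly agrees
luke_bad = [0.5, 1.0, 4.0]     # student that disagrees
print(distillation_loss(luke_good, obi_wan, label=0))
print(distillation_loss(luke_bad, obi_wan, label=0))
```

The student that agrees with both the label and the teacher gets the smaller loss. Setting `alpha=1.0` switches the teacher off entirely, which is Luke learning on his own.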
Naturally, the question is how does Obi-Wan give Luke feedback? For that, we need to first understand what kind of knowledge is worth transferring over from mentor to mentee.
Types of knowledge
There are broadly two kinds of knowledge that can be used as feedback signals. The first is just the outcome itself. In this case, the teacher and student would both be asked to do the same thing, and we'd compare their results. The student would then learn how different their result was from their teacher's. Now, a special case of this is when we need to perform complex tasks. Instead of just learning from the final result, intermediate results could also be used as feedback signals. For example, Obi-Wan could either perform a Jedi mind trick ("You don't need to see our identification") as Luke watches, or explain each part of the process to him.
In knowledge distillation terms, these are called Response-based knowledge (just the outcome) and Feature-based knowledge (the intermediate results too).
The second broad kind captures the relationship between the intermediate steps: how the steps tie together. This is called Relation-based knowledge and, in knowledge distillation parlance, it refers to the relationships between feature maps of a network.
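To make the difference between feature-based and relation-based signals a bit more concrete, here's a small sketch. It's one possible reading, not the definitive formulation: the feature-based loss compares intermediate activations directly, while the relation-based loss compares how the feature maps relate to each other (here, through pairwise dot products). All function names and numbers are hypothetical.

```python
def feature_loss(student_features, teacher_features):
    """Feature-based: mean squared error between intermediate activations.

    Assumes both vectors have the same length.
    """
    if len(student_features) != len(teacher_features):
        raise ValueError("feature vectors must have the same length")
    n = len(student_features)
    return sum((s - t) ** 2
               for s, t in zip(student_features, teacher_features)) / n


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def similarity_matrix(feature_maps):
    """Pairwise dot products: how each feature map relates to the others."""
    return [[dot(a, b) for b in feature_maps] for a in feature_maps]


def relation_loss(student_maps, teacher_maps):
    """Relation-based: match the similarity structure, not the raw values."""
    s = similarity_matrix(student_maps)
    t = similarity_matrix(teacher_maps)
    n = len(s)
    return sum((s[i][j] - t[i][j]) ** 2
               for i in range(n) for j in range(n)) / (n * n)


# Three feature maps each; the teacher's are wider than the student's.
teacher_maps = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 0.0], [2.0, 1.0, 0.0, 1.0]]
student_maps = [[1.0, 0.5], [0.2, 1.0], [0.9, 0.1]]

print(feature_loss([1.0, 2.0], [2.0, 4.0]))
print(relation_loss(student_maps, teacher_maps))
```

Notice that the relation-based loss still works when teacher and student feature maps have different sizes, because it only compares the 3x3 similarity matrices.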
Modes of knowledge distillation
Now that we know the types of knowledge worth transferring, how can we transfer the knowledge over?
The first way is offline distillation, where you have a pretrained teacher model already available and only the student learns. The second is called online distillation. This is a rarer form in which you don't have access to a pretrained model. Therefore, you have to train the teacher yourself. While doing so, you can also set up the student model side by side, and train both. The third is self-distillation, where the same model is used as teacher and student. Here, deeper layers of the model can teach earlier layers, or a model from earlier epochs can be used to train a model deeper in its learning journey. Very meta, I agree.
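Here's a toy sketch of the offline mode, under heavy simplifying assumptions: the "teacher" is a fixed function with made-up weights standing in for a frozen pretrained model, and the student is a two-parameter logistic regression trained with plain gradient descent. The point is only the shape of the loop: the teacher gets queried, and only the student gets updated.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def teacher_predict(x):
    """Stand-in for a frozen, pretrained teacher (hypothetical weights)."""
    return sigmoid(6.0 * x - 3.0)


def binary_cross_entropy(p, target):
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))


def train_student_offline(xs, labels, alpha=0.5, lr=0.5, epochs=200):
    """Train a tiny logistic-regression student against a fixed teacher."""
    w, b = 0.0, 0.0  # the student's only parameters
    history = []
    n = len(xs)
    for _ in range(epochs):
        grad_w = grad_b = loss = 0.0
        for x, y in zip(xs, labels):
            p = sigmoid(w * x + b)
            t = teacher_predict(x)  # teacher is queried, never updated
            loss += (alpha * binary_cross_entropy(p, y)
                     + (1 - alpha) * binary_cross_entropy(p, t))
            # Gradient of the mixed loss with respect to the student's logit.
            error = p - (alpha * y + (1 - alpha) * t)
            grad_w += error * x
            grad_b += error
        w -= lr * grad_w / n
        b -= lr * grad_b / n
        history.append(loss / n)
    return w, b, history


xs = [i / 10 for i in range(11)]
labels = [1 if x > 0.5 else 0 for x in xs]
w, b, history = train_student_offline(xs, labels)
print(f"loss went from {history[0]:.3f} to {history[-1]:.3f}")
```

In the online mode, `teacher_predict` would have trainable parameters updated in the same loop. In self-distillation, the teacher signal would come from the student itself, for example from a snapshot saved at an earlier epoch.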
Knowledge distillation in the wild
Here are two notable stories where knowledge distillation was applied IRL. The first is DistilBERT, where knowledge from a HUGE language model, BERT, was distilled successfully into a smaller model. The resulting model was 40% smaller (66M vs 110M parameters) and 60% faster than its teacher. What's even more impressive is that this smaller model was 97% as good as the teacher in specific tasks. The second comes from Amazon Alexa, where a teacher-student model combination helped label over 1 million hours of unlabeled speech from just 7,000 hours of labeled examples.
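If you want to sanity-check those numbers, the arithmetic is short. This sketch uses only the figures quoted above; the variable names are mine.

```python
# Figures quoted above for DistilBERT and its teacher, BERT.
bert_params = 110_000_000
distilbert_params = 66_000_000
size_reduction = 1 - distilbert_params / bert_params
print(f"DistilBERT is {size_reduction:.0%} smaller than BERT")

# Figures quoted above for the Alexa example.
labeled_hours = 7_000
unlabeled_hours = 1_000_000
hours_ratio = unlabeled_hours / labeled_hours
print(f"about {hours_ratio:.0f} unlabeled hours per labeled hour")
```

So the 40% figure checks out, and in the Alexa case every labeled hour was stretched across roughly 143 hours of unlabeled audio.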
So the next time you're in a pinch with compute resources but really want a good model to use, think of knowledge distillation.
Poorly Drawn Machine Learning:
Learn more about skip connections here
Papers & Resources To Consider:
Understanding VQ-GAN
This is a really good video that explains how VQ-GAN (Vector Quantized Generative Adversarial Network) works. Before diffusion models came in and completely took over image generation, GANs were the best thing since sliced bread. If you're curious about how this GAN works, watch the video below. It's really well explained.
Use code to convert code to equations
Code: https://github.com/google/latexify_py
Latexify is a Python package that's really nifty at converting code into equations. This is an excellent resource for two reasons. 1) As practitioners, we sometimes struggle with the reverse problem, i.e., converting equations to code. This is great to use as a learning tool to see how code maps back into the math you see in papers. 2) Also, this is great for quickly converting your prototypes into technical reports or papers without worrying about the right LaTeX syntax. Look at the image below to see how it works:
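If the image isn't handy, here's a minimal sketch of what using it looks like. It assumes the `latexify-py` package is installed and that your version exposes `latexify.get_latex`, as recent releases do; check the repo linked above if yours differs. The import is guarded so the snippet still runs without the package.

```python
import math

try:
    import latexify  # installed with: pip install latexify-py
except ImportError:  # lets the sketch run even without the package
    latexify = None


def solve(a, b, c):
    """Larger root of a*x**2 + b*x + c = 0 (assumes real roots)."""
    return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)


# The function is still ordinary Python and can be called as usual.
print(solve(1, 4, 3))

if latexify is not None:
    # Prints a LaTeX string describing the body of solve().
    print(latexify.get_latex(solve))
```

Paste the printed string into a LaTeX document or a notebook cell to see the rendered equation.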
One optimizer to rule them all
Paper: https://arxiv.org/abs/2211.09760v1
Code: https://velo-code.github.io/
One of the most annoying parts of training a model is figuring out how to get the last drop of juice out of it. This is called hyperparameter tuning in deep learning speak. One of the many knobs that we can tweak as practitioners and see meaningful results from is the optimizer. In this paper, the authors have trained an optimizer that outperforms hand-designed ones. See the Twitter thread below, which explains the idea (the code and paper are linked above if you want to take this for a spin).