WEBVTT
00:00.000 --> 00:15.280
Hi, I am Mohammad Pezeshki from Mila, and today I am going to talk about gradient starvation.
00:15.280 --> 00:22.280
This is joint work with Oumar Kaba, Yoshua Bengio, Aaron Courville, Doina Precup, and Guillaume
00:22.280 --> 00:23.280
Lajoie.
00:23.280 --> 00:25.480
Let me start with a story.
00:25.480 --> 00:32.000
Back in 1904, there was a horse called Hans, and people believed that he could do arithmetic.
00:32.000 --> 00:36.440
Here is an article from The New York Times published in 1904.
00:36.440 --> 00:39.720
The article says that Hans is an expert in numbers.
00:39.720 --> 00:45.960
For example, when the two numbers 5 and 9 are written on a blackboard, Hans replies by tapping
00:45.960 --> 00:49.000
on the ground 14 times.
00:49.000 --> 00:54.960
Seven years later, in an article, Oskar Pfungst unveiled that the so-called Clever Hans was
00:54.960 --> 01:01.320
not actually capable of doing any arithmetic and was instead reading subtle hints in his trainer's
01:01.320 --> 01:05.720
behavior indicating when to stop tapping.
01:05.720 --> 01:12.900
As the article says, even the trainer was not aware of providing these shortcut signals.
01:12.900 --> 01:16.800
So Hans was clever, but probably not at doing arithmetic.
01:16.800 --> 01:21.120
His cleverness was in reading his trainer's cues.
01:21.120 --> 01:26.480
A similar phenomenon has been observed in many applications of machine learning.
01:26.480 --> 01:32.200
Essentially, situations where the model seemingly has very good performance but
01:32.200 --> 01:38.960
in fact has not learned the true underlying relationship between the input and the target.
01:38.960 --> 01:47.040
In this paper by Robert Geirhos and co-authors, they list several instances of what they call
01:47.040 --> 01:48.520
shortcut learning.
01:48.520 --> 01:55.080
For example, in a task of image captioning, the model predicts "grazing sheep" only by seeing
01:55.080 --> 01:57.840
the green hillside.
01:57.840 --> 02:03.480
In another instance, the network hallucinates a teapot with high confidence in an image
02:03.480 --> 02:06.400
of pure noise.
02:06.400 --> 02:11.960
This is another, and indeed dangerous, example from the task of pneumonia detection from X-ray
02:11.960 --> 02:13.160
images.
02:13.160 --> 02:17.280
The model appears to have very good performance, even on the test set.
02:17.280 --> 02:23.600
However, the heat maps reveal that the network is not looking at the lung region at all
02:23.600 --> 02:28.440
and is just latching onto some features in the corner of the image.
02:28.440 --> 02:33.880
The intuition behind this phenomenon is folk knowledge in one form or another:
02:33.880 --> 02:39.520
given strongly correlated and fast-to-learn features in the training data, gradient descent
02:39.520 --> 02:42.560
is biased towards learning them first.
02:42.560 --> 02:48.960
However, this intuition is a bit abstract and hand-wavy, so let's look at a more concrete
02:48.960 --> 02:51.240
example.
02:51.240 --> 02:57.240
Consider a 2D classification task with red and blue data points, as shown.
02:57.240 --> 03:03.240
If you train a neural network on this data, here is the decision boundary that it learns.
03:03.240 --> 03:08.680
Now consider a slightly different arrangement of the data points, such that the blue data
03:08.680 --> 03:14.400
points are slightly shifted to the left and the red data points are shifted to the right,
03:14.400 --> 03:17.640
making the data linearly separable.
03:17.640 --> 03:23.480
Now if we train a neural network on this, we get an almost linear decision boundary.
03:23.480 --> 03:30.600
Note that the network is making its predictions based only on the feature along the x-axis.
03:30.600 --> 03:35.200
As indicated by the red circle here, you can see that the decision boundary is very close
03:35.200 --> 03:36.520
to the data points.
03:36.520 --> 03:42.400
However, the network is super confident in its predictions, and the training loss is indeed
03:42.400 --> 03:43.880
zero.
03:43.880 --> 03:49.560
So you can see that slightly perturbing the data points can get the network to predict
03:49.560 --> 03:52.600
an incorrect label with high confidence.
03:52.600 --> 03:59.720
This problem is even more visible when testing the model on OOD, meaning out-of-distribution,
03:59.720 --> 04:03.040
test data.
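The separable variant of this toy task is easy to reproduce. Below is a minimal numpy sketch, not the demo from the talk: a plain logistic regression stands in for the neural network, and the data ranges are illustrative. It shows that the learned boundary relies almost entirely on the x-feature and sits inside the narrow gap between the classes, so a tiny nudge across it flips the prediction.

```python
import numpy as np

rng = np.random.default_rng(0)

# Blue class lives at x < -0.05, red class at x > 0.05; the y-coordinate is
# uninformative, so the data is linearly separable with a small margin along x.
n = 100
X = np.vstack([
    np.column_stack([rng.uniform(-1.0, -0.05, n), rng.uniform(-1, 1, n)]),
    np.column_stack([rng.uniform(0.05, 1.0, n), rng.uniform(-1, 1, n)]),
])
y = np.concatenate([np.zeros(n), np.ones(n)])

# Plain gradient descent on binary cross-entropy (logistic regression).
w, b = np.zeros(2), 0.0
for _ in range(5000):
    p = 1 / (1 + np.exp(-(X @ w + b)))
    g = p - y
    w -= 0.5 * X.T @ g / len(y)
    b -= 0.5 * g.mean()

print(w)  # dominated by the x-feature; the y-feature is essentially ignored

# The boundary (at y = 0) sits inside the narrow gap, close to both classes ...
x_cross = -b / w[0]
print(-0.05 < x_cross < 0.05)

# ... so nudging a correctly classified red point slightly across it flips
# the prediction, even though the point barely moved.
p_red = 1 / (1 + np.exp(-(np.array([x_cross + 0.04, 0.0]) @ w + b)))
p_nudged = 1 / (1 + np.exp(-(np.array([x_cross - 0.04, 0.0]) @ w + b)))
print(p_red > 0.5, p_nudged < 0.5)
```

The same qualitative picture holds for a small MLP; logistic regression just makes the reliance on the single separating feature easy to read off the weights.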
04:03.040 --> 04:07.440
An online interactive demo of this work is available in a blog post we wrote.
04:07.440 --> 04:12.440
If you wish to play with it a bit, please visit the link provided here.
04:12.440 --> 04:18.240
So we hypothesize that what is happening here is gradient starvation.
04:18.240 --> 04:24.880
Gradient starvation is a phenomenon in which a neural network captures statistically dominant
04:24.880 --> 04:31.160
features while remaining invariant to the rest.
04:31.160 --> 04:37.000
Here, gradient descent leads to parameter updates predominantly in directions that only capture
04:37.000 --> 04:43.320
these dominant features, thus starving the gradient from other, potentially informative,
04:43.320 --> 04:44.320
features.
04:44.320 --> 04:50.280
Here, the notions of a feature and the dominance of a feature are rather vague.
04:50.280 --> 04:55.520
To define them more formally, we need to look into the learning dynamics.
04:55.520 --> 05:00.720
In the interest of time, I will cover only the general intuition of our results
05:00.720 --> 05:07.360
and encourage interested audiences to take a look at the full paper for a detailed treatment.
05:07.360 --> 05:13.160
So the two main theorems of the paper can be summarized in these two plots, which I
05:13.160 --> 05:14.160
will now explain.
05:14.160 --> 05:19.720
Let's first start with gradient starvation itself, on the left.
05:19.720 --> 05:23.800
We train a model with the common binary cross-entropy loss.
05:23.800 --> 05:29.000
On the x-axis we have training iterations, or epochs, and on the y-axis we monitor two
05:29.000 --> 05:32.120
features, z1 and z2.
05:32.120 --> 05:37.480
Their dynamics depend on several factors, including their strength, meaning how easy
05:37.480 --> 05:43.280
or how hard it is for the network to learn those features, and their correlation with
05:43.280 --> 05:44.280
the target.
05:44.280 --> 05:51.600
Here, z1 has a larger correlation and hence converges to a value around 6, and z2, with
05:51.600 --> 05:55.800
a smaller correlation, converges to a value around 2.
05:55.800 --> 06:01.440
However, their strength is equal, i.e., kappa is set to 1.
06:01.440 --> 06:09.800
Again, this means that both of these features are equally easy for the network to learn.
06:09.800 --> 06:20.280
Now let's keep their correlations fixed but increase the strength of z1.
06:20.280 --> 06:25.400
A kappa equal to 2 means that z1 is easier to learn than z2.
06:25.400 --> 06:31.640
We can immediately see that although their correlations are still the same as before, z1
06:31.640 --> 06:36.560
is overestimated while z2 is underestimated.
06:36.560 --> 06:44.000
If we set kappa to 4 or 8, it becomes even more evident that simply because z1 is easier
06:44.000 --> 06:51.400
to learn, it is being overestimated, while z2 is being starved.
06:51.400 --> 06:58.520
Our theory shows that an increase in the strength of feature z1 has a detrimental effect on
06:58.520 --> 07:01.760
the learning of feature z2.
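The starvation effect shows up even in a deliberately simple setting. The sketch below is not the paper's setup: it trains a logistic regression on two redundant copies of the same signal, with one copy scaled by `kappa` as a stand-in for feature strength, and all constants chosen for illustration. With kappa = 1 the two feature responses come out identical; with kappa = 2 the stronger feature's response is overestimated while the weaker one is starved.

```python
import numpy as np

def feature_responses(kappa, steps=2000, lr=0.1):
    """Logistic regression on two redundant features carrying the same +/-1
    signal s, where feature 1 is scaled by kappa (i.e. easier to learn).
    Returns each feature's contribution scale to the logit."""
    rng = np.random.default_rng(0)
    s = rng.choice([-1.0, 1.0], size=500)   # underlying signal
    y = (s > 0).astype(float)               # labels determined by the signal
    X = np.column_stack([kappa * s, s])     # feature 1 is kappa times stronger
    w = np.zeros(2)
    for _ in range(steps):
        p = 1 / (1 + np.exp(-(X @ w)))
        w -= lr * X.T @ (p - y) / len(y)    # gradient descent on cross-entropy
    return kappa * w[0], w[1]

z1_eq, z2_eq = feature_responses(kappa=1.0)   # equal strength
z1_k2, z2_k2 = feature_responses(kappa=2.0)   # feature 1 twice as strong
print(z1_eq, z2_eq)   # identical responses
print(z1_k2, z2_k2)   # z1 overestimated, z2 starved
```

Because both columns carry exactly the same information, any split of the logit between them fits the data equally well; gradient descent nevertheless allocates the share kappa²:1 in favor of the stronger feature, which is the starvation mechanism in miniature.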
07:01.760 --> 07:08.840
Now our second theorem shows that adding this term, indicated in the red rectangle, to the
07:08.840 --> 07:11.800
loss decouples the features.
07:11.800 --> 07:17.640
As you can see, spectral decoupling decouples the features at the converged solution.
07:17.640 --> 07:25.680
Regardless of the value of kappa, all of the experiments on z1 and z2 converge to the same
07:25.680 --> 07:26.680
place.
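In the paper, spectral decoupling takes the form of an L2 penalty on the network's logits (in place of the usual weight decay). The sketch below is illustrative, not the paper's experiment: a logistic regression on two redundant ±1 features, one scaled by `kappa`, with an arbitrary penalty weight `sd_lambda`. With the penalty, the converged logit no longer depends on kappa, matching the "converge to the same place" behavior described above.

```python
import numpy as np

def train_with_sd(kappa, sd_lambda=0.1, steps=5000, lr=0.1):
    """Logistic regression on two redundant +/-1 features, feature 1 scaled by
    kappa. Spectral decoupling adds (sd_lambda / 2) * mean(logits**2) to the
    binary cross-entropy loss; its gradient contributes sd_lambda * logits."""
    rng = np.random.default_rng(0)
    s = rng.choice([-1.0, 1.0], size=500)
    y = (s > 0).astype(float)
    X = np.column_stack([kappa * s, s])
    w = np.zeros(2)
    for _ in range(steps):
        logits = X @ w
        p = 1 / (1 + np.exp(-logits))
        w -= lr * X.T @ ((p - y) + sd_lambda * logits) / len(y)
    return X[0] @ w * np.sign(s[0])   # converged logit magnitude

# Without the penalty the converged logit keeps growing with kappa;
# with it, the fixed point solves sigmoid(-q) = sd_lambda * q for any kappa.
q_k1 = train_with_sd(kappa=1.0)
q_k4 = train_with_sd(kappa=4.0)
print(q_k1, q_k4)   # approximately equal
```

The fixed-point condition makes the decoupling visible: the penalty caps the logit at a kappa-independent value instead of letting the easiest feature run away.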
07:26.680 --> 07:33.640
Again, we refer the interested audience to the paper for more theory as well as more intuition.
07:33.640 --> 07:36.720
Now let's look at some experiments.
07:36.720 --> 07:39.080
Recall the task that we studied earlier.
07:39.080 --> 07:44.880
When the data is not linearly separable, the network learns a curved decision boundary.
07:44.880 --> 07:49.840
On the right, we see how z1 and z2 evolve.
07:49.840 --> 07:55.080
When the data is linearly separable with a small margin, a linear decision boundary is
07:55.080 --> 07:56.080
learned.
07:56.080 --> 08:02.920
We observe that z1 is overestimated, while z2 is heavily underestimated.
08:02.920 --> 08:07.880
Now let's see what happens if we add spectral decoupling.
08:07.880 --> 08:14.320
Spectral decoupling suppresses z1 and, as a result, allows z2 to grow.
08:14.320 --> 08:20.480
It also appears that other regularization methods do not succeed at learning a curved
08:20.480 --> 08:23.240
decision boundary.
08:23.240 --> 08:30.860
So we observed that spectral decoupling leads to a decision boundary with a larger margin.
08:30.860 --> 08:33.520
What happens in real-world tasks?
08:33.520 --> 08:38.640
The distance to the decision boundary is not trivial to compute when working with nonlinear
08:38.640 --> 08:39.640
models.
08:39.640 --> 08:42.040
However, we can use a proxy.
08:42.040 --> 08:48.200
The amount of perturbation required to fool the network is a proxy for the margin.
08:48.200 --> 08:50.400
Look at the plot on the right.
08:50.400 --> 08:56.040
On the x-axis, we have the amount of perturbation, and on the y-axis, we have how many of the
08:56.040 --> 08:59.760
examples are misclassified.
08:59.760 --> 09:07.420
You can see that for a fixed amount of perturbation, a model trained with vanilla binary cross-entropy
09:07.420 --> 09:14.320
is much more vulnerable compared to a model trained with spectral decoupling.
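The perturbation-as-margin proxy is exact in the linear case, which makes the idea concrete: for a linear classifier, the smallest L2 perturbation that flips the prediction equals the distance to the decision boundary, aimed against the weight vector. A minimal sketch with made-up weights (for deep networks one instead searches for the smallest adversarial perturbation):

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=5)   # illustrative linear classifier
b = 0.3
x = rng.normal(size=5)   # illustrative input point

logit = w @ x + b
margin = abs(logit) / np.linalg.norm(w)          # distance to the boundary
delta = -np.sign(logit) * w / np.linalg.norm(w)  # unit step toward the boundary

# Just past the margin the prediction flips; just short of it, it does not.
flipped = np.sign(w @ (x + 1.001 * margin * delta) + b) != np.sign(logit)
not_yet = np.sign(w @ (x + 0.999 * margin * delta) + b) == np.sign(logit)
print(flipped, not_yet)   # True True
```

Reading the experiment's plot through this lens: if a fixed perturbation budget misclassifies fewer examples, more examples sit farther than that budget from the boundary, i.e. the model has a larger effective margin.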
09:14.320 --> 09:20.320
In another experiment, we studied Colored MNIST, a well-known OOD-generalization task
09:20.320 --> 09:27.960
where the color is spuriously correlated with the labels.
09:27.960 --> 09:33.740
Another OOD-generalization task is classification on the CelebA dataset,
09:33.740 --> 09:44.100
where the training data is again biased with respect to hair color and gender,
09:44.100 --> 09:50.080
such that most male images have black hair while the majority of female images have blonde
09:50.080 --> 09:51.080
hair.
09:51.080 --> 09:54.760
Here, we skip the details in the interest of time.
09:54.760 --> 10:00.920
However, let me just draw your attention to the superiority of spectral decoupling on
10:00.920 --> 10:03.840
both of these tasks.
10:03.840 --> 10:09.080
Finally, to conclude: we talked about the Clever Hans effect.
10:09.080 --> 10:15.360
We showed that a similar phenomenon can happen in neural networks, and we call it gradient
10:15.360 --> 10:16.360
starvation.
10:16.360 --> 10:21.600
To understand gradient starvation, we looked into the learning dynamics.
10:21.600 --> 10:29.080
We showed that the presence of a strongly correlated feature can result in the starvation
10:29.080 --> 10:30.960
of other features.
10:30.960 --> 10:36.560
We also showed that spectral decoupling provides some degree of control over which features
10:36.560 --> 10:44.040
to learn and essentially decouples the features.
10:44.040 --> 10:45.720
Thanks for your attention.
10:45.720 --> 10:50.880
If you're interested in chatting more, please visit our poster this afternoon.
10:50.880 --> 11:01.760
Thank you very much.