WEBVTT
00:00.000 --> 00:15.280
Hi, I am Mohammad Pezeshki from Mila, and today I am going to talk about gradient starvation.
00:15.280 --> 00:22.280
This is joint work with Oumar Kaba, Yoshua Bengio, Aaron Courville, Doina Precup, and Guillaume
00:22.280 --> 00:23.280
Lajoie.
00:23.280 --> 00:25.480
Let me start with a story.
00:25.480 --> 00:32.000
Back in 1904, there was a horse called Hans and people believed that he could do arithmetic.
00:32.000 --> 00:36.440
Here is an article from New York Times published in 1904.
00:36.440 --> 00:39.720
The article says that Hans is an expert in numbers.
00:39.720 --> 00:45.960
For example, when the numbers 5 and 9 are written on a blackboard, Hans replies by tapping
00:45.960 --> 00:49.000
on the ground 14 times.
00:49.000 --> 00:54.960
Seven years later, in an article, Oskar Pfungst revealed that the so-called Clever Hans was
00:54.960 --> 01:01.320
not actually capable of doing any arithmetic and was instead reading subtle hints in his trainer's
01:01.320 --> 01:05.720
behavior indicating when to stop tapping.
01:05.720 --> 01:12.900
As the article says, even the trainer was not aware of providing these shortcut signals.
01:12.900 --> 01:16.800
So Hans was clever, but probably not at doing arithmetic.
01:16.800 --> 01:21.120
His cleverness was in reading his trainer's cues.
01:21.120 --> 01:26.480
A similar phenomenon has been observed in many applications of machine learning.
01:26.480 --> 01:32.200
Essentially, these are situations where the model seemingly has very good performance but
01:32.200 --> 01:38.960
in fact has not learned the true underlying relationship between the input and the target.
01:38.960 --> 01:47.040
In this paper, Robert Geirhos and co-authors list several instances of what they call
01:47.040 --> 01:48.520
shortcut learning.
01:48.520 --> 01:55.080
For example, in a task of image captioning, the model predicts grazing sheep only by seeing
01:55.080 --> 01:57.840
the green hillside.
01:57.840 --> 02:03.480
In another instance, the network hallucinates a teapot with high confidence in an image
02:03.480 --> 02:06.400
of pure noise.
02:06.400 --> 02:11.960
Here is another, and indeed dangerous, example: the task of pneumonia detection from X-ray
02:11.960 --> 02:13.160
images.
02:13.160 --> 02:17.280
The model appears to have a very good performance even on the test set.
02:17.280 --> 02:23.600
However, the heat maps reveal that the network is not looking at the lungs at all
02:23.600 --> 02:28.440
and is just latching on to some features in the corner of the image.
02:28.440 --> 02:33.880
The intuition behind this phenomenon is folk knowledge in one form or another.
02:33.880 --> 02:39.520
Given strongly correlated and fast-to-learn features in the training data, gradient descent
02:39.520 --> 02:42.560
is biased towards learning them first.
02:42.560 --> 02:48.960
However, this intuition is a bit abstract and hand-wavy, so let's look at a more concrete
02:48.960 --> 02:51.240
example.
02:51.240 --> 02:57.240
Consider a 2D classification task with red and blue data points as shown.
02:57.240 --> 03:03.240
If we train a neural network on this data, here is the decision boundary that it learns.
03:03.240 --> 03:08.680
Now consider slightly different arrangements of the data points such that the blue data
03:08.680 --> 03:14.400
points are slightly shifted to the left and the red data points are shifted to the right,
03:14.400 --> 03:17.640
making the data linearly separable.
03:17.640 --> 03:23.480
Now if we train a neural network on this, we get an almost linear decision boundary.
03:23.480 --> 03:30.600
Note that the network is only making its predictions based on the feature along the x-axis.
03:30.600 --> 03:35.200
As indicated by the red circle here, you can see that the decision boundary is very close
03:35.200 --> 03:36.520
to the data points.
03:36.520 --> 03:42.400
However, the network is super confident in its predictions, and the training loss is indeed
03:42.400 --> 03:43.880
zero.
03:43.880 --> 03:49.560
So you can see that slightly perturbing a data point can get the network to predict
03:49.560 --> 03:52.600
an incorrect label with high confidence.
03:52.600 --> 03:59.720
This problem becomes even more visible when testing the model on OOD, meaning out-of-distribution,
03:59.720 --> 04:03.040
test data.
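To make the toy experiment concrete, here is a minimal sketch in PyTorch. The setup is an assumption for illustration (the 0.5 shift, network width, and optimizer settings are not from the talk): shifting the two classes apart along the x-axis makes the data linearly separable, and the trained network then relies almost entirely on that single feature.

```python
# Minimal sketch of the toy 2D experiment (assumed settings, not the paper's exact code).
import torch
import torch.nn as nn
from sklearn.datasets import make_moons

X, y = make_moons(n_samples=500, noise=0.1)
X[y == 0, 0] -= 0.5   # shift one class left along the x-axis (illustrative margin)
X[y == 1, 0] += 0.5   # shift the other class right, making the data linearly separable
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

model = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.BCEWithLogitsLoss()

for step in range(2000):
    opt.zero_grad()
    loss = loss_fn(model(X).squeeze(1), y)  # training loss goes to zero
    loss.backward()
    opt.step()
```

Without the shifts, the same script learns the curved boundary; with them, it converges to the nearly linear, small-margin boundary described above.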
04:03.040 --> 04:07.440
An online interactive demo of this work is available on a blog post we wrote.
04:07.440 --> 04:12.440
If you wish to play with it a bit, please visit the link provided here.
04:12.440 --> 04:18.240
So we hypothesize that what is happening here is gradient starvation.
04:18.240 --> 04:24.880
Gradient starvation is a phenomenon in which a neural network captures statistically dominant
04:24.880 --> 04:31.160
features while remaining invariant to the rest.
04:31.160 --> 04:37.000
Here, gradient descent leads to parameter updates predominantly in directions that only capture
04:37.000 --> 04:43.320
these dominant features, thus starving the gradient from other potentially informative
04:43.320 --> 04:44.320
features.
04:44.320 --> 04:50.280
Here, the notions of a feature and the dominance of a feature are rather vague.
04:50.280 --> 04:55.520
To define them more formally, we need to look into the learning dynamics.
04:55.520 --> 05:00.720
In the interest of time, I will be covering only the general intuition of our results
05:00.720 --> 05:07.360
and encourage interested audiences to take a look at the full paper for detailed treatment.
05:07.360 --> 05:13.160
So the two main theorems of the paper can be summarized into these two plots that I
05:13.160 --> 05:14.160
now explain.
05:14.160 --> 05:19.720
Let's first start with gradient starvation itself on the left.
05:19.720 --> 05:23.800
We train a model with the common binary cross-entropy loss.
05:23.800 --> 05:29.000
On the x-axis we have training iterations or epochs, and on the y-axis we monitor two
05:29.000 --> 05:32.120
features z1 and z2.
05:32.120 --> 05:37.480
Their dynamics depend on several factors, including their strength, meaning how easy
05:37.480 --> 05:43.280
or how hard it is for the network to learn those features, and their correlation with
05:43.280 --> 05:44.280
the target.
05:44.280 --> 05:51.600
Here, z1 has a larger correlation and hence converges to a value around 6, and z2 with
05:51.600 --> 05:55.800
a smaller correlation converges to a value around 2.
05:55.800 --> 06:01.440
However, their strengths are equal, i.e., kappa is set to 1.
06:01.440 --> 06:09.800
Again, this means that both of these features are equally easy for the network to learn.
06:09.800 --> 06:20.280
Now let's keep their correlation fixed but increase the strength of z1.
06:20.280 --> 06:25.400
A kappa equal to 2 means that z1 is learned more easily than z2.
06:25.400 --> 06:31.640
We can immediately see that although their correlation is still the same as before, z1
06:31.640 --> 06:36.560
is overestimated while z2 is underestimated.
06:36.560 --> 06:44.000
If we set kappa to 4 or 8, it becomes more evident that simply because z1 is easier
06:44.000 --> 06:51.400
to learn, it is being overestimated, while z2 is being starved.
06:51.400 --> 06:58.520
Our theory shows that an increase in the strength of feature z1 has a detrimental effect on
06:58.520 --> 07:01.760
the learning of feature z2.
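The following sketch illustrates this coupling in a minimal setting: a logistic regression on two equally label-correlated features, where feature 1 is scaled by kappa so it is faster to learn. The data generation and constants are assumptions for illustration, not the paper's analytical setup.

```python
# Illustrative sketch of gradient starvation on two features (assumed setup).
import torch

torch.manual_seed(0)
n, kappa = 1000, 4.0
y = torch.randint(0, 2, (n,)).float() * 2 - 1   # labels in {-1, +1}
z1 = y + 0.3 * torch.randn(n)                   # feature 1, correlated with the label
z2 = y + 0.3 * torch.randn(n)                   # feature 2, equally correlated
X = torch.stack([kappa * z1, z2], dim=1)        # kappa makes feature 1 "stronger"

w = torch.zeros(2, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.05)
for step in range(500):
    opt.zero_grad()
    loss = torch.nn.functional.softplus(-y * (X @ w)).mean()  # logistic loss
    loss.backward()
    opt.step()

with torch.no_grad():
    r1 = (y * kappa * z1 * w[0]).mean()  # logit contribution of feature 1
    r2 = (y * z2 * w[1]).mean()          # logit contribution of feature 2
    print(r1.item(), r2.item())          # r1 dominates; feature 2 is starved
```

Because the logistic loss saturates once the stronger feature drives the margins up, the gradient flowing to the second feature shrinks, which is exactly the starvation effect described above.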
07:01.760 --> 07:08.840
Now our second theorem shows that adding this term, indicated in the red rectangle, to the
07:08.840 --> 07:11.800
loss decouples the features.
07:11.800 --> 07:17.640
As you can see, spectral decoupling decouples the features at the converged solution.
07:17.640 --> 07:25.680
Regardless of the value of kappa, all of the experiments on z1 and z2 converge to the same
07:25.680 --> 07:26.680
place.
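In the paper, this term is an L2 penalty on the network's unnormalized outputs (the logits). A minimal sketch of the resulting loss is below; the coefficient name and value, and the use of a mean rather than a sum, are illustrative choices.

```python
# Spectral decoupling: binary cross-entropy plus an L2 penalty on the logits.
import torch
import torch.nn.functional as F

def spectral_decoupling_loss(logits, targets, lambda_sd=1e-2):
    ce = F.binary_cross_entropy_with_logits(logits, targets)
    sd_penalty = 0.5 * lambda_sd * (logits ** 2).mean()  # penalize logit magnitude
    return ce + sd_penalty
```

Swapping this in for the plain cross-entropy in a training loop like the earlier sketch is the only change needed.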
07:26.680 --> 07:33.640
Again, we refer the interested audience to the paper for more theory as well as more intuition.
07:33.640 --> 07:36.720
Now let's look at some experiments.
07:36.720 --> 07:39.080
Recall the task that we studied earlier.
07:39.080 --> 07:44.880
When the data is not linearly separable, we learn a curved decision boundary.
07:44.880 --> 07:49.840
On the right, we see how z1 and z2 evolve.
07:49.840 --> 07:55.080
When the data is linearly separable with a small margin, a linear decision boundary is
07:55.080 --> 07:56.080
learned.
07:56.080 --> 08:02.920
We observe that z1 is overestimated, while z2 is heavily underestimated.
08:02.920 --> 08:07.880
Now let's see what happens if we add spectral decoupling.
08:07.880 --> 08:14.320
Spectral decoupling suppresses z1 and as a result allows z2 to grow.
08:14.320 --> 08:20.480
It also appears that other regularization methods do not succeed at learning a curved
08:20.480 --> 08:23.240
decision boundary.
08:23.240 --> 08:30.860
So we observed that spectral decoupling leads to a decision boundary with a larger margin.
08:30.860 --> 08:33.520
What happens in real-world tasks?
08:33.520 --> 08:38.640
The distance to the decision boundary is not trivial to compute when working with nonlinear
08:38.640 --> 08:39.640
models.
08:39.640 --> 08:42.040
However, we can use a proxy.
08:42.040 --> 08:48.200
The amount of perturbation required to fool the network is a proxy to the margin.
08:48.200 --> 08:50.400
Look at the plot on the right.
08:50.400 --> 08:56.040
On the x-axis, we have the amount of perturbation and on the y-axis, we have how many of the
08:56.040 --> 08:59.760
examples are misclassified.
08:59.760 --> 09:07.420
You can see that with a fixed amount of perturbation, a model with vanilla binary cross entropy
09:07.420 --> 09:14.320
is much more vulnerable compared to a model trained with spectral decoupling.
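As a sketch of this proxy, one can apply a one-step gradient-sign (FGSM-style) perturbation and sweep its size; the helper below is an illustration, and the talk does not specify which attack was used.

```python
# Illustrative margin proxy: error rate under a one-step sign perturbation.
import torch
import torch.nn.functional as F

def fgsm_error_rate(model, X, y, eps):
    X = X.clone().requires_grad_(True)
    loss = F.binary_cross_entropy_with_logits(model(X).squeeze(1), y)
    loss.backward()
    with torch.no_grad():
        X_adv = X + eps * X.grad.sign()            # move each input against the model
        preds = (model(X_adv).squeeze(1) > 0).float()
        return (preds != y).float().mean().item()  # fraction misclassified at this eps
```

Sweeping eps and plotting the error rate produces the kind of curve described here: the spectrally decoupled model needs a larger perturbation before its error rate rises.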
09:14.320 --> 09:20.320
In another experiment, we studied colored MNIST, a well-known task of OOD generalization
09:20.320 --> 09:27.960
where the color is spuriously correlated with the labels.
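For intuition, here is one common way such a spurious color-label correlation is injected; the correlation strength and the two-channel color encoding are illustrative, and Colored MNIST setups vary in their details.

```python
# Sketch: color each digit so that the color agrees with the (binary) label
# with probability p_spurious, creating a shortcut feature (assumed encoding).
import torch

def colorize(images, labels, p_spurious=0.9):
    # images: (n, 28, 28) grayscale in [0, 1]; labels: (n,) with values in {0, 1}
    n = images.shape[0]
    colored = torch.zeros(n, 2, 28, 28)              # channel 0: red, channel 1: green
    match = torch.rand(n) < p_spurious               # does the color agree with the label?
    channel = torch.where(match, labels.long(), 1 - labels.long())
    colored[torch.arange(n), channel] = images
    return colored
```

At test time the correlation is typically reversed or removed, so a model that latched on to color fails while one that learned the digit shapes generalizes.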
09:27.960 --> 09:33.740
Another task of OOD generalization is a classification task on the CelebA dataset,
09:33.740 --> 09:44.100
where the training data is again biased with respect to hair color and gender,
09:44.100 --> 09:50.080
such that most male images have black hair while the majority of female images have blonde
09:50.080 --> 09:51.080
hair.
09:51.080 --> 09:54.760
Here, we skip the details in the interest of time.
09:54.760 --> 10:00.920
However, let me just draw your attention to the superiority of spectral decoupling in
10:00.920 --> 10:03.840
both of these tasks.
10:03.840 --> 10:09.080
Finally, to conclude, we talked about the Clever Hans effect.
10:09.080 --> 10:15.360
We showed that a similar phenomenon can happen in neural networks, and we called it gradient
10:15.360 --> 10:16.360
starvation.
10:16.360 --> 10:21.600
To understand gradient starvation, we looked into the learning dynamics.
10:21.600 --> 10:29.080
We showed that the presence of a strongly correlated feature could result in the starvation
10:29.080 --> 10:30.960
of other features.
10:30.960 --> 10:36.560
We also showed that spectral decoupling provides some degree of control over which features
10:36.560 --> 10:44.040
to learn, and essentially decouples the features.
10:44.040 --> 10:45.720
Thanks for your attention.
10:45.720 --> 10:50.880
If you're interested in chatting more, please visit our poster this afternoon.
10:50.880 --> 11:01.760
Thank you very much.