ytseg_demo / demo_data /nips-2021 /25970 /transcript_whisper_large-v2.vtt
WEBVTT
00:00.000 --> 00:14.520
Hi, my name is Maxwell Nye, and today I'll be talking about improving coherence and consistency
00:14.520 --> 00:19.620
in neural sequence models with dual system neurosymbolic reasoning.
00:19.620 --> 00:23.800
So I first want to give a little bit of a demo, which is to ask this question.
00:23.800 --> 00:26.920
A bat and a ball cost $1.10 in total.
00:26.920 --> 00:29.300
The bat costs $1 more than the ball.
00:29.300 --> 00:31.720
How much does the ball cost?
00:31.720 --> 00:34.920
So I'll let you think a little bit for this.
00:34.920 --> 00:39.200
So one answer that might jump out at you is $0.10, but this is actually incorrect,
00:39.200 --> 00:43.920
because if the ball cost $0.10, the bat would cost $1.10 and the total would be $1.20, not $1.10.
00:43.920 --> 00:46.880
So the correct answer is actually $0.05.
00:46.880 --> 00:54.240
And this is an example from a cognitive reflection test, and these are questions designed to
00:54.240 --> 01:00.140
have a particular answer which comes to mind quite quickly, which is in fact wrong.
01:00.140 --> 01:06.640
And something that's interesting is that large-scale language models such as GPT-3 predict the
01:06.640 --> 01:08.320
wrong answers as well.
01:08.320 --> 01:11.300
And this is true not just for the classic cognitive reflection test, but
01:11.300 --> 01:15.160
also for variants with different numbers.
01:15.160 --> 01:19.680
So this is an interesting phenomenon.
01:19.680 --> 01:27.400
It illustrates how neural language models often have issues with consistency and coherence.
01:27.400 --> 01:30.720
So another place where we can see this a little more concretely is the CLUTRR dataset.
01:30.720 --> 01:36.680
In the CLUTRR dataset,
01:36.680 --> 01:42.080
there are sentences about people and their family relationships, and stories about those
01:42.080 --> 01:43.840
people.
01:43.840 --> 01:48.800
And this was originally devised as a question-answering data set where you ask what the relations
01:48.800 --> 01:49.800
are.
01:49.800 --> 01:58.080
One thing you can do is train models on this dataset and then have them generate new stories.
01:58.080 --> 02:02.880
And when you do that, you'll see that the generated stories often have inconsistencies.
02:02.880 --> 02:06.560
So if we look at the bottom of the screen here, we can see an example of this.
02:06.560 --> 02:10.080
Robert and his brother Antonio played harmonicas together.
02:10.080 --> 02:13.440
Robert's daughter, Elsie, asked him to play with her.
02:13.440 --> 02:17.280
Elsie doesn't like having to babysit her younger brother, Antonio.
02:17.280 --> 02:21.240
And so we can see that this is a commonsense error: Antonio was introduced as Robert's
02:21.240 --> 02:22.240
brother,
02:22.240 --> 02:27.720
so he cannot also be the younger brother of Robert's daughter, Elsie.
02:27.720 --> 02:35.760
So what we've done is we've built a dual system model using large-scale neural networks and
02:35.760 --> 02:42.800
symbolic deliberative logic in order to try to help with these consistency issues.
02:42.800 --> 02:44.400
So the model is as follows.
02:44.400 --> 02:52.680
You use neural generation to generate sentences in a particular story.
02:52.680 --> 02:59.360
You might generate the next sentence using a model such as GPT-3 or BART.
02:59.360 --> 03:10.320
What you can then do is parse that sentence into its semantic meaning with respect to
03:10.320 --> 03:15.520
the family relationships and check whether or not it matches the current state of the
03:15.520 --> 03:20.960
family relationships that's been described so far, and only accept the candidate sentence
03:20.960 --> 03:25.800
generations that are actually consistent.
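This generate-parse-check loop can be sketched in Python. This is a minimal sketch with hypothetical stub functions, not the paper's implementation: `propose` stands in for the neural (System 1) generator, and `parse` and `is_consistent` stand in for the semantic parser and symbolic (System 2) checker.

```python
def generate_story(propose, parse, is_consistent, n_sentences, max_tries=10):
    """Dual-system generation: propose(story) -> candidate sentence,
    parse(sentence) -> semantic form, is_consistent(facts, fact) -> bool."""
    story, facts = [], []
    for _ in range(n_sentences):
        for _ in range(max_tries):
            candidate = propose(story)      # System 1: neural generation
            fact = parse(candidate)         # semantic parse of the candidate
            if is_consistent(facts, fact):  # System 2: symbolic check
                story.append(candidate)
                facts.append(fact)
                break                       # accept; move to the next sentence
    return story
```

Inconsistent candidates are simply rejected and regenerated, so the neural model never needs to be retrained.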
03:25.800 --> 03:27.600
So this has a few components.
03:27.600 --> 03:30.380
One of the components here is a symbolic world model.
03:30.380 --> 03:35.160
In the case of this clutter domain, the symbolic world model that we built encodes people and
03:35.160 --> 03:36.160
their family relationships.
03:36.160 --> 03:42.840
So in other words, you could take a sentence and encode what the underlying family relationship
03:42.840 --> 03:43.840
is.
03:43.840 --> 03:50.680
And what you can do is you can use SMT solvers such as the Z3 solver to check consistency.
03:50.680 --> 03:57.240
So given a new sentence, you can check that it doesn't violate the rules of ancestry that
03:57.240 --> 03:58.240
we've defined here.
03:58.240 --> 04:04.120
And so some of those are, for example, what is the relationship between children and grandchildren?
04:04.120 --> 04:10.000
And another concerns the rules of ancestry, for example, whether you can be your own
04:10.000 --> 04:12.180
ancestor, et cetera.
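As a toy illustration, an ancestry rule like "nobody can be their own ancestor" can be checked directly in Python. The paper instead encodes such constraints for an SMT solver such as Z3; the class and the single rule here are simplified assumptions.

```python
class FamilyWorldModel:
    """Toy symbolic world model over parent-of facts (illustrative only)."""

    def __init__(self):
        self.parent_of = set()  # (parent, child) facts accepted so far

    def ancestors(self, person):
        # Transitive closure of the parent relation.
        found, frontier = set(), {person}
        while frontier:
            parents = {p for (p, c) in self.parent_of if c in frontier}
            frontier = parents - found
            found |= parents
        return found

    def consistent(self, parent, child):
        # Rule: nobody may be their own ancestor. Adding (parent, child)
        # would create a cycle iff child is already an ancestor of parent.
        return parent != child and child not in self.ancestors(parent)

    def add(self, parent, child):
        # Accept the fact only if it keeps the world model consistent.
        if not self.consistent(parent, child):
            return False
        self.parent_of.add((parent, child))
        return True
```

The same closure also answers derived questions, such as the relationship between children and grandchildren, since a grandparent is simply an ancestor two parent links away.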
04:12.180 --> 04:15.040
So one question is how is this semantic parsing done?
04:15.040 --> 04:19.560
And it turns out we can actually do this quite cheaply using GPT-3.
04:19.560 --> 04:26.920
So what we can see here in the dotted box is an actual example of a few-shot prompt
04:26.920 --> 04:34.440
we can use to parse each new candidate sentence from the System 1
04:34.440 --> 04:42.360
generation model into the semantic form that we can then give to the world model
04:42.360 --> 04:46.280
solver.
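A few-shot parsing prompt in this style might look as follows. The example sentences and the "relation(X, Y)" target format are my own assumptions for illustration, not the paper's actual prompt.

```python
# Few-shot prompt template: each block maps a sentence to a relation,
# and the language model completes the final "Parse:" line.
FEW_SHOT_PARSE = """\
Sentence: Alice went shopping with her son Bob.
Parse: child(Bob, Alice)

Sentence: Carol called her brother Dan.
Parse: sibling(Carol, Dan)

Sentence: {sentence}
Parse:"""

def make_parse_prompt(sentence):
    # The completion the model returns after "Parse:" is the semantic
    # form handed to the world-model solver.
    return FEW_SHOT_PARSE.format(sentence=sentence)
```

Because the parse targets are short and formulaic, a handful of in-context examples is often enough, which is why this step is cheap.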
04:46.280 --> 04:52.120
So the results here show that stories generated with this dual-system neurosymbolic
04:52.120 --> 05:02.160
approach show improved coherence over stories constructed by a neural model alone.
05:02.160 --> 05:10.160
Specifically, what we've done is we've collected human judgments on which of
05:10.160 --> 05:14.800
the candidate sentences makes more sense given the prior context of the story.
05:14.800 --> 05:25.280
And we see that if we use a symbolic world model and the parsing scheme described above,
05:25.280 --> 05:32.520
humans prefer the sentences generated by this model.
05:32.520 --> 05:36.360
We can also apply the same sort of reasoning to a completely different task.
05:36.360 --> 05:42.080
Here we discuss a grounded instruction-following
05:42.080 --> 05:44.020
domain called gSCAN.
05:44.020 --> 05:49.360
In this domain, the goal is to have an agent, which is shown by this pink triangle, follow
05:49.360 --> 05:53.240
a command to perform some simple action in this grid world.
05:53.240 --> 06:00.520
So you can see here, walk to a small yellow cylinder might be an example of a command.
06:00.520 --> 06:06.800
Prior work has shown that one thing you can do is encode the initial state, encode the
06:06.800 --> 06:14.280
instruction and then train a neural model to predict the action sequences.
06:14.280 --> 06:19.600
Other work has also shown that one thing you can do is train a model to predict a distribution
06:19.600 --> 06:25.200
over the correct target location as part of the neural model.
06:25.200 --> 06:29.600
That will also increase the performance of the model.
06:29.600 --> 06:38.400
What we do here is show that if you do both of these things, you predict both an action
06:38.400 --> 06:43.800
sequence and a target location, that is, the location you should end up in, and
06:43.800 --> 06:48.600
then check whether, when you execute the predicted action sequence, you end up in
06:48.600 --> 06:50.720
the predicted target location.
06:50.720 --> 06:57.800
You can sort of check consistency between these two different predictions and only accept
06:57.800 --> 07:06.560
those instruction sequences which match the target location prediction.
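This agreement check between the two prediction heads can be sketched as follows. The grid encoding and the action names here are assumptions for illustration, not gSCAN's actual action vocabulary.

```python
# Map each action to a (dx, dy) step on the grid.
MOVES = {"north": (0, 1), "south": (0, -1), "east": (1, 0), "west": (-1, 0)}

def execute(start, actions):
    # Replay the predicted action sequence from the agent's start cell.
    x, y = start
    for action in actions:
        dx, dy = MOVES[action]
        x, y = x + dx, y + dy
    return (x, y)

def accept(start, actions, predicted_target):
    # Accept the action sequence only if executing it lands the agent on
    # the independently predicted target location.
    return execute(start, actions) == predicted_target
```

Rejecting sequences whose execution disagrees with the predicted target filters out many incorrect predictions, which is where the gains in the low-data regime come from.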
07:06.560 --> 07:14.700
And this also leads to higher accuracy, especially in the low-data regime.
07:14.700 --> 07:18.320
There are more details about these results in the paper.
07:18.320 --> 07:21.160
So that's a little bit of an overview of our paper.
07:21.160 --> 07:24.520
Our takeaways are that you can build systems that combine neural methods and explicit
07:24.520 --> 07:25.560
world knowledge.
07:25.560 --> 07:28.880
And if you add just a little bit of world knowledge, you can really help increase coherence
07:28.880 --> 07:34.880
and consistency for these large sequence models.
07:34.880 --> 07:38.520
There are some challenges here about parsing in larger scale domains and also what it would
07:38.520 --> 07:41.360
mean to automatically build a more complete world model.
07:41.360 --> 08:01.360
Thank you very much.