kibrq committed
Commit: 56bad2a
Parent: ec08061

Update model

Files changed (4):
  1. config.json +66 -0
  2. configuration_greedy.py +39 -0
  3. modeling_greedy.py +85 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,66 @@
+ {
+   "architectures": [
+     "GreedyModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_greedy.GreedyConfig",
+     "AutoModelForCausalLM": "modeling_greedy.GreedyModel"
+   },
+   "eos_token_id": 8,
+   "pad_token_id": 9,
+   "reciprocals": [
+     [
+       4,
+       3
+     ],
+     [
+       5,
+       2
+     ],
+     [
+       6,
+       1
+     ]
+   ],
+   "reducables": [
+     [
+       [
+         4
+       ],
+       [
+         3
+       ]
+     ],
+     [
+       [
+         5
+       ],
+       [
+         2
+       ]
+     ],
+     [
+       [
+         6
+       ],
+       [
+         1
+       ]
+     ],
+     [
+       [
+         4,
+         5,
+         6
+       ],
+       [
+         1,
+         2,
+         3
+       ]
+     ]
+   ],
+   "torch_dtype": "float32",
+   "transformers_version": "4.21.1",
+   "vocab_size": 10
+ }
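
Because config.json registers the custom classes under auto_map, the checkpoint can be loaded through the Auto classes by passing trust_remote_code=True. A minimal loading sketch; the local path below is a placeholder for wherever this repository is cloned:

    from transformers import AutoConfig, AutoModelForCausalLM

    repo = "./greedy-model"  # placeholder: path to a local clone of this repository

    # auto_map points AutoConfig at GreedyConfig and AutoModelForCausalLM at GreedyModel,
    # so the custom code from this commit is executed; hence trust_remote_code=True.
    config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

    print(config.reciprocals)  # [[4, 3], [5, 2], [6, 1]] per config.json
    print(config.reducables)   # token-id generators of the normal closures
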
configuration_greedy.py ADDED
@@ -0,0 +1,39 @@
+ from transformers import PretrainedConfig, PreTrainedTokenizerBase
+ from freegroup import tools
+
+ class GreedyConfig(PretrainedConfig):
+
+     @classmethod
+     def from_tokenizer(cls, freegroup_dimension, tokenizer: PreTrainedTokenizerBase, **kwargs):
+
+         freegroup_generators = list(range(1, freegroup_dimension + 1))
+
+         reciprocals = []
+         for x in freegroup_generators:
+             a, b = tokenizer.convert_tokens_to_ids([str(x), str(-x)])
+             reciprocals.append([a, b])
+
+         reducables = [[] for _ in range(freegroup_dimension + 1)]
+         for reducable, closure_generator in zip(reducables, [[x] for x in freegroup_generators] + [freegroup_generators[::]]):
+             reducable.append(tokenizer.convert_tokens_to_ids(list(map(str, closure_generator))))
+             reducable.append(tokenizer.convert_tokens_to_ids(list(map(str, tools.reciprocal(closure_generator)))))
+
+         return cls(
+             reciprocals = reciprocals,
+             reducables = reducables,
+             vocab_size = len(tokenizer),
+             eos_token_id = tokenizer.eos_token_id,
+             pad_token_id = tokenizer.pad_token_id,
+             **kwargs
+         )
+
+     def __init__(self, **kwargs):
+         # reciprocals: List[List[int]]: token ids of inverse pairs, e.g. ['x', 'X'], ...
+         self.reciprocals = kwargs.pop('reciprocals', None)
+
+         # reducables: List[List[List[int]]]: generators for normal closures, e.g. [[[x], [X]], [[y], [Y]], ...]
+         self.reducables = kwargs.pop('reducables', None)
+
+         super().__init__(**kwargs)
+
+
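
GreedyConfig.from_tokenizer expects a tokenizer whose vocabulary contains the string forms of the free-group generators and their inverses ("1", "-1", ...). The sketch below instead builds the config directly from the values stored in config.json above, which amounts to the same object for this checkpoint; the output path is a placeholder, and the freegroup package imported by the module is assumed to be installed:

    from configuration_greedy import GreedyConfig

    config = GreedyConfig(
        reciprocals=[[4, 3], [5, 2], [6, 1]],   # token ids of each (x, x^-1) pair
        reducables=[
            [[4], [3]],                         # closure generator x_1 and its inverse
            [[5], [2]],                         # closure generator x_2 and its inverse
            [[6], [1]],                         # closure generator x_3 and its inverse
            [[4, 5, 6], [1, 2, 3]],             # closure of the whole generating set
        ],
        vocab_size=10,
        eos_token_id=8,
        pad_token_id=9,
    )
    config.save_pretrained("./greedy-model")    # placeholder output directory
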
modeling_greedy.py ADDED
@@ -0,0 +1,85 @@
+ import torch
+
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import CausalLMOutputWithPast
+
+ from .configuration_greedy import GreedyConfig
+ from freegroup import tools
+
+ class GreedyModel(PreTrainedModel):
+     config_class = GreedyConfig
+
+     def __init__(self, config: GreedyConfig):
+         super().__init__(config)
+         self.stub = torch.nn.parameter.Parameter(torch.tensor(0.))
+
+     def _reduce_step(self, token, stack, reducables):
+         stack.append(token.item())
+
+         for reducable in self.config.reciprocals + reducables:
+             n = len(reducable)
+             if len(stack) >= len(reducable):
+                 if tools.occurs(stack[-n:], reducable * 2):
+                     del stack[-n:]
+
+         return stack
+
+     def prepare_inputs_for_generation(self, input_ids, **kwargs):
+         past = kwargs.pop('past', None)
+         return {'input_ids': input_ids, 'past': past}
+
+     def forward(self, input_ids = None, past = None, **kwargs):
+
+         assert (input_ids is not None), "Can't be None"
+
+         batch_size, sequence_length = input_ids.shape
+
+         if past is None:
+             stacks = [[[] for _ in range(len(self.config.reducables))] for _ in range(batch_size)]
+             hidden_states = None
+         else:
+             stacks, hidden_states = past
+
+         begin_idx = 0 if hidden_states is None else hidden_states.size(0)
+
+         for t in range(begin_idx, sequence_length):
+             last_hidden_states = torch.zeros((batch_size, self.config.vocab_size))
+
+             for batch_idx, word in enumerate(input_ids):
+                 for stack, reducables in zip(stacks[batch_idx], self.config.reducables):
+
+                     self._reduce_step(word[t], stack, reducables)
+                     if not stack: continue
+
+                     last = stack[-1]
+
+                     for r in reducables:
+                         if last not in r:
+                             key = r[0]
+                             last_hidden_states[batch_idx][r[0]] += 1
+                         if last in r:
+                             pos = r.index(last)
+                             key = r[(pos + 1) % len(r)]
+                             last_hidden_states[batch_idx][key] += 1
+                     for r in self.config.reciprocals:
+                         if last in r:
+                             pos = r.index(last)
+                             key = r[(pos + 1) % len(r)]
+                             last_hidden_states[batch_idx][key] += 1
+
+                 for r in self.config.reciprocals:
+                     if word[t] in r:
+                         pos = r.index(word[t])
+                         key = r[(pos + 1) % len(r)]
+                         last_hidden_states[batch_idx][key] = -torch.inf
+
+                 if all(map(lambda x: len(x) == 0, stacks[batch_idx])):
+                     last_hidden_states[batch_idx][self.config.eos_token_id] = torch.inf
+
+             if hidden_states is None: hidden_states = last_hidden_states.clone().unsqueeze(0)
+             else: hidden_states = torch.cat((hidden_states, last_hidden_states.unsqueeze(0)))
+
+         return CausalLMOutputWithPast(
+             logits = hidden_states.permute(1, 0, 2),
+             past_key_values = (stacks, hidden_states)
+         )
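
The forward pass is rule-based rather than learned: each incoming token is pushed onto one stack per normal closure, _reduce_step cancels any run that is trivial in that closure, and the hand-crafted logits roughly reward tokens that keep cancelling the stack tops, forbid the immediate inverse of the token just read with -inf, and force EOS once every stack is empty. A generation sketch using the token ids from config.json; the path is a placeholder and exact generate() behaviour can vary across transformers versions:

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("./greedy-model", trust_remote_code=True)

    # ids 4/3, 5/2, 6/1 are generator/inverse pairs, 8 is EOS (see config.json).
    prompt = torch.tensor([[4, 5]])

    # Greedy decoding follows the hand-crafted logits: the model keeps emitting
    # tokens that reduce the word modulo the normal closures, then emits EOS.
    out = model.generate(prompt, max_length=16, eos_token_id=model.config.eos_token_id)
    print(out)
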
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8316f6ca2e3d5db92da31339c2ddee2b14adf2d3cbc0668dc5d8960db7668d67
+ size 747
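
pytorch_model.bin is committed as a Git LFS pointer; the real file is only 747 bytes because the model's sole parameter is the scalar stub. A quick inspection sketch, assuming the path is adjusted and git lfs has fetched the actual file:

    import torch

    state_dict = torch.load("./greedy-model/pytorch_model.bin", map_location="cpu")
    print(state_dict)  # expected to contain just the scalar 'stub' parameter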