
Ruurd committed · Commit bd9baef · 1 Parent(s): 332db3a

Add model config files

Files changed (2)
  1. app.py +2 -0
  2. models.py +50 -0
app.py CHANGED
@@ -17,6 +17,8 @@ from infer import (
     generate_diffusion_text,
     filter_logits
 )
+from models import CustomTransformerModel
+from model_config import CustomTransformerConfig
 
 # Load .env only when running locally
 if os.getenv("HF_TOKEN") is None:
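
Not part of this commit: a minimal sketch of how app.py might use these two imports further down. The masking_type value and the checkpoint id are placeholders, and the step that attaches a base model to model.llama is an assumption based on models.py (added below), whose forward() reads self.llama without creating it in __init__.

# Hypothetical wiring sketch; placeholder names are flagged in the comments.
from transformers import AutoModelForCausalLM
from models import CustomTransformerModel
from model_config import CustomTransformerConfig

# Assumes CustomTransformerConfig accepts masking_type as a keyword argument,
# as PretrainedConfig subclasses typically do.
config = CustomTransformerConfig(masking_type="bidirectional_masked")
model = CustomTransformerModel(config)

# forward() calls self.llama, so the wrapped base model is presumably attached
# after construction; the checkpoint id below is only a placeholder.
model.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")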
models.py ADDED
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.amp import autocast
+from transformers import PreTrainedModel
+from model_config import CustomTransformerConfig
+
+class CustomTransformerModel(PreTrainedModel):
+    config_class = CustomTransformerConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+
+    def forward(self, input_ids, labels=None, **kwargs):
+        batch_size, seq_len = input_ids.shape
+        device = input_ids.device
+        masking_type = getattr(self.config, "masking_type", "bidirectional")
+
+        if masking_type == 'bidirectional':
+            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
+        elif masking_type == 'bidirectional_masked':
+            base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
+            base_mask.fill_diagonal_(False)
+        elif masking_type == 'unidirectional':
+            base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
+        else:
+            raise ValueError(f"Unknown masking type: {masking_type}")
+
+        attention_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()
+        attention_mask = attention_mask.to(dtype=torch.float32)
+
+
+        with autocast("mps", dtype=torch.float16):
+            outputs = self.llama(
+                input_ids,
+                attention_mask=attention_mask,
+                output_hidden_states=True,
+                use_cache=False,
+                **kwargs
+            )
+
+        logits = outputs.logits[:, :, :self.config.vocab_size].view(batch_size, seq_len, self.config.vocab_size)
+        loss = None
+
+        if labels is not None:
+            assert labels.shape == (batch_size, seq_len)
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        return {"loss": loss, "logits": logits} if loss is not None else {"logits": logits}
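
Not part of the commit: a small sketch that reproduces the three base masks built in CustomTransformerModel.forward for a tiny sequence, to make the masking variants easy to inspect. bidirectional allows every position to attend everywhere, bidirectional_masked additionally blocks each position from attending to itself (presumably so a token is predicted from its context rather than from itself, fitting the diffusion-style generation imported in app.py), and unidirectional is the usual causal mask.

# Standalone mask demo; mirrors the branches in CustomTransformerModel.forward.
# The seq_len value is an arbitrary choice for illustration.
import torch

seq_len = 4

# 'bidirectional': every position may attend to every position.
bidirectional = torch.ones(seq_len, seq_len, dtype=torch.bool)

# 'bidirectional_masked': full attention, except each position attending to
# itself is blocked (diagonal set to False).
bidirectional_masked = torch.ones(seq_len, seq_len, dtype=torch.bool)
bidirectional_masked.fill_diagonal_(False)

# 'unidirectional': standard causal, lower-triangular mask.
unidirectional = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(bidirectional.int())
print(bidirectional_masked.int())
print(unidirectional.int())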