Spaces · Running on Zero

Ruurd committed · Commit ec35c53 · verified · 1 parent: a2125f4

Change to mask model

Files changed (1):
app.py · +34 -34
app.py CHANGED
@@ -25,51 +25,51 @@ with open("token_probabilities.json") as f:
     token_probs_dict = json.load(f)
     token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
 
-# def load_model():
-#     ckpt_path = hf_hub_download(
-#         repo_id="ruurd/tini_bi_m",
-#         filename="diffusion-model.pth",
-#         token=os.getenv("HF_TOKEN")
-#     )
-
-#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-#     model = torch.load(ckpt_path, map_location=device)
-#     model = disable_dropout(model)
-#     model.to(device)
-#     model.eval()
-#     return model
-
 def load_model():
     ckpt_path = hf_hub_download(
-        repo_id="ruurd/tini_bi",
+        repo_id="ruurd/tini_bi_m",
         filename="diffusion-model.pth",
-        token=os.getenv("HF_TOKEN"),
-        revision="5a22a8b6168466dbbf704efd00d8cbf2eee51426",
+        token=os.getenv("HF_TOKEN")
     )
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = torch.load(ckpt_path, map_location=device)
+    model = disable_dropout(model)
+    model.to(device)
+    model.eval()
+    return model
 
-    # Step 1: Create model from scratch
-    model = CustomTransformerModel(CustomTransformerConfig())
+# def load_model():
+#     ckpt_path = hf_hub_download(
+#         repo_id="ruurd/tini_bi",
+#         filename="diffusion-model.pth",
+#         token=os.getenv("HF_TOKEN"),
+#         revision="5a22a8b6168466dbbf704efd00d8cbf2eee51426",
+#     )
 
-    # Step 2: Load state_dict from full checkpoint
-    full_model = torch.load(ckpt_path, map_location=device)
+#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # This handles both full model or just state_dict
-    try:
-        state_dict = full_model.state_dict()
-    except AttributeError:
-        state_dict = full_model  # already a state_dict
+#     # Step 1: Create model from scratch
+#     model = CustomTransformerModel(CustomTransformerConfig())
 
-    # Step 3: Load weights (might print mismatches)
-    missing, unexpected = model.load_state_dict(state_dict, strict=False)
-    print("Missing keys:", missing)
-    print("Unexpected keys:", unexpected)
+#     # Step 2: Load state_dict from full checkpoint
+#     full_model = torch.load(ckpt_path, map_location=device)
 
-    model = disable_dropout(model)
-    model.to(device)
-    model.eval()
-    return model
+#     # This handles both full model or just state_dict
+#     try:
+#         state_dict = full_model.state_dict()
+#     except AttributeError:
+#         state_dict = full_model  # already a state_dict
+
+#     # Step 3: Load weights (might print mismatches)
+#     missing, unexpected = model.load_state_dict(state_dict, strict=False)
+#     print("Missing keys:", missing)
+#     print("Unexpected keys:", unexpected)
+
+#     model = disable_dropout(model)
+#     model.to(device)
+#     model.eval()
+#     return model
 
 
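For readers comparing the two loaders: the newly active load_model() treats the downloaded .pth file as a fully pickled nn.Module, so torch.load() returns a usable model object directly, while the now commented-out variant rebuilds CustomTransformerModel and copies weights in via load_state_dict(). The sketch below illustrates that distinction only; TinyModel is a hypothetical stand-in for the Space's real architecture, the file names are placeholders, and weights_only assumes a reasonably recent PyTorch.

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Hypothetical stand-in for CustomTransformerModel; not the Space's real architecture.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

# Variant A -- what the newly active load_model() assumes: the checkpoint is a fully
# pickled nn.Module, so torch.load() hands back a ready-to-use model object.
# weights_only=False (available from roughly PyTorch 1.13 on) is needed on PyTorch 2.6+,
# where the default changed to True and blocks unpickling arbitrary classes.
torch.save(TinyModel(), "full-model.pth")
model_a = torch.load("full-model.pth", map_location="cpu", weights_only=False)
model_a.eval()

# Variant B -- what the commented-out loader did: rebuild the architecture first, then
# copy weights from a state_dict; strict=False reports key mismatches instead of raising.
torch.save(TinyModel().state_dict(), "state-dict.pth")
model_b = TinyModel()
loaded = torch.load("state-dict.pth", map_location="cpu")
state_dict = loaded.state_dict() if hasattr(loaded, "state_dict") else loaded
missing, unexpected = model_b.load_state_dict(state_dict, strict=False)
print("Missing keys:", missing, "Unexpected keys:", unexpected)
model_b.eval()

Variant A is simpler but ties the checkpoint to the exact class definitions importable at load time; Variant B only needs parameter names and shapes to match, which is why it tolerates partial mismatches via strict=False.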