Spaces: Running on Zero

Ruurd committed
Commit 44296bc · verified · 1 parent: ec35c53

Change back to normal model

Files changed (1)
  1. app.py +34 -34
app.py CHANGED
@@ -25,51 +25,51 @@ with open("token_probabilities.json") as f:
 token_probs_dict = json.load(f)
 token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
 
-def load_model():
-    ckpt_path = hf_hub_download(
-        repo_id="ruurd/tini_bi_m",
-        filename="diffusion-model.pth",
-        token=os.getenv("HF_TOKEN")
-    )
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = torch.load(ckpt_path, map_location=device)
-    model = disable_dropout(model)
-    model.to(device)
-    model.eval()
-    return model
-
 # def load_model():
 #     ckpt_path = hf_hub_download(
-#         repo_id="ruurd/tini_bi",
+#         repo_id="ruurd/tini_bi_m",
 #         filename="diffusion-model.pth",
-#         token=os.getenv("HF_TOKEN"),
-#         revision="5a22a8b6168466dbbf704efd00d8cbf2eee51426",
+#         token=os.getenv("HF_TOKEN")
 #     )
 
 #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#     model = torch.load(ckpt_path, map_location=device)
+#     model = disable_dropout(model)
+#     model.to(device)
+#     model.eval()
+#     return model
 
-#     # Step 1: Create model from scratch
-#     model = CustomTransformerModel(CustomTransformerConfig())
+def load_model():
+    ckpt_path = hf_hub_download(
+        repo_id="ruurd/tini_bi",
+        filename="diffusion-model.pth",
+        token=os.getenv("HF_TOKEN"),
+        revision="5a22a8b6168466dbbf704efd00d8cbf2eee51426",
+    )
 
-#     # Step 2: Load state_dict from full checkpoint
-#     full_model = torch.load(ckpt_path, map_location=device)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-#     # This handles both full model or just state_dict
-#     try:
-#         state_dict = full_model.state_dict()
-#     except AttributeError:
-#         state_dict = full_model  # already a state_dict
+    # Step 1: Create model from scratch
+    model = CustomTransformerModel(CustomTransformerConfig())
 
-#     # Step 3: Load weights (might print mismatches)
-#     missing, unexpected = model.load_state_dict(state_dict, strict=False)
-#     print("Missing keys:", missing)
-#     print("Unexpected keys:", unexpected)
+    # Step 2: Load state_dict from full checkpoint
+    full_model = torch.load(ckpt_path, map_location=device)
 
-#     model = disable_dropout(model)
-#     model.to(device)
-#     model.eval()
-#     return model
+    # This handles both full model or just state_dict
+    try:
+        state_dict = full_model.state_dict()
+    except AttributeError:
+        state_dict = full_model  # already a state_dict
+
+    # Step 3: Load weights (might print mismatches)
+    missing, unexpected = model.load_state_dict(state_dict, strict=False)
+    print("Missing keys:", missing)
+    print("Unexpected keys:", unexpected)
+
+    model = disable_dropout(model)
+    model.to(device)
+    model.eval()
+    return model
 
 
 
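Note on the re-enabled loader: the new load_model follows a standard PyTorch pattern, that is, download a pinned checkpoint with hf_hub_download, instantiate the architecture, then load weights whether the file holds a pickled module or a bare state_dict. The sketch below isolates only that pattern; TinyModel, load_checkpoint, and dummy.pth are hypothetical stand-ins rather than code from this Space, and model.eval() stands in for the Space's disable_dropout helper.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the Space's CustomTransformerModel."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x):
        return self.dropout(self.linear(x))


def load_checkpoint(ckpt_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the architecture first, then pull weights into it.
    model = TinyModel()

    # torch.load may return either a pickled nn.Module or a bare state_dict.
    # (On PyTorch >= 2.6, unpickling a full module also needs weights_only=False.)
    obj = torch.load(ckpt_path, map_location=device)
    try:
        state_dict = obj.state_dict()  # a full model was pickled
    except AttributeError:
        state_dict = obj               # already a state_dict

    # strict=False tolerates renamed or missing keys; print whatever was skipped.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)

    model.to(device)
    model.eval()  # inference mode: dropout layers become no-ops
    return model


if __name__ == "__main__":
    # Round-trip a dummy checkpoint saved as a bare state_dict.
    torch.save(TinyModel().state_dict(), "dummy.pth")
    model = load_checkpoint("dummy.pth")
    x = torch.zeros(1, 4, device=next(model.parameters()).device)
    print(model(x).shape)

Pinning revision= in hf_hub_download, as the re-enabled load_model does, keeps the Space reproducible even if ruurd/tini_bi receives new commits, and strict=False surfaces key mismatches in the logs instead of raising.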