Akash Garg commited on
Commit
76ac4e2
·
1 Parent(s): 6cf650b

updating for zerogpu

Browse files
Files changed (1) hide show
  1. app.py +21 -16
app.py CHANGED
@@ -19,6 +19,7 @@ from pathlib import Path
19
  import uuid
20
  import shutil
21
  from huggingface_hub import snapshot_download
 
22
 
23
 
24
  GLOBAL_STATE = {}
@@ -91,6 +92,25 @@ def build_interface():
91
 
92
  return interface
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  if __name__=="__main__":
95
 
96
  parser = argparse.ArgumentParser()
@@ -123,19 +143,4 @@ if __name__=="__main__":
123
  repo_id="Roblox/cube3d-v0.1",
124
  local_dir="./model_weights"
125
  )
126
- config_path = args.config_path
127
- gpt_ckpt_path = "./model_weights/shape_gpt.safetensors"
128
- shape_ckpt_path = "./model_weights/shape_tokenizer.safetensors"
129
- engine_fast = EngineFast(
130
- config_path,
131
- gpt_ckpt_path,
132
- shape_ckpt_path,
133
- device=torch.device("cuda"),
134
- )
135
- GLOBAL_STATE["engine_fast"] = engine_fast
136
- GLOBAL_STATE["SAVE_DIR"] = args.save_dir
137
- os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)
138
-
139
- demo = build_interface()
140
- demo.queue(default_concurrency_limit=1)
141
- demo.launch()
 
19
  import uuid
20
  import shutil
21
  from huggingface_hub import snapshot_download
22
+ import spaces
23
 
24
 
25
  GLOBAL_STATE = {}
 
92
 
93
  return interface
94
 
95
+ @spaces.GPU
96
+ def generate(args):
97
+ config_path = args.config_path
98
+ gpt_ckpt_path = "./model_weights/shape_gpt.safetensors"
99
+ shape_ckpt_path = "./model_weights/shape_tokenizer.safetensors"
100
+ engine_fast = EngineFast(
101
+ config_path,
102
+ gpt_ckpt_path,
103
+ shape_ckpt_path,
104
+ device=torch.device("cuda"),
105
+ )
106
+ GLOBAL_STATE["engine_fast"] = engine_fast
107
+ GLOBAL_STATE["SAVE_DIR"] = args.save_dir
108
+ os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)
109
+
110
+ demo = build_interface()
111
+ demo.queue(default_concurrency_limit=1)
112
+ demo.launch()
113
+
114
  if __name__=="__main__":
115
 
116
  parser = argparse.ArgumentParser()
 
143
  repo_id="Roblox/cube3d-v0.1",
144
  local_dir="./model_weights"
145
  )
146
+ generate(args)