Spaces:
Running
on
A10G
Running
on
A10G
Jiaming Han
commited on
Commit
·
22ff2b2
1
Parent(s):
3c55139
update
Browse files- app.py +3 -2
- t2i_inference.py +1 -1
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
from torchvision.transforms.functional import to_tensor
|
4 |
-
from huggingface_hub import hf_hub_download, login
|
5 |
|
6 |
from t2i_inference import T2IConfig, TextToImageInference
|
7 |
|
@@ -28,7 +28,8 @@ def generate_text(self, image: str, prompt: str) -> str:
|
|
28 |
|
29 |
login(token=os.getenv('HF_TOKEN'))
|
30 |
config = T2IConfig()
|
31 |
-
config.
|
|
|
32 |
config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
|
33 |
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
|
34 |
inference = TextToImageInference(config)
|
|
|
1 |
import os
|
2 |
import gradio as gr
|
3 |
from torchvision.transforms.functional import to_tensor
|
4 |
+
from huggingface_hub import hf_hub_download, snapshot_download, login
|
5 |
|
6 |
from t2i_inference import T2IConfig, TextToImageInference
|
7 |
|
|
|
28 |
|
29 |
login(token=os.getenv('HF_TOKEN'))
|
30 |
config = T2IConfig()
|
31 |
+
config.model = snapshot_download("csuhan/Tar-7B-v0.1")
|
32 |
+
config.ar_path = hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth")
|
33 |
config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
|
34 |
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
|
35 |
inference = TextToImageInference(config)
|
t2i_inference.py
CHANGED
@@ -13,7 +13,7 @@ from tok.mm_autoencoder import MMAutoEncoder
|
|
13 |
class T2IConfig:
|
14 |
model_path: str = "csuhan/Tar-1.5B"
|
15 |
# visual tokenizer config
|
16 |
-
ar_path: str = '
|
17 |
encoder_path: str = 'ta_tok.pth'
|
18 |
decoder_path: str = 'vq_ds16_t2i.pt'
|
19 |
|
|
|
13 |
class T2IConfig:
|
14 |
model_path: str = "csuhan/Tar-1.5B"
|
15 |
# visual tokenizer config
|
16 |
+
ar_path: str = 'ar_dtok_lp_256px.pth'
|
17 |
encoder_path: str = 'ta_tok.pth'
|
18 |
decoder_path: str = 'vq_ds16_t2i.pt'
|
19 |
|