"""Generate one image with CondRefAR from a text prompt and a control image.

The script builds a ``CondRefARPipeline`` from the current directory and two
JSON configs, encodes the prompt with the ``google/flan-t5-xl`` encoder, calls
the pipeline with the prompt embedding and the control image, and saves the
first returned image to ``sample.png``.
"""

import json

import torch
from PIL import Image, ImageOps  # noqa: F401  (ImageOps is unused here; import kept as in the original)
from transformers import AutoTokenizer, T5EncoderModel

from CondRefAR.pipeline import CondRefARPipeline


def _load_json(path):
    """Return the parsed contents of the JSON file at *path*, closing the file afterwards."""
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


# Use the GPU with bfloat16 when CUDA is available; otherwise CPU with float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Pipeline configuration; paths are relative to the working directory.
gpt_cfg = _load_json("configs/gpt_config.json")
vq_cfg = _load_json("configs/vq_config.json")
pipe = CondRefARPipeline.from_pretrained(
    ".", gpt_cfg, vq_cfg, device=device, torch_dtype=dtype
)

# Text encoder used to turn the prompt into an embedding for the pipeline.
tok = AutoTokenizer.from_pretrained("google/flan-t5-xl")
enc = (
    T5EncoderModel.from_pretrained("google/flan-t5-xl", torch_dtype=dtype)
    .to(device)
    .eval()
)

# NOTE(review): the prompt begins with "Aaerial", which looks like a typo for
# "Aerial". It is left untouched because it is fed to the encoder verbatim and
# changing it changes the generated sample — confirm before correcting.
prompt = "Aaerial view of a forested area with a river running through it. On the right side of the image, there is a small town or village with a red-roofed building."
control = "assets/examples/example2.jpg"

# convert() returns a new image, so the source file can be closed right away.
with Image.open(control) as control_src:
    control_img = control_src.convert("RGB")

# Pad or truncate the prompt to exactly 120 tokens.
inputs = tok(
    [prompt],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=120,
)

# Encoding is inference-only, so gradient tracking is switched off for it.
with torch.no_grad():
    emb = enc(
        input_ids=inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
    ).last_hidden_state

# NOTE(review): the sampling arguments are passed through as written; their
# exact semantics are defined by CondRefARPipeline, which is not visible here.
imgs = pipe(emb, control_img, cfg_scale=4, temperature=1.0, top_k=2000, top_p=1.0)
imgs[0].save("sample.png")