'''
from diffusers import StableDiffusionPipeline
import torch

# Load the fine-tuned DreamBooth model
pipe = StableDiffusionPipeline.from_pretrained(
    "./nyc-ad-model",
    torch_dtype=torch.float16,
).to("cuda")  # use "cpu" if no GPU

prompt = "brand name: xyc, fried chicken advertisement poster: a fried chicken in brooklyn street"
image = pipe(prompt, num_inference_steps=500, guidance_scale=7.5).images[0]

# Display or save the image
image.save("output_nyc_ad.png")
image.show()
'''
'''
import torch, faiss, json
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline

texts = json.load(open("prompt.txt"))
index = faiss.read_index("prompt.index")
emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
pipe = StableDiffusionPipeline.from_pretrained("./nyc-ad-model", torch_dtype=torch.float16).to("cuda")

def rag_prompt(query, k=3):
    q = emb.encode(query, normalize_embeddings=True).astype("float32")
    _, I = index.search(q.reshape(1, -1), k)
    retrieved = " ".join(texts[i] for i in I[0])
    return f"{retrieved}. {query}"

prompt = rag_prompt("fried chicken advertisement poster")
img = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
img.save("rag_output.png")
'''
import torch, faiss, json
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the RAG assets: retrieval corpus (a JSON list of prompt strings), FAISS index, and embedding model
texts = json.load(open("prompt.txt"))
index = faiss.read_index("prompt.index")
emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
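# How "prompt.txt" and "prompt.index" might have been built (a sketch, not from
# the original repo; it assumes prompt.txt holds a JSON list of reference ad
# prompts and the index stores normalized embeddings for inner-product search):
'''
import faiss, json
from sentence_transformers import SentenceTransformer

texts = ["..."]  # reference ad-prompt strings collected for retrieval
json.dump(texts, open("prompt.txt", "w"))

emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
vecs = emb.encode(texts, normalize_embeddings=True).astype("float32")
index = faiss.IndexFlatIP(vecs.shape[1])  # inner product = cosine on normalized vectors
index.add(vecs)
faiss.write_index(index, "prompt.index")
'''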
# Load image generation pipeline
pipe = StableDiffusionPipeline.from_pretrained(
    "./nyc-ad-model",
    torch_dtype=torch.float16
).to("cuda")
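# The hard-coded "cuda" above fails on CPU-only machines. A minimal device-selection
# sketch (an assumption, not part of the original script; float16 is GPU-oriented,
# so fall back to float32 on CPU):
# device = "cuda" if torch.cuda.is_available() else "cpu"
# dtype = torch.float16 if device == "cuda" else torch.float32
# pipe = StableDiffusionPipeline.from_pretrained("./nyc-ad-model", torch_dtype=dtype).to(device)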
# Load your own fine-tuned SFT model for slogan generation
text_model_path = "./sft-model"  # path to your SFT-finetuned model
tokenizer = AutoTokenizer.from_pretrained(text_model_path)
text_model = AutoModelForCausalLM.from_pretrained(
    text_model_path,
    torch_dtype=torch.float16,
    device_map="auto"  # requires the accelerate package
)
# Build a retrieval-augmented prompt: embed the query, retrieve the k most
# similar reference prompts from the FAISS index, and prepend them to the query
def rag_prompt(query, k=3):
    q = emb.encode(query, normalize_embeddings=True).astype("float32")
    _, I = index.search(q.reshape(1, -1), k)
    retrieved = " ".join(texts[i] for i in I[0])
    return f"{retrieved}. {query}"
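# Optional sanity check: print what the retrieval step actually prepends
# (the query string here is just an illustrative example)
# print(rag_prompt("pizza advertisement poster", k=2))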
# Prompt for generation
user_prompt = "fried chicken advertisement poster"
full_prompt = rag_prompt(user_prompt)
# Generate image
image = pipe(full_prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("rag_output.png")
# Construct an Alpaca-style "### Instruction / ### Response" prompt matching the SFT training format
copy_prompt = f"""### Instruction:
Generate a catchy advertisement slogan for: {user_prompt}
### Response:"""
inputs = tokenizer(copy_prompt, return_tensors="pt").to("cuda")
output_ids = text_model.generate(
    **inputs,
    max_new_tokens=30,
    do_sample=True,
    top_p=0.95
)
# Decode only the newly generated tokens so the printed slogan does not repeat the prompt
response = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

# Output results
print("🖼️ Image saved to rag_output.png")
print("📝 Generated slogan:")
print(response.strip())
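# Optional extension (a sketch, not part of the original script): overlay the
# generated slogan onto the saved poster with Pillow. Position, colour, and the
# default font below are placeholder choices.
'''
from PIL import Image, ImageDraw, ImageFont

poster = Image.open("rag_output.png")
draw = ImageDraw.Draw(poster)
draw.text((20, 20), response.strip(), fill=(255, 255, 255), font=ImageFont.load_default())
poster.save("rag_output_with_slogan.png")
'''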