'''
# --- Draft 1 (kept for reference): plain DreamBooth inference, no retrieval ---
from diffusers import StableDiffusionPipeline
import torch

# Load the fine-tuned DreamBooth model
pipe = StableDiffusionPipeline.from_pretrained(
    "./nyc-ad-model",
    torch_dtype=torch.float16,
).to("cuda")  # use "cpu" if no GPU

prompt = "brand name: xyc, fried chicken advertisement poster: a fried chicken in brooklyn street"
# note: 500 steps is unusually high; 25-50 is typical for Stable Diffusion
image = pipe(prompt, num_inference_steps=500, guidance_scale=7.5).images[0]

# Display or save the image
image.save("output_nyc_ad.png")
image.show()
'''
'''
# --- Draft 2 (kept for reference): retrieval-augmented image generation only ---
import torch, faiss, json
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline

texts = json.load(open("prompt.txt"))
index = faiss.read_index("prompt.index")
emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
pipe = StableDiffusionPipeline.from_pretrained("./nyc-ad-model", torch_dtype=torch.float16).to("cuda")

def rag_prompt(query, k=3):
    q = emb.encode(query, normalize_embeddings=True).astype("float32")
    _, I = index.search(q.reshape(1, -1), k)
    retrieved = " ".join(texts[i] for i in I[0])
    return f"{retrieved}. {query}"

prompt = rag_prompt("fried chicken advertisement poster")
img = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
img.save("rag_output.png")
'''
import torch, faiss, json
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the RAG corpus (a JSON list of prompt snippets) and its FAISS index
with open("prompt.txt") as f:
    texts = json.load(f)
index = faiss.read_index("prompt.index")
emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
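
# The block below is not part of the running script: a minimal sketch of how
# prompt.txt / prompt.index could have been built. It assumes a flat
# inner-product FAISS index over L2-normalized mxbai embeddings, which matches
# the normalized queries used in rag_prompt below.
'''
import faiss, json
from sentence_transformers import SentenceTransformer

corpus = ["gritty Brooklyn street food poster", "bold red-and-yellow fast-food ad"]  # example snippets
emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
vecs = emb.encode(corpus, normalize_embeddings=True).astype("float32")
index = faiss.IndexFlatIP(vecs.shape[1])  # inner product == cosine on unit vectors
index.add(vecs)
faiss.write_index(index, "prompt.index")
with open("prompt.txt", "w") as f:
    json.dump(corpus, f)  # the script above loads this as a JSON list despite the .txt extension
'''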
# Load the fine-tuned DreamBooth image-generation pipeline
pipe = StableDiffusionPipeline.from_pretrained(
    "./nyc-ad-model",
    torch_dtype=torch.float16,
).to("cuda")  # use "cpu" if no GPU

# Load the SFT-finetuned text model for slogan generation
text_model_path = "./sft-model"  # path to your SFT-finetuned model
tokenizer = AutoTokenizer.from_pretrained(text_model_path)
text_model = AutoModelForCausalLM.from_pretrained(
    text_model_path,
    torch_dtype=torch.float16,
    device_map="auto",
)
# Build a retrieval-augmented prompt: prepend the k nearest corpus snippets
def rag_prompt(query, k=3):
    q = emb.encode(query, normalize_embeddings=True).astype("float32")
    _, I = index.search(q.reshape(1, -1), k)
    retrieved = " ".join(texts[i] for i in I[0])
    return f"{retrieved}. {query}"
# Run retrieval, then generate the poster image
user_prompt = "fried chicken advertisement poster"
full_prompt = rag_prompt(user_prompt)
image = pipe(full_prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("rag_output.png")

# Construct an input prompt matching the SFT training format
copy_prompt = f"""### Instruction:
Generate a catchy advertisement slogan for: {user_prompt}
### Response:"""
inputs = tokenizer(copy_prompt, return_tensors="pt").to("cuda")
output_ids = text_model.generate(
    **inputs,
    max_new_tokens=30,
    do_sample=True,
    top_p=0.95,
)
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Output results (note: the decoded text includes the prompt as well as the completion)
print("Image saved to rag_output.png")
print("Generated slogan:")
print(response.strip())
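
# Optional follow-up (a sketch, not part of the original script): composite the
# generated slogan onto the saved poster with Pillow. Position, color, and the
# default bitmap font are placeholder assumptions; adjust for your layout.
'''
from PIL import Image, ImageDraw

poster = Image.open("rag_output.png")
draw = ImageDraw.Draw(poster)
slogan = response.strip().split("### Response:")[-1].strip()  # keep only the completion
draw.text((20, 20), slogan, fill="white")  # top-left corner; default PIL font
poster.save("rag_output_with_slogan.png")
'''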