File size: 3,055 Bytes
696ae63
09fb2db
 
 
 
 
 
 
 
 
c269ab0
 
09fb2db
 
 
696ae63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# NOTE(review): the two triple-quoted strings below appear to be earlier
# iterations of this script, kept as inert string literals rather than
# deleted. They are evaluated and discarded; none of their contents run.
# (1) Plain DreamBooth generation from "./nyc-ad-model" with a hand-written
#     prompt (note num_inference_steps=500 there vs. 30 in the active code).
'''
from diffusers import StableDiffusionPipeline
import torch

# Load the fine-tuned DreamBooth model
pipe = StableDiffusionPipeline.from_pretrained(
    "./nyc-ad-model",
    torch_dtype=torch.float16,
).to("cuda")  # use "cpu" if no GPU

prompt = "brand name: xyc, fried chicken advertisement poster: a fried chicken in brooklyn street"
image = pipe(prompt, num_inference_steps=500, guidance_scale=7.5).images[0]

# Display or save the image
image.save("output_nyc_ad.png")
image.show()
'''
# (2) Retrieval-augmented prompt + image generation only, without the
#     slogan-writing text model used by the active code below.
'''
import torch, faiss, json
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline

texts=json.load(open("prompt.txt"))
index=faiss.read_index("prompt.index")
emb=SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
pipe=StableDiffusionPipeline.from_pretrained("./nyc-ad-model",torch_dtype=torch.float16).to("cuda")

def rag_prompt(query,k=3):
    q=emb.encode(query,normalize_embeddings=True).astype("float32")
    _,I=index.search(q.reshape(1,-1),k)
    retrieved=" ".join(texts[i] for i in I[0])
    return f"{retrieved}. {query}"

prompt=rag_prompt("fried chicken advertisement poster")
img=pipe(prompt,num_inference_steps=30,guidance_scale=7.5).images[0]
img.save("rag_output.png")
'''

import torch, faiss, json
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load RAG index
# Retrieval corpus. Despite the .txt extension, prompt.txt is parsed as JSON;
# it is expected to be a list whose positions match the vector ids stored in
# prompt.index (rag_prompt() looks texts up by those ids).
# Use a context manager so the handle is closed, and read as UTF-8 (the JSON
# standard encoding) rather than whatever the platform locale defaults to.
with open("prompt.txt", encoding="utf-8") as f:
    texts = json.load(f)
index = faiss.read_index("prompt.index")
# Query encoder. Presumably the same model that built prompt.index; its
# embedding size must match the index dimension. TODO confirm.
emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Load image generation pipeline
# Image generator: the fine-tuned Stable Diffusion checkpoint in
# ./nyc-ad-model, loaded in float16 and placed on the GPU.
pipe = StableDiffusionPipeline.from_pretrained("./nyc-ad-model", torch_dtype=torch.float16).to("cuda")

# Copywriter: the locally SFT-finetuned causal language model and its
# tokenizer. device_map="auto" lets transformers/accelerate decide placement.
text_model_path = "./sft-model"  # Path to your SFT-finetuned model
tokenizer = AutoTokenizer.from_pretrained(text_model_path)
text_model = AutoModelForCausalLM.from_pretrained(
    text_model_path, torch_dtype=torch.float16, device_map="auto"
)

# Build retrieval-augmented prompt
def rag_prompt(query, k=3):
    """Build a retrieval-augmented text-to-image prompt for *query*.

    Encodes ``query`` with the module-level encoder ``emb`` (normalized
    embeddings), searches the module-level FAISS ``index`` for the ``k``
    nearest vectors, and prepends the corresponding ``texts`` entries to the
    query.

    Args:
        query: The user's short prompt.
        k: Number of neighbours to retrieve.

    Returns:
        ``"<retrieved texts>. <query>"``, or just ``query`` when the search
        yields no valid neighbours.
    """
    q = emb.encode(query, normalize_embeddings=True).astype("float32")
    # FAISS expects a 2-D (n_queries, dim) float32 array.
    _, ids = index.search(q.reshape(1, -1), k)
    # FAISS pads the result with id -1 when fewer than k vectors exist;
    # texts[-1] would silently inject the last corpus entry, so skip those.
    retrieved = " ".join(texts[i] for i in ids[0] if i >= 0)
    if not retrieved:
        return query
    # NOTE(review): Stable Diffusion's CLIP text encoder keeps only the first
    # 77 tokens, so with long retrieved texts the trailing user query may be
    # truncated. Consider putting the query first; confirm desired ordering.
    return f"{retrieved}. {query}"

# Prompt for generation
# Short user request; rag_prompt() enriches it with retrieved examples.
user_prompt = "fried chicken advertisement poster"
full_prompt = rag_prompt(user_prompt)

# Render a single poster from the augmented prompt and write it to disk.
image = pipe(
    full_prompt,
    num_inference_steps=30,
    guidance_scale=7.5,
).images[0]
image.save("rag_output.png")

# Construct input prompt compatible with SFT format
# Instruction/response template for the SFT model. Assumed to match the
# format used during fine-tuning; TODO confirm against the SFT data.
copy_prompt = f"""### Instruction:
Generate a catchy advertisement slogan for: {user_prompt}

### Response:"""

# Send the inputs to wherever the model's weights were placed. With
# device_map="auto" that need not be "cuda" (e.g. CPU-only machines).
inputs = tokenizer(copy_prompt, return_tensors="pt").to(text_model.device)
output_ids = text_model.generate(
    **inputs,
    max_new_tokens=30,
    do_sample=True,
    top_p=0.95,
)
# For a causal LM, generate() returns prompt tokens followed by the new
# tokens. Drop the prompt so only the generated slogan is decoded/printed.
prompt_len = inputs["input_ids"].shape[-1]
response = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

# Output result
print("🖼️ Image saved to rag_output.png")
print("📝 Generated slogan:")
print(response.strip())