File size: 4,115 Bytes
2cc756f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c52625e
 
 
 
511e59d
3bf9b00
2cc756f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# main.py
import os
from typing import Optional

import spacy
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline

# Initialize FastAPI app
app = FastAPI()

# Configure CORS: accept cross-origin requests from any origin, with any
# method and any header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# combination the FastAPI CORS docs say is not allowed (credentialed requests
# need explicitly listed origins). No endpoint visible here reads cookies or
# auth headers, so allow_credentials=False or an explicit origin list is
# probably what is wanted — confirm with the frontend before changing.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Pydantic model for request body (JSON body of POST /; field names are the
# JSON keys).
class SloganRequest(BaseModel):
    # Brand name, inserted verbatim into both prompts.
    brand: str
    # Free-text description; reduced to keywords by summarize_description()
    # before it reaches the prompts.
    description: str
    # Industry label, inserted verbatim into the prompts.
    industry: str
    # Looked up in TONE_PRESETS to pick the sampling settings, so it is
    # expected to be one of that dict's keys; when no liked_slogan is given it
    # is also named in the prompts.
    tone: Optional[str] = "playful"
    # Samples drawn per prompt (two prompts), so the response holds at most
    # 2 * num slogans.
    num: Optional[int] = 5
    # When non-empty, the prompts ask for slogans in the style of this one
    # instead of naming the tone.
    liked_slogan: Optional[str] = None

# Load models
# spaCy English pipeline, used by summarize_description() for part-of-speech
# based keyword extraction.
nlp = spacy.load("en_core_web_sm")

# Directory holding the GPT-2 weights and tokenizer files. The default "./"
# is resolved against the process working directory, not this file's
# location; set SLOGAN_MODEL_DIR to load a checkpoint from somewhere else.
MODEL_DIR = os.environ.get("SLOGAN_MODEL_DIR", "./")
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR)
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)

# Text-generation pipeline on the first CUDA device when one is available,
# otherwise on CPU (device=-1).
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)

# Tone presets: sampling settings per supported tone. The POST / handler
# unpacks the selected entry into its text-generation kwargs, so every key
# here must be a valid generation argument.
TONE_PRESETS = {
    "playful": dict(temperature=0.95, top_p=0.95, repetition_penalty=1.2),
    "bold": dict(temperature=0.8, top_p=0.9, repetition_penalty=1.45),
    "minimalist": dict(temperature=0.6, top_p=0.8, repetition_penalty=1.5),
    "luxury": dict(temperature=0.7, top_p=0.85, repetition_penalty=1.35),
    "classic": dict(temperature=0.7, top_p=0.9, repetition_penalty=1.25),
}

def summarize_description(text: str) -> str:
    """Reduce a free-text description to a short keyword string.

    Runs the module-level spaCy pipeline over ``text`` and keeps the text of
    every token tagged NOUN, PROPN or ADJ, in document order (duplicates are
    kept). Only the first 12 such tokens survive; they are joined with single
    spaces. Returns an empty string when no token qualifies.
    """
    content_tags = ("NOUN", "PROPN", "ADJ")
    picked = []
    for tok in nlp(text):
        if tok.pos_ in content_tags:
            picked.append(tok.text)
    return " ".join(picked[:12])

@app.get("/")
def read_root():
    """Return a static welcome message that points clients at POST /."""
    welcome = "Welcome to Slogan Generator API. Use POST / to generate slogans."
    return dict(message=welcome)

# @app.post("/generate-slogans")
@app.post("/")
async def generate_slogans(request: SloganRequest):
    """Generate brand slogans from two prompt variants.

    The description is first reduced to keywords with ``summarize_description``.
    When ``liked_slogan`` is set, both prompts ask for slogans in that style;
    otherwise they ask for slogans in the requested tone. Each prompt is
    sampled ``num`` times with the sampling settings of ``TONE_PRESETS[tone]``,
    and each sample is trimmed to the first line / first sentence that follows
    the final ``Slogan:`` marker.

    Returns:
        ``{"slogans": [...]}``: cleaned slogans longer than 4 characters,
        de-duplicated, in generation order (prompt 1's samples first), at most
        ``2 * num`` entries.

    Raises:
        HTTPException: 422 if ``tone`` is not a key of ``TONE_PRESETS`` or
            ``num`` is below 1; 500 for any error raised while extracting
            keywords or generating (the detail carries the original message).
    """
    # ``tone`` and ``num`` are declared Optional on the model, so an explicit
    # JSON null is accepted. Fall back to the model's declared defaults rather
    # than failing on TONE_PRESETS[None] / None * 2 or prompting with "None".
    tone = "playful" if request.tone is None else request.tone
    num = 5 if request.num is None else request.num

    # Validate before the try block: its catch-all would otherwise rewrap
    # these client errors as 500s.
    if tone not in TONE_PRESETS:
        raise HTTPException(
            status_code=422,
            detail=f"Unknown tone {tone!r}; expected one of: {', '.join(TONE_PRESETS)}",
        )
    if num < 1:
        raise HTTPException(status_code=422, detail="num must be at least 1")

    # NOTE(review): the generator calls below are blocking CPU/GPU work inside
    # an ``async def``, so they stall the event loop (and every other request,
    # including GET /) until they finish. Declaring this a plain ``def`` lets
    # FastAPI run it in its threadpool — confirm concurrent model calls are
    # acceptable (e.g. GPU memory) before switching.
    # NOTE(review): ``num`` has no upper bound; each request samples 2 * num
    # sequences, which a client can make arbitrarily expensive.
    try:
        # Process description
        processed_desc = summarize_description(request.description)

        # Two phrasings of the same request. Both end with "Slogan:" so the
        # model's continuation can be split off after that marker.
        if request.liked_slogan:
            prompt1 = (
                f"Create {request.industry} brand slogans similar to: '{request.liked_slogan}'\n"
                f"Brand: {request.brand}\n"
                f"Key Attributes: {processed_desc}\n"
                "Slogan:"
            )
            prompt2 = (
                f"Generate slogans in the style of: '{request.liked_slogan}'\n"
                f"For: {request.brand}\n"
                f"Details: {processed_desc}\n"
                "Slogan:"
            )
        else:
            prompt1 = (
                f"Create a {request.industry} brand slogan that's {tone} and unique.\n"
                f"Brand: {request.brand}\n"
                f"Attributes: {processed_desc}\n"
                "Slogan:"
            )
            prompt2 = (
                f"Write {tone} marketing slogans for this {request.industry} brand:\n"
                f"Name: {request.brand}\n"
                f"About: {processed_desc}\n"
                "Slogan:"
            )

        # Generation parameters: the tone's sampling preset plus fixed settings.
        gen_params = {
            **TONE_PRESETS[tone],
            "max_new_tokens": 25,
            "num_return_sequences": num,
            "do_sample": True,
            # Presumably set because GPT-2's tokenizer has no pad token;
            # reusing EOS avoids padding errors/warnings during sampling.
            "pad_token_id": tokenizer.eos_token_id,
        }

        # Generate from both prompts
        outputs1 = generator(prompt1, **gen_params)
        outputs2 = generator(prompt2, **gen_params)

        # Clean and de-duplicate, preserving generation order.
        slogans = []
        seen = set()
        for output in [*outputs1, *outputs2]:
            # The pipeline's generated_text starts with the prompt
            # (return_full_text defaults to True); keep what follows the last
            # "Slogan:" marker.
            raw = output["generated_text"].split("Slogan:")[-1].strip()
            # First line only, quotes and both parentheses removed (stripping
            # only "(" used to leave a dangling ")"), cut at the first period.
            clean = (
                raw.split("\n")[0]
                .replace('"', "")
                .replace("(", "")
                .replace(")", "")
                .split(".")[0]
                .strip()
            )
            if len(clean) > 4 and clean not in seen:
                seen.add(clean)
                slogans.append(clean)

        return {"slogans": slogans[: num * 2]}

    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e