from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv

# Load environment variables from a local .env file (Hugging Face Spaces injects
# secrets directly as environment variables, so this mainly helps local testing)
load_dotenv()

app = FastAPI()

# Read the token from the environment (set as a secret on Hugging Face Spaces,
# or in a local .env file when testing)
token = os.getenv("HF_TOKEN")

# Initialize the client with your token
client = InferenceClient(
    "google/gemma-3n-E4B-it",
    token=token
)

class Item(BaseModel):
    prompt: str
    history: list = []
    system_prompt: str = "You are a helpful AI assistant."
    temperature: float = 0.7
    max_new_tokens: int = 1024
    top_p: float = 0.95
    repetition_penalty: float = 1.0

# Build the chat message list in the format the Inference API expects
def format_messages(message, history, system_prompt):
    messages = [
        {"role": "system", "content": system_prompt}
    ]
    for user_prompt, bot_response in history:
        messages.append({"role": "user", "content": user_prompt})
        messages.append({"role": "assistant", "content": bot_response})
    messages.append({"role": "user", "content": message})
    return messages

def generate(item: Item):
    temperature = item.temperature
    if temperature < 1e-2:
        temperature = 1e-2

    messages = format_messages(item.prompt, item.history, item.system_prompt)

    # chat_completion applies the model's own chat template server-side,
    # so no manual prompt formatting is needed.
    # Note: repetition_penalty is not a chat_completion parameter, so it is
    # kept on the request model but not forwarded here.
    stream = client.chat_completion(
        messages,
        max_tokens=item.max_new_tokens,
        temperature=temperature,
        top_p=item.top_p,
        stream=True,
    )

    output = ""
    for chunk in stream:
        if chunk.choices:
            output += chunk.choices[0].delta.content or ""

    return output

@app.post("/generate/")
async def generate_text(item: Item):
    response_text = generate(item)
    return {"response": response_text}

# Optional: Add a root endpoint to show the app is alive
@app.get("/")
def read_root():
    return {"Status": "API is running. Use the /generate/ endpoint to get a response."}