from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv
# Load environment variables from a local .env file (on HF Spaces, secrets are
# injected directly into the environment, so this is only needed for local runs)
load_dotenv()
app = FastAPI()
# Read the API token from the environment (an HF Spaces secret, or from .env locally)
token = os.getenv("HF_TOKEN")
# Initialize the client with your token
client = InferenceClient(
"google/gemma-3n-E4B-it",
token=token
)
class Item(BaseModel):
prompt: str
history: list = []
system_prompt: str = "You are a helpful AI assistant."
temperature: float = 0.7
max_new_tokens: int = 1024
top_p: float = 0.95
repetition_penalty: float = 1.0
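# For reference, a request body matching the Item model above could look like the
# following (a hypothetical example; the field values are illustrative only):
#
#   {
#       "prompt": "What is FastAPI?",
#       "history": [["Hi", "Hello! How can I help?"]],
#       "system_prompt": "You are a helpful AI assistant.",
#       "temperature": 0.7,
#       "max_new_tokens": 1024,
#       "top_p": 0.95,
#       "repetition_penalty": 1.0
#   }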
# Build an OpenAI-style message list from the prompt, history, and system prompt.
# InferenceClient.chat_completion applies the model's own chat template
# server-side, so no manual prompt formatting is needed here.
def build_messages(message, history, system_prompt):
    messages = [
        {"role": "system", "content": system_prompt}
    ]
    for user_prompt, bot_response in history:
        messages.append({"role": "user", "content": user_prompt})
        messages.append({"role": "assistant", "content": bot_response})
    messages.append({"role": "user", "content": message})
    return messages
def generate(item: Item):
    # Avoid a zero or near-zero temperature, which the sampling backend rejects
    temperature = max(item.temperature, 1e-2)
    messages = build_messages(item.prompt, item.history, item.system_prompt)
    # Note: chat_completion does not expose repetition_penalty (that is a
    # text_generation parameter), so Item.repetition_penalty is kept on the
    # model for API compatibility but is not forwarded here.
    stream = client.chat_completion(
        messages,
        max_tokens=item.max_new_tokens,
        temperature=temperature,
        top_p=item.top_p,
        stream=True,
    )
    # Accumulate the streamed deltas into a single response string
    output = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            output += delta
    return output
@app.post("/generate/")
async def generate_text(item: Item):
response_text = generate(item)
return {"response": response_text}
# Optional: Add a root endpoint to show the app is alive
@app.get("/")
def read_root():
return {"Status": "API is running. Use the /generate/ endpoint to get a response."}