from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv

# Load environment variables from a .env file (handy for local testing).
# On Hugging Face Spaces, secrets are injected as environment variables.
load_dotenv()
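# For local testing, a .env file next to this script can provide the token,
# e.g. (illustrative placeholder value only):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
# On Spaces, set HF_TOKEN as a repository secret instead.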
app = FastAPI()

# Get the token from the environment (Spaces secret, or local .env as a fallback)
token = os.getenv("HF_TOKEN")

# Initialize the Inference API client for the chosen model
client = InferenceClient(
    "google/gemma-3n-E4B-it",
    token=token,
)
class Item(BaseModel):
    prompt: str
    history: list = []
    system_prompt: str = "You are a helpful AI assistant."
    temperature: float = 0.7
    max_new_tokens: int = 1024
    top_p: float = 0.95
    repetition_penalty: float = 1.0
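# Illustrative request body for the /generate/ endpoint (field names come from
# the Item model above; the values are just examples):
#
# {
#   "prompt": "Summarize the plot of Hamlet in two sentences.",
#   "history": [["Hi!", "Hello! How can I help you today?"]],
#   "system_prompt": "You are a helpful AI assistant.",
#   "temperature": 0.7,
#   "max_new_tokens": 512
# }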
def format_prompt(message, history, system_prompt):
    """Build the OpenAI-style messages list expected by chat_completion."""
    messages = [{"role": "system", "content": system_prompt}]
    for user_prompt, bot_response in history:
        messages.append({"role": "user", "content": user_prompt})
        messages.append({"role": "assistant", "content": bot_response})
    messages.append({"role": "user", "content": message})
    return messages
def generate(item: Item) -> str:
    # The API rejects a temperature of 0, so clamp it to a small positive value
    temperature = max(item.temperature, 1e-2)

    messages = format_prompt(item.prompt, item.history, item.system_prompt)

    # chat_completion applies the model's own chat template server-side, so no
    # manual prompt formatting is needed. Note: repetition_penalty is not part
    # of the OpenAI-compatible chat_completion parameters and is not forwarded.
    stream = client.chat_completion(
        messages,
        max_tokens=item.max_new_tokens,
        temperature=temperature,
        top_p=item.top_p,
        stream=True,
    )

    output = ""
    for chunk in stream:
        # Each streamed chunk carries a delta with the next piece of generated text
        output += chunk.choices[0].delta.content or ""
    return output
@app.post("/generate/")
async def generate_text(item: Item):
    response_text = generate(item)
    return {"response": response_text}

# Optional: a root endpoint to show the app is alive
@app.get("/")
def read_root():
    return {"Status": "API is running. Use the /generate/ endpoint to get a response."}