from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import os
from dotenv import load_dotenv
# Load variables from a local .env file for development; on HF Spaces,
# secrets are injected as environment variables automatically.
load_dotenv()
app = FastAPI()
# Read the token from the environment (set as a secret on Hugging Face Spaces).
# If unset, the client falls back to anonymous, rate-limited access,
# which is convenient for local testing.
token = os.getenv("HF_TOKEN")

# Initialize the inference client for the chosen model
client = InferenceClient(
    "google/gemma-3n-E4B-it",
    token=token,
)
class Item(BaseModel):
    prompt: str
    history: list[tuple[str, str]] = []  # (user, assistant) turn pairs
    system_prompt: str = "You are a helpful AI assistant."
    temperature: float = 0.7
    max_new_tokens: int = 1024
    top_p: float = 0.95
    repetition_penalty: float = 1.0  # accepted for compatibility; not forwarded, see generate()
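
# Example request body matching the Item model above (values are illustrative):
# {"prompt": "How are you?", "history": [["Hi", "Hello!"]], "temperature": 0.7}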
# Build the message list expected by the chat-completion API. The server
# applies the model's own chat template (Gemma here, not Llama-3), so no
# manual prompt formatting is needed.
def format_messages(message, history, system_prompt):
    messages = [{"role": "system", "content": system_prompt}]
    for user_prompt, bot_response in history:
        messages.append({"role": "user", "content": user_prompt})
        messages.append({"role": "assistant", "content": bot_response})
    messages.append({"role": "user", "content": message})
    return messages
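
# For example, history=[("Hi", "Hello!")] and message="How are you?" yields:
# [{"role": "system", "content": "You are a helpful AI assistant."},
#  {"role": "user", "content": "Hi"},
#  {"role": "assistant", "content": "Hello!"},
#  {"role": "user", "content": "How are you?"}]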
def generate(item: Item):
    # Keep temperature strictly positive so sampling stays valid.
    temperature = max(item.temperature, 1e-2)
    messages = format_messages(item.prompt, item.history, item.system_prompt)
    # Note: chat_completion takes neither repetition_penalty nor do_sample;
    # sampling is implied by temperature > 0, so those knobs are not forwarded.
    response = client.chat_completion(
        messages,
        max_tokens=item.max_new_tokens,
        temperature=temperature,
        top_p=item.top_p,
    )
    return response.choices[0].message.content
@app.post("/generate/")
def generate_text(item: Item):
    # Declared without async so FastAPI runs this blocking inference
    # call in a worker thread instead of stalling the event loop.
    response_text = generate(item)
    return {"response": response_text}
# Optional: Add a root endpoint to show the app is alive
@app.get("/")
def read_root():
return {"Status": "API is running. Use the /generate/ endpoint to get a response."} |