# api-test / app.py
# FastAPI wrapper exposing a /generate/ endpoint backed by the
# Qwen2.5-Coder-32B-Instruct model via huggingface_hub's InferenceClient.
from fastapi import FastAPI
from pydantic import BaseModel
from huggingface_hub import InferenceClient
import uvicorn
# FastAPI application instance; the /generate/ route below is registered on it.
app = FastAPI()
# Hosted-inference client pinned to the Qwen2.5-Coder-32B-Instruct model.
client = InferenceClient("Qwen/Qwen2.5-Coder-32B-Instruct")
class Item(BaseModel):
    """Request payload for the /generate/ endpoint."""

    prompt: str            # user message; appended after system_prompt in the formatted prompt
    history: list          # pairs of (user_prompt, bot_response) — unpacked as 2-tuples in format_prompt()
    system_prompt: str     # prepended to prompt (joined with ", ") before formatting
    temperature: float = 0.0   # clamped up to 1e-2 in generate() before use
    max_new_tokens: int = 1048
    top_p: float = 0.15
    repetition_penalty: float = 1.0
def format_prompt(message, history):
    """Render a [INST]-delimited chat prompt from prior turns plus the new message.

    Each (user, assistant) pair in *history* becomes
    ``[INST] user [/INST] assistant</s> ``, preceded by a single ``<s>`` and
    followed by the new *message* wrapped in ``[INST] ... [/INST]``.
    """
    segments = ["<s>"]
    for past_user, past_bot in history:
        segments.append(f"[INST] {past_user} [/INST] {past_bot}</s> ")
    segments.append(f"[INST] {message} [/INST]")
    return "".join(segments)
def generate(item: Item):
    """Stream a completion from the inference client and return the full text.

    Builds the prompt from ``item.system_prompt``/``item.prompt``/``item.history``,
    streams token events from the remote model, and concatenates their text.
    """
    # The backend rejects temperatures at/near zero, so clamp to a small floor.
    effective_temperature = max(float(item.temperature), 1e-2)
    sampling_args = {
        "temperature": effective_temperature,
        "max_new_tokens": item.max_new_tokens,
        "top_p": float(item.top_p),
        "repetition_penalty": item.repetition_penalty,
        "do_sample": True,
        "seed": 42,  # fixed seed: identical requests yield identical samples
    }
    formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
    token_stream = client.text_generation(
        formatted_prompt,
        stream=True,
        details=True,
        return_full_text=False,
        **sampling_args,
    )
    # Each streamed event carries one token; join their text pieces in order.
    return "".join(event.token.text for event in token_stream)
@app.post("/generate/")
async def generate_text(item: Item):
    """POST /generate/: run generate() on the payload and wrap it for JSON."""
    completion = generate(item)
    return {"response": completion}
# import gradio as gr
# from fastapi import FastAPI, Request, HTTPException
# from fastapi.responses import JSONResponse
# import datetime
# import requests
# import os
# import logging
# import toml
# # Initialize FastAPI
# app = FastAPI()
# # Configure logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# # Load config
# with open("config.toml") as f:
# config = toml.load(f)
# #API_URL = os.getenv('API_URL')
# #API_TOKEN = os.getenv('API_TOKEN')
# # API_URL = 'https://ojciectadeusz-fastapi-inference-qwen2-5-coder-32-a0ab504.hf.space/v1/chat/completions'
# API_URL = 'https://ojciectadeusz-fastapi-inference-qwen2.5-coder-32b-instruct.hf.space/v1/chat/completions'
# headers = {
# "Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}",
# "Content-Type": "application/json"
# }
# def format_chat_response(response_text, prompt_tokens=0, completion_tokens=0):
# return {
# "id": f"chatcmpl-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}",
# "object": "chat.completion",
# "created": int(datetime.datetime.now().timestamp()),
# "model": "Qwen/Qwen2.5-Coder-32B",
# "choices": [{
# "index": 0,
# "message": {
# "role": "assistant",
# "content": response_text
# },
# "finish_reason": "stop"
# }],
# "usage": {
# "prompt_tokens": prompt_tokens,
# "completion_tokens": completion_tokens,
# "total_tokens": prompt_tokens + completion_tokens
# }
# }
# async def query_model(payload):
# try:
# response = requests.post(API_URL, headers=headers, json=payload)
# response.raise_for_status()
# return response.json()
# except requests.exceptions.RequestException as e:
# logger.error(f"Request failed: {e}")
# raise HTTPException(status_code=500, detail=str(e))
# @app.get("/status")
# async def status():
# try:
# response_text = os.getenv('HF_API_TOKEN') + "it's working"
# return JSONResponse(content=format_chat_response(response_text))
# except Exception as e:
# logger.error(f"Status check failed: {e}")
# raise HTTPException(status_code=500, detail=str(e))
# @app.post("/v1/chat/completions")
# async def chat_completion(request: Request):
# try:
# data = await request.json()
# messages = data.get("messages", [])
# if not messages:
# raise HTTPException(status_code=400, detail="Messages are required")
# payload = {
# "inputs": {
# "messages": messages
# },
# "parameters": {
# "max_new_tokens": data.get("max_tokens", 2048),
# "temperature": data.get("temperature", 0.7),
# "top_p": data.get("top_p", 0.95),
# "do_sample": True
# }
# }
# response = await query_model(payload)
# if isinstance(response, dict) and "error" in response:
# raise HTTPException(status_code=500, detail=response["error"])
# response_text = response[0]["generated_text"]
# return JSONResponse(content=format_chat_response(response_text))
# except HTTPException as e:
# logger.error(f"Chat completion failed: {e.detail}")
# raise e
# except Exception as e:
# logger.error(f"Unexpected error: {e}")
# raise HTTPException(status_code=500, detail=str(e))
# def generate_response(messages):
# payload = {
# "inputs": {
# "messages": messages
# },
# "parameters": {
# "max_new_tokens": 2048,
# "temperature": 0.7,
# "top_p": 0.95,
# "do_sample": True
# }
# }
# try:
# response = requests.post(API_URL, headers=headers, json=payload)
# response.raise_for_status()
# result = response.json()
# if isinstance(result, dict) and "error" in result:
# return f"Error: {result['error']}"
# return result[0]["generated_text"]
# except requests.exceptions.RequestException as e:
# logger.error(f"Request failed: {e}")
# return f"Error: {e}"
# def chat_interface(messages):
# chat_history = []
# for message in messages:
# try:
# response = generate_response([{"role": "user", "content": message}])
# chat_history.append({"role": "user", "content": message})
# chat_history.append({"role": "assistant", "content": response})
# except Exception as e:
# chat_history.append({"role": "user", "content": message})
# chat_history.append({"role": "assistant", "content": f"Error: {str(e)}"})
# return chat_history
# # Create Gradio interface
# def gradio_app():
# return gr.ChatInterface(chat_interface, type="messages")
# # Mount both FastAPI and Gradio
# app = gr.mount_gradio_app(app, gradio_app(), path="/")
# # For running with uvicorn directly
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run(app, host="0.0.0.0", port=7860)