# api-test / app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import datetime
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

# Initialize FastAPI
app = FastAPI()
# Load model and tokenizer
model_name = "Qwen/Qwen2.5-Coder-32B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Configure model loading with specific parameters
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
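
# Note: in fp16 the 32B model needs roughly 64 GB of accelerator memory, so it will
# not fit on most single GPUs. One common workaround (a sketch only, assuming the
# optional bitsandbytes dependency is installed) is to load the weights in 4-bit:
#
#     from transformers import BitsAndBytesConfig
#     bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True,
#         bnb_4bit_compute_dtype=torch.float16,
#         bnb_4bit_quant_type="nf4",
#     )
#     model = AutoModelForCausalLM.from_pretrained(
#         model_name,
#         device_map="auto",
#         trust_remote_code=True,
#         quantization_config=bnb_config,
#     )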

def format_chat_response(response_text, prompt_tokens, completion_tokens):
    return {
        "id": f"chatcmpl-{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}",
        "object": "chat.completion",
        "created": int(datetime.datetime.now().timestamp()),
        "model": model_name,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response_text
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        }
    }
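
# For reference, the payload above mirrors the OpenAI chat.completions response
# schema, e.g. (values below are illustrative only):
#
#     {
#       "id": "chatcmpl-20240101120000",
#       "object": "chat.completion",
#       "created": 1704110400,
#       "model": "Qwen/Qwen2.5-Coder-32B",
#       "choices": [{"index": 0,
#                    "message": {"role": "assistant", "content": "..."},
#                    "finish_reason": "stop"}],
#       "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}
#     }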

@app.post("/v1/chat/completions")
async def chat_completion(request: Request):
    try:
        data = await request.json()
        messages = data.get("messages", [])
        # Convert messages to model input format
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Count prompt tokens
        prompt_tokens = len(tokenizer.encode(prompt))
        # Generate response
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=data.get("max_tokens", 2048),
            temperature=data.get("temperature", 0.7),
            top_p=data.get("top_p", 0.95),
            do_sample=True
        )
        # Decode only the newly generated tokens, not the echoed prompt
        response_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        completion_tokens = len(tokenizer.encode(response_text))
        return JSONResponse(
            content=format_chat_response(response_text, prompt_tokens, completion_tokens)
        )
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )
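
# Example request against this endpoint (assuming the Space runs locally on
# port 7860; adjust the URL for a deployed Space). Note that model.generate above
# runs synchronously inside the async handler, so concurrent requests are served
# one at a time.
#
#     import requests
#     resp = requests.post(
#         "http://localhost:7860/v1/chat/completions",
#         json={
#             "messages": [{"role": "user", "content": "Write a Python hello world."}],
#             "max_tokens": 256,
#             "temperature": 0.7,
#         },
#     )
#     print(resp.json()["choices"][0]["message"]["content"])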

# Synchronous function to generate response
def generate_response(messages):
    # Convert messages to model input format
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Generate response
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,
        temperature=0.7,
        top_p=0.95,
        do_sample=True
    )
    return tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
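
# Example usage (hypothetical messages, in the same format the endpoint accepts):
#
#     generate_response([
#         {"role": "system", "content": "You are a helpful coding assistant."},
#         {"role": "user", "content": "Explain list comprehensions."},
#     ])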

# Gradio interface for testing
def chat_interface(message, history):
    history = history or []
    messages = []
    # Convert history to messages format
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    # Add current message
    messages.append({"role": "user", "content": message})
    # Generate response synchronously
    try:
        response_text = generate_response(messages)
        return response_text
    except Exception as e:
        return f"Error generating response: {str(e)}"

interface = gr.ChatInterface(
    chat_interface,
    title="Qwen2.5-Coder-32B Chat",
    description="Chat with the Qwen2.5-Coder-32B model. This Space also provides a /v1/chat/completions endpoint."
)
# Mount the Gradio UI onto the FastAPI app so both are served from one process
app = gr.mount_gradio_app(app, interface, path="/")
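
# When running outside Hugging Face Spaces, the combined app can be served with
# uvicorn (assuming it is installed), e.g.:
#
#     uvicorn app:app --host 0.0.0.0 --port 7860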