import logging
import threading
import time

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModel,
    pipeline,
    T5ForConditionalGeneration,
)
import gradio as gr
from flask import Flask, request, jsonify

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultiModelAPI:
    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.pipelines = {}
        self.model_configs = {
            'Lyon28/Tinny-Llama': 'causal-lm',
            'Lyon28/Pythia': 'causal-lm',
            'Lyon28/Bert-Tinny': 'feature-extraction',
            'Lyon28/Albert-Base-V2': 'feature-extraction',
            'Lyon28/T5-Small': 'text2text-generation',
            'Lyon28/GPT-2': 'causal-lm',
            'Lyon28/GPT-Neo': 'causal-lm',
            'Lyon28/Distilbert-Base-Uncased': 'feature-extraction',
            'Lyon28/Distil_GPT-2': 'causal-lm',
            'Lyon28/GPT-2-Tinny': 'causal-lm',
            'Lyon28/Electra-Small': 'feature-extraction'
        }
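
    # Task routing: 'causal-lm' models are served through a text-generation
    # pipeline, 'text2text-generation' through a T5-style seq2seq pipeline,
    # and everything else falls back to feature-extraction (embeddings).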
    def load_model(self, model_name: str):
        """Load a specific model and build the matching pipeline."""
        try:
            logger.info(f"Loading model: {model_name}")
            if model_name in self.models:
                logger.info(f"Model {model_name} already loaded")
                return True
            model_type = self.model_configs.get(model_name, 'causal-lm')
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                cache_dir="/app/cache"
            )
            # Add a pad token if the tokenizer does not define one
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            # Load model based on type
            if model_type == 'causal-lm':
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    device_map="auto" if torch.cuda.is_available() else None,
                    cache_dir="/app/cache"
                )
                # device_map="auto" already places the model, so do not also
                # pass an explicit device to the pipeline (doing both raises
                # an error in recent transformers versions).
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer
                )
            elif model_type == 'text2text-generation':
                model = T5ForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    cache_dir="/app/cache"
                )
                pipe = pipeline(
                    "text2text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=0 if torch.cuda.is_available() else -1
                )
            else:  # feature-extraction or other BERT-like models
                model = AutoModel.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    cache_dir="/app/cache"
                )
                pipe = pipeline(
                    "feature-extraction",
                    model=model,
                    tokenizer=tokenizer,
                    device=0 if torch.cuda.is_available() else -1
                )
            self.models[model_name] = model
            self.tokenizers[model_name] = tokenizer
            self.pipelines[model_name] = pipe
            logger.info(f"Successfully loaded model: {model_name}")
            return True
        except Exception as e:
            logger.error(f"Error loading model {model_name}: {str(e)}")
            return False
    def generate_text(self, model_name: str, prompt: str, **kwargs):
        """Generate text using specified model"""
        try:
            if model_name not in self.pipelines:
                if not self.load_model(model_name):
                    return {"error": f"Failed to load model {model_name}"}
            pipe = self.pipelines[model_name]
            model_type = self.model_configs.get(model_name, 'causal-lm')
            # Set default generation parameters
            max_length = kwargs.get('max_length', 100)
            temperature = kwargs.get('temperature', 0.7)
            top_p = kwargs.get('top_p', 0.9)
            do_sample = kwargs.get('do_sample', True)
            if model_type == 'causal-lm':
                result = pipe(
                    prompt,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    pad_token_id=pipe.tokenizer.eos_token_id
                )
                return {"generated_text": result[0]['generated_text']}
            elif model_type == 'text2text-generation':
                result = pipe(
                    prompt,
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=do_sample
                )
                return {"generated_text": result[0]['generated_text']}
            else:  # feature extraction: return raw embeddings
                result = pipe(prompt)
                return {"embeddings": result}
        except Exception as e:
            logger.error(f"Error generating text with {model_name}: {str(e)}")
            return {"error": str(e)}
    def get_model_info(self):
        """Get information about loaded models"""
        return {
            "available_models": list(self.model_configs.keys()),
            "loaded_models": list(self.models.keys()),
            "model_types": self.model_configs
        }
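
# Example (sketch): the class can also be exercised directly, without the
# HTTP layer, assuming the Lyon28 checkpoints are reachable from this
# environment:
#
#   demo = MultiModelAPI()
#   demo.load_model("Lyon28/GPT-2")
#   print(demo.generate_text("Lyon28/GPT-2", "Hello", max_length=40))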
# Initialize API
api = MultiModelAPI()
# Flask API
app = Flask(__name__)
@app.route('/api/models', methods=['GET'])
def get_models():
    """Get available models"""
    return jsonify(api.get_model_info())
@app.route('/api/load_model', methods=['POST'])
def load_model():
    """Load a specific model"""
    # get_json(silent=True) keeps the error response as JSON even when the
    # request body or Content-Type is malformed
    data = request.get_json(silent=True) or {}
    model_name = data.get('model_name')
    if not model_name:
        return jsonify({"error": "model_name is required"}), 400
    success = api.load_model(model_name)
    if success:
        return jsonify({"message": f"Model {model_name} loaded successfully"})
    else:
        return jsonify({"error": f"Failed to load model {model_name}"}), 500
@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate text using specified model"""
    data = request.get_json(silent=True) or {}
    model_name = data.get('model_name')
    prompt = data.get('prompt')
    if not model_name or not prompt:
        return jsonify({"error": "model_name and prompt are required"}), 400
    # Extract generation parameters
    params = {
        'max_length': data.get('max_length', 100),
        'temperature': data.get('temperature', 0.7),
        'top_p': data.get('top_p', 0.9),
        'do_sample': data.get('do_sample', True)
    }
    result = api.generate_text(model_name, prompt, **params)
    return jsonify(result)
@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    return jsonify({"status": "healthy", "loaded_models": len(api.models)})
# Gradio Interface
def gradio_interface():
    def generate_text_ui(model_name, prompt, max_length, temperature, top_p):
        if not model_name or not prompt:
            return "Please select a model and enter a prompt"
        params = {
            'max_length': int(max_length),
            'temperature': float(temperature),
            'top_p': float(top_p),
            'do_sample': True
        }
        result = api.generate_text(model_name, prompt, **params)
        if 'error' in result:
            return f"Error: {result['error']}"
        return result.get('generated_text', str(result))

    def load_model_ui(model_name):
        if not model_name:
            return "Please select a model"
        success = api.load_model(model_name)
        if success:
            return f"✅ Model {model_name} loaded successfully"
        else:
            return f"❌ Failed to load model {model_name}"

    with gr.Blocks(title="Multi-Model API Interface") as interface:
        gr.Markdown("# Multi-Model API Interface")
        gr.Markdown("Load and interact with multiple Hugging Face models")

        with gr.Tab("Model Management"):
            model_dropdown = gr.Dropdown(
                choices=list(api.model_configs.keys()),
                label="Select Model",
                value=None
            )
            load_btn = gr.Button("Load Model")
            load_status = gr.Textbox(label="Status", interactive=False)
            load_btn.click(
                load_model_ui,
                inputs=[model_dropdown],
                outputs=[load_status]
            )

        with gr.Tab("Text Generation"):
            with gr.Row():
                with gr.Column():
                    gen_model = gr.Dropdown(
                        choices=list(api.model_configs.keys()),
                        label="Model",
                        value=None
                    )
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your prompt here...",
                        lines=3
                    )
                    with gr.Row():
                        max_length = gr.Slider(10, 500, value=100, label="Max Length")
                        temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
                        top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top P")
                    generate_btn = gr.Button("Generate")
                with gr.Column():
                    output_text = gr.Textbox(
                        label="Generated Text",
                        lines=10,
                        interactive=False
                    )
            generate_btn.click(
                generate_text_ui,
                inputs=[gen_model, prompt_input, max_length, temperature, top_p],
                outputs=[output_text]
            )
with gr.Tab("API Documentation"):
gr.Markdown("""
## API Endpoints
### GET /api/models
Get list of available and loaded models
### POST /api/load_model
Load a specific model
```json
{
"model_name": "Lyon28/GPT-2"
}
```
### POST /api/generate
Generate text using a model
```json
{
"model_name": "Lyon28/GPT-2",
"prompt": "Hello world",
"max_length": 100,
"temperature": 0.7,
"top_p": 0.9,
"do_sample": true
}
```
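
            ### Example (Python client)
            A minimal sketch using the `requests` library; the host and port
            assume the Flask defaults configured in this file:
            ```python
            import requests

            resp = requests.post(
                "http://localhost:5000/api/generate",
                json={"model_name": "Lyon28/GPT-2", "prompt": "Hello world", "max_length": 60},
            )
            print(resp.json())
            ```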
            ### GET /health
            Health check endpoint
            """)

    return interface
def run_flask():
    """Run Flask API server"""
    app.run(host="0.0.0.0", port=5000, debug=False)
def main():
    """Main function to run both Flask and Gradio"""
    # Start Flask in a separate thread
    flask_thread = threading.Thread(target=run_flask, daemon=True)
    flask_thread.start()
    # Give Flask time to start
    time.sleep(2)
    # Create and launch Gradio interface
    interface = gradio_interface()
    # Launch Gradio on port 7860 (HF Spaces default)
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
if __name__ == "__main__":
    main()
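
# Example requests against the Flask API (sketch; host and ports match the
# defaults configured above):
#   curl http://localhost:5000/api/models
#   curl -X POST http://localhost:5000/api/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "Lyon28/GPT-2", "prompt": "Hello world"}'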