import logging
import threading
import time

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModel,
    pipeline,
    T5ForConditionalGeneration,
)
import gradio as gr
from flask import Flask, request, jsonify

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MultiModelAPI:
    """Loads Hugging Face models on demand and routes requests to them."""

    def __init__(self):
        self.models = {}
        self.tokenizers = {}
        self.pipelines = {}
        # Maps each model repo to the pipeline task it should serve.
        # To expose another checkpoint, add its repo id and task type here.
        self.model_configs = {
            'Lyon28/Tinny-Llama': 'causal-lm',
            'Lyon28/Pythia': 'causal-lm',
            'Lyon28/Bert-Tinny': 'feature-extraction',
            'Lyon28/Albert-Base-V2': 'feature-extraction',
            'Lyon28/T5-Small': 'text2text-generation',
            'Lyon28/GPT-2': 'causal-lm',
            'Lyon28/GPT-Neo': 'causal-lm',
            'Lyon28/Distilbert-Base-Uncased': 'feature-extraction',
            'Lyon28/Distil_GPT-2': 'causal-lm',
            'Lyon28/GPT-2-Tinny': 'causal-lm',
            'Lyon28/Electra-Small': 'feature-extraction'
        }

    def load_model(self, model_name: str):
        """Load a specific model and build a task pipeline for it."""
        try:
            logger.info(f"Loading model: {model_name}")

            if model_name in self.models:
                logger.info(f"Model {model_name} already loaded")
                return True

            model_type = self.model_configs.get(model_name, 'causal-lm')

            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                cache_dir="/app/cache"
            )

            # Some GPT-style tokenizers ship without a pad token.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            if model_type == 'causal-lm':
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    device_map="auto" if torch.cuda.is_available() else None,
                    cache_dir="/app/cache"
                )
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    # On GPU the model is already placed by device_map="auto",
                    # so do not also pass an explicit device to the pipeline.
                    device=None if torch.cuda.is_available() else -1
                )

            elif model_type == 'text2text-generation':
                model = T5ForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    cache_dir="/app/cache"
                )
                pipe = pipeline(
                    "text2text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=0 if torch.cuda.is_available() else -1
                )

            else:
                model = AutoModel.from_pretrained(
                    model_name,
                    trust_remote_code=True,
                    cache_dir="/app/cache"
                )
                pipe = pipeline(
                    "feature-extraction",
                    model=model,
                    tokenizer=tokenizer,
                    device=0 if torch.cuda.is_available() else -1
                )

            self.models[model_name] = model
            self.tokenizers[model_name] = tokenizer
            self.pipelines[model_name] = pipe

            logger.info(f"Successfully loaded model: {model_name}")
            return True

        except Exception as e:
            logger.error(f"Error loading model {model_name}: {str(e)}")
            return False

    def generate_text(self, model_name: str, prompt: str, **kwargs):
        """Generate text (or embeddings) with the specified model, loading it on demand."""
        try:
            if model_name not in self.pipelines:
                if not self.load_model(model_name):
                    return {"error": f"Failed to load model {model_name}"}

            pipe = self.pipelines[model_name]
            model_type = self.model_configs.get(model_name, 'causal-lm')

            # Generation parameters with sensible defaults.
            # Note: max_length counts the prompt tokens plus the newly generated ones.
            max_length = kwargs.get('max_length', 100)
            temperature = kwargs.get('temperature', 0.7)
            top_p = kwargs.get('top_p', 0.9)
            do_sample = kwargs.get('do_sample', True)

            if model_type == 'causal-lm':
                result = pipe(
                    prompt,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=do_sample,
                    pad_token_id=pipe.tokenizer.eos_token_id
                )
                return {"generated_text": result[0]['generated_text']}

            elif model_type == 'text2text-generation':
                result = pipe(
                    prompt,
                    max_length=max_length,
                    temperature=temperature,
                    do_sample=do_sample
                )
                return {"generated_text": result[0]['generated_text']}

            else:
                # Feature-extraction models return token embeddings rather than text.
                result = pipe(prompt)
                return {"embeddings": result}

        except Exception as e:
            logger.error(f"Error generating text with {model_name}: {str(e)}")
            return {"error": str(e)}

    def get_model_info(self):
        """Get information about available and loaded models."""
        return {
            "available_models": list(self.model_configs.keys()),
            "loaded_models": list(self.models.keys()),
            "model_types": self.model_configs
        }
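

# Direct (illustrative) use of MultiModelAPI without the HTTP layer; a minimal
# sketch only, and the first call for a repo downloads its checkpoint:
#
#     manager = MultiModelAPI()
#     manager.load_model("Lyon28/GPT-2")
#     out = manager.generate_text("Lyon28/GPT-2", "Once upon a time", max_length=60)
#     print(out.get("generated_text", out))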

# Shared model manager used by both the REST API and the Gradio UI.
api = MultiModelAPI()

# Flask application exposing the REST endpoints.
app = Flask(__name__)

@app.route('/api/models', methods=['GET'])
def get_models():
    """Get available and loaded models."""
    return jsonify(api.get_model_info())


@app.route('/api/load_model', methods=['POST'])
def load_model():
    """Load a specific model."""
    # silent=True avoids an exception when the body is missing or not JSON.
    data = request.get_json(silent=True) or {}
    model_name = data.get('model_name')

    if not model_name:
        return jsonify({"error": "model_name is required"}), 400

    success = api.load_model(model_name)
    if success:
        return jsonify({"message": f"Model {model_name} loaded successfully"})
    else:
        return jsonify({"error": f"Failed to load model {model_name}"}), 500


@app.route('/api/generate', methods=['POST'])
def generate():
    """Generate text using the specified model."""
    data = request.get_json(silent=True) or {}
    model_name = data.get('model_name')
    prompt = data.get('prompt')

    if not model_name or not prompt:
        return jsonify({"error": "model_name and prompt are required"}), 400

    params = {
        'max_length': data.get('max_length', 100),
        'temperature': data.get('temperature', 0.7),
        'top_p': data.get('top_p', 0.9),
        'do_sample': data.get('do_sample', True)
    }

    result = api.generate_text(model_name, prompt, **params)
    return jsonify(result)


@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint."""
    return jsonify({"status": "healthy", "loaded_models": len(api.models)})
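

# Illustrative client for the endpoints above; a minimal sketch, not called by the
# app itself. It assumes the Flask server is reachable on localhost:5000 and that
# the `requests` package is installed (it is not imported elsewhere in this file).
def _example_generate_request():
    import requests  # illustrative-only dependency

    resp = requests.post(
        "http://localhost:5000/api/generate",
        json={
            "model_name": "Lyon28/GPT-2",
            "prompt": "Hello world",
            "max_length": 100,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True
        },
        timeout=300
    )
    return resp.json()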


def gradio_interface():
    """Build the Gradio UI that wraps MultiModelAPI."""

    def generate_text_ui(model_name, prompt, max_length, temperature, top_p):
        if not model_name or not prompt:
            return "Please select a model and enter a prompt"

        params = {
            'max_length': int(max_length),
            'temperature': float(temperature),
            'top_p': float(top_p),
            'do_sample': True
        }

        result = api.generate_text(model_name, prompt, **params)

        if 'error' in result:
            return f"Error: {result['error']}"

        return result.get('generated_text', str(result))

    def load_model_ui(model_name):
        if not model_name:
            return "Please select a model"

        success = api.load_model(model_name)
        if success:
            return f"✅ Model {model_name} loaded successfully"
        else:
            return f"❌ Failed to load model {model_name}"

    with gr.Blocks(title="Multi-Model API Interface") as interface:
        gr.Markdown("# Multi-Model API Interface")
        gr.Markdown("Load and interact with multiple Hugging Face models")

        with gr.Tab("Model Management"):
            model_dropdown = gr.Dropdown(
                choices=list(api.model_configs.keys()),
                label="Select Model",
                value=None
            )
            load_btn = gr.Button("Load Model")
            load_status = gr.Textbox(label="Status", interactive=False)

            load_btn.click(
                load_model_ui,
                inputs=[model_dropdown],
                outputs=[load_status]
            )

        with gr.Tab("Text Generation"):
            with gr.Row():
                with gr.Column():
                    gen_model = gr.Dropdown(
                        choices=list(api.model_configs.keys()),
                        label="Model",
                        value=None
                    )
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your prompt here...",
                        lines=3
                    )

                    with gr.Row():
                        max_length = gr.Slider(10, 500, value=100, label="Max Length")
                        temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
                        top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top P")

                    generate_btn = gr.Button("Generate")

                with gr.Column():
                    output_text = gr.Textbox(
                        label="Generated Text",
                        lines=10,
                        interactive=False
                    )

            generate_btn.click(
                generate_text_ui,
                inputs=[gen_model, prompt_input, max_length, temperature, top_p],
                outputs=[output_text]
            )

        with gr.Tab("API Documentation"):
            gr.Markdown("""
            ## API Endpoints

            ### GET /api/models
            Get list of available and loaded models

            ### POST /api/load_model
            Load a specific model
            ```json
            {
                "model_name": "Lyon28/GPT-2"
            }
            ```

            ### POST /api/generate
            Generate text using a model
            ```json
            {
                "model_name": "Lyon28/GPT-2",
                "prompt": "Hello world",
                "max_length": 100,
                "temperature": 0.7,
                "top_p": 0.9,
                "do_sample": true
            }
            ```

            ### GET /health
            Health check endpoint
            """)

    return interface


def run_flask():
    """Run the Flask API server."""
    app.run(host="0.0.0.0", port=5000, debug=False)


def main():
    """Run the Flask API and the Gradio UI together."""
    # Run the REST API in a background thread so Gradio can own the main thread.
    flask_thread = threading.Thread(target=run_flask, daemon=True)
    flask_thread.start()

    # Give Flask a moment to bind its port before launching the UI.
    time.sleep(2)

    interface = gradio_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )


if __name__ == "__main__":
    main()