import os import json import logging from typing import Dict, List, Optional, Any import torch from transformers import ( AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModel, pipeline, T5ForConditionalGeneration, T5Tokenizer ) import gradio as gr from flask import Flask, request, jsonify import threading import time # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class MultiModelAPI: def __init__(self): self.models = {} self.tokenizers = {} self.pipelines = {} self.model_configs = { 'Lyon28/Tinny-Llama': 'causal-lm', 'Lyon28/Pythia': 'causal-lm', 'Lyon28/Bert-Tinny': 'feature-extraction', 'Lyon28/Albert-Base-V2': 'feature-extraction', 'Lyon28/T5-Small': 'text2text-generation', 'Lyon28/GPT-2': 'causal-lm', 'Lyon28/GPT-Neo': 'causal-lm', 'Lyon28/Distilbert-Base-Uncased': 'feature-extraction', 'Lyon28/Distil_GPT-2': 'causal-lm', 'Lyon28/GPT-2-Tinny': 'causal-lm', 'Lyon28/Electra-Small': 'feature-extraction' } def load_model(self, model_name: str): """Load a specific model""" try: logger.info(f"Loading model: {model_name}") if model_name in self.models: logger.info(f"Model {model_name} already loaded") return True model_type = self.model_configs.get(model_name, 'causal-lm') # Load tokenizer tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=True, cache_dir="/app/cache" ) # Add pad token if not exists if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Load model based on type if model_type == 'causal-lm': model = AutoModelForCausalLM.from_pretrained( model_name, trust_remote_code=True, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" if torch.cuda.is_available() else None, cache_dir="/app/cache" ) # Create pipeline pipe = pipeline( "text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1 ) elif model_type == 'text2text-generation': model = T5ForConditionalGeneration.from_pretrained( model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, cache_dir="/app/cache" ) pipe = pipeline( "text2text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1 ) else: # feature-extraction or other BERT-like models model = AutoModel.from_pretrained( model_name, trust_remote_code=True, cache_dir="/app/cache" ) pipe = pipeline( "feature-extraction", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1 ) self.models[model_name] = model self.tokenizers[model_name] = tokenizer self.pipelines[model_name] = pipe logger.info(f"Successfully loaded model: {model_name}") return True except Exception as e: logger.error(f"Error loading model {model_name}: {str(e)}") return False def generate_text(self, model_name: str, prompt: str, **kwargs): """Generate text using specified model""" try: if model_name not in self.pipelines: if not self.load_model(model_name): return {"error": f"Failed to load model {model_name}"} pipe = self.pipelines[model_name] model_type = self.model_configs.get(model_name, 'causal-lm') # Set default parameters max_length = kwargs.get('max_length', 100) temperature = kwargs.get('temperature', 0.7) top_p = kwargs.get('top_p', 0.9) do_sample = kwargs.get('do_sample', True) if model_type == 'causal-lm': result = pipe( prompt, max_length=max_length, temperature=temperature, top_p=top_p, do_sample=do_sample, pad_token_id=pipe.tokenizer.eos_token_id ) return {"generated_text": result[0]['generated_text']} elif model_type == 'text2text-generation': result = pipe( prompt, max_length=max_length, temperature=temperature, do_sample=do_sample ) return {"generated_text": result[0]['generated_text']} else: # feature extraction result = pipe(prompt) return {"embeddings": result} except Exception as e: logger.error(f"Error generating text with {model_name}: {str(e)}") return {"error": str(e)} def get_model_info(self): """Get information about loaded models""" return { "available_models": list(self.model_configs.keys()), "loaded_models": list(self.models.keys()), "model_types": self.model_configs } # Initialize API api = MultiModelAPI() # Flask API app = Flask(__name__) @app.route('/api/models', methods=['GET']) def get_models(): """Get available models""" return jsonify(api.get_model_info()) @app.route('/api/load_model', methods=['POST']) def load_model(): """Load a specific model""" data = request.json model_name = data.get('model_name') if not model_name: return jsonify({"error": "model_name is required"}), 400 success = api.load_model(model_name) if success: return jsonify({"message": f"Model {model_name} loaded successfully"}) else: return jsonify({"error": f"Failed to load model {model_name}"}), 500 @app.route('/api/generate', methods=['POST']) def generate(): """Generate text using specified model""" data = request.json model_name = data.get('model_name') prompt = data.get('prompt') if not model_name or not prompt: return jsonify({"error": "model_name and prompt are required"}), 400 # Extract generation parameters params = { 'max_length': data.get('max_length', 100), 'temperature': data.get('temperature', 0.7), 'top_p': data.get('top_p', 0.9), 'do_sample': data.get('do_sample', True) } result = api.generate_text(model_name, prompt, **params) return jsonify(result) @app.route('/health', methods=['GET']) def health_check(): """Health check endpoint""" return jsonify({"status": "healthy", "loaded_models": len(api.models)}) # Gradio Interface def gradio_interface(): def generate_text_ui(model_name, prompt, max_length, temperature, top_p): if not model_name or not prompt: return "Please select a model and enter a prompt" params = { 'max_length': int(max_length), 'temperature': float(temperature), 'top_p': float(top_p), 'do_sample': True } result = api.generate_text(model_name, prompt, **params) if 'error' in result: return f"Error: {result['error']}" return result.get('generated_text', str(result)) def load_model_ui(model_name): if not model_name: return "Please select a model" success = api.load_model(model_name) if success: return f"✅ Model {model_name} loaded successfully" else: return f"❌ Failed to load model {model_name}" with gr.Blocks(title="Multi-Model API Interface") as interface: gr.Markdown("# Multi-Model API Interface") gr.Markdown("Load and interact with multiple Hugging Face models") with gr.Tab("Model Management"): model_dropdown = gr.Dropdown( choices=list(api.model_configs.keys()), label="Select Model", value=None ) load_btn = gr.Button("Load Model") load_status = gr.Textbox(label="Status", interactive=False) load_btn.click( load_model_ui, inputs=[model_dropdown], outputs=[load_status] ) with gr.Tab("Text Generation"): with gr.Row(): with gr.Column(): gen_model = gr.Dropdown( choices=list(api.model_configs.keys()), label="Model", value=None ) prompt_input = gr.Textbox( label="Prompt", placeholder="Enter your prompt here...", lines=3 ) with gr.Row(): max_length = gr.Slider(10, 500, value=100, label="Max Length") temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature") top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top P") generate_btn = gr.Button("Generate") with gr.Column(): output_text = gr.Textbox( label="Generated Text", lines=10, interactive=False ) generate_btn.click( generate_text_ui, inputs=[gen_model, prompt_input, max_length, temperature, top_p], outputs=[output_text] ) with gr.Tab("API Documentation"): gr.Markdown(""" ## API Endpoints ### GET /api/models Get list of available and loaded models ### POST /api/load_model Load a specific model ```json { "model_name": "Lyon28/GPT-2" } ``` ### POST /api/generate Generate text using a model ```json { "model_name": "Lyon28/GPT-2", "prompt": "Hello world", "max_length": 100, "temperature": 0.7, "top_p": 0.9, "do_sample": true } ``` ### GET /health Health check endpoint """) return interface def run_flask(): """Run Flask API server""" app.run(host="0.0.0.0", port=5000, debug=False) def main(): """Main function to run both Flask and Gradio""" # Start Flask in a separate thread flask_thread = threading.Thread(target=run_flask, daemon=True) flask_thread.start() # Give Flask time to start time.sleep(2) # Create and launch Gradio interface interface = gradio_interface() # Launch Gradio on port 7860 (HF Spaces default) interface.launch( server_name="0.0.0.0", server_port=7860, share=False ) if __name__ == "__main__": main()