Spaces:
Runtime error
```python
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
import torch

app = FastAPI()

# Allow CORS for all origins (adjust this in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# HF Hub repo holding the PEFT adapter (the base model is resolved from its config)
adapter_path = "imnim/multi-label-email-classifier"

try:
    # Load the PEFT config to get the base model path
    peft_config = PeftConfig.from_pretrained(adapter_path, use_auth_token=True)

    # Load base model and tokenizer with the HF auth token
    base_model = AutoModelForCausalLM.from_pretrained(
        peft_config.base_model_name_or_path,
        torch_dtype=torch.float32,
        device_map={"": "cpu"},
        use_auth_token=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        peft_config.base_model_name_or_path,
        use_auth_token=True,
    )

    # Attach the adapter to the base model
    model = PeftModel.from_pretrained(
        base_model,
        adapter_path,
        device_map={"": "cpu"},
        use_auth_token=True,
    )

    # Set up a text-generation pipeline on CPU
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
except Exception as e:
    raise RuntimeError(f"Failed to load model + adapter: {str(e)}")

# Request schema
class EmailInput(BaseModel):
    subject: str
    body: str
# POST /classify endpoint
@app.post("/classify")
async def classify_email(data: EmailInput):
    prompt = f"""### Subject:\n{data.subject}\n\n### Body:\n{data.body}\n\n### Labels:"""

    try:
        result = pipe(prompt, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
        full_text = result[0]["generated_text"]
        # The pipeline echoes the prompt, so everything after "### Labels:" is the prediction
        label_section = full_text.split("### Labels:")[1].strip()
        return {"label": label_section}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Model inference failed: {str(e)}")
```
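
For checking the endpoint without deploying, a quick local smoke test can be run with FastAPI's `TestClient`. This is a minimal sketch; the module name `app.py` and the sample payload are assumptions, not from the original post:

```python
# Smoke test for the /classify endpoint.
# Assumes the service code above is saved as app.py (hypothetical filename).
from fastapi.testclient import TestClient

from app import app

client = TestClient(app)

response = client.post(
    "/classify",
    json={"subject": "Invoice overdue", "body": "Please pay invoice #123 by Friday."},
)
print(response.status_code)  # 200 once the model loads successfully
print(response.json())       # {"label": "..."} with the generated label text
```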