from fastapi import FastAPI, Request
from pydantic import BaseModel
from pathlib import Path
import torch
from fastapi.middleware.cors import CORSMiddleware
from ROBERTAmodel import *
from BERTmodel import *
from DISTILLBERTmodel import *
import os
import zipfile
import shutil

# Maps the model name sent by the client to its visualizer class.
VISUALIZER_CLASSES = {
    "BERT": BERTVisualizer,
    "RoBERTa": RoBERTaVisualizer,
    "DistilBERT": DistilBERTVisualizer,
}

VISUALIZER_CACHE = {}

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Maps the model name to its Hugging Face checkpoint identifier.
MODEL_MAP = {
    "BERT": "bert-base-uncased",
    "RoBERTa": "roberta-base",
    "DistilBERT": "distilbert-base-uncased",
}


class LoadModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str


class GradAttnModelRequest(BaseModel):
    model: str
    task: str
    sentence: str
    hypothesis: str
    maskID: int | None = None


class PredModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str
    maskID: int | None = None


@app.get("/ping")
def ping():
    return {"message": "pong"}


@app.post("/migrate_cache")
def migrate_hf_cache():
    src_root = Path("/hf_cache")
    dst_root = Path("/data/hf_cache")

    if not src_root.exists():
        return {"status": "error", "message": "Source directory does not exist"}

    migrated_files = []
    for src_path in src_root.rglob("*"):
        if src_path.is_file():
            relative_path = src_path.relative_to(src_root)
            dst_path = dst_root / relative_path

            # Create destination directory if needed
            dst_path.parent.mkdir(parents=True, exist_ok=True)

            # Copy file
            shutil.copy2(src_path, dst_path)
            migrated_files.append(str(relative_path))

    return {
        "status": "done",
        "files_migrated": migrated_files,
        "total": len(migrated_files),
    }


@app.get("/copy_and_extract")
def copy_and_extract():
    printstr = copy_extract_and_report()
    return {"message": "done", "log": printstr}


@app.get("/data_check")
def data_check():
    with open("/data/marker.txt", "w") as f:
        f.write("hello from server.py\n")
    files = os.listdir("/data")
    return {"message": "done", "contents": files}


@app.get("/list_data")
def list_data():
    base_path = Path("/data")
    all_items = []
    for path in base_path.rglob("*"):  # recursive glob
        all_items.append({
            "path": str(path.relative_to(base_path)),
            "type": "dir" if path.is_dir() else "file",
            "size": path.stat().st_size if path.is_file() else None,
        })
    return {"items": all_items}


@app.post("/purge_data_123456789")
def purge_data():
    base_path = Path("/data")
    if not base_path.exists():
        return {"status": "error", "message": "/data does not exist"}

    deleted = []
    for child in base_path.iterdir():
        try:
            if child.is_file() or child.is_symlink():
                child.unlink()
            elif child.is_dir():
                shutil.rmtree(child)
            deleted.append(str(child.name))
        except Exception as e:
            deleted.append(f"FAILED: {child.name} ({e})")

    return {"status": "done", "deleted": deleted, "total": len(deleted)}


##############################################################


@app.post("/load_model")
def load_model(req: LoadModelRequest):
    print("\n--- /load_model request received ---")
    print(f"Model: {req.model}")
    print(f"Sentence: {req.sentence}")
    print(f"Task: {req.task}")
    print(f"Hypothesis: {req.hypothesis}")

    # Drop any cached visualizer for this model and free GPU memory before reloading.
    if req.model in VISUALIZER_CACHE:
        del VISUALIZER_CACHE[req.model]
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    vis_class = VISUALIZER_CLASSES.get(req.model)
    if vis_class is None:
        return {"error": f"Unknown model: {req.model}"}

    print("instantiating visualizer")
    try:
        vis = vis_class(task=req.task.lower())
        print(vis)
        VISUALIZER_CACHE[req.model] = vis
        print("Visualizer instantiated")
    except Exception as e:
        print("Visualizer init failed:", e)
return {"error": f"Instantiation failed: {str(e)}"} print('tokenizing') try: if req.task.lower() == 'mnli': token_output = vis.tokenize(req.sentence, hypothesis=req.hypothesis) else: token_output = vis.tokenize(req.sentence) print("0 Tokenization successful:", token_output["tokens"]) except Exception as e: print("Tokenization failed:", e) return {"error": f"Tokenization failed: {str(e)}"} print('response ready') response = { "model": req.model, "tokens": token_output['tokens'], "num_layers": vis.num_attention_layers, } print("load model successful") print(response) return response @app.post("/predict_model") def predict_model(req: PredModelRequest): print(f"\n--- /predict_model request received ---") print(f"predict: Model: {req.model}") print(f"predict: Task: {req.task}") print(f"predict: sentence: {req.sentence}") print(f"predict: hypothesis: {req.hypothesis}") print(f"predict: maskID: {req.maskID}") print('predict: instantiating') try: vis_class = VISUALIZER_CLASSES.get(req.model) if vis_class is None: return {"error": f"Unknown model: {req.model}"} #if any(p.device.type == 'meta' for p in vis.model.parameters()): # vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu")) vis = vis_class(task=req.task.lower()) VISUALIZER_CACHE[req.model] = vis print("Model reloaded and cached.") except Exception as e: return {"error": f"Failed to reload model: {str(e)}"} print('predict: meta stuff') print('predict: Run prediction') try: if req.task.lower() == 'mnli': decoded, top_probs = vis.predict(req.task.lower(), req.sentence, hypothesis=req.hypothesis) elif req.task.lower() == 'mlm': decoded, top_probs = vis.predict(req.task.lower(), req.sentence, maskID=req.maskID) else: decoded, top_probs = vis.predict(req.task.lower(), req.sentence) except Exception as e: decoded, top_probs = "error", e print(e) print('predict: response ready') response = { "decoded": decoded, "top_probs": top_probs.tolist(), } print("predict: predict model successful") if len(decoded) > 5: print([(k,v[:5]) for k,v in response.items()]) else: print(response) return response @app.post("/get_grad_attn_matrix") def get_grad_attn_matrix(req: GradAttnModelRequest): try: print(f"\n--- /get_grad_matrix request received ---") print(f"grad:Model: {req.model}") print(f"grad:Task: {req.task}") print(f"grad:sentence: {req.sentence}") print(f"grad: hypothesis: {req.hypothesis}") print(f"predict: maskID: {req.maskID}") try: vis_class = VISUALIZER_CLASSES.get(req.model) if vis_class is None: return {"error": f"Unknown model: {req.model}"} #if any(p.device.type == 'meta' for p in vis.model.parameters()): # vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu")) vis = vis_class(task=req.task.lower()) VISUALIZER_CACHE[req.model] = vis print("Model reloaded and cached.") except Exception as e: return {"error": f"Failed to reload model: {str(e)}"} print("run function") try: if req.task.lower()=='mnli': grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence,hypothesis=req.hypothesis) elif req.task.lower()=='mlm': grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence,maskID=req.maskID) else: grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence) except Exception as e: print("Exception during grad/attn computation:", e) grad_matrix, attn_matrix = e,e response = { "grad_matrix": grad_matrix, "attn_matrix": attn_matrix, } print('grad attn successful') return response except Exception as e: print("SERVER EXCEPTION:", e) 
return {"error": str(e)} @app.post("/load_all_files") def load_all_files(): print('load BERTmlm ') BERTVisualizer('mlm') print('load BERTmnli ') BERTVisualizer('mnli') print('load BERTsst ') BERTVisualizer('sst') print('load roBERTmlm ') RoBERTaVisualizer('mlm') print('load roBERTmnli') RoBERTaVisualizer('mnli') print('load roBERTsst') RoBERTaVisualizer('sst') print('load distillBERTmlm ') DistilBERTVisualizer('mlm') print('load distillBERTmmli ') DistilBERTVisualizer('mnli') print('load distillBERTsst ') DistilBERTVisualizer('sst')