from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import os
import zipfile
import shutil

from ROBERTAmodel import *
from BERTmodel import *
from DISTILLBERTmodel import *

# Maps the model name sent by the client to its visualizer class.
VISUALIZER_CLASSES = {
    "BERT": BERTVisualizer,
    "RoBERTa": RoBERTaVisualizer,
    "DistilBERT": DistilBERTVisualizer,
}

# Cache of instantiated visualizers, keyed by model name.
VISUALIZER_CACHE = {}

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_MAP = {
    "BERT": "bert-base-uncased",
    "RoBERTa": "roberta-base",
    "DistilBERT": "distilbert-base-uncased",
}


class LoadModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str


class GradAttnModelRequest(BaseModel):
    model: str
    task: str
    sentence: str
    hypothesis: str
    maskID: int | None = None


class PredModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str
    maskID: int | None = None


@app.get("/ping")
def ping():
    return {"message": "pong"}


def zip_if_needed(src_dir, zip_path):
    """Zip src_dir into zip_path unless the archive already exists."""
    if os.path.exists(zip_path):
        return  # already zipped
    print(f"šŸ“¦ Zipping {src_dir} → {zip_path}")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(src_dir):
            for file in files:
                full_path = os.path.join(root, file)
                rel_path = os.path.relpath(full_path, src_dir)
                zf.write(full_path, arcname=rel_path)
    print(f"āœ… Created zip: {zip_path}")


def extract_zip_if_needed(zip_path, dest_dir):
    """Extract zip_path into dest_dir unless dest_dir already exists."""
    if os.path.exists(dest_dir):
        print(f"āœ… Already exists: {dest_dir}")
        return
    print(f"šŸ”“ Extracting {zip_path} → {dest_dir}")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dest_dir)
    print(f"āœ… Extracted to: {dest_dir}")


def print_directory_tree(path):
    """Return an indented text listing of the directory tree rooted at path."""
    printstr = ''
    for root, dirs, files in os.walk(path):
        indent = '    ' * (root.count(os.sep) - path.count(os.sep))
        printstr += indent + os.path.basename(root) + '\n'
        for f in files:
            printstr += indent + '    ' + f + '\n'
    return printstr


def copy_zip_extract_and_report():
    """Zip the local hf_cache models/tokenizers, copy the archives to the
    persistent /data volume, extract them there, and return a text report of
    both directory trees."""
    src_base = "./hf_cache"
    dst_base = "/data/hf_cache"

    for category in ["models", "tokenizers"]:
        src_dir = os.path.join(src_base, category)
        dst_dir = os.path.join(dst_base, category)
        if not os.path.exists(src_dir):
            continue
        os.makedirs(dst_dir, exist_ok=True)

        for name in os.listdir(src_dir):
            full_path = os.path.join(src_dir, name)
            if not os.path.isdir(full_path):
                continue

            zip_name = f"{name}.zip"
            local_zip = os.path.join(src_dir, zip_name)
            dst_zip = os.path.join(dst_dir, zip_name)
            extract_path = os.path.join(dst_dir, name)

            zip_if_needed(full_path, local_zip)

            # Copy the zip to /data if it is not there yet.
            if not os.path.exists(dst_zip):
                shutil.copy(local_zip, dst_zip)
                print(f"šŸ“¤ Copied zip to: {dst_zip}")

            extract_zip_if_needed(dst_zip, extract_path)

    print("\nšŸ“¦ Local hf_cache structure:")
    printstr1 = print_directory_tree("./hf_cache")
    print("\nšŸ’¾ Persistent /data/hf_cache structure:")
    printstr2 = print_directory_tree("/data/hf_cache")
    return printstr1 + '\n\n' + printstr2


@app.get("/copy_and_extract")
def copy_and_extract():
    printstr = copy_zip_extract_and_report()
    return {"message": "done", "log": printstr}


@app.post("/load_model")
def load_model(req: LoadModelRequest):
    print("\n--- /load_model request received ---")
    print(f"Model: {req.model}")
    print(f"Sentence: {req.sentence}")
    print(f"Task: {req.task}")
    print(f"Hypothesis: {req.hypothesis}")

    # Drop any cached visualizer for this model so it is rebuilt from scratch.
    if req.model in VISUALIZER_CACHE:
        del VISUALIZER_CACHE[req.model]
        torch.cuda.empty_cache()

    vis_class = VISUALIZER_CLASSES.get(req.model)
    if vis_class is None:
        return {"error": f"Unknown model: {req.model}"}

    print("Instantiating visualizer")
    try:
        vis = vis_class(task=req.task.lower())
        print(vis)
        VISUALIZER_CACHE[req.model] = vis
        print("Visualizer instantiated")
    except Exception as e:
        print("Visualizer init failed:", e)
        return {"error": f"Instantiation failed: {str(e)}"}

    print("Tokenizing")
    try:
        # MNLI takes a premise/hypothesis pair; the other tasks take a single sentence.
        if req.task.lower() == 'mnli':
            token_output = vis.tokenize(req.sentence, hypothesis=req.hypothesis)
        else:
            token_output = vis.tokenize(req.sentence)
        print("Tokenization successful:", token_output["tokens"])
    except Exception as e:
        print("Tokenization failed:", e)
        return {"error": f"Tokenization failed: {str(e)}"}

    print("Response ready")
    response = {
        "model": req.model,
        "tokens": token_output['tokens'],
        "num_layers": vis.num_attention_layers,
    }
    print("load_model successful")
    print(response)
    return response


@app.post("/predict_model")
def predict_model(req: PredModelRequest):
    print("\n--- /predict_model request received ---")
    print(f"predict: Model: {req.model}")
    print(f"predict: Task: {req.task}")
    print(f"predict: Sentence: {req.sentence}")
    print(f"predict: Hypothesis: {req.hypothesis}")
    print(f"predict: maskID: {req.maskID}")

    print("predict: instantiating visualizer")
    try:
        vis_class = VISUALIZER_CLASSES.get(req.model)
        if vis_class is None:
            return {"error": f"Unknown model: {req.model}"}
        # if any(p.device.type == 'meta' for p in vis.model.parameters()):
        #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
        vis = vis_class(task=req.task.lower())
        VISUALIZER_CACHE[req.model] = vis
        print("Model reloaded and cached.")
    except Exception as e:
        return {"error": f"Failed to reload model: {str(e)}"}

    print("predict: running prediction")
    try:
        if req.task.lower() == 'mnli':
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
        elif req.task.lower() == 'mlm':
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, maskID=req.maskID)
        else:
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence)
    except Exception as e:
        print("Prediction failed:", e)
        return {"error": f"Prediction failed: {str(e)}"}

    print("predict: response ready")
    response = {
        "decoded": decoded,
        "top_probs": top_probs.tolist(),
    }
    print("predict: predict_model successful")
    # Avoid dumping huge outputs to the log; truncate long results.
    if len(decoded) > 5:
        print([(k, v[:5]) for k, v in response.items()])
    else:
        print(response)
    return response


@app.post("/get_grad_attn_matrix")
def get_grad_attn_matrix(req: GradAttnModelRequest):
    try:
        print("\n--- /get_grad_attn_matrix request received ---")
        print(f"grad: Model: {req.model}")
        print(f"grad: Task: {req.task}")
        print(f"grad: Sentence: {req.sentence}")
        print(f"grad: Hypothesis: {req.hypothesis}")
        print(f"grad: maskID: {req.maskID}")

        try:
            vis_class = VISUALIZER_CLASSES.get(req.model)
            if vis_class is None:
                return {"error": f"Unknown model: {req.model}"}
            # if any(p.device.type == 'meta' for p in vis.model.parameters()):
            #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
            vis = vis_class(task=req.task.lower())
            VISUALIZER_CACHE[req.model] = vis
            print("Model reloaded and cached.")
        except Exception as e:
            return {"error": f"Failed to reload model: {str(e)}"}

        print("grad: computing gradient and attention matrices")
        try:
            if req.task.lower() == 'mnli':
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(
                    req.task.lower(), req.sentence, hypothesis=req.hypothesis)
            elif req.task.lower() == 'mlm':
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(
                    req.task.lower(), req.sentence, maskID=req.maskID)
            else:
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(
                    req.task.lower(), req.sentence)
        except Exception as e:
            print("Exception during grad/attn computation:", e)
            return {"error": f"Grad/attention computation failed: {str(e)}"}

        response = {
            "grad_matrix": grad_matrix,
            "attn_matrix": attn_matrix,
        }
        print("grad: get_grad_attn_matrix successful")
        return response
    except Exception as e:
        print("SERVER EXCEPTION:", e)
        return {"error": str(e)}
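

# Optional local entry point, shown only as a sketch: it assumes uvicorn is
# installed and that this module is run directly (the deployed app may instead
# be launched externally, e.g. `uvicorn app:app`). Host and port are
# placeholder values, not part of the original configuration.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)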