from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from ROBERTAmodel import *
from BERTmodel import *
from DISTILLBERTmodel import *
import os
import zipfile
import shutil
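
# FastAPI backend for a transformer attention/gradient visualizer. It loads
# BERT / RoBERTa / DistilBERT visualizers, tokenizes input sentences, runs
# MLM / MNLI predictions, and returns gradient and attention matrices.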

# Map model names used in requests to their visualizer classes.
VISUALIZER_CLASSES = {
    "BERT": BERTVisualizer,
    "RoBERTa": RoBERTaVisualizer,
    "DistilBERT": DistilBERTVisualizer,
}

# Most recently instantiated visualizer, keyed by model name.
VISUALIZER_CACHE = {}

app = FastAPI()

# Allow the browser frontend to call this API from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Hugging Face Hub checkpoint identifiers for each supported model.
MODEL_MAP = {
    "BERT": "bert-base-uncased",
    "RoBERTa": "roberta-base",
    "DistilBERT": "distilbert-base-uncased",
}

class LoadModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str

class GradAttnModelRequest(BaseModel):
    model: str
    task: str
    sentence: str
    hypothesis: str
    maskID: int | None = None

class PredModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str
    maskID: int | None = None
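
# Note: `hypothesis` is only consumed when task == "mnli" and `maskID` only
# when task == "mlm"; the handlers below ignore them otherwise.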

# NOTE: the @app route decorators and their paths are reconstructed from the
# log messages inside each handler; the /ping path and the HTTP methods are
# assumptions.
@app.get("/ping")
def ping():
    return {"message": "pong"}

def zip_if_needed(src_dir, zip_path):
    if os.path.exists(zip_path):
        return  # already zipped
    print(f"Zipping {src_dir} -> {zip_path}")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(src_dir):
            for file in files:
                full_path = os.path.join(root, file)
                rel_path = os.path.relpath(full_path, src_dir)
                zf.write(full_path, arcname=rel_path)
    print(f"Created zip: {zip_path}")

def extract_zip_if_needed(zip_path, dest_dir):
    if os.path.exists(dest_dir):
        print(f"Already exists: {dest_dir}")
        return
    print(f"Extracting {zip_path} -> {dest_dir}")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dest_dir)
    print(f"Extracted to: {dest_dir}")

def print_directory_tree(path):
    for root, dirs, files in os.walk(path):
        indent = '    ' * (root.count(os.sep) - path.count(os.sep))
        print(f"{indent}{os.path.basename(root)}/")
        for f in files:
            print(f"{indent}    {f}")

def copy_zip_extract_and_report():
    src_base = "./hf_cache"
    dst_base = "/data/hf_cache"
    for category in ["models", "tokenizers"]:
        src_dir = os.path.join(src_base, category)
        dst_dir = os.path.join(dst_base, category)
        if not os.path.exists(src_dir):
            continue
        os.makedirs(dst_dir, exist_ok=True)
        for name in os.listdir(src_dir):
            full_path = os.path.join(src_dir, name)
            if not os.path.isdir(full_path):
                continue
            zip_name = f"{name}.zip"
            local_zip = os.path.join(src_dir, zip_name)
            dst_zip = os.path.join(dst_dir, zip_name)
            extract_path = os.path.join(dst_dir, name)
            zip_if_needed(full_path, local_zip)
            # Copy zip to /data
            if not os.path.exists(dst_zip):
                shutil.copy(local_zip, dst_zip)
                print(f"Copied zip to: {dst_zip}")
            extract_zip_if_needed(dst_zip, extract_path)
    print("\nLocal hf_cache structure:")
    print_directory_tree("./hf_cache")
    print("\nPersistent /data/hf_cache structure:")
    print_directory_tree("/data/hf_cache")

@app.post("/load_model")
def load_model(req: LoadModelRequest):
    print("\n--- /load_model request received ---")
    print(f"Model: {req.model}")
    print(f"Sentence: {req.sentence}")
    print(f"Task: {req.task}")
    print(f"Hypothesis: {req.hypothesis}")

    # Drop any previously cached visualizer for this model and free GPU memory.
    if req.model in VISUALIZER_CACHE:
        del VISUALIZER_CACHE[req.model]
        torch.cuda.empty_cache()

    vis_class = VISUALIZER_CLASSES.get(req.model)
    if vis_class is None:
        return {"error": f"Unknown model: {req.model}"}

    print("instantiating visualizer")
    try:
        vis = vis_class(task=req.task.lower())
        print(vis)
        VISUALIZER_CACHE[req.model] = vis
        print("Visualizer instantiated")
    except Exception as e:
        print("Visualizer init failed:", e)
        return {"error": f"Instantiation failed: {str(e)}"}

    print("tokenizing")
    try:
        # MNLI takes a premise/hypothesis pair; other tasks tokenize a single sentence.
        if req.task.lower() == 'mnli':
            token_output = vis.tokenize(req.sentence, hypothesis=req.hypothesis)
        else:
            token_output = vis.tokenize(req.sentence)
        print("Tokenization successful:", token_output["tokens"])
    except Exception as e:
        print("Tokenization failed:", e)
        return {"error": f"Tokenization failed: {str(e)}"}

    print("response ready")
    response = {
        "model": req.model,
        "tokens": token_output['tokens'],
        "num_layers": vis.num_attention_layers,
    }
    print("load model successful")
    print(response)
    return response
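
# Illustrative /load_model response shape (values are examples, not from a real run):
#   {"model": "BERT", "tokens": ["[CLS]", "the", "cat", "sat", "[SEP]"], "num_layers": 12}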

@app.post("/predict_model")
def predict_model(req: PredModelRequest):
    print("\n--- /predict_model request received ---")
    print(f"predict: Model: {req.model}")
    print(f"predict: Task: {req.task}")
    print(f"predict: sentence: {req.sentence}")
    print(f"predict: hypothesis: {req.hypothesis}")
    print(f"predict: maskID: {req.maskID}")

    print('predict: instantiating')
    try:
        vis_class = VISUALIZER_CLASSES.get(req.model)
        if vis_class is None:
            return {"error": f"Unknown model: {req.model}"}
        # if any(p.device.type == 'meta' for p in vis.model.parameters()):
        #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
        vis = vis_class(task=req.task.lower())
        VISUALIZER_CACHE[req.model] = vis
        print("Model reloaded and cached.")
    except Exception as e:
        return {"error": f"Failed to reload model: {str(e)}"}

    print('predict: Run prediction')
    try:
        if req.task.lower() == 'mnli':
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
        elif req.task.lower() == 'mlm':
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, maskID=req.maskID)
        else:
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence)
    except Exception as e:
        # Return a JSON error instead of letting the response construction below fail.
        print("predict: prediction failed:", e)
        return {"error": f"Prediction failed: {str(e)}"}

    print('predict: response ready')
    response = {
        "decoded": decoded,
        "top_probs": top_probs.tolist(),
    }
    print("predict: predict model successful")
    # Avoid dumping huge probability lists to the log.
    if len(decoded) > 5:
        print([(k, v[:5]) for k, v in response.items()])
    else:
        print(response)
    return response
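
# `top_probs` is converted with .tolist() above because a torch.Tensor is not
# JSON serializable by FastAPI's default response encoding.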

@app.post("/get_grad_matrix")
def get_grad_attn_matrix(req: GradAttnModelRequest):
    try:
        print("\n--- /get_grad_matrix request received ---")
        print(f"grad: Model: {req.model}")
        print(f"grad: Task: {req.task}")
        print(f"grad: sentence: {req.sentence}")
        print(f"grad: hypothesis: {req.hypothesis}")
        print(f"grad: maskID: {req.maskID}")

        try:
            vis_class = VISUALIZER_CLASSES.get(req.model)
            if vis_class is None:
                return {"error": f"Unknown model: {req.model}"}
            # if any(p.device.type == 'meta' for p in vis.model.parameters()):
            #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
            vis = vis_class(task=req.task.lower())
            VISUALIZER_CACHE[req.model] = vis
            print("Model reloaded and cached.")
        except Exception as e:
            return {"error": f"Failed to reload model: {str(e)}"}

        print("run function")
        try:
            if req.task.lower() == 'mnli':
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
            elif req.task.lower() == 'mlm':
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, maskID=req.maskID)
            else:
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence)
        except Exception as e:
            # Return a JSON error rather than placing the exception object in the response.
            print("Exception during grad/attn computation:", e)
            return {"error": f"Grad/attn computation failed: {str(e)}"}

        response = {
            "grad_matrix": grad_matrix,
            "attn_matrix": attn_matrix,
        }
        print('grad attn successful')
        return response
    except Exception as e:
        print("SERVER EXCEPTION:", e)
        return {"error": str(e)}