# BERTGradGraph / server.py
from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
from fastapi.middleware.cors import CORSMiddleware
from ROBERTAmodel import *
from BERTmodel import *
from DISTILLBERTmodel import *
import os
import zipfile
import shutil

VISUALIZER_CLASSES = {
    "BERT": BERTVisualizer,
    "RoBERTa": RoBERTaVisualizer,
    "DistilBERT": DistilBERTVisualizer,
}
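
# Cache of instantiated visualizer objects, keyed by model name ("BERT",
# "RoBERTa", "DistilBERT"); the endpoints below store the most recently
# created instance here.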
VISUALIZER_CACHE = {}

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
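
# Hugging Face checkpoint IDs for each supported model name. MODEL_MAP is not
# referenced elsewhere in this file; presumably the visualizer classes resolve
# their own checkpoints and this table is kept for reference.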
MODEL_MAP = {
    "BERT": "bert-base-uncased",
    "RoBERTa": "roberta-base",
    "DistilBERT": "distilbert-base-uncased",
}


class LoadModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str


class GradAttnModelRequest(BaseModel):
    model: str
    task: str
    sentence: str
    hypothesis: str
    maskID: int | None = None


class PredModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str
    maskID: int | None = None


@app.get("/ping")
def ping():
    return {"message": "pong"}
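
# Quick health check (assuming the app is served locally on port 7860, the
# usual Hugging Face Spaces port; adjust host/port to your deployment):
#   curl http://localhost:7860/ping
#   -> {"message": "pong"}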


def zip_if_needed(src_dir, zip_path):
    if os.path.exists(zip_path):
        return  # already zipped
    print(f"📦 Zipping {src_dir} → {zip_path}")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(src_dir):
            for file in files:
                full_path = os.path.join(root, file)
                rel_path = os.path.relpath(full_path, src_dir)
                zf.write(full_path, arcname=rel_path)
    print(f"✅ Created zip: {zip_path}")


def extract_zip_if_needed(zip_path, dest_dir):
    if os.path.exists(dest_dir):
        print(f"✅ Already exists: {dest_dir}")
        return
    print(f"🔓 Extracting {zip_path} → {dest_dir}")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dest_dir)
    print(f"✅ Extracted to: {dest_dir}")


def print_directory_tree(path):
    # Build an indented, human-readable listing of the directory tree.
    printstr = ''
    for root, dirs, files in os.walk(path):
        indent = '    ' * (root.count(os.sep) - path.count(os.sep))
        printstr += indent + os.path.basename(root) + '\n'
        for f in files:
            printstr += indent + '    ' + f + '\n'
    return printstr


def copy_zip_extract_and_report():
    src_base = "./hf_cache"
    dst_base = "/data/hf_cache"
    for category in ["models", "tokenizers"]:
        src_dir = os.path.join(src_base, category)
        dst_dir = os.path.join(dst_base, category)
        if not os.path.exists(src_dir):
            continue
        os.makedirs(dst_dir, exist_ok=True)
        for name in os.listdir(src_dir):
            full_path = os.path.join(src_dir, name)
            if not os.path.isdir(full_path):
                continue
            zip_name = f"{name}.zip"
            local_zip = os.path.join(src_dir, zip_name)
            dst_zip = os.path.join(dst_dir, zip_name)
            extract_path = os.path.join(dst_dir, name)
            zip_if_needed(full_path, local_zip)
            # Copy zip to /data
            if not os.path.exists(dst_zip):
                shutil.copy(local_zip, dst_zip)
                print(f"📤 Copied zip to: {dst_zip}")
            extract_zip_if_needed(dst_zip, extract_path)
    print("\n📦 Local hf_cache structure:")
    printstr1 = print_directory_tree("./hf_cache")
    print("\n💾 Persistent /data/hf_cache structure:")
    printstr2 = print_directory_tree("/data/hf_cache")
    return printstr1 + '\n\n' + printstr2


@app.get("/copy_and_extract")
def copy_and_extract():
    printstr = copy_zip_extract_and_report()
    return {"message": "done", "log": printstr}
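
# Example call (hypothetical host/port):
#   curl http://localhost:7860/copy_and_extract
# The "log" field contains the directory trees of ./hf_cache and /data/hf_cache.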


@app.post("/load_model")
def load_model(req: LoadModelRequest):
    print("\n--- /load_model request received ---")
    print(f"Model: {req.model}")
    print(f"Sentence: {req.sentence}")
    print(f"Task: {req.task}")
    print(f"Hypothesis: {req.hypothesis}")

    if req.model in VISUALIZER_CACHE:
        del VISUALIZER_CACHE[req.model]
        torch.cuda.empty_cache()

    vis_class = VISUALIZER_CLASSES.get(req.model)
    if vis_class is None:
        return {"error": f"Unknown model: {req.model}"}

    print("instantiating visualizer")
    try:
        vis = vis_class(task=req.task.lower())
        print(vis)
        VISUALIZER_CACHE[req.model] = vis
        print("Visualizer instantiated")
    except Exception as e:
        print("Visualizer init failed:", e)
        return {"error": f"Instantiation failed: {str(e)}"}

    print('tokenizing')
    try:
        if req.task.lower() == 'mnli':
            token_output = vis.tokenize(req.sentence, hypothesis=req.hypothesis)
        else:
            token_output = vis.tokenize(req.sentence)
        print("Tokenization successful:", token_output["tokens"])
    except Exception as e:
        print("Tokenization failed:", e)
        return {"error": f"Tokenization failed: {str(e)}"}

    print('response ready')
    response = {
        "model": req.model,
        "tokens": token_output['tokens'],
        "num_layers": vis.num_attention_layers,
    }
    print("load model successful")
    print(response)
    return response
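
# Illustrative request (values are made up; "hypothesis" is only used for the
# MNLI task but the schema always requires the field):
#   curl -X POST http://localhost:7860/load_model \
#        -H "Content-Type: application/json" \
#        -d '{"model": "BERT", "task": "mlm", "sentence": "The cat sat on the mat.", "hypothesis": ""}'
# A successful response echoes the model name and returns the token list and
# the number of attention layers.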


@app.post("/predict_model")
def predict_model(req: PredModelRequest):
    print("\n--- /predict_model request received ---")
    print(f"predict: Model: {req.model}")
    print(f"predict: Task: {req.task}")
    print(f"predict: sentence: {req.sentence}")
    print(f"predict: hypothesis: {req.hypothesis}")
    print(f"predict: maskID: {req.maskID}")

    print('predict: instantiating')
    try:
        vis_class = VISUALIZER_CLASSES.get(req.model)
        if vis_class is None:
            return {"error": f"Unknown model: {req.model}"}
        # if any(p.device.type == 'meta' for p in vis.model.parameters()):
        #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
        vis = vis_class(task=req.task.lower())
        VISUALIZER_CACHE[req.model] = vis
        print("Model reloaded and cached.")
    except Exception as e:
        return {"error": f"Failed to reload model: {str(e)}"}

    print('predict: run prediction')
    try:
        if req.task.lower() == 'mnli':
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
        elif req.task.lower() == 'mlm':
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, maskID=req.maskID)
        else:
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence)
    except Exception as e:
        print("predict: prediction failed:", e)
        return {"error": f"Prediction failed: {str(e)}"}

    print('predict: response ready')
    response = {
        "decoded": decoded,
        "top_probs": top_probs.tolist(),
    }
    print("predict: predict model successful")
    if len(decoded) > 5:
        print([(k, v[:5]) for k, v in response.items()])
    else:
        print(response)
    return response
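
# Illustrative masked-LM request (values are made up; maskID identifies the
# masked position, with its exact semantics defined by the visualizer's
# predict() implementation):
#   curl -X POST http://localhost:7860/predict_model \
#        -H "Content-Type: application/json" \
#        -d '{"model": "BERT", "task": "mlm", "sentence": "The capital of France is [MASK].", "hypothesis": "", "maskID": 6}'
# The response contains the decoded prediction(s) and their probabilities.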


@app.post("/get_grad_attn_matrix")
def get_grad_attn_matrix(req: GradAttnModelRequest):
    try:
        print("\n--- /get_grad_attn_matrix request received ---")
        print(f"grad: Model: {req.model}")
        print(f"grad: Task: {req.task}")
        print(f"grad: sentence: {req.sentence}")
        print(f"grad: hypothesis: {req.hypothesis}")
        print(f"grad: maskID: {req.maskID}")

        try:
            vis_class = VISUALIZER_CLASSES.get(req.model)
            if vis_class is None:
                return {"error": f"Unknown model: {req.model}"}
            # if any(p.device.type == 'meta' for p in vis.model.parameters()):
            #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
            vis = vis_class(task=req.task.lower())
            VISUALIZER_CACHE[req.model] = vis
            print("Model reloaded and cached.")
        except Exception as e:
            return {"error": f"Failed to reload model: {str(e)}"}

        print("run function")
        try:
            if req.task.lower() == 'mnli':
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
            elif req.task.lower() == 'mlm':
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, maskID=req.maskID)
            else:
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence)
        except Exception as e:
            print("Exception during grad/attn computation:", e)
            return {"error": f"Gradient/attention computation failed: {str(e)}"}

        response = {
            "grad_matrix": grad_matrix,
            "attn_matrix": attn_matrix,
        }
        print('grad attn successful')
        return response
    except Exception as e:
        print("SERVER EXCEPTION:", e)
        return {"error": str(e)}
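
# Illustrative MNLI request (values are made up):
#   curl -X POST http://localhost:7860/get_grad_attn_matrix \
#        -H "Content-Type: application/json" \
#        -d '{"model": "RoBERTa", "task": "mnli", "sentence": "A man is playing a guitar.", "hypothesis": "A person is making music.", "maskID": null}'
# The response carries whatever gradient and attention matrices
# get_all_grad_attn_matrix() returns for the chosen model and task.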