BERTGradGraph / server.py
yifan0sun's picture
Update server.py
0397def
raw
history blame
9.64 kB
from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
from fastapi.middleware.cors import CORSMiddleware
from ROBERTAmodel import *
from BERTmodel import *
from DISTILLBERTmodel import *
import os
import zipfile
import shutil
# Maps the model name sent by the client to the visualizer class that serves it.
VISUALIZER_CLASSES = {
    "BERT": BERTVisualizer,
    "RoBERTa": RoBERTaVisualizer,
    "DistilBERT": DistilBERTVisualizer,
}

# Last visualizer built for each model name; written by the request handlers.
VISUALIZER_CACHE = {}

app = FastAPI()

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Presumably the Hugging Face checkpoint id for each supported model name;
# nothing in this file reads it — TODO confirm it is still needed.
MODEL_MAP = {
    "BERT": "bert-base-uncased",
    "RoBERTa": "roberta-base",
    "DistilBERT": "distilbert-base-uncased",
}
class LoadModelRequest(BaseModel):
    # Request body for POST /load_model.
    model: str  # looked up in VISUALIZER_CLASSES, e.g. "BERT"
    sentence: str  # text passed to the visualizer's tokenize()
    task:str  # lower-cased by the handler before use
    hypothesis:str  # only forwarded when task is 'mnli'; has no default, so it must be present
class GradAttnModelRequest(BaseModel):
    # Request body for POST /get_grad_attn_matrix.
    model: str  # looked up in VISUALIZER_CLASSES, e.g. "BERT"
    task: str  # lower-cased by the handler; 'mnli' and 'mlm' receive extra arguments
    sentence: str  # input text handed to the visualizer
    hypothesis:str  # only forwarded when task is 'mnli'; has no default, so it must be present
    maskID: int | None = None  # only forwarded when task is 'mlm'; may be omitted
class PredModelRequest(BaseModel):
    # Request body for POST /predict_model.
    model: str  # looked up in VISUALIZER_CLASSES, e.g. "BERT"
    sentence: str  # input text handed to the visualizer's predict()
    task:str  # lower-cased by the handler; 'mnli' and 'mlm' receive extra arguments
    hypothesis:str  # only forwarded when task is 'mnli'; has no default, so it must be present
    maskID: int | None = None  # only forwarded when task is 'mlm'; may be omitted
@app.get("/ping")
def ping():
    # Liveness probe: always answers with the same fixed payload.
    reply = {"message": "pong"}
    return reply
def extract_zip_if_needed(zip_path, dest_dir):
    """Unpack the archive at ``zip_path`` into ``dest_dir`` unless it already exists.

    The existence of ``dest_dir`` (file or directory) is the only check made;
    its contents are not inspected. Progress is reported on stdout and nothing
    is returned.
    """
    if not os.path.exists(dest_dir):
        print(f"πŸ”“ Extracting {zip_path} β†’ {dest_dir}")
        archive = zipfile.ZipFile(zip_path, 'r')
        with archive:
            archive.extractall(dest_dir)
        print(f"βœ… Extracted to: {dest_dir}")
    else:
        print(f"βœ… Already exists: {dest_dir}")
def print_directory_tree(path):
    """Return an indented text listing of the directory tree rooted at ``path``.

    Despite the name, nothing is printed; the listing is returned as a string.
    Each directory and each file gets its own line. A directory is indented by
    one space per level below ``path``, and its files are indented one space
    further than the directory itself. Entries are listed in sorted order so
    the output is stable across runs.

    Args:
        path: Directory to describe. A trailing separator is tolerated.

    Returns:
        The listing, one entry per line, each line terminated by a newline.
        An empty string when ``path`` is empty or cannot be walked.
    """
    if not path:
        return ''
    # Normalise first so a trailing separator does not skew the depth count.
    top = os.path.normpath(path)
    base_depth = top.count(os.sep)
    lines = []
    for root, dirs, files in os.walk(top):
        dirs.sort()  # sorted in place so os.walk descends in a stable order
        indent = ' ' * (root.count(os.sep) - base_depth)
        lines.append(indent + os.path.basename(root))
        lines.extend(f"{indent} {name}" for name in sorted(files))
    return ''.join(f"{line}\n" for line in lines)
def copy_extract_and_report():
    """Mirror bundled model/tokenizer archives into /data and unpack them.

    For each of the ``models`` and ``tokenizers`` folders under ``./hf_cache``,
    every ``*.zip`` file is copied to the matching folder under
    ``/data/hf_cache`` (unless already there) and extracted into a sibling
    folder named after the archive without its ``.zip`` suffix (unless that
    folder already exists).

    Both steps are staged under a temporary name and renamed into place only
    once complete, so an interrupted copy or extraction cannot leave behind a
    truncated archive or half-filled folder that later calls would mistake for
    finished work.

    Returns:
        The directory listings of ``./hf_cache`` and ``/data/hf_cache`` as
        produced by ``print_directory_tree``, separated by a blank line.

    Raises:
        OSError, zipfile.BadZipFile: Propagated from the copy or extraction.
    """
    src_base = "./hf_cache"
    dst_base = "/data/hf_cache"
    for category in ["models", "tokenizers"]:
        src_dir = os.path.join(src_base, category)
        dst_dir = os.path.join(dst_base, category)
        if not os.path.exists(src_dir):
            continue
        os.makedirs(dst_dir, exist_ok=True)
        for name in sorted(os.listdir(src_dir)):
            if not name.endswith(".zip"):
                continue
            src_zip = os.path.join(src_dir, name)
            dst_zip = os.path.join(dst_dir, name)

            # Copy zip to /data if not already present. The copy lands under a
            # temporary name first so dst_zip only ever exists when complete.
            if not os.path.exists(dst_zip):
                partial_zip = dst_zip + ".part"
                shutil.copy(src_zip, partial_zip)
                os.replace(partial_zip, dst_zip)
                print(f"πŸ“€ Copied zip to: {dst_zip}")

            # Determine the extract folder name (strip ".zip" from the filename)
            extract_folder = os.path.splitext(name)[0]
            extract_path = os.path.join(dst_dir, extract_folder)

            # Extract if not already extracted. Extraction goes to a staging
            # folder that is renamed on success, so a failure leaves
            # extract_path absent and the next call retries.
            if not os.path.exists(extract_path):
                staging_path = extract_path + ".partial"
                shutil.rmtree(staging_path, ignore_errors=True)
                os.makedirs(staging_path, exist_ok=True)
                with zipfile.ZipFile(dst_zip, 'r') as zip_ref:
                    zip_ref.extractall(staging_path)
                os.replace(staging_path, extract_path)
                print(f"πŸ“¦ Extracted zip to: {extract_path}")

    # Only the headers are printed; the listings themselves are returned.
    print("\nπŸ“¦ Local hf_cache structure:")
    printstr1 = print_directory_tree("./hf_cache")
    print("\nπŸ’Ύ Persistent /data/hf_cache structure:")
    printstr2 = print_directory_tree("/data/hf_cache")
    return printstr1 + '\n\n' + printstr2
@app.get("/copy_and_extract")
def copy_and_extract():
    # Runs the copy/extract pass and returns the directory report it produces.
    return {"message": "done", "log": copy_extract_and_report()}
@app.get("/data_check")
def data_check():
    # Write probe for /data: (re)creates a marker file there, then reports the
    # top-level entries of the directory.
    with open("/data/marker.txt", "w") as marker:
        marker.write("hello from server.py\n")
    listing = os.listdir("/data")
    return {"message": "done", "contents": listing}
@app.post("/load_model")
def load_model(req: LoadModelRequest):
    # Builds a fresh visualizer for the requested model/task, caches it under
    # the model name, tokenizes the input, and returns the tokens together with
    # the number of attention layers. Every failure is reported as an
    # {"error": ...} payload rather than raised.
    print("\n--- /load_model request received ---")
    print(f"Model: {req.model}")
    print(f"Sentence: {req.sentence}")
    print(f"Task: {req.task}")
    print(f"hypothesis: {req.hypothesis}")

    model_name = req.model
    task = req.task.lower()

    # Evict any previous instance for this model before building a new one.
    if model_name in VISUALIZER_CACHE:
        del VISUALIZER_CACHE[model_name]
        torch.cuda.empty_cache()

    vis_class = VISUALIZER_CLASSES.get(model_name)
    if vis_class is None:
        return {"error": f"Unknown model: {req.model}"}

    print("instantiating visualizer")
    try:
        vis = vis_class(task=task)
        print(vis)
        VISUALIZER_CACHE[model_name] = vis
        print("Visualizer instantiated")
    except Exception as e:
        print("Visualizer init failed:", e)
        return {"error": f"Instantiation failed: {str(e)}"}

    print('tokenizing')
    # Only the MNLI task supplies a hypothesis to the tokenizer.
    extra_kwargs = {"hypothesis": req.hypothesis} if task == 'mnli' else {}
    try:
        token_output = vis.tokenize(req.sentence, **extra_kwargs)
        print("0 Tokenization successful:", token_output["tokens"])
    except Exception as e:
        print("Tokenization failed:", e)
        return {"error": f"Tokenization failed: {str(e)}"}

    print('response ready')
    response = {
        "model": req.model,
        "tokens": token_output['tokens'],
        "num_layers": vis.num_attention_layers,
    }
    print("load model successful")
    print(response)
    return response
@app.post("/predict_model")
def predict_model(req: PredModelRequest):
    """Run a prediction for the requested model and task.

    A new visualizer is built on every call and stored in ``VISUALIZER_CACHE``
    under the model name. The task string is lower-cased; ``'mnli'`` also
    forwards the hypothesis and ``'mlm'`` also forwards ``maskID``.

    Returns:
        ``{"decoded": ..., "top_probs": [...]}`` on success, where
        ``top_probs`` is the result of calling ``.tolist()`` on the second
        value returned by the visualizer's ``predict``. On any failure
        (unknown model, instantiation error, prediction error) a dict with a
        single ``"error"`` key describing the problem.
    """
    print("\n--- /predict_model request received ---")
    print(f"predict: Model: {req.model}")
    print(f"predict: Task: {req.task}")
    print(f"predict: sentence: {req.sentence}")
    print(f"predict: hypothesis: {req.hypothesis}")
    print(f"predict: maskID: {req.maskID}")

    task = req.task.lower()

    print('predict: instantiating')
    vis_class = VISUALIZER_CLASSES.get(req.model)
    if vis_class is None:
        return {"error": f"Unknown model: {req.model}"}
    try:
        # NOTE(review): the model is rebuilt on every request instead of being
        # read from VISUALIZER_CACHE; earlier commented-out code suggests this
        # works around parameters left on the 'meta' device — confirm.
        vis = vis_class(task=task)
        VISUALIZER_CACHE[req.model] = vis
        print("Model reloaded and cached.")
    except Exception as e:
        return {"error": f"Failed to reload model: {str(e)}"}

    print('predict: Run prediction')
    # Task-specific extra arguments for predict().
    if task == 'mnli':
        extra_kwargs = {"hypothesis": req.hypothesis}
    elif task == 'mlm':
        extra_kwargs = {"maskID": req.maskID}
    else:
        extra_kwargs = {}

    try:
        decoded, top_probs = vis.predict(task, req.sentence, **extra_kwargs)
        # Converted inside the try so a malformed result is reported as an
        # error payload instead of crashing the request.
        response = {
            "decoded": decoded,
            "top_probs": top_probs.tolist(),
        }
    except Exception as e:
        # Previously the exception object itself was bound to top_probs and
        # the later .tolist() call raised, producing an unhandled 500.
        print(e)
        return {"error": f"Prediction failed: {str(e)}"}

    print('predict: response ready')
    print("predict: predict model successful")
    if len(decoded) > 5:
        # Keep the log short for long outputs.
        print([(k, v[:5]) for k, v in response.items()])
    else:
        print(response)
    return response
@app.post("/get_grad_attn_matrix")
def get_grad_attn_matrix(req: GradAttnModelRequest):
    """Compute the gradient and attention matrices for the given input.

    A new visualizer is built on every call and stored in ``VISUALIZER_CACHE``
    under the model name. The task string is lower-cased; ``'mnli'`` also
    forwards the hypothesis and ``'mlm'`` also forwards ``maskID``.

    Returns:
        ``{"grad_matrix": ..., "attn_matrix": ...}`` holding the two values
        returned by the visualizer's ``get_all_grad_attn_matrix``. On any
        failure a dict with a single ``"error"`` key describing the problem.
    """
    try:
        print("\n--- /get_grad_attn_matrix request received ---")
        print(f"grad:Model: {req.model}")
        print(f"grad:Task: {req.task}")
        print(f"grad:sentence: {req.sentence}")
        print(f"grad: hypothesis: {req.hypothesis}")
        print(f"grad: maskID: {req.maskID}")

        task = req.task.lower()

        vis_class = VISUALIZER_CLASSES.get(req.model)
        if vis_class is None:
            return {"error": f"Unknown model: {req.model}"}
        try:
            # NOTE(review): the model is rebuilt on every request instead of
            # being read from VISUALIZER_CACHE; earlier commented-out code
            # suggests a 'meta' device workaround — confirm.
            vis = vis_class(task=task)
            VISUALIZER_CACHE[req.model] = vis
            print("Model reloaded and cached.")
        except Exception as e:
            return {"error": f"Failed to reload model: {str(e)}"}

        print("run function")
        # Task-specific extra arguments for get_all_grad_attn_matrix().
        if task == 'mnli':
            extra_kwargs = {"hypothesis": req.hypothesis}
        elif task == 'mlm':
            extra_kwargs = {"maskID": req.maskID}
        else:
            extra_kwargs = {}

        try:
            grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(
                task, req.sentence, **extra_kwargs
            )
        except Exception as e:
            # Previously the exception object was returned as both matrices.
            # It is not JSON data, and encoding happens after this function
            # returns, so the outer handler could not catch the failure.
            print("Exception during grad/attn computation:", e)
            return {"error": f"Grad/attn computation failed: {str(e)}"}

        response = {
            "grad_matrix": grad_matrix,
            "attn_matrix": attn_matrix,
        }
        print('grad attn successful')
        return response
    except Exception as e:
        print("SERVER EXCEPTION:", e)
        return {"error": str(e)}
@app.post("/load_all_files")
def load_all_files():
    # Instantiates every visualizer class once per task ('mlm', 'mnli', 'sst'),
    # announcing each one first. The instances are discarded immediately and
    # nothing is returned.
    warmup_plan = (
        ('load BERTmlm ', BERTVisualizer, 'mlm'),
        ('load BERTmnli ', BERTVisualizer, 'mnli'),
        ('load BERTsst ', BERTVisualizer, 'sst'),
        ('load roBERTmlm ', RoBERTaVisualizer, 'mlm'),
        ('load roBERTmnli', RoBERTaVisualizer, 'mnli'),
        ('load roBERTsst', RoBERTaVisualizer, 'sst'),
        ('load distillBERTmlm ', DistilBERTVisualizer, 'mlm'),
        ('load distillBERTmmli ', DistilBERTVisualizer, 'mnli'),
        ('load distillBERTsst ', DistilBERTVisualizer, 'sst'),
    )
    for banner, visualizer_cls, task_name in warmup_plan:
        print(banner)
        visualizer_cls(task_name)