# BERTGradGraph / server.py
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch

from ROBERTAmodel import *
from BERTmodel import *
from DISTILLBERTmodel import *

VISUALIZER_CLASSES = {
    "BERT": BERTVisualizer,
    "RoBERTa": RoBERTaVisualizer,
    "DistilBERT": DistilBERTVisualizer,
}

# Cache of instantiated visualizers, keyed by model name.
VISUALIZER_CACHE = {}
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # or restrict to ["http://localhost:3000"]
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
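# To serve this app locally (a sketch; assumes uvicorn is installed and this
# file is named server.py — adjust for your deployment):
#
#   uvicorn server:app --host 0.0.0.0 --port 8000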
# UI-facing model names mapped to their Hugging Face checkpoint IDs.
MODEL_MAP = {
    "BERT": "bert-base-uncased",
    "RoBERTa": "roberta-base",
    "DistilBERT": "distilbert-base-uncased",
}
class LoadModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str


class GradAttnModelRequest(BaseModel):
    model: str
    task: str
    sentence: str
    hypothesis: str
    maskID: int | None = None


class PredModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str
    maskID: int | None = None
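# Illustrative request payloads for the schemas above (the field values are
# examples, not taken from any fixed client):
#
#   LoadModelRequest: {"model": "BERT", "sentence": "The cat sat on the mat.",
#                      "task": "mlm", "hypothesis": ""}
#   PredModelRequest: {"model": "RoBERTa", "sentence": "A soccer game.",
#                      "task": "mnli", "hypothesis": "A sports match.",
#                      "maskID": null}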
@app.post("/load_model")
def load_model(req: LoadModelRequest):
    print("\n--- /load_model request received ---")
    print(f"Model: {req.model}")
    print(f"Sentence: {req.sentence}")
    print(f"Task: {req.task}")
    print(f"Hypothesis: {req.hypothesis}")

    # Drop any cached visualizer for this model so it is rebuilt fresh.
    if req.model in VISUALIZER_CACHE:
        del VISUALIZER_CACHE[req.model]
        torch.cuda.empty_cache()

    vis_class = VISUALIZER_CLASSES.get(req.model)
    if vis_class is None:
        return {"error": f"Unknown model: {req.model}"}

    print("Instantiating visualizer")
    try:
        vis = vis_class(task=req.task.lower())
        print(vis)
        VISUALIZER_CACHE[req.model] = vis
        print("Visualizer instantiated")
    except Exception as e:
        print("Visualizer init failed:", e)
        return {"error": f"Instantiation failed: {e}"}

    print("Tokenizing")
    try:
        # MNLI takes a premise/hypothesis pair; other tasks a single sentence.
        if req.task.lower() == "mnli":
            token_output = vis.tokenize(req.sentence, hypothesis=req.hypothesis)
        else:
            token_output = vis.tokenize(req.sentence)
        print("Tokenization successful:", token_output["tokens"])
    except Exception as e:
        print("Tokenization failed:", e)
        return {"error": f"Tokenization failed: {e}"}

    print("Response ready")
    response = {
        "model": req.model,
        "tokens": token_output["tokens"],
        "num_layers": vis.num_attention_layers,
    }
    print("Load model successful")
    print(response)
    return response
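# Example client call for /load_model (a sketch; assumes the server listens on
# localhost:8000 and the `requests` package is available):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/load_model",
#       json={"model": "BERT", "sentence": "The cat sat on the mat.",
#             "task": "mlm", "hypothesis": ""},
#   )
#   print(resp.json())  # {"model": "BERT", "tokens": [...], "num_layers": 12}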
@app.post("/predict_model")
def predict_model(req: PredModelRequest):
    print("\n--- /predict_model request received ---")
    print(f"predict: Model: {req.model}")
    print(f"predict: Task: {req.task}")
    print(f"predict: Sentence: {req.sentence}")
    print(f"predict: Hypothesis: {req.hypothesis}")
    print(f"predict: maskID: {req.maskID}")

    print("predict: instantiating")
    try:
        vis_class = VISUALIZER_CLASSES.get(req.model)
        if vis_class is None:
            return {"error": f"Unknown model: {req.model}"}
        # if any(p.device.type == 'meta' for p in vis.model.parameters()):
        #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
        vis = vis_class(task=req.task.lower())
        VISUALIZER_CACHE[req.model] = vis
        print("Model reloaded and cached.")
    except Exception as e:
        return {"error": f"Failed to reload model: {e}"}

    print("predict: run prediction")
    try:
        # MNLI needs the hypothesis; MLM needs the index of the masked token.
        if req.task.lower() == "mnli":
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
        elif req.task.lower() == "mlm":
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, maskID=req.maskID)
        else:
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence)
    except Exception as e:
        # Return early: falling through would call .tolist() on the exception.
        print("Prediction failed:", e)
        return {"error": f"Prediction failed: {e}"}

    print("predict: response ready")
    response = {
        "decoded": decoded,
        "top_probs": top_probs.tolist(),
    }
    print("predict: predict model successful")
    # Truncate long outputs when logging.
    if len(decoded) > 5:
        print([(k, v[:5]) for k, v in response.items()])
    else:
        print(response)
    return response
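# Example /predict_model call for MLM (a sketch; maskID is the index of the
# masked token in the tokenized sentence, per the handler above — the value 6
# here is only illustrative):
#
#   resp = requests.post(
#       "http://localhost:8000/predict_model",
#       json={"model": "BERT", "sentence": "The cat sat on the [MASK].",
#             "task": "mlm", "hypothesis": "", "maskID": 6},
#   )
#   print(resp.json()["decoded"], resp.json()["top_probs"])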
@app.post("/get_grad_attn_matrix")
def get_grad_attn_matrix(req: GradAttnModelRequest):
    try:
        print("\n--- /get_grad_attn_matrix request received ---")
        print(f"grad: Model: {req.model}")
        print(f"grad: Task: {req.task}")
        print(f"grad: Sentence: {req.sentence}")
        print(f"grad: Hypothesis: {req.hypothesis}")
        print(f"grad: maskID: {req.maskID}")

        try:
            vis_class = VISUALIZER_CLASSES.get(req.model)
            if vis_class is None:
                return {"error": f"Unknown model: {req.model}"}
            # if any(p.device.type == 'meta' for p in vis.model.parameters()):
            #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
            vis = vis_class(task=req.task.lower())
            VISUALIZER_CACHE[req.model] = vis
            print("Model reloaded and cached.")
        except Exception as e:
            return {"error": f"Failed to reload model: {e}"}

        print("Run function")
        try:
            if req.task.lower() == "mnli":
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
            elif req.task.lower() == "mlm":
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, maskID=req.maskID)
            else:
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence)
        except Exception as e:
            # Return the error directly; exception objects are not JSON-serializable.
            print("Exception during grad/attn computation:", e)
            return {"error": f"Grad/attn computation failed: {e}"}

        response = {
            "grad_matrix": grad_matrix,
            "attn_matrix": attn_matrix,
        }
        print("grad attn successful")
        return response
    except Exception as e:
        print("SERVER EXCEPTION:", e)
        return {"error": str(e)}
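# Example /get_grad_attn_matrix call (a sketch, same localhost assumptions as
# above; uses the MNLI branch since that task appears in the handler):
#
#   resp = requests.post(
#       "http://localhost:8000/get_grad_attn_matrix",
#       json={"model": "DistilBERT",
#             "sentence": "A soccer game with multiple males playing.",
#             "task": "mnli", "hypothesis": "Some men are playing a sport.",
#             "maskID": None},
#   )
#   data = resp.json()  # {"grad_matrix": [...], "attn_matrix": [...]}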