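# FastAPI backend for a transformer attention/gradient visualization Space.
# It exposes endpoints for loading a model, running predictions, and computing
# gradient/attention matrices for BERT, RoBERTa, and DistilBERT visualizers.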
from fastapi import FastAPI, Request
from pydantic import BaseModel
import torch
from fastapi.middleware.cors import CORSMiddleware
from ROBERTAmodel import *
from BERTmodel import *
from DISTILLBERTmodel import *
import os
import zipfile
import shutil
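# Map the model names used by the frontend to their visualizer classes
# (defined in the *model modules imported above).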
VISUALIZER_CLASSES = {
    "BERT": BERTVisualizer,
    "RoBERTa": RoBERTaVisualizer,
    "DistilBERT": DistilBERTVisualizer,
}
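# Cache of instantiated visualizers, keyed by model name.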
VISUALIZER_CACHE = {}
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
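# Hugging Face Hub identifiers for each supported architecture.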
MODEL_MAP = {
    "BERT": "bert-base-uncased",
    "RoBERTa": "roberta-base",
    "DistilBERT": "distilbert-base-uncased",
}
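# Request schemas for the /load_model, /predict_model, and /get_grad_matrix endpoints.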
class LoadModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str
class GradAttnModelRequest(BaseModel):
    model: str
    task: str
    sentence: str
    hypothesis: str
    maskID: int | None = None
class PredModelRequest(BaseModel):
    model: str
    sentence: str
    task: str
    hypothesis: str
    maskID: int | None = None
@app.get("/ping")  # health-check route; path assumed
def ping():
    return {"message": "pong"}
def extract_zip_if_needed(zip_path, dest_dir):
    if os.path.exists(dest_dir):
        print(f"Already exists: {dest_dir}")
        return
    print(f"Extracting {zip_path} -> {dest_dir}")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(dest_dir)
    print(f"Extracted to: {dest_dir}")
def print_directory_tree(path):
    # Build an indented text listing of the directory tree rooted at `path`.
    printstr = ''
    for root, dirs, files in os.walk(path):
        indent = '  ' * (root.count(os.sep) - path.count(os.sep))
        printstr += indent
        printstr += os.path.basename(root)
        printstr += '\n'
        for f in files:
            # List each file on its own line, indented under its directory.
            printstr += indent + '  '
            printstr += f
            printstr += '\n'
    return printstr
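# Copy the zipped model/tokenizer caches bundled with the repo into the persistent
# /data volume, extract them, and return a text report of both directory trees.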
def copy_extract_and_report():
    src_base = "./hf_cache"
    dst_base = "/data/hf_cache"
    for category in ["models", "tokenizers"]:
        src_dir = os.path.join(src_base, category)
        dst_dir = os.path.join(dst_base, category)
        if not os.path.exists(src_dir):
            continue
        os.makedirs(dst_dir, exist_ok=True)
        for name in os.listdir(src_dir):
            if not name.endswith(".zip"):
                continue
            src_zip = os.path.join(src_dir, name)
            dst_zip = os.path.join(dst_dir, name)
            # Copy zip to /data if not already present
            if not os.path.exists(dst_zip):
                shutil.copy(src_zip, dst_zip)
                print(f"Copied zip to: {dst_zip}")
            # Determine the extract folder name (strip ".zip" from the filename)
            extract_folder = os.path.splitext(name)[0]
            extract_path = os.path.join(dst_dir, extract_folder)
            # Extract if not already extracted
            if not os.path.exists(extract_path):
                os.makedirs(extract_path, exist_ok=True)
                with zipfile.ZipFile(dst_zip, 'r') as zip_ref:
                    zip_ref.extractall(extract_path)
                print(f"Extracted zip to: {extract_path}")
    print("\nLocal hf_cache structure:")
    printstr1 = print_directory_tree("./hf_cache")
    print("\nPersistent /data/hf_cache structure:")
    printstr2 = print_directory_tree("/data/hf_cache")
    return printstr1 + '\n\n' + printstr2
@app.get("/copy_and_extract")  # route path assumed from the function name
def copy_and_extract():
    printstr = copy_extract_and_report()
    return {"message": "done", "log": printstr}
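# Quick writability check for the persistent /data volume.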
@app.get("/data_check")  # route path assumed from the function name
def data_check():
    with open("/data/marker.txt", "w") as f:
        f.write("hello from server.py\n")
    files = os.listdir("/data")
    return {
        "message": "done",
        "contents": files
    }
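# Instantiate (or re-instantiate) the requested visualizer, tokenize the input,
# and return the token list plus the number of attention layers.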
@app.post("/load_model")  # path taken from the log line below; POST assumed since the handler takes a request body
def load_model(req: LoadModelRequest):
    print(f"\n--- /load_model request received ---")
    print(f"Model: {req.model}")
    print(f"Sentence: {req.sentence}")
    print(f"Task: {req.task}")
    print(f"hypothesis: {req.hypothesis}")

    # Drop any previously cached visualizer for this model and free GPU memory.
    if req.model in VISUALIZER_CACHE:
        del VISUALIZER_CACHE[req.model]
        torch.cuda.empty_cache()

    vis_class = VISUALIZER_CLASSES.get(req.model)
    if vis_class is None:
        return {"error": f"Unknown model: {req.model}"}

    print("instantiating visualizer")
    try:
        vis = vis_class(task=req.task.lower())
        print(vis)
        VISUALIZER_CACHE[req.model] = vis
        print("Visualizer instantiated")
    except Exception as e:
        print("Visualizer init failed:", e)
        return {"error": f"Instantiation failed: {str(e)}"}

    print('tokenizing')
    try:
        # MNLI takes a premise/hypothesis pair; all other tasks tokenize a single sentence.
        if req.task.lower() == 'mnli':
            token_output = vis.tokenize(req.sentence, hypothesis=req.hypothesis)
        else:
            token_output = vis.tokenize(req.sentence)
        print("Tokenization successful:", token_output["tokens"])
    except Exception as e:
        print("Tokenization failed:", e)
        return {"error": f"Tokenization failed: {str(e)}"}

    print('response ready')
    response = {
        "model": req.model,
        "tokens": token_output['tokens'],
        "num_layers": vis.num_attention_layers,
    }
    print("load model successful")
    print(response)
    return response
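# Run the prediction head for the selected task (MNLI, MLM, or single-sentence)
# and return the decoded output plus the top probabilities.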
@app.post("/predict_model")  # path taken from the log line below; POST assumed since the handler takes a request body
def predict_model(req: PredModelRequest):
    print(f"\n--- /predict_model request received ---")
    print(f"predict: Model: {req.model}")
    print(f"predict: Task: {req.task}")
    print(f"predict: sentence: {req.sentence}")
    print(f"predict: hypothesis: {req.hypothesis}")
    print(f"predict: maskID: {req.maskID}")

    print('predict: instantiating')
    try:
        vis_class = VISUALIZER_CLASSES.get(req.model)
        if vis_class is None:
            return {"error": f"Unknown model: {req.model}"}
        # if any(p.device.type == 'meta' for p in vis.model.parameters()):
        #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
        vis = vis_class(task=req.task.lower())
        VISUALIZER_CACHE[req.model] = vis
        print("Model reloaded and cached.")
    except Exception as e:
        return {"error": f"Failed to reload model: {str(e)}"}

    print('predict: meta stuff')
    print('predict: Run prediction')
    try:
        # MNLI needs the hypothesis, MLM needs the mask position; other tasks take only the sentence.
        if req.task.lower() == 'mnli':
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
        elif req.task.lower() == 'mlm':
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence, maskID=req.maskID)
        else:
            decoded, top_probs = vis.predict(req.task.lower(), req.sentence)
    except Exception as e:
        print(e)
        return {"error": f"Prediction failed: {str(e)}"}

    print('predict: response ready')
    response = {
        "decoded": decoded,
        "top_probs": top_probs.tolist(),
    }
    print("predict: predict model successful")
    # Truncate the printed log for long outputs.
    if len(decoded) > 5:
        print([(k, v[:5]) for k, v in response.items()])
    else:
        print(response)
    return response
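# Compute the per-layer gradient and attention matrices for the input and return both.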
@app.post("/get_grad_matrix")  # path taken from the log line below; POST assumed since the handler takes a request body
def get_grad_attn_matrix(req: GradAttnModelRequest):
    try:
        print(f"\n--- /get_grad_matrix request received ---")
        print(f"grad: Model: {req.model}")
        print(f"grad: Task: {req.task}")
        print(f"grad: sentence: {req.sentence}")
        print(f"grad: hypothesis: {req.hypothesis}")
        print(f"grad: maskID: {req.maskID}")

        try:
            vis_class = VISUALIZER_CLASSES.get(req.model)
            if vis_class is None:
                return {"error": f"Unknown model: {req.model}"}
            # if any(p.device.type == 'meta' for p in vis.model.parameters()):
            #     vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
            vis = vis_class(task=req.task.lower())
            VISUALIZER_CACHE[req.model] = vis
            print("Model reloaded and cached.")
        except Exception as e:
            return {"error": f"Failed to reload model: {str(e)}"}

        print("run function")
        try:
            # MNLI needs the hypothesis, MLM needs the mask position; other tasks take only the sentence.
            if req.task.lower() == 'mnli':
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
            elif req.task.lower() == 'mlm':
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, maskID=req.maskID)
            else:
                grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence)
        except Exception as e:
            print("Exception during grad/attn computation:", e)
            return {"error": f"Grad/attn computation failed: {str(e)}"}

        response = {
            "grad_matrix": grad_matrix,
            "attn_matrix": attn_matrix,
        }
        print('grad attn successful')
        return response
    except Exception as e:
        print("SERVER EXCEPTION:", e)
        return {"error": str(e)}
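
# Minimal sketch of a local entry point. Assumptions: uvicorn is installed and
# port 7860 (the conventional Hugging Face Spaces port) is appropriate; the
# Space itself may launch the app differently.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)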