yifan0sun commited on
Commit
23957f3
Β·
2 Parent(s): a19e5f5 10ec503

Merge branch 'main' of https://huggingface.co/spaces/yifan0sun/BERTGradGraph

Browse files
Files changed (1) hide show
  1. server.py +104 -17
server.py CHANGED
@@ -9,6 +9,11 @@ from ROBERTAmodel import *
9
  from BERTmodel import *
10
  from DISTILLBERTmodel import *
11
 
 
 
 
 
 
12
  VISUALIZER_CLASSES = {
13
  "BERT": BERTVisualizer,
14
  "RoBERTa": RoBERTaVisualizer,
@@ -57,6 +62,105 @@ def ping():
57
  return {"message": "pong"}
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  @app.post("/load_model")
61
  def load_model(req: LoadModelRequest):
62
  print(f"\n--- /load_model request received ---")
@@ -216,20 +320,3 @@ def get_grad_attn_matrix(req: GradAttnModelRequest):
216
  print("SERVER EXCEPTION:", e)
217
  return {"error": str(e)}
218
 
219
-
220
- if __name__ == "__main__":
221
-
222
- print('rim ')
223
- BERTVisualizer('mlm')
224
- BERTVisualizer('mnli')
225
- BERTVisualizer('sst')
226
-
227
-
228
- RoBERTaVisualizer('mlm')
229
- RoBERTaVisualizer('mnli')
230
- RoBERTaVisualizer('sst')
231
-
232
-
233
- DistilBERTVisualizer('mlm')
234
- DistilBERTVisualizer('mnli')
235
- DistilBERTVisualizer('sst')
 
9
  from BERTmodel import *
10
  from DISTILLBERTmodel import *
11
 
12
+ import os
13
+ import zipfile
14
+ import shutil
15
+
16
+
17
  VISUALIZER_CLASSES = {
18
  "BERT": BERTVisualizer,
19
  "RoBERTa": RoBERTaVisualizer,
 
62
  return {"message": "pong"}
63
 
64
 
65
+
66
+
67
+ def zip_if_needed(src_dir, zip_path):
68
+ if os.path.exists(zip_path):
69
+ return # already zipped
70
+ print(f"πŸ“¦ Zipping {src_dir} β†’ {zip_path}")
71
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
72
+ for root, _, files in os.walk(src_dir):
73
+ for file in files:
74
+ full_path = os.path.join(root, file)
75
+ rel_path = os.path.relpath(full_path, src_dir)
76
+ zf.write(full_path, arcname=rel_path)
77
+ print(f"βœ… Created zip: {zip_path}")
78
+
79
+ def extract_zip_if_needed(zip_path, dest_dir):
80
+ if os.path.exists(dest_dir):
81
+ print(f"βœ… Already exists: {dest_dir}")
82
+ return
83
+ print(f"πŸ”“ Extracting {zip_path} β†’ {dest_dir}")
84
+ with zipfile.ZipFile(zip_path, 'r') as zip_ref:
85
+ zip_ref.extractall(dest_dir)
86
+ print(f"βœ… Extracted to: {dest_dir}")
87
+
88
+ def print_directory_tree(path):
89
+ printstr = ''
90
+ for root, dirs, files in os.walk(path):
91
+ indent = ' ' * (root.count(os.sep) - path.count(os.sep))
92
+ printstr += indent
93
+ printstr += os.path.basename(root)
94
+ printstr += '\n'
95
+ for f in files:
96
+ printstr += indent
97
+ printstr += f
98
+ return printstr
99
+
100
+
101
+ def copy_zip_extract_and_report():
102
+ src_base = "./hf_cache"
103
+ dst_base = "/data/hf_cache"
104
+
105
+ for category in ["models", "tokenizers"]:
106
+ src_dir = os.path.join(src_base, category)
107
+ dst_dir = os.path.join(dst_base, category)
108
+
109
+ if not os.path.exists(src_dir):
110
+ continue
111
+
112
+ os.makedirs(dst_dir, exist_ok=True)
113
+
114
+ for name in os.listdir(src_dir):
115
+ full_path = os.path.join(src_dir, name)
116
+ if not os.path.isdir(full_path):
117
+ continue
118
+
119
+ zip_name = f"{name}.zip"
120
+ local_zip = os.path.join(src_dir, zip_name)
121
+ dst_zip = os.path.join(dst_dir, zip_name)
122
+ extract_path = os.path.join(dst_dir, name)
123
+
124
+ zip_if_needed(full_path, local_zip)
125
+
126
+ # Copy zip to /data
127
+ if not os.path.exists(dst_zip):
128
+ shutil.copy(local_zip, dst_zip)
129
+ print(f"πŸ“€ Copied zip to: {dst_zip}")
130
+
131
+ extract_zip_if_needed(dst_zip, extract_path)
132
+
133
+ print("\nπŸ“¦ Local hf_cache structure:")
134
+ printstr1 = print_directory_tree("./hf_cache")
135
+
136
+ print("\nπŸ’Ύ Persistent /data/hf_cache structure:")
137
+ printstr2 = print_directory_tree("/data/hf_cache")
138
+ return printstr1 + '\n\n' + printstr2
139
+
140
+
141
+
142
+ @app.get("/copy_and_extract")
143
+ def copy_and_extract():
144
+ import io
145
+ from contextlib import redirect_stdout
146
+
147
+ printstr = copy_zip_extract_and_report()
148
+
149
+ return {"message": "done", "log": printstr}
150
+
151
+
152
+
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
164
  @app.post("/load_model")
165
  def load_model(req: LoadModelRequest):
166
  print(f"\n--- /load_model request received ---")
 
320
  print("SERVER EXCEPTION:", e)
321
  return {"error": str(e)}
322