jukofyork committed
Commit 8bdaa9d · verified · 1 Parent(s): bf7c12f

Create app.py

Files changed (1): app.py (+394, -0)
app.py ADDED
@@ -0,0 +1,394 @@
import gc
import json
import os
import shutil

import gradio as gr
import torch
from huggingface_hub import hf_hub_download, HfApi, login, list_repo_files
from safetensors import safe_open
from safetensors.torch import save_file, load_file

api = HfApi()

def info_fn(text):
    gr.Info(text)

def warning_fn(text):
    gr.Warning(text)

def load_lora_state(lora_model_name):
    """Download and load LoRA adapter weights."""
    temp_lora_dir = "/tmp/lora_adapter"
    os.makedirs(temp_lora_dir, exist_ok=True)

    # Download adapter config
    config_path = hf_hub_download(
        repo_id=lora_model_name,
        filename="adapter_config.json",
        local_dir=temp_lora_dir,
        local_dir_use_symlinks=False
    )

    with open(config_path, 'r') as f:
        lora_config = json.load(f)

    scale = lora_config['lora_alpha'] / lora_config['r']

    # Download adapter weights (prefer safetensors, fall back to .bin)
    try:
        adapter_path = hf_hub_download(
            repo_id=lora_model_name,
            filename="adapter_model.safetensors",
            local_dir=temp_lora_dir,
            local_dir_use_symlinks=False
        )
        lora_state = load_file(adapter_path, device='cpu')
    except Exception:
        adapter_path = hf_hub_download(
            repo_id=lora_model_name,
            filename="adapter_model.bin",
            local_dir=temp_lora_dir,
            local_dir_use_symlinks=False
        )
        lora_state = torch.load(adapter_path, map_location='cpu')

    return lora_state, scale, temp_lora_dir

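# For reference, a typical PEFT adapter_config.json looks roughly like this
# (fields vary by PEFT version; the values here are illustrative only):
#   {"r": 16, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"], ...}
# which would give a base scale of lora_alpha / r = 32 / 16 = 2.0 before the
# user-supplied multiplier applied in merge_lora_efficient() below.
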
def find_lora_weights(lora_state, key):
    """Find the corresponding LoRA A and B weights for a given key."""
    lora_A = None
    lora_B = None

    # Remove the .weight suffix and handle potential prefixes
    clean_key = key.replace('.weight', '')

    for lora_key, lora_weight in lora_state.items():
        if clean_key in lora_key or clean_key.replace('language_model.', '') in lora_key:
            if 'lora_A' in lora_key:
                lora_A = lora_weight
            elif 'lora_B' in lora_key:
                lora_B = lora_weight

    # Either both are found or neither is
    if (lora_A is None) != (lora_B is None):
        return None, None

    return lora_A, lora_B

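# Illustration of the matching above, assuming standard PEFT key naming
# (exact prefixes vary by architecture; these names are examples only):
#   shard key:   "model.layers.0.self_attn.q_proj.weight"
#   adapter key: "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
#   adapter key: "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight"
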
def download_and_upload_non_model_files(base_model_name, output_repo_name):
    """Download and upload non-model files (config, tokenizer, etc.)."""
    temp_config_dir = "/tmp/config_files"
    os.makedirs(temp_config_dir, exist_ok=True)

    try:
        # List all files in the repository
        files = list_repo_files(repo_id=base_model_name)

        # Filter out the model shards
        non_model_files = [
            f for f in files
            if not (f.startswith('model') and f.endswith('.safetensors'))
        ]

        # Download and upload each non-model file
        for filename in non_model_files:
            if filename.endswith(('.gguf', '.bin')) and 'model' in filename:
                continue  # Skip other model formats

            try:
                file_path = hf_hub_download(
                    repo_id=base_model_name,
                    filename=filename,
                    local_dir=temp_config_dir,
                    local_dir_use_symlinks=False
                )

                # Upload to the output repo
                api.upload_file(
                    path_or_fileobj=file_path,
                    path_in_repo=filename,
                    repo_id=output_repo_name,
                    repo_type="model"
                )

            except Exception as e:
                info_fn(f"Skipping {filename}: {e}")

    finally:
        shutil.rmtree(temp_config_dir, ignore_errors=True)

def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo_name,
                         lora_scale, lm_head_scale, multiplicative_lora, progress=gr.Progress()):
    temp_lora_dir = None
    try:
        login(hf_token)

        progress(0.1, desc="Loading LoRA adapter...")
        info_fn("Loading LoRA adapter...")

        # Load LoRA state (this downloads the adapter)
        lora_state, base_scale, temp_lora_dir = load_lora_state(lora_model_name)

        # Apply the LoRA scale multiplier
        scale = base_scale * lora_scale
        info_fn(f"Using LoRA scale: {scale} (base: {base_scale}, multiplier: {lora_scale})")

        progress(0.2, desc="Creating output repository...")

        # Create the output repository
        try:
            repo_url = api.create_repo(repo_id=output_repo_name, exist_ok=True)
            info_fn(f"Repository created/updated: {repo_url}")
        except Exception as e:
            warning_fn(f"Repository might already exist: {e}")

        progress(0.3, desc="Uploading configuration files...")
        info_fn("Uploading configuration files...")

        # Download and upload non-model files
        download_and_upload_non_model_files(base_model_name, output_repo_name)

        progress(0.4, desc="Finding model shards...")
        info_fn("Finding model shards...")

        # Get the list of all safetensors shards
        all_files = list_repo_files(repo_id=base_model_name)
        shard_files = [f for f in all_files if f.startswith('model') and f.endswith('.safetensors')]

        if not shard_files:
            raise FileNotFoundError("No model safetensors files found in the repository")

        info_fn(f"Found {len(shard_files)} model shards to process")

        merged_tensors = 0
        scaled_lm_heads = 0
        total_shards = len(shard_files)

        # Process each shard individually
        for i, shard_filename in enumerate(shard_files):
            progress(0.4 + (i / total_shards) * 0.5,
                     desc=f"Processing {shard_filename} ({i+1}/{total_shards})")
            info_fn(f"Processing shard {i+1}/{total_shards}: {shard_filename}")

            # Create a temporary directory for this shard only
            temp_shard_dir = f"/tmp/shard_{i}"
            os.makedirs(temp_shard_dir, exist_ok=True)

            try:
                # Download the current shard
                shard_path = hf_hub_download(
                    repo_id=base_model_name,
                    filename=shard_filename,
                    local_dir=temp_shard_dir,
                    local_dir_use_symlinks=False
                )

                # Process the shard
                tensors = {}
                shard_merged_count = 0
                shard_lm_head_count = 0

                with safe_open(shard_path, framework='pt', device='cpu') as f:
                    # Get metadata if available
                    metadata = f.metadata() if hasattr(f, 'metadata') else {}

                    for key in f.keys():
                        tensor = f.get_tensor(key)

                        # Apply lm_head scaling if applicable
                        if key.endswith('lm_head.weight') and lm_head_scale != 1.0:
                            info_fn(f"Scaling {key} by {lm_head_scale}")
                            original_dtype = tensor.dtype
                            tensor = tensor.to(torch.float32)
                            tensor = tensor * lm_head_scale
                            tensor = tensor.to(original_dtype)
                            shard_lm_head_count += 1
                            scaled_lm_heads += 1

                        # Try to find corresponding LoRA weights
                        lora_A, lora_B = find_lora_weights(lora_state, key)

                        if lora_A is not None and lora_B is not None:
                            lora_type = "Multiplicative" if multiplicative_lora else "Additive"
                            info_fn(f"Merging {lora_type} LoRA weights for {key}")
                            shard_merged_count += 1
                            merged_tensors += 1

                            # Convert to float32 for the computation
                            original_dtype = tensor.dtype
                            tensor_f32 = tensor.to(torch.float32)
                            lora_A_f32 = lora_A.to(torch.float32)
                            lora_B_f32 = lora_B.to(torch.float32)

                            if multiplicative_lora:
                                # Apply multiplicative LoRA: W = W + scale * B @ A @ W
                                tensor_f32 += scale * lora_B_f32 @ lora_A_f32 @ tensor_f32
                            else:
                                # Apply standard additive LoRA: W = W + scale * B @ A
                                tensor_f32 += scale * lora_B_f32 @ lora_A_f32
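                            # Note: the multiplicative form is equivalent to
                            # W_new = (I + scale * B @ A) @ W, applied here
                            # without materializing the (d_out x d_out) identity.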

                            # Convert back to the original dtype
                            tensor = tensor_f32.to(original_dtype)

                            # Clean up intermediate tensors
                            del tensor_f32, lora_A_f32, lora_B_f32
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()

                        tensors[key] = tensor

                # Save the processed shard to a temporary file
                output_shard_path = os.path.join(temp_shard_dir, f"processed_{shard_filename}")
                save_file(tensors, output_shard_path, metadata=metadata)

                info_fn(f"Shard {shard_filename}:\n- Merged {shard_merged_count} tensors\n- Scaled {shard_lm_head_count} lm_head tensors")

                # Upload the processed shard
                api.upload_file(
                    path_or_fileobj=output_shard_path,
                    path_in_repo=shard_filename,
                    repo_id=output_repo_name,
                    repo_type="model"
                )

                # Clean up this shard's data
                del tensors
                gc.collect()

            finally:
                # Always clean up the temporary shard directory
                shutil.rmtree(temp_shard_dir, ignore_errors=True)

        progress(1.0, desc="Upload completed!")

        success_msg = (f"✓ Successfully merged and uploaded model!\n"
                       f"Model URL: https://huggingface.co/{output_repo_name}\n"
                       f"Processed {total_shards} shards\n"
                       f"Merged {merged_tensors} layers with LoRA weights\n"
                       f"Scaled {scaled_lm_heads} lm_head layers")
        info_fn("Merge completed successfully!")

        return success_msg

    except Exception as e:
        error_msg = f"✗ Error during merge: {str(e)}"
        warning_fn(error_msg)
        return error_msg

    finally:
        # Clean up the LoRA directory
        if temp_lora_dir and os.path.exists(temp_lora_dir):
            shutil.rmtree(temp_lora_dir, ignore_errors=True)
        gc.collect()

INTRODUCTION_TEXT = """
## Memory-Efficient LoRA Merge

This tool merges LoRA (Low-Rank Adaptation) adapters with base models using a memory-efficient approach that processes model files individually, significantly reducing memory requirements compared to traditional methods.

### Key Features
- **Minimal Memory Usage**: Processes one model shard at a time instead of loading the entire model
- **Streaming Processing**: Downloads → Processes → Uploads → Deletes each shard sequentially
- **Automatic Cleanup**: Temporary files are automatically removed after processing
- **Progress Tracking**: Real-time status updates throughout the merge process
- **Advanced Options**: Configurable LoRA scaling, LM head scaling, and multiplicative LoRA support

### How It Works
LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model. This tool applies the LoRA transformation with configurable scaling:

- **Standard Additive-LoRA**: `W_new = W + scale × B @ A`
- **Multiplicative LoRA**: `W_new = W + scale × B @ A @ W`

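Here `A` is the adapter's `(r × d_in)` down-projection and `B` its `(d_out × r)` up-projection (the usual PEFT shapes), so each update has rank at most `r`.
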
Additionally, the model's default temperature behavior can be adjusted by scaling the `lm_head.weight` tensor:

- **Up-scaling**: Makes the model's outputs more peaked, equivalent to running the original model at a lower temperature
- **Down-scaling**: Makes the model's outputs flatter, equivalent to running the original model at a higher temperature
- **Examples**:
  - Scaling `lm_head.weight` by `1.25` makes the new model with `temperature = 1.0` act like the old model with `temperature = 0.8`
  - Scaling `lm_head.weight` by `0.667` makes the new model with `temperature = 1.0` act like the old model with `temperature = 1.5`
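
These equivalences follow because sampling applies `softmax(logits / T)`: scaling `lm_head.weight` by `s` scales the logits by `s`, and `softmax(s × z / T) = softmax(z / (T / s))`, so the new model at `temperature = T` behaves like the old model at `temperature = T / s` (e.g. `1.0 / 1.25 = 0.8` and `1.0 / 0.667 ≈ 1.5`).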

### Memory Efficiency
- **Traditional approach**: Loads the entire model into memory (~15 GB+ for a 7B-parameter model)
- **This approach**: Peak usage is determined by the largest shard size, not the total model size
- **Result**: Enables merging much larger models on limited hardware

### Example Usage
- **Base Model:** `microsoft/DialoGPT-medium`
- **LoRA Adapter:** `username/my-trained-lora`
- **Output Name:** `username/dialogpt-merged`

### Attribution
This tool builds upon excellent work from the community:

- **Base implementation:** [Weyaxi/merge-lora](https://huggingface.co/spaces/Weyaxi/merge-lora)
- **Memory-efficient method:** [qlora-pipe](https://github.com/tdrussell/qlora-pipe/blob/main/tools/merge_lora.py) by tdrussell
"""

with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRODUCTION_TEXT)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Configuration")
            hf_token = gr.Textbox(
                label="Hugging Face Token",
                placeholder="hf_...",
                type="password",
                info="Token with write access to create repositories"
            )
            base_model_name = gr.Textbox(
                label="Base Model Repository",
                placeholder="microsoft/DialoGPT-medium",
                info="The original model to merge LoRA into"
            )
            lora_model_name = gr.Textbox(
                label="LoRA Adapter Repository",
                placeholder="username/my-lora-adapter",
                info="Repository containing adapter_model.safetensors"
            )
            output_repo_name = gr.Textbox(
                label="Output Repository Name",
                placeholder="username/my-merged-model",
                info="Name for the new merged model repository"
            )

            gr.Markdown("### Advanced Options")
            lora_scale = gr.Number(
                label="LoRA Scale",
                value=1.0,
                minimum=0.0,
                maximum=10.0,
                step=0.1,
                info="Multiplier for LoRA strength (1.0 = default)"
            )
            lm_head_scale = gr.Number(
                label="LM Head Scale",
                value=1.0,
                minimum=0.1,
                maximum=5.0,
                step=0.05,
                info="Multiplier for lm_head weights (1.0 = default)"
            )
            multiplicative_lora = gr.Checkbox(
                label="Multiplicative LoRA",
                value=False,
                info='Apply a "multiplicative-LoRA" instead of a standard "additive-LoRA"'
            )

        with gr.Column(scale=1):
            gr.Markdown("### Status")
            output_text = gr.Textbox(
                label="Merge Progress & Results",
                lines=20,
                interactive=False,
                show_copy_button=True
            )

    with gr.Row():
        submit_btn = gr.Button("Start LoRA Merge", variant="primary", size="lg")

    submit_btn.click(
        fn=merge_lora_efficient,
        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
                lora_scale, lm_head_scale, multiplicative_lora],
        outputs=output_text
    )

demo.queue()
demo.launch(show_error=True)
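
For local testing (assuming `gradio`, `torch`, `safetensors`, and `huggingface_hub` are installed), running `python app.py` starts the interface via `demo.launch()`; on Hugging Face Spaces, `app.py` is picked up automatically.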