jukofyork committed
Commit cbbd0ce · verified · 1 parent: 063b1a8

Update app.py

Files changed (1)
app.py +68 -64
app.py CHANGED
@@ -20,7 +20,7 @@ def load_lora_state(lora_model_name):
     """Download and load LoRA adapter weights"""
     temp_lora_dir = "/tmp/lora_adapter"
     os.makedirs(temp_lora_dir, exist_ok=True)
-
+
     # Download adapter config
     config_path = hf_hub_download(
         repo_id=lora_model_name,
@@ -28,12 +28,12 @@ def load_lora_state(lora_model_name):
         local_dir=temp_lora_dir,
         local_dir_use_symlinks=False
     )
-
+
     with open(config_path, 'r') as f:
         lora_config = json.load(f)
-
+
     scale = lora_config['lora_alpha'] / lora_config['r']
-
+
     # Download adapter weights
     try:
         adapter_path = hf_hub_download(
@@ -51,50 +51,50 @@ def load_lora_state(lora_model_name):
             local_dir_use_symlinks=False
         )
         lora_state = torch.load(adapter_path, map_location='cpu')
-
+
     return lora_state, scale, temp_lora_dir

 def find_lora_weights(lora_state, key):
     """Find corresponding LoRA A and B weights for a given key"""
     lora_A = None
     lora_B = None
-
+
     # Remove .weight suffix and handle potential prefixes
     clean_key = key.replace('.weight', '')
-
+
     for lora_key, lora_weight in lora_state.items():
         if clean_key in lora_key or clean_key.replace('language_model.', '') in lora_key:
             if 'lora_A' in lora_key:
                 lora_A = lora_weight
             elif 'lora_B' in lora_key:
                 lora_B = lora_weight
-
+
     # Both should be None or both should have values
     if (lora_A is None) != (lora_B is None):
         return None, None
-
+
     return lora_A, lora_B

 def download_and_upload_non_model_files(base_model_name, output_repo_name):
     """Download and upload non-model files (config, tokenizer, etc.)"""
     temp_config_dir = "/tmp/config_files"
     os.makedirs(temp_config_dir, exist_ok=True)
-
+
     try:
         # List all files in the repository
         files = list_repo_files(repo_id=base_model_name)
-
+
         # Filter non-model files
         non_model_files = [
-            f for f in files
+            f for f in files
             if not (f.startswith('model') and f.endswith('.safetensors'))
         ]
-
+
         # Download and upload each non-model file
         for filename in non_model_files:
             if filename.endswith(('.gguf', '.bin')) and 'model' in filename:
                 continue # Skip other model formats
-
+
             try:
                 file_path = hf_hub_download(
                     repo_id=base_model_name,
@@ -102,7 +102,7 @@ def download_and_upload_non_model_files(base_model_name, output_repo_name):
                     local_dir=temp_config_dir,
                     local_dir_use_symlinks=False
                 )
-
+
                 # Upload to output repo
                 api.upload_file(
                     path_or_fileobj=file_path,
@@ -110,70 +110,70 @@ def download_and_upload_non_model_files(base_model_name, output_repo_name):
                     repo_id=output_repo_name,
                     repo_type="model"
                 )
-
+
             except Exception as e:
                 info_fn(f"Skipping {filename}: {e}")
-
+
     finally:
         shutil.rmtree(temp_config_dir, ignore_errors=True)

-def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo_name,
+def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo_name,
                          lora_scale, lm_head_scale, multiplicative_lora, progress=gr.Progress()):
     temp_lora_dir = None
     try:
         login(hf_token)
-
+
         progress(0.1, desc="Loading LoRA adapter...")
         info_fn("Loading LoRA adapter...")
-
+
         # Load LoRA state (this downloads the adapter)
         lora_state, base_scale, temp_lora_dir = load_lora_state(lora_model_name)
-
+
         # Apply LoRA scale multiplier
         scale = base_scale * lora_scale
         info_fn(f"Using LoRA scale: {scale} (base: {base_scale}, multiplier: {lora_scale})")
-
+
         progress(0.2, desc="Creating output repository...")
-
+
         # Create repository
         try:
             repo_url = api.create_repo(repo_id=output_repo_name, exist_ok=True)
             info_fn(f"Repository created/updated: {repo_url}")
         except Exception as e:
             warning_fn(f"Repository might already exist: {e}")
-
+
         progress(0.3, desc="Uploading configuration files...")
         info_fn("Uploading configuration files...")
-
+
         # Download and upload non-model files
         download_and_upload_non_model_files(base_model_name, output_repo_name)
-
+
         progress(0.4, desc="Finding model shards...")
         info_fn("Finding model shards...")
-
+
         # Get list of all safetensors files
         all_files = list_repo_files(repo_id=base_model_name)
         shard_files = [f for f in all_files if f.startswith('model') and f.endswith('.safetensors')]
-
+
         if not shard_files:
             raise FileNotFoundError("No model safetensors files found in the repository")
-
+
         info_fn(f"Found {len(shard_files)} model shards to process")
-
+
         merged_tensors = 0
         scaled_lm_heads = 0
         total_shards = len(shard_files)
-
+
         # Process each shard individually
         for i, shard_filename in enumerate(shard_files):
-            progress(0.4 + (i / total_shards) * 0.5,
+            progress(0.4 + (i / total_shards) * 0.5,
                      desc=f"Processing {shard_filename} ({i+1}/{total_shards})")
             info_fn(f"Processing shard {i+1}/{total_shards}: {shard_filename}")
-
+
             # Create temporary directory for this shard only
             temp_shard_dir = f"/tmp/shard_{i}"
             os.makedirs(temp_shard_dir, exist_ok=True)
-
+
             try:
                 # Download the current shard
                 shard_path = hf_hub_download(
@@ -182,19 +182,19 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo
                     local_dir=temp_shard_dir,
                     local_dir_use_symlinks=False
                 )
-
+
                 # Process the shard
                 tensors = {}
                 shard_merged_count = 0
                 shard_lm_head_count = 0
-
+
                 with safe_open(shard_path, framework='pt', device='cpu') as f:
                     # Get metadata if available
                     metadata = f.metadata() if hasattr(f, 'metadata') else {}
-
+
                     for key in f.keys():
                         tensor = f.get_tensor(key)
-
+
                         # Apply lm_head scaling if applicable
                         if key.endswith('lm_head.weight') and lm_head_scale != 1.0:
                             info_fn(f"Scaling {key} by {lm_head_scale}")
@@ -204,45 +204,45 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo
                             tensor = tensor.to(original_dtype)
                             shard_lm_head_count += 1
                             scaled_lm_heads += 1
-
+
                         # Try to find corresponding LoRA weights
                         lora_A, lora_B = find_lora_weights(lora_state, key)
-
+
                         if lora_A is not None and lora_B is not None:
                             lora_type = "Multiplicative" if multiplicative_lora else "Additive"
                             info_fn(f"Merging {lora_type} LoRA weights for {key}")
                             shard_merged_count += 1
                             merged_tensors += 1
-
+
                             # Convert to float32 for computation
                             original_dtype = tensor.dtype
                             tensor_f32 = tensor.to(torch.float32)
                             lora_A_f32 = lora_A.to(torch.float32)
                             lora_B_f32 = lora_B.to(torch.float32)
-
+
                             if multiplicative_lora:
                                 # Apply Multiplicative-LoRA: W = W + scale * B @ A @ W
                                 tensor_f32 += scale * lora_B_f32 @ lora_A_f32 @ tensor_f32
                             else:
                                 # Apply standard LoRA: W = W + scale * B @ A
                                 tensor_f32 += scale * lora_B_f32 @ lora_A_f32
-
+
                             # Convert back to original dtype
                             tensor = tensor_f32.to(original_dtype)
-
+
                             # Clean up intermediate tensors
                             del tensor_f32, lora_A_f32, lora_B_f32
                             if torch.cuda.is_available():
                                 torch.cuda.empty_cache()
-
+
                         tensors[key] = tensor
-
+
                 # Save processed shard to temporary file
                 output_shard_path = os.path.join(temp_shard_dir, f"processed_{shard_filename}")
                 save_file(tensors, output_shard_path, metadata=metadata)
-
+
                 info_fn(f"Shard {shard_filename}:\n- Merged {shard_merged_count} tensors\n- Scaled {shard_lm_head_count} lm_head tensors")
-
+
                 # Upload the processed shard
                 api.upload_file(
                     path_or_fileobj=output_shard_path,
@@ -250,27 +250,27 @@ def merge_lora_efficient(hf_token, base_model_name, lora_model_name, output_repo
                     repo_id=output_repo_name,
                     repo_type="model"
                 )
-
+
                 # Clean up this shard's data
                 del tensors
                 gc.collect()
-
+
             finally:
                 # Always clean up the temporary shard directory
                 shutil.rmtree(temp_shard_dir, ignore_errors=True)
-
+
         progress(1.0, desc="Upload completed!")
-
+
         success_msg = f"✓ Successfully merged and uploaded model!\nModel URL: https://huggingface.co/{output_repo_name}\nProcessed {total_shards} shards\nMerged {merged_tensors} layers with LoRA weights\nScaled {scaled_lm_heads} lm_head layers"
         info_fn("Merge completed successfully!")
-
+
         return success_msg
-
+
     except Exception as e:
         error_msg = f"✗ Error during merge: {str(e)}"
         warning_fn(error_msg)
         return error_msg
-
+
     finally:
         # Cleanup LoRA directory
         if temp_lora_dir and os.path.exists(temp_lora_dir):
@@ -284,11 +284,13 @@ This tool merges LoRA (Low-Rank Adaptation) adapters with base models using a me

 ### Key Features
 - **Minimal Memory Usage**: Processes one model shard at a time instead of loading the entire model
-- **Streaming Processing**: Downloads → Processes → Uploads → Deletes each shard sequentially
+- **Streaming Processing**: Downloads → Processes → Uploads → Deletes each shard sequentially
 - **Automatic Cleanup**: Temporary files are automatically removed after processing
 - **Progress Tracking**: Real-time status updates throughout the merge process
-- **Advanced Options**: Configurable LoRA scaling, LM head scaling, and multiplicative LoRA support
+- **Advanced Options**: Configurable LoRA scaling, LM-head scaling, and multiplicative LoRA support
+"""

+DETAILS_TEXT = """
 ### How It Works
 LoRA enables efficient fine-tuning by adding small adapter weights rather than modifying the entire model. This tool applies the LoRA transformation with configurable scaling:

@@ -299,7 +301,7 @@ Additionally, the model's default temperature behavior can be adjusted by scalin

 - **Up-scaling**: Makes the model's outputs more peaked, requiring lower temperature settings for the same output distribution
 - **Down-scaling**: Makes the model's outputs flatter, requiring higher temperature settings for the same output distribution
-- **Examples**:
+- **Examples**:
   - Scaling `lm_head.weight` by `1.25` makes the new model with `temperature = 1.0` act like the old model with `temperature = 0.8`
   - Scaling `lm_head.weight` by `0.667` makes the new model with `temperature = 1.0` act like the old model with `temperature = 1.5`

@@ -309,8 +311,8 @@ Additionally, the model's default temperature behavior can be adjusted by scalin
 - **Result**: Enables merging of much larger models on limited hardware

 ### Example Usage
-- **Base Model:** `microsoft/DialoGPT-medium`
-- **LoRA Adapter:** `username/my-trained-lora`
+- **Base Model:** `microsoft/DialoGPT-medium`
+- **LoRA Adapter:** `username/my-trained-lora`
 - **Output Name:** `username/dialogpt-merged`

 ### Attribution
@@ -338,7 +340,7 @@ with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as d
                 info="The original model to merge LoRA into"
             )
             lora_model_name = gr.Textbox(
-                label="LoRA Adapter Repository",
+                label="LoRA Adapter Repository",
                 placeholder="username/my-lora-adapter",
                 info="Repository containing adapter_model.safetensors"
             )
@@ -347,7 +349,7 @@ with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as d
                 placeholder="username/my-merged-model",
                 info="Name for the new merged model repository"
             )
-
+
             gr.Markdown("### Advanced Options")
             lora_scale = gr.Number(
                 label="LoRA Scale",
@@ -382,13 +384,15 @@ with gr.Blocks(title="Memory-Efficient LoRA Merge", theme=gr.themes.Soft()) as d

     with gr.Row():
         submit_btn = gr.Button("Start LoRA Merge", variant="primary", size="lg")
-
+
     submit_btn.click(
         fn=merge_lora_efficient,
-        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
+        inputs=[hf_token, base_model_name, lora_model_name, output_repo_name,
                 lora_scale, lm_head_scale, multiplicative_lora],
         outputs=output_text
     )

+    gr.Markdown(DETAILS_TEXT)
+
 demo.queue()
 demo.launch(show_error=True)
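The two merge rules applied per tensor above can be illustrated with a minimal sketch (not part of the commit); the sizes, the square weight, and the `scale` value below are assumptions chosen only so the shapes line up:

```python
import torch

# Toy sizes; real weights are far larger. The square weight is an
# illustrative assumption so both rules type-check with one A/B pair.
d, r, scale = 6, 2, 0.5

W = torch.randn(d, d)   # base weight, as stored in a shard
A = torch.randn(r, d)   # lora_A
B = torch.randn(d, r)   # lora_B

# Standard (additive) LoRA: W' = W + scale * B @ A
W_additive = W + scale * (B @ A)

# Multiplicative LoRA: W' = W + scale * B @ A @ W, i.e. (I + scale * B @ A) @ W
W_multiplicative = W + scale * (B @ A @ W)
assert torch.allclose(W_multiplicative, (torch.eye(d) + scale * (B @ A)) @ W, atol=1e-5)
```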
 
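The temperature claim in the description (scaling `lm_head.weight` by `1.25` at `temperature = 1.0` behaves like the unscaled head at `temperature = 0.8`) can be checked numerically with a small sketch; the toy logits are illustrative, not model outputs:

```python
import torch

# Illustrative next-token logits; only the ratio between the two settings matters.
logits = torch.randn(32)
s = 1.25  # lm_head scale

p_scaled_head = torch.softmax(s * logits, dim=-1)  # scaled lm_head, temperature = 1.0
p_low_temp = torch.softmax(logits / 0.8, dim=-1)   # original lm_head, temperature = 0.8

assert torch.allclose(p_scaled_head, p_low_temp, atol=1e-6)
```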