# civitai-to-hf / app.py
# Hugging Face Space by multimodalart — commit 0ec77e4 ("Update app.py", verified), 39.6 kB.
import requests
import os
import gradio as gr
from huggingface_hub import update_repo_visibility, upload_folder, create_repo, upload_file
from slugify import slugify
import re
import uuid
from typing import Optional, Dict, Any, List
import json
import shutil # For cleaning up local folders
import traceback # For debugging
# Hugging Face usernames treated as trusted: they skip the per-media NSFW
# level check (a model-level NSFW flag still blocks them) and the
# CivitAI <-> Hugging Face authorship verification.
TRUSTED_UPLOADERS = [
    "KappaNeuro",
    "CiroN2022",
    "multimodalart",
    "Norod78",
    "joachimsallstrom",
    "blink7630",
    "e-n-v-y",
    "DoctorDiffusion",
    "RalFinger",
    "artificialguybr",
]
# --- Helper Functions (CivitAI API, Data Extraction, File Handling) ---
def get_json_data(url: str) -> Optional[Dict[str, Any]]:
    """Fetch the CivitAI API record for a model.

    ``url`` may be a model page URL (``https://civitai.com/models/12345/slug``,
    optionally followed by a query string such as ``?modelVersionId=678``) or
    just the numeric model ID.

    Returns:
        The decoded JSON from ``/api/v1/models/<id>``, or ``None`` when no model
        ID can be extracted, the request fails, or the body is not valid JSON.
    """
    candidate = (url or "").strip()
    model_id: Optional[str] = None
    if candidate.isdigit():
        model_id = candidate
    else:
        # Prefer the explicit ".../models/<id>" segment; it stays correct when a
        # query string (e.g. "?modelVersionId=678") follows the ID.
        match = re.search(r"/models/(\d+)", candidate)
        if match is None:
            # Fallback for inputs without a "/models/" part, e.g. "12345/my-lora".
            # Query/fragment are dropped first so their digits are not mistaken
            # for the model ID.
            path = candidate.split("?", 1)[0].split("#", 1)[0]
            match = re.search(r"(\d+)(?:/[^/]+)?/?$", path)
        if match is not None:
            model_id = match.group(1)

    if model_id is None:
        print(f"Error: Invalid CivitAI URL format or missing model ID: {url}")
        return None

    api_url = f"https://civitai.com/api/v1/models/{model_id}"
    try:
        response = requests.get(api_url, timeout=15)
        response.raise_for_status()
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on older `requests` versions.
        print(f"Error fetching JSON data from {api_url}: {e}")
        return None
def check_nsfw(json_data: Dict[str, Any], profile: Optional[gr.OAuthProfile]) -> bool:
    """Decide whether a CivitAI model passes the NSFW policy.

    Returns False when the model itself carries the ``nsfw`` flag. Otherwise a
    logged-in uploader listed in ``TRUSTED_UPLOADERS`` passes immediately, and
    everyone else passes only if no media item in any version has an
    ``nsfwLevel`` above 5.
    """
    model_label = json_data.get('id', 'Unknown')
    if json_data.get("nsfw", False):
        print(f"Model {model_label} flagged as NSFW at model level.")
        return False
    if profile and profile.username in TRUSTED_UPLOADERS:
        print(f"Trusted uploader {profile.username}, bypassing strict image NSFW check for model {model_label}.")
        return True
    # Levels 0-5 are tolerated ('images' may also hold videos); the first version
    # containing anything higher fails the check.
    flagged_version = next(
        (
            version
            for version in json_data.get("modelVersions", [])
            if any(media.get("nsfwLevel", 0) > 5 for media in version.get("images", []))
        ),
        None,
    )
    if flagged_version is None:
        return True
    print(f"Model {model_label} version {flagged_version.get('id')} has media with nsfwLevel > 5.")
    return False
def get_prompts_from_image(image_id: int) -> "tuple[str, str]":
    """Fetch the generation prompt and negative prompt of a CivitAI image.

    Queries CivitAI's tRPC ``image.getGenerationData`` endpoint.

    Returns:
        ``(prompt, negative_prompt)``. Either value is ``""`` when missing or
        null; both are ``""`` on a non-200 response, a network error, or an
        unexpected response body.
    """
    url = f'https://civitai.com/api/trpc/image.getGenerationData?input={{"json":{{"id":{image_id}}}}}'
    try:
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            return "", ""
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        print(f"Error fetching prompt data for image_id {image_id}: {e}")
        return "", ""

    # Walk result.data.json.meta defensively: any level may be absent or null.
    node: Any = data
    for key in ("result", "data", "json", "meta"):
        node = node.get(key) if isinstance(node, dict) else None
    if not isinstance(node, dict):
        return "", ""

    # `or ""` turns explicit JSON nulls into empty strings, so callers can run
    # string operations (such as the re.sub calls in download_files) safely.
    prompt = node.get("prompt") or ""
    negative_prompt = node.get("negativePrompt") or ""
    return str(prompt), str(negative_prompt)
# CivitAI "baseModel" labels mapped to the Hugging Face base repository used as
# `base_model` in the generated model card and in its diffusers example.
_CIVITAI_BASE_MODEL_TO_HF: Dict[str, str] = {
    "SDXL 1.0": "stabilityai/stable-diffusion-xl-base-1.0",
    "SDXL 0.9": "stabilityai/stable-diffusion-xl-base-1.0",
    "SD 1.5": "runwayml/stable-diffusion-v1-5",
    "SD 1.4": "CompVis/stable-diffusion-v1-4",
    "SD 2.1": "stabilityai/stable-diffusion-2-1-base",
    "SD 2.0": "stabilityai/stable-diffusion-2-base",
    "SD 2.1 768": "stabilityai/stable-diffusion-2-1",
    "SD 2.0 768": "stabilityai/stable-diffusion-2",
    "SD 3": "stabilityai/stable-diffusion-3-medium-diffusers",
    # SD 3.5 LoRAs target the 3.5 checkpoints, not SD3 Medium.
    "SD 3.5": "stabilityai/stable-diffusion-3.5-large",
    "SD 3.5 Large": "stabilityai/stable-diffusion-3.5-large",
    "SD 3.5 Medium": "stabilityai/stable-diffusion-3.5-medium",
    "SD 3.5 Large Turbo": "stabilityai/stable-diffusion-3.5-large-turbo",
    "Flux.1 D": "black-forest-labs/FLUX.1-dev",
    "Flux.1 S": "black-forest-labs/FLUX.1-schnell",
    "LTXV": "Lightricks/LTX-Video-0.9.7-dev",
    "Hunyuan Video": "hunyuanvideo-community/HunyuanVideo",
    "Wan Video 1.3B t2v": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "Wan Video 14B t2v": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    "Wan Video 14B i2v 480p": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
    "Wan Video 14B i2v 720p": "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
    "Pony": "SG161222/RealVisXL_V4.0",
    # Illustrious is SDXL-based; it previously pointed at a LoRA repository,
    # which is not a loadable base model.
    "Illustrious": "stabilityai/stable-diffusion-xl-base-1.0",
}


def _normalize_commercial_use(raw: Any) -> str:
    """Collapse CivitAI's ``allowCommercialUse`` (list, bool or str) into one string."""
    if isinstance(raw, list):
        return raw[0] if raw else "Sell"
    if isinstance(raw, bool):
        return "Sell" if raw else "None"
    if isinstance(raw, str):
        return raw
    return "Sell"  # Fallback for unexpected types (including null).


def _primary_weight_entry(model_version: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Return the download entry for the version's primary "Model" file, or None."""
    for file_data in model_version.get("files") or []:
        if file_data.get("primary") and file_data.get("type") == "Model":
            return {
                "url": file_data["downloadUrl"],
                "filename": os.path.basename(file_data["name"]),
                "type": "weightName",
                "is_video": False,
            }
    return None


def _media_entries(model_version: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return download entries for the version's images/videos with nsfwLevel <= 5."""
    entries: List[Dict[str, Any]] = []
    for media_data in model_version.get("images") or []:  # 'images' can contain videos
        if (media_data.get("nsfwLevel") or 0) > 5:
            continue
        media_url = media_data.get("url") or ""
        filename_part = media_url.split("/")[-1]
        if not filename_part:
            continue
        # Numeric file stems are used as CivitAI image IDs for the prompt lookup.
        id_candidate = filename_part.split(".")[0].split("?")[0]
        prompt, negative_prompt = "", ""
        if media_data.get("hasMeta", False) and media_data.get("type") == "image" and id_candidate.isdigit():
            try:
                prompt, negative_prompt = get_prompts_from_image(int(id_candidate))
            except ValueError:
                print(f"Warning: Non-integer ID '{id_candidate}' for prompt fetching.")
            except Exception as e:
                print(f"Warning: Prompt fetch failed for ID {id_candidate}: {e}")
        is_video_file = media_data.get("type") == "video"
        entries.append({
            "url": media_url,
            "filename": os.path.basename(filename_part),
            "type": "videoName" if is_video_file else "imageName",
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "is_video": is_video_file,
        })
    return entries


def extract_info(json_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Build the upload description for a CivitAI LoRA.

    Picks the first model version whose ``baseModel`` is in the supported
    mapping and that has a primary "Model" file.

    Returns:
        A dict with the files to download (weights first, then media at
        nsfwLevel <= 5), the mapped HF base model, metadata (name, description,
        trigger words, creator, tags) and license flags. Returns ``None`` when
        the model is not a LoRA or no version qualifies.
    """
    if json_data.get("type") != "LORA":
        return None
    for model_version in json_data.get("modelVersions") or []:
        base_model_hf_name = _CIVITAI_BASE_MODEL_TO_HF.get(model_version.get("baseModel"))
        if base_model_hf_name is None:
            continue
        weight_entry = _primary_weight_entry(model_version)
        if weight_entry is None:
            continue
        creator = json_data.get("creator")
        creator_name = (creator.get("username") if isinstance(creator, dict) else None) or "Unknown Creator"
        return {
            "urls_to_download": [weight_entry] + _media_entries(model_version),
            "id": model_version.get("id"),
            "baseModel": base_model_hf_name,
            "modelId": model_version.get("modelId", json_data.get("id")),
            "name": json_data.get("name", "Untitled LoRA"),
            # `or` also covers explicit nulls, which would otherwise render as "None".
            "description": json_data.get("description") or "No description provided.",
            "trainedWords": model_version.get("trainedWords") or [],
            "creator": creator_name,
            "tags": json_data.get("tags") or [],
            "allowNoCredit": json_data.get("allowNoCredit", True),
            "allowCommercialUse": _normalize_commercial_use(json_data.get("allowCommercialUse", "Sell")),
            "allowDerivatives": json_data.get("allowDerivatives", True),
            "allowDifferentLicense": json_data.get("allowDifferentLicense", True),
        }
    return None
def download_file_from_url(url: str, filename: str, folder: str = "."):
    """Stream ``url`` into ``os.path.join(folder, filename)``.

    A browser-like User-Agent is always sent. When ``CIVITAI_API_TOKEN`` is set,
    it is added as a Bearer token (needed for restricted files).

    Raises:
        gr.Error: on any HTTP or transport failure. A partially written file is
            deleted first.
    """
    local_filepath = os.path.join(folder, filename)
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    civitai_token = os.environ.get("CIVITAI_API_TOKEN")
    if civitai_token:
        headers['Authorization'] = f'Bearer {civitai_token}'

    def _discard_partial_file():
        # A transfer that dies mid-stream must not leave a truncated file behind.
        if os.path.exists(local_filepath):
            try:
                os.remove(local_filepath)
            except OSError as e_rm:
                print(f"Could not remove partial download {local_filepath}: {e_rm}")

    try:
        # The context manager releases the pooled connection of the streamed
        # response even when reading fails part-way through.
        with requests.get(url, headers=headers, stream=True, timeout=120) as response:
            response.raise_for_status()
            with open(local_filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
    except requests.exceptions.HTTPError as e_http:
        _discard_partial_file()
        status = getattr(e_http.response, "status_code", None)
        reason = getattr(e_http.response, "reason", None)
        if status in (401, 403) and not civitai_token:
            print(f"Authorization error (401/403) downloading {url}. Consider setting CIVITAI_API_TOKEN for restricted files.")
        raise gr.Error(f"HTTP Error downloading {filename}: {status} {reason}. URL: {url}") from e_http
    except requests.exceptions.RequestException as e_req:
        _discard_partial_file()
        raise gr.Error(f"Request Error downloading {filename}: {e_req}. URL: {url}") from e_req
def download_files(info: Dict[str, Any], folder: str = ".") -> Dict[str, List[Any]]:
    """Download every entry of ``info["urls_to_download"]`` into ``folder``.

    Characters that are invalid in common filesystems are replaced with ``_``.
    An unusable filename is replaced by a random ``downloaded_file_<hex>`` name
    that keeps the URL's extension.

    Returns:
        ``{"media_items": [...], "weightName": [...]}``. ``weightName`` lists the
        saved weight filenames. ``media_items`` holds one dict per image/video,
        with its filename, HTML-tag-stripped prompts and ``is_video`` flag.
    """
    downloaded_media_items: List[Dict[str, Any]] = []
    downloaded_weights: List[str] = []
    for item in info["urls_to_download"]:
        filename_to_save = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', item.get("filename") or "")
        # "." and ".." would resolve to a directory instead of a file in `folder`.
        if filename_to_save in ("", ".", ".."):
            # Ignore query string / fragment when taking the URL's extension.
            url_path = item["url"].split("?", 1)[0].split("#", 1)[0]
            ext = os.path.splitext(url_path)[1]
            filename_to_save = f"downloaded_file_{uuid.uuid4().hex[:8]}{ext if ext else '.bin'}"
        gr.Info(f"Downloading {filename_to_save}...")
        download_file_from_url(item["url"], filename_to_save, folder)
        if item["type"] == "weightName":
            downloaded_weights.append(filename_to_save)
        elif item["type"] in ("imageName", "videoName"):
            downloaded_media_items.append({
                "filename": filename_to_save,
                # Prompts may be missing or None; strip HTML tags from what exists.
                "prompt": re.sub(r'<.*?>', '', item.get("prompt") or ""),
                "negative_prompt": re.sub(r'<.*?>', '', item.get("negative_prompt") or ""),
                "is_video": item.get("is_video", False),
            })
    return {"media_items": downloaded_media_items, "weightName": downloaded_weights}
def process_url(url: str, profile: Optional[gr.OAuthProfile], do_download: bool = True, folder: str = ".") -> (Optional[Dict[str, Any]], Optional[Dict[str, List[Any]]]):
    """Resolve a CivitAI URL or ID into ``(info, downloaded_files)``.

    Raises ``gr.Error`` when the API data cannot be fetched, when the model
    fails the NSFW policy, or when it is not a LoRA on a supported base model
    with a primary file. ``downloaded_files`` is None unless ``do_download`` is
    True, in which case the files are saved into ``folder``.
    """
    json_data = get_json_data(url)
    if not json_data:
        raise gr.Error("Could not fetch CivitAI API data. Check URL or model ID. Example: https://civitai.com/models/12345 or just 12345")
    if not check_nsfw(json_data, profile):
        raise gr.Error("This model is flagged as NSFW by CivitAI or its media exceeds the allowed NSFW level (max level 5).")

    info = extract_info(json_data)
    if not info:
        model_type = json_data.get("type", "Unknown type")
        base_models = [version.get("baseModel", "Unknown base") for version in json_data.get("modelVersions", [])]
        message_parts = [
            "This LoRA is not supported. Details:\n",
            f"- Model Type: {model_type} (expected LORA)\n",
        ]
        if base_models:
            message_parts.append(f"- Detected Base Models in CivitAI: {', '.join(list(set(base_models)))}\n")
        message_parts.append("Ensure it's a LORA for a supported base (SD, SDXL, Pony, Flux, LTXV, Hunyuan, Wan) and has primary files.")
        raise gr.Error("".join(message_parts))

    downloaded = download_files(info, folder) if do_download else None
    return info, downloaded
# --- README Creation ---
# Hugging Face base repositories whose LoRAs generate video rather than images.
_VIDEO_BASE_MODELS_HF = frozenset({
    "Lightricks/LTX-Video-0.9.7-dev",
    "hunyuanvideo-community/HunyuanVideo",
    "hunyuanvideo-community/HunyuanVideo-I2V",
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
    "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
})
# Base repositories whose example code should load in bfloat16.
_BF16_BASE_MODELS_HF = frozenset({"black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"})
# Values that can be written as plain (unquoted) YAML scalars: they start with a
# letter and contain no "#", no ": " sequence and no leading/trailing space.
_YAML_PLAIN_SAFE = re.compile(r"[A-Za-z][A-Za-z0-9_./+-]*(?:[ :][A-Za-z0-9_./+-]+)*")
# Plain words some YAML parsers read as booleans/null rather than strings.
_YAML_RESERVED_WORDS = frozenset({"true", "false", "yes", "no", "on", "off", "null", "y", "n"})


def _yaml_quoted(text: str) -> str:
    """Render ``text`` as a one-line YAML single-quoted scalar."""
    one_line = re.sub(r"[\r\n\x85\u2028\u2029]", " ", text)
    return "'" + one_line.replace("'", "''") + "'"


def _yaml_scalar(text: str) -> str:
    """Render ``text`` plain when that is unambiguous YAML, otherwise single-quoted."""
    if _YAML_PLAIN_SAFE.fullmatch(text) and text.lower() not in _YAML_RESERVED_WORDS:
        return text
    return _yaml_quoted(text)


def _readme_tags(civit_tags: Any, is_video_model: bool, is_i2v_model: bool) -> List[str]:
    """Return the sorted, de-duplicated model-card tags (defaults plus CivitAI tags, minus ':')."""
    if is_video_model:
        default_tags = ["lora", "diffusers", "migrated", "video",
                        "image-to-video" if is_i2v_model else "text-to-video", "template:video-lora"]
    else:
        default_tags = ["lora", "diffusers", "migrated", "text-to-image", "stable-diffusion", "template:sd-lora"]
    tags = set(default_tags)
    if isinstance(civit_tags, list):
        for raw_tag in civit_tags:
            tag = str(raw_tag).replace(":", "").strip()
            if tag:
                tags.add(tag)
    return sorted(tags)


def _widget_yaml(media_items: List[Dict[str, Any]], limit: int = 5) -> str:
    """Return the YAML list under ``widget:`` for the first ``limit`` media items."""
    if not media_items:
        return "# No example media available for widget."
    entries = []
    for item in media_items[:limit]:
        prompt = item.get("prompt") or ""
        negative_prompt = item.get("negative_prompt") or ""
        text_value = _yaml_quoted(prompt) if prompt else "' '"
        lines = [f"- text: {text_value}"]
        if negative_prompt:
            lines += ["  parameters:", f"    negative_prompt: {_yaml_quoted(negative_prompt)}"]
        lines += ["  output:", "    url: >-", f"      {item['filename']}"]
        entries.append("\n".join(lines))
    return "\n".join(entries)


def _diffusers_snippet(base_model: str, user_repo_id: str, weight_name: str,
                       example_prompt: str, is_video_model: bool) -> str:
    """Return the fenced Python usage example for the model card."""
    dtype = "torch.bfloat16" if base_model in _BF16_BASE_MODELS_HF else "torch.float16"
    # repr() yields a valid Python literal whatever quotes/backslashes the prompt holds.
    prompt_literal = repr(" ".join(example_prompt.splitlines()))
    if is_video_model:
        pipeline_import = "DiffusionPipeline"
        call_lines = [
            "# Example prompt for video generation",
            f"prompt = {prompt_literal}",
            "# Adjust parameters like num_frames, num_inference_steps, height, width as needed for the specific pipeline.",
            "# video_frames = pipeline(prompt, num_frames=16, guidance_scale=7.5, num_inference_steps=25).frames # Example parameters",
        ]
        if "LTX-Video" in base_model:
            call_lines.append("# LTX-Video uses a specific setup (e.g. diffusers' LTXPipeline). Check its model card on Hugging Face.")
        elif "HunyuanVideo" in base_model:
            call_lines.append("# HunyuanVideo is available in diffusers as HunyuanVideoPipeline. Check its HF model card.")
        elif "Wan-AI" in base_model:
            call_lines.append("# Wan-AI models use dedicated pipeline classes (e.g. WanPipeline, WanImageToVideoPipeline). Check model card for usage.")
    else:
        pipeline_import = "AutoPipelineForText2Image"
        call_lines = [f"image = pipeline({prompt_literal}).images[0]"]
    lines = [
        "```py",
        f"from diffusers import {pipeline_import}",
        "import torch",
        "",
        'device = "cuda" if torch.cuda.is_available() else "cpu"',
        "",
        f"# Note: The pipeline class '{pipeline_import}' is a general suggestion.",
        "# For specific video models (LTX, Hunyuan, Wan), you will likely need a dedicated pipeline class",
        "# (e.g., LTXPipeline, HunyuanVideoPipeline, WanPipeline, etc.).",
        f"# Please refer to the documentation of the base model '{base_model}' on Hugging Face for precise usage.",
        f"pipeline = {pipeline_import}.from_pretrained({base_model!r}, torch_dtype={dtype})",
        "pipeline.to(device)",
        "",
        "# Load LoRA weights",
        f"pipeline.load_lora_weights({user_repo_id!r}, weight_name={weight_name!r})",
        "",
        "# For some pipelines, you might need to fuse LoRA layers before inference",
        "# and unfuse them after, or apply scaling. Check model card.",
        '# Example: pipeline.fuse_lora() or pipeline.set_adapters(["default"], adapter_weights=[0.8])',
        "",
        "# Example generation call (adjust parameters as needed for the specific pipeline)",
        *call_lines,
        "",
        "# If using fused LoRA:",
        "# pipeline.unfuse_lora()",
        "```",
    ]
    return "\n".join(lines)


def create_readme(info: Dict[str, Any], downloaded_files: Dict[str, List[Any]], user_repo_id: str, link_civit: bool = False, is_author: bool = True, folder: str = "."):
    """Write a Hugging Face model card (``README.md``) for a migrated CivitAI LoRA.

    Args:
        info: Model metadata as produced by ``extract_info``. Uses ``modelId``,
            ``creator``, ``baseModel``, ``name``, ``description``, ``tags``,
            ``trainedWords`` and the four ``allow*`` license flags.
        downloaded_files: Result of ``download_files``. ``media_items`` feed the
            widget gallery; ``weightName[0]`` names the LoRA weights file.
        user_repo_id: Target ``namespace/repo`` on Hugging Face.
        link_civit: Append a link back to the CivitAI model page.
        is_author: When False, add a disclaimer crediting the CivitAI creator.
        folder: Directory where ``README.md`` is written (overwritten if present).
    """
    from urllib.parse import quote

    original_url = f"https://civitai.com/models/{info['modelId']}"
    creator = info["creator"]
    base_model = info["baseModel"]
    is_video_model = base_model in _VIDEO_BASE_MODELS_HF
    is_i2v_model = "i2v" in base_model.lower()

    tags = _readme_tags(info.get("tags", []), is_video_model, is_i2v_model)
    trained_words = [str(word) for word in (info.get("trainedWords") or []) if word]
    formatted_words = ", ".join(f"`{word}`" for word in trained_words)

    media_items = downloaded_files.get("media_items") or []
    weight_names = downloaded_files.get("weightName") or []
    weight_name = weight_names[0] if weight_names else "your_lora_weights.safetensors"

    # Plain trigger words (not the backtick-formatted markdown) seed the code example.
    example_prompt = ", ".join(trained_words) if trained_words else "Your custom prompt"
    if media_items and media_items[0].get("prompt"):
        example_prompt = media_items[0]["prompt"]

    license_link = (
        "https://multimodal.art/civitai-licenses"
        f"?allowNoCredit={info['allowNoCredit']}"
        f"&allowCommercialUse={info['allowCommercialUse']}"
        f"&allowDerivatives={info['allowDerivatives']}"
        f"&allowDifferentLicense={info['allowDifferentLicense']}"
    )
    instance_prompt = _yaml_scalar(trained_words[0]) if trained_words else ""
    front_matter = "\n".join([
        "---",
        "license: other",
        "license_name: bespoke-lora-trained-license",
        f"license_link: {license_link}",
        "tags:",
        "- " + "\n- ".join(_yaml_scalar(tag) for tag in tags),
        f"base_model: {base_model}",
        f"instance_prompt: {instance_prompt}",
        "widget:",
        _widget_yaml(media_items),
        "---",
    ])

    body = [f"# {info['name']}", "<Gallery />"]
    if not is_author:
        body.append(
            f"This model was originally uploaded on [CivitAI]({original_url}), by "
            f"[{creator}](https://civitai.com/user/{quote(str(creator))}/models). "
            "The information below was provided by the author on CivitAI:"
        )
    if link_civit:
        body.append(f"([CivitAI]({original_url}))")
    body += ["## Model description", info.get("description") or "No description provided."]
    if formatted_words:
        body += ["## Trigger words", f"You should use {formatted_words} to trigger the generation."]
    body += [
        "## Download model",
        "Weights for this model are available in Safetensors format.",
        f"[Download](/{user_repo_id}/tree/main/{quote(weight_name)}) the LoRA in the Files & versions tab.",
        "## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)",
        _diffusers_snippet(base_model, user_repo_id, weight_name, example_prompt, is_video_model),
        "For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters).",
    ]
    content = front_matter + "\n\n" + "\n\n".join(body) + "\n"

    with open(os.path.join(folder, "README.md"), "w", encoding="utf-8") as file:
        file.write(content)
# --- Hugging Face Profile / Authorship ---
def get_creator(username: str) -> Dict:
    """Fetch a CivitAI creator profile through CivitAI's tRPC ``user.getCreator`` endpoint.

    Needs a CivitAI session cookie in the ``COOKIE_INFO`` environment variable.
    Returns a stub ``{"result": {"data": {"json": {"links": []}}}}`` when the
    cookie is missing, when the request fails, or when the body is not a JSON
    object. Callers can therefore always walk the same structure.
    """
    import urllib.parse

    def _empty() -> Dict:
        return {"result": {"data": {"json": {"links": []}}}}

    cookie = os.environ.get("COOKIE_INFO")
    if not cookie:
        print("Warning: COOKIE_INFO env var not set. Cannot fetch CivitAI creator's HF username.")
        return _empty()

    # Serialize the tRPC input with json and let requests percent-encode it, so a
    # username containing quotes, spaces or '&' cannot corrupt the query string.
    trpc_input = json.dumps({"json": {"username": username, "authed": True}}, separators=(",", ":"))
    headers = {
        "authority": "civitai.com", "accept": "*/*", "accept-language": "en-US,en;q=0.9",
        "content-type": "application/json", "cookie": cookie,
        "referer": f"https://civitai.com/user/{urllib.parse.quote(username)}/models",
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.0.0 Safari/537.36"
    }
    try:
        response = requests.get(
            "https://civitai.com/api/trpc/user.getCreator",
            params={"input": trpc_input}, headers=headers, timeout=10,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        print(f"Error fetching CivitAI creator data for {username}: {e}")
        return _empty()
    return data if isinstance(data, dict) else _empty()
def extract_huggingface_username(civitai_username: str) -> Optional[str]:
    """Return the Hugging Face username linked on a CivitAI creator profile, or None.

    Scans the profile links returned by ``get_creator`` for the first
    ``http(s)://(www.)huggingface.co/...`` URL. Returns the account name from its
    path: the first path segment, or the second for ``/spaces/<user>/...`` and
    ``/datasets/<user>/...`` links. A trailing slash or a repo path therefore
    still yields the account name, not ``""`` or the repo name.
    """
    from urllib.parse import urlparse

    data = get_creator(civitai_username)
    # Walk result.data.json.links defensively: any level may be absent or null.
    node: Any = data
    for key in ("result", "data", "json", "links"):
        node = node.get(key) if isinstance(node, dict) else None
    if not isinstance(node, list):
        return None

    for link in node:
        if not isinstance(link, dict):
            continue
        url = link.get("url")
        if not isinstance(url, str):
            continue
        try:
            parsed = urlparse(url.strip())
        except ValueError as e:
            print(f"Error parsing CivitAI creator data for HF username: {e}")
            continue
        if parsed.scheme not in ("http", "https") or parsed.netloc.lower() not in ("huggingface.co", "www.huggingface.co"):
            continue
        segments = [segment for segment in parsed.path.split("/") if segment]
        if segments and segments[0] in ("spaces", "datasets"):
            segments = segments[1:]
        if segments:
            return segments[0]
    return None
# --- Gradio UI Logic Functions ---
def check_civit_link(profile_state: Optional[gr.OAuthProfile], url_input: str):
    """Validate a CivitAI URL/ID and the logged-in user's right to upload it.

    Returns four values for the outputs (instructions HTML, submit button,
    "re-check" button, submit button again): a feedback message plus
    ``gr.update`` objects that enable the submit button only when the model is
    processable and the user is either a trusted uploader or the CivitAI
    creator whose profile links this Hugging Face account.
    """
    url_input = (url_input or "").strip()
    if not url_input:
        return "", gr.update(interactive=False, visible=False), gr.update(visible=False), gr.update(visible=False)
    if not profile_state:
        return "Please log in with Hugging Face first.", gr.update(interactive=False, visible=False), gr.update(visible=False), gr.update(visible=False)
    try:
        info, _ = process_url(url_input, profile_state, do_download=False)
        if not info:  # Should be caught by process_url, but as a safeguard
            return "Could not process this CivitAI URL. Model might be unsupported or invalid.", gr.update(interactive=False, visible=True), gr.update(visible=False), gr.update(visible=False)
    except gr.Error as e:  # NSFW, unsupported model, API fetch failure, ...
        return str(e), gr.update(interactive=False, visible=True), gr.update(visible=False), gr.update(visible=False)
    except Exception as e:  # Any other unexpected error during the processing check
        print(f"Unexpected error in check_civit_link during process_url: {e}\n{traceback.format_exc()}")
        return f"An unexpected error occurred: {str(e)}", gr.update(interactive=False, visible=True), gr.update(visible=False), gr.update(visible=False)

    civitai_creator_username = info['creator']
    # Trusted uploaders need no authorship proof, so skip the CivitAI profile lookup.
    if profile_state.username in TRUSTED_UPLOADERS:
        return f'Welcome, trusted uploader {profile_state.username}! You can upload this model by "{civitai_creator_username}".', gr.update(interactive=True, visible=True), gr.update(visible=False), gr.update(visible=True)

    hf_username_on_civitai = extract_huggingface_username(civitai_creator_username)
    if not hf_username_on_civitai:
        no_username_text = (
            f'If you are "{civitai_creator_username}" on CivitAI, hi! Your CivitAI profile does not seem to have a Hugging Face username linked. '
            f'Please visit <a href="https://civitai.com/user/account" target="_blank">your CivitAI account settings</a> and add your 🤗 username ({profile_state.username}). '
            f'Example: <br/><img width="60%" src="https://i.imgur.com/hCbo9uL.png" alt="CivitAI profile settings example"/><br/>'
            f'(If you are not "{civitai_creator_username}", you cannot submit their model at this time.)'
        )
        return no_username_text, gr.update(interactive=False, visible=False), gr.update(visible=True), gr.update(visible=False)
    if profile_state.username.lower() != hf_username_on_civitai.lower():
        unmatched_username_text = (
            f'The Hugging Face username on "{civitai_creator_username}"\'s CivitAI profile ("{hf_username_on_civitai}") '
            f'does not match your logged-in Hugging Face account ("{profile_state.username}"). '
            f'Please update it on <a href="https://civitai.com/user/account" target="_blank">CivitAI</a> or log in to Hugging Face as "{hf_username_on_civitai}".<br/>'
            f'<img src="https://i.imgur.com/hCbo9uL.png" alt="CivitAI profile settings example"/>'
        )
        return unmatched_username_text, gr.update(interactive=False, visible=False), gr.update(visible=True), gr.update(visible=False)
    return f'Authorship verified for "{civitai_creator_username}" (🤗 {profile_state.username}). Ready to upload!', gr.update(interactive=True, visible=True), gr.update(visible=False), gr.update(visible=True)
def handle_auth_change_and_update_state(profile: Optional[gr.OAuthProfile]):
    """Sync the UI with the login state.

    Returns, in order: the profile to store in state (None when logged out),
    updates for the disabled and enabled areas, an empty instructions text, a
    cleared URL box, and hidden submit / re-check buttons.
    """
    logged_in = bool(profile)
    stored_profile = profile if logged_in else None
    return (
        stored_profile,
        gr.update(visible=not logged_in),
        gr.update(visible=logged_in),
        "",
        gr.update(value=""),
        gr.update(interactive=False, visible=False),
        gr.update(visible=False),
    )
def show_output_area():
    """Return an update that makes the results Markdown area visible."""
    reveal = gr.update(visible=True)
    return reveal
def list_civit_models(username: str) -> str:
    """Return newline-separated CivitAI model URLs for a user's LoRAs (newest first).

    Only the first page (up to 100 models) is fetched. Returns "" for a blank
    username or when no LoRA is found. When a request fails, a Gradio warning is
    shown and whatever was collected so far is returned.
    """
    username = (username or "").strip()
    if not username:
        return ""
    # The username goes through `params` so requests URL-encodes it; later pages
    # come from metadata.nextPage, which already carries its full query string.
    next_url: Optional[str] = "https://civitai.com/api/v1/models"
    params: Optional[Dict[str, Any]] = {"username": username, "limit": 100, "sort": "Newest"}  # Max limit is 100 per page
    json_models_list: List[Dict[str, Any]] = []
    page_count, max_pages = 0, 1  # One page keeps this fast; raise max_pages to fetch more.
    gr.Info(f"Fetching LoRAs for CivitAI user: {username}...")
    while next_url and page_count < max_pages:
        try:
            response = requests.get(next_url, params=params, timeout=15)
            response.raise_for_status()
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            gr.Warning(f"Failed to fetch page {page_count + 1} for {username}: {e}")
            break
        params = None
        if not isinstance(data, dict):
            break
        json_models_list.extend(
            item for item in (data.get('items') or [])
            if item.get("type") == "LORA" and item.get("name") and item.get("id") is not None
        )
        next_url = (data.get('metadata') or {}).get('nextPage')
        page_count += 1
    if not json_models_list:
        gr.Info(f"No suitable LoRA models found for {username} or failed to fetch.")
        return ""
    urls_text = "\n".join(
        f'https://civitai.com/models/{model["id"]}/{slugify(model["name"])}'
        for model in json_models_list
    )
    gr.Info(f"Found {len(json_models_list)} LoRA models for {username}.")
    return urls_text.strip()
# --- Main Upload Functions ---
def upload_civit_to_hf(profile: Optional[gr.OAuthProfile], oauth_token_obj: gr.OAuthToken, url: str, link_civit_checkbox_val: bool):
    """Migrate one CivitAI LoRA into the logged-in user's Hugging Face namespace.

    The steps are:
    1. Download the weights and media into a temporary folder.
    2. Write the model card.
    3. Create the repository (private first), upload the folder, then make it public.
    The temporary folder is always removed.

    Returns:
        A Markdown success message linking to the new repository.

    Raises:
        gr.Error: when not logged in, when the model is rejected (NSFW,
            unsupported, download failure; those messages are surfaced
            unchanged), or when a Hugging Face Hub operation fails.
    """
    if not profile or not profile.username:
        raise gr.Error("User profile not available. Please log in.")
    if not oauth_token_obj or not oauth_token_obj.token:
        raise gr.Error("Hugging Face token not available. Please log in again.")
    hf_auth_token = oauth_token_obj.token
    # makedirs creates the shared "temp_uploads" parent as well on first use.
    folder_path = os.path.join("temp_uploads", str(uuid.uuid4()))
    os.makedirs(folder_path, exist_ok=True)
    gr.Info(f"Starting processing of model {url}")
    try:
        info, downloaded_data = process_url(url, profile, do_download=True, folder=folder_path)
        if not info or not downloaded_data:
            # process_url should raise gr.Error, but this is a fallback.
            raise gr.Error("Failed to process URL or download files after initial checks.")
        # slugify returns "" for names without transliterable characters, which
        # would produce an invalid repo id; fall back to the CivitAI model id.
        slug_name = slugify(info["name"]) or f"civitai-{info['modelId']}"
        user_repo_id = f"{profile.username}/{slug_name}"

        # Re-verify authorship just before upload, using info from the processed model.
        is_author = profile.username in TRUSTED_UPLOADERS
        if not is_author:
            hf_username_on_civitai = extract_huggingface_username(info.get('creator', 'Unknown Creator'))
            is_author = bool(hf_username_on_civitai) and profile.username.lower() == hf_username_on_civitai.lower()
        create_readme(info, downloaded_data, user_repo_id, link_civit_checkbox_val, is_author=is_author, folder=folder_path)

        repo_url_huggingface = f"https://huggingface.co/{user_repo_id}"
        gr.Info(f"Creating/updating repository {user_repo_id} on Hugging Face...")
        create_repo(repo_id=user_repo_id, private=True, exist_ok=True, token=hf_auth_token)
        gr.Info(f"Starting upload to {repo_url_huggingface}...")
        upload_folder(
            folder_path=folder_path, repo_id=user_repo_id, repo_type="model",
            token=hf_auth_token, commit_message=f"Upload LoRA: {info['name']} from CivitAI ID {info['modelId']}"
        )
        update_repo_visibility(repo_id=user_repo_id, private=False, token=hf_auth_token)
        gr.Info("Model uploaded successfully!")
        return f'''# Model uploaded to 🤗!
## Access it here [{user_repo_id}]({repo_url_huggingface}) '''
    except gr.Error:
        # Already user-facing (NSFW, unsupported model, download failure, ...):
        # re-wrapping would wrongly suggest an expired token.
        raise
    except Exception as e:
        print(f"Error during Hugging Face repo operations for {url}: {e}\n{traceback.format_exc()}")
        raise gr.Error(f"Upload failed for {url}: {str(e)}. Token might be expired. Try re-logging or check server logs.") from e
    finally:
        try:
            if os.path.exists(folder_path):
                shutil.rmtree(folder_path)
        except Exception as e_clean:
            print(f"Error cleaning up folder {folder_path}: {e_clean}")
def bulk_upload(profile: Optional[gr.OAuthProfile], oauth_token_obj: gr.OAuthToken, urls_text: str, link_civit_checkbox_val: bool):
    """Upload every CivitAI URL/ID in ``urls_text`` (one per line), continuing past failures.

    Returns a Markdown report with one section per URL (success message or
    failure reason), separated by horizontal rules.
    """
    if not (profile and oauth_token_obj and oauth_token_obj.token):
        raise gr.Error("Authentication missing for bulk upload. Please log in.")
    pending = [entry for entry in (line.strip() for line in urls_text.splitlines()) if entry]
    if not pending:
        return "No URLs provided for bulk upload."

    report = []
    total = len(pending)
    gr.Info(f"Starting bulk upload for {total} models.")
    for position, target in enumerate(pending, start=1):
        gr.Info(f"Processing model {position}/{total}: {target}")
        try:
            outcome = upload_civit_to_hf(profile, oauth_token_obj, target, link_civit_checkbox_val)
            report.append(outcome)
            gr.Info(f"Successfully processed {target}")
        except gr.Error as known_error:  # User-facing errors: show and keep going.
            gr.Warning(f"Skipping model {target} due to error: {str(known_error)}")
            report.append(f"Failed to upload {target}: {str(known_error)}")
        except Exception as unexpected:  # Anything else: log the traceback and keep going.
            gr.Warning(f"Unhandled error uploading model {target}: {str(unexpected)}")
            report.append(f"Failed to upload {target}: Unhandled exception - {str(unexpected)}")
            print(f"Unhandled exception during bulk upload for {target}: {unexpected}\n{traceback.format_exc()}")
    if not report:
        return "No URLs were processed or all failed."
    return "\n\n---\n\n".join(report)
# --- Gradio UI Definition ---
# Custom CSS passed to gr.Blocks below. It spaces the login row (#login_button_area),
# greys out and disables the pre-login area (#disabled_upload_area), formats
# lists/links/images inside gr.HTML components, and boxes the feedback
# area (#instructions_area).
css = '''
#login_button_area { margin-bottom: 10px; }
#disabled_upload_area { opacity: 0.6; pointer-events: none; }
.gr-html ul { list-style-type: disc; margin-left: 20px; }
.gr-html ol { list-style-type: decimal; margin-left: 20px; }
.gr-html a { color: #007bff; text-decoration: underline; }
.gr-html img { max-width: 100%; height: auto; margin-top: 5px; margin-bottom: 5px; border: 1px solid #ddd; }
#instructions_area { padding: 10px; border: 1px solid #eee; border-radius: 5px; margin-top: 10px; background-color: #f9f9f9; }
'''
with gr.Blocks(css=css, title="CivitAI to Hugging Face LoRA Uploader") as demo:
    # Copy of the logged-in gr.OAuthProfile (or None), written by the load handler below.
    auth_profile_state = gr.State()
    gr.Markdown('''# Upload your CivitAI LoRA to Hugging Face 🤗
By uploading your LoRAs to Hugging Face you get diffusers compatibility, a free GPU-based Inference Widget, you'll be listed in [LoRA Studio](https://lorastudio.co/models) after a short review, and get the possibility to submit your model to the [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer) ✨
''')
    with gr.Row(elem_id="login_button_area"):
        # Login/logout goes through a Hugging Face OAuth redirect; the resulting
        # page load re-runs the demo.load handler wired below.
        login_button = gr.LoginButton()

    with gr.Column(visible=True, elem_id="disabled_upload_area") as disabled_area:
        gr.HTML("<h3>Please log in with Hugging Face to enable uploads.</h3>")
        gr.Textbox(
            placeholder="e.g., https://civitai.com/models/12345/my-lora or just 12345",
            label="CivitAI Model URL or ID (Log in to enable)",
            interactive=False
        )

    with gr.Column(visible=False) as enabled_area:
        gr.HTML("<h3 style='color:green;'>Logged in! You can now upload models.</h3>")
        with gr.Tabs():
            with gr.TabItem("Single Model Upload"):
                submit_source_civit_enabled = gr.Textbox(
                    placeholder="e.g., https://civitai.com/models/12345/my-lora or just 12345",
                    label="CivitAI Model URL or ID",
                    info="Enter the full URL or just the numeric ID of the CivitAI LoRA model page.",
                )
                instructions_html = gr.HTML(elem_id="instructions_area")  # Feedback from check_civit_link
                try_again_button = gr.Button("I've updated my CivitAI profile (Re-check Authorship)", visible=False)
                link_civit_checkbox_single = gr.Checkbox(label="Add a link back to CivitAI in the README?", value=True, visible=True)
                submit_button_single_model = gr.Button("Upload This Model to Hugging Face", interactive=False, visible=False, variant="primary")
            with gr.TabItem("Bulk Upload"):
                civit_username_to_bulk = gr.Textbox(
                    label="Your CivitAI Username (Optional)",
                    info="Enter your CivitAI username to auto-populate the list below with your LoRAs (up to 100 newest)."
                )
                submit_bulk_civit_urls = gr.Textbox(
                    label="CivitAI Model URLs or IDs (One per line)",
                    info="Paste multiple CivitAI model page URLs or just IDs here, one on each line.",
                    lines=8,
                )
                link_civit_checkbox_bulk = gr.Checkbox(label="Add a link back to CivitAI in READMEs?", value=True)
                bulk_upload_button = gr.Button("Start Bulk Upload", variant="primary")
        output_markdown_area = gr.Markdown(label="Upload Progress & Results", visible=False)

    # --- Event Handlers Wiring ---
    # Gradio fills parameters annotated gr.OAuthProfile / gr.OAuthToken (including
    # Optional[...]) automatically, inserting them at their positional slot. They
    # must therefore NOT be listed in `inputs`; only the remaining parameters map
    # to components. Listing extra inputs (or a gr.OAuthToken instance, which is
    # not a component) breaks the call.
    demo.load(
        fn=handle_auth_change_and_update_state,
        inputs=None,  # `profile` is injected by Gradio
        outputs=[auth_profile_state, disabled_area, enabled_area, instructions_html, submit_source_civit_enabled, submit_button_single_model, try_again_button],
        api_name=False, queue=False
    )
    # submit_button_single_model appears twice: check_civit_link returns one
    # interactivity/visibility update and one visibility update for it.
    submit_source_civit_enabled.change(
        fn=check_civit_link,
        inputs=[submit_source_civit_enabled],  # `profile_state` is injected by Gradio
        outputs=[instructions_html, submit_button_single_model, try_again_button, submit_button_single_model],
        api_name=False
    )
    try_again_button.click(
        fn=check_civit_link,
        inputs=[submit_source_civit_enabled],
        outputs=[instructions_html, submit_button_single_model, try_again_button, submit_button_single_model],
        api_name=False
    )
    civit_username_to_bulk.submit(
        fn=list_civit_models,
        inputs=[civit_username_to_bulk],
        outputs=[submit_bulk_civit_urls],
        api_name=False
    )
    submit_button_single_model.click(
        fn=show_output_area, inputs=[], outputs=[output_markdown_area], api_name=False
    ).then(
        fn=upload_civit_to_hf,
        inputs=[submit_source_civit_enabled, link_civit_checkbox_single],  # profile + OAuth token injected
        outputs=[output_markdown_area],
        api_name="upload_single_model"
    )
    bulk_upload_button.click(
        fn=show_output_area, inputs=[], outputs=[output_markdown_area], api_name=False
    ).then(
        fn=bulk_upload,
        inputs=[submit_bulk_civit_urls, link_civit_checkbox_bulk],  # profile + OAuth token injected
        outputs=[output_markdown_area],
        api_name="upload_bulk_models"
    )
# Queue events: at most 3 concurrent jobs per event listener, at most 10 waiting requests.
demo.queue(default_concurrency_limit=3, max_size=10)

if __name__ == "__main__":
    # Optional environment variables for local runs:
    #   COOKIE_INFO        CivitAI session cookie; get_creator() needs it to read a creator's linked HF account.
    #   CIVITAI_API_TOKEN  CivitAI API key; sent as a Bearer token when downloading files.
    #   GRADIO_SHARE=true  Also create a public Gradio share link.
    # Hugging Face OAuth login additionally needs an OAuth app configured for the
    # local origin (e.g. http://localhost:7860 or http://127.0.0.1:7860).
    share_enabled = os.environ.get("GRADIO_SHARE") == "true"
    demo.launch(debug=True, share=share_enabled)