Spaces:
Running
Running
import requests | |
import os | |
import gradio as gr | |
from huggingface_hub import update_repo_visibility, upload_folder, create_repo, upload_file | |
from slugify import slugify | |
import re | |
import uuid | |
from typing import Optional, Dict, Any, List | |
import json | |
import shutil # For cleaning up local folders | |
import traceback # For debugging | |
# Hugging Face usernames with elevated trust. They skip the per-media
# nsfwLevel check in check_nsfw (the model-level NSFW flag still applies) and
# the CivitAI -> Hugging Face authorship check before upload.
TRUSTED_UPLOADERS = [
    "KappaNeuro",
    "CiroN2022",
    "multimodalart",
    "Norod78",
    "joachimsallstrom",
    "blink7630",
    "e-n-v-y",
    "DoctorDiffusion",
    "RalFinger",
    "artificialguybr",
]
# --- Helper Functions (CivitAI API, Data Extraction, File Handling) --- | |
def extract_model_id(url: str) -> Optional[str]:
    """Return the numeric CivitAI model ID referenced by *url*, or None.

    Accepts a bare ID ("12345"), a model page URL
    ("https://civitai.com/models/12345/some-slug") and variants that carry a
    query string or fragment ("https://civitai.com/models/12345?modelVersionId=678").
    Surrounding whitespace is ignored.
    """
    candidate = url.strip()
    if candidate.isdigit():
        return candidate
    # Prefer the segment right after "/models/" so that a trailing
    # "?modelVersionId=..." is never mistaken for the model ID.
    match = re.search(r"/models/(\d+)", candidate)
    if match:
        return match.group(1)
    # Fallback: a trailing ID, optionally followed by one slug segment, once the
    # query string / fragment has been removed.
    path = re.split(r"[?#]", candidate, maxsplit=1)[0].rstrip("/")
    match = re.search(r"(\d+)(?:/[^/]+)?$", path)
    return match.group(1) if match else None


def get_json_data(url: str) -> Optional[Dict[str, Any]]:
    """Fetch the CivitAI API record (``/api/v1/models/{id}``) for a model.

    Args:
        url: A CivitAI model page URL or a bare numeric model ID.

    Returns:
        The decoded JSON response, or None when no model ID can be extracted
        or the request / decoding fails (the reason is printed).
    """
    model_id = extract_model_id(url)
    if model_id is None:
        print(f"Error: Invalid CivitAI URL format or missing model ID: {url}")
        return None
    api_url = f"https://civitai.com/api/v1/models/{model_id}"
    try:
        response = requests.get(api_url, timeout=15)
        response.raise_for_status()
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError: non-JSON body on requests versions whose JSONDecodeError
        # is not a RequestException subclass.
        print(f"Error fetching JSON data from {api_url}: {e}")
        return None
def check_nsfw(json_data: Dict[str, Any], profile: Optional[gr.OAuthProfile]) -> bool:
    """Decide whether a CivitAI model passes the NSFW policy.

    Rules, in order:
      * a model flagged ``nsfw`` at model level is always rejected;
      * a logged-in trusted uploader is then accepted outright;
      * anyone else is accepted only if no media item of any version has an
        ``nsfwLevel`` above 5.

    Returns:
        True if the model may be processed, False otherwise (the reason is printed).
    """
    model_label = json_data.get('id', 'Unknown')

    if json_data.get("nsfw", False):
        print(f"Model {model_label} flagged as NSFW at model level.")
        return False

    if profile and profile.username in TRUSTED_UPLOADERS:
        print(f"Trusted uploader {profile.username}, bypassing strict image NSFW check for model {model_label}.")
        return True

    # First version holding any media ('images' may also contain videos) above level 5.
    offending_version = next(
        (
            version
            for version in json_data.get("modelVersions", [])
            if any(media.get("nsfwLevel", 0) > 5 for media in version.get("images", []))
        ),
        None,
    )
    if offending_version is not None:
        print(f"Model {model_label} version {offending_version.get('id')} has media with nsfwLevel > 5.")
        return False
    return True
def parse_generation_data(data: Any) -> (str, str):
    """Return ``(prompt, negative_prompt)`` from an ``image.getGenerationData`` payload.

    The payload is expected to look like
    ``{"result": {"data": {"json": {"meta": {"prompt": ..., "negativePrompt": ...}}}}}``.
    A missing level, a null value or a non-string field yields "" so callers
    can always treat both results as text.
    """
    node = data
    for key in ("result", "data", "json", "meta"):
        node = node.get(key) if isinstance(node, dict) else None
    if not isinstance(node, dict):
        return "", ""
    prompt = node.get("prompt")
    negative_prompt = node.get("negativePrompt")
    return (
        prompt if isinstance(prompt, str) else "",
        negative_prompt if isinstance(negative_prompt, str) else "",
    )


def get_prompts_from_image(image_id: int) -> (str, str):
    """Fetch the generation prompt and negative prompt of a CivitAI image.

    Args:
        image_id: Numeric CivitAI image ID.

    Returns:
        ``(prompt, negative_prompt)``. Both are "" when the request fails, the
        server answers with a non-200 status, or the image carries no metadata.
    """
    api_url = "https://civitai.com/api/trpc/image.getGenerationData"
    # Let requests URL-encode the tRPC JSON input instead of splicing it into the URL.
    query = {"input": json.dumps({"json": {"id": image_id}}, separators=(",", ":"))}
    try:
        response = requests.get(api_url, params=query, timeout=10)
        if response.status_code != 200:
            return "", ""
        return parse_generation_data(response.json())
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError: non-JSON body on requests versions whose JSONDecodeError
        # is not a RequestException subclass.
        print(f"Error fetching prompt data for image_id {image_id}: {e}")
        return "", ""
# CivitAI "baseModel" label -> Hugging Face repo recorded as the LoRA's base model.
_CIVITAI_BASE_MODEL_TO_HF = {
    "SDXL 1.0": "stabilityai/stable-diffusion-xl-base-1.0",
    "SDXL 0.9": "stabilityai/stable-diffusion-xl-base-1.0",
    "SD 1.5": "runwayml/stable-diffusion-v1-5",
    "SD 1.4": "CompVis/stable-diffusion-v1-4",
    "SD 2.1": "stabilityai/stable-diffusion-2-1-base",
    "SD 2.0": "stabilityai/stable-diffusion-2-base",
    "SD 2.1 768": "stabilityai/stable-diffusion-2-1",
    "SD 2.0 768": "stabilityai/stable-diffusion-2",
    "SD 3": "stabilityai/stable-diffusion-3-medium-diffusers",
    # SD 3.5 LoRAs must point at the matching SD 3.5 checkpoint, not SD 3 Medium.
    "SD 3.5": "stabilityai/stable-diffusion-3.5-large",
    "SD 3.5 Large": "stabilityai/stable-diffusion-3.5-large",
    "SD 3.5 Medium": "stabilityai/stable-diffusion-3.5-medium",
    "SD 3.5 Large Turbo": "stabilityai/stable-diffusion-3.5-large-turbo",
    "Flux.1 D": "black-forest-labs/FLUX.1-dev",
    "Flux.1 S": "black-forest-labs/FLUX.1-schnell",
    "LTXV": "Lightricks/LTX-Video-0.9.7-dev",
    "Hunyuan Video": "hunyuanvideo-community/HunyuanVideo",
    "Wan Video 1.3B t2v": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "Wan Video 14B t2v": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    "Wan Video 14B i2v 480p": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
    "Wan Video 14B i2v 720p": "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
    "Pony": "SG161222/RealVisXL_V4.0",
    # The base model must be a full pipeline repo, not a LoRA repo; SDXL base
    # is used for this SDXL-family model.
    "Illustrious": "stabilityai/stable-diffusion-xl-base-1.0",
}


def _primary_weight_entry(model_version: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Return the download entry for the version's primary "Model" file, or None."""
    for file_data in model_version.get("files", []):
        if file_data.get("primary") and file_data.get("type") == "Model":
            return {
                "url": file_data["downloadUrl"],
                "filename": os.path.basename(file_data["name"]),
                "type": "weightName",
                "is_video": False,
            }
    return None


def _media_entries(model_version: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return download entries for the version's media with nsfwLevel <= 5.

    For images flagged ``hasMeta`` whose file name is a numeric image ID, the
    generation prompts are fetched with ``get_prompts_from_image``; failures
    only print a warning and leave the prompts empty.
    """
    entries: List[Dict[str, Any]] = []
    for media_data in model_version.get("images", []):  # 'images' can contain videos
        # `or 0` also covers an explicit null nsfwLevel, which would make `>` raise.
        if (media_data.get("nsfwLevel") or 0) > 5:
            continue
        media_url = media_data.get("url") or ""
        filename_part = media_url.split("/")[-1]
        if not filename_part:  # no URL or URL ending in "/"
            continue
        id_candidate = filename_part.split(".")[0].split("?")[0]
        prompt, negative_prompt = "", ""
        if media_data.get("hasMeta", False) and media_data.get("type") == "image" and id_candidate.isdigit():
            try:
                prompt, negative_prompt = get_prompts_from_image(int(id_candidate))
            except ValueError:
                print(f"Warning: Non-integer ID '{id_candidate}' for prompt fetching.")
            except Exception as e:
                print(f"Warning: Prompt fetch failed for ID {id_candidate}: {e}")
        is_video_file = media_data.get("type") == "video"
        entries.append({
            "url": media_url,
            "filename": os.path.basename(filename_part),
            "type": "videoName" if is_video_file else "imageName",
            "prompt": prompt or "",
            "negative_prompt": negative_prompt or "",
            "is_video": is_video_file,
        })
    return entries


def _normalize_commercial_use(raw: Any) -> Any:
    """Collapse CivitAI's ``allowCommercialUse`` (list, bool or str) to one value.

    A list yields its first element, a bool yields "Sell"/"None", a string is
    kept, and anything else (including an empty list or null) yields "Sell".
    """
    if isinstance(raw, list):
        return raw[0] if raw else "Sell"
    if isinstance(raw, bool):
        return "Sell" if raw else "None"
    if isinstance(raw, str):
        return raw
    return "Sell"


def extract_info(json_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Describe the first uploadable version of a CivitAI LoRA.

    A version is uploadable when its ``baseModel`` appears in
    ``_CIVITAI_BASE_MODEL_TO_HF`` and it has a primary file of type "Model".

    Args:
        json_data: The ``/api/v1/models/{id}`` response.

    Returns:
        None if the model is not a LORA or no version is uploadable. Otherwise
        a dict with ``urls_to_download`` (the primary weight file first, then
        the allowed media), the Hugging Face ``baseModel`` repo, model metadata
        and the CivitAI licence flags.
    """
    if json_data.get("type") != "LORA":
        return None
    for model_version in json_data.get("modelVersions", []):
        base_model_hf_name = _CIVITAI_BASE_MODEL_TO_HF.get(model_version.get("baseModel"))
        if base_model_hf_name is None:
            continue
        weight_entry = _primary_weight_entry(model_version)
        if weight_entry is None:
            continue
        creator = json_data.get("creator") or {}
        return {
            "urls_to_download": [weight_entry] + _media_entries(model_version),
            "id": model_version.get("id"),
            "baseModel": base_model_hf_name,
            "modelId": model_version.get("modelId", json_data.get("id")),
            "name": json_data.get("name", "Untitled LoRA"),
            # `or` also replaces an explicit null coming from the API.
            "description": json_data.get("description") or "No description provided.",
            "trainedWords": model_version.get("trainedWords") or [],
            "creator": creator.get("username") or "Unknown Creator",
            "tags": json_data.get("tags") or [],
            "allowNoCredit": json_data.get("allowNoCredit", True),
            "allowCommercialUse": _normalize_commercial_use(json_data.get("allowCommercialUse", "Sell")),
            "allowDerivatives": json_data.get("allowDerivatives", True),
            "allowDifferentLicense": json_data.get("allowDifferentLicense", True),
        }
    return None
def _discard_partial_download(path: str) -> None:
    """Best-effort removal of a file left behind by a failed download."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass  # the failure happened before anything was written
    except OSError as e:
        print(f"Warning: could not remove partial download {path}: {e}")


def download_file_from_url(url: str, filename: str, folder: str = "."):
    """Stream *url* into ``os.path.join(folder, filename)``.

    A browser-like User-Agent is always sent, plus ``Authorization: Bearer``
    when the CIVITAI_API_TOKEN environment variable is set.

    Raises:
        gr.Error: on any HTTP or network failure. A file truncated by a failure
            during streaming is removed first.
    """
    local_filepath = os.path.join(folder, filename)
    civitai_token = os.environ.get("CIVITAI_API_TOKEN")
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
    }
    if civitai_token:
        headers['Authorization'] = f'Bearer {civitai_token}'
    try:
        # `with` returns the streamed connection to the pool even on failure.
        with requests.get(url, headers=headers, stream=True, timeout=120) as response:
            response.raise_for_status()
            try:
                with open(local_filepath, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            except (requests.exceptions.RequestException, OSError):
                # Never leave a truncated file behind for the upload step.
                _discard_partial_download(local_filepath)
                raise
    except requests.exceptions.HTTPError as e_http:
        failed = e_http.response
        status = failed.status_code if failed is not None else None
        reason = failed.reason if failed is not None else str(e_http)
        if status in (401, 403) and not civitai_token:
            print(f"Authorization error (401/403) downloading {url}. Consider setting CIVITAI_API_TOKEN for restricted files.")
        raise gr.Error(f"HTTP Error downloading {filename}: {status} {reason}. URL: {url}") from e_http
    except requests.exceptions.RequestException as e_req:
        raise gr.Error(f"Request Error downloading {filename}: {e_req}. URL: {url}") from e_req
def _safe_local_filename(raw_name: Any, url: str) -> str:
    """Return a file name that can be written inside the download folder.

    Characters illegal in Windows/POSIX file names are replaced by "_". Empty
    names and the special entries "." / ".." get a random
    ``downloaded_file_<hex>`` name that keeps the URL's extension (".bin" if none).
    """
    cleaned = re.sub(r'[<>:"/\\|?*]', '_', raw_name or "")
    if cleaned not in ("", ".", ".."):
        return cleaned
    ext = os.path.splitext(url)[1]
    return f"downloaded_file_{uuid.uuid4().hex[:8]}{ext if ext else '.bin'}"


def _strip_html_tags(text: Any) -> str:
    """Remove ``<...>`` tags from *text*; None is treated as ""."""
    return re.sub(r'<.*?>', '', text or "")


def download_files(info: Dict[str, Any], folder: str = ".") -> Dict[str, List[Any]]:
    """Download every entry of ``info["urls_to_download"]`` into *folder*.

    Returns:
        ``{"media_items": [...], "weightName": [...]}``: the saved weight file
        names, and for each image/video its saved file name, HTML-stripped
        prompt and negative prompt, and ``is_video`` flag.

    Raises:
        gr.Error: propagated from ``download_file_from_url`` when a download fails.
    """
    media_items: List[Dict[str, Any]] = []
    weight_files: List[str] = []
    for item in info["urls_to_download"]:
        local_name = _safe_local_filename(item["filename"], item["url"])
        gr.Info(f"Downloading {local_name}...")
        download_file_from_url(item["url"], local_name, folder)
        if item["type"] == "weightName":
            weight_files.append(local_name)
        elif item["type"] in ("imageName", "videoName"):
            media_items.append({
                "filename": local_name,
                # Prompts may be None when the API returned null.
                "prompt": _strip_html_tags(item.get("prompt")),
                "negative_prompt": _strip_html_tags(item.get("negative_prompt")),
                "is_video": item.get("is_video", False),
            })
    return {"media_items": media_items, "weightName": weight_files}
def process_url(url: str, profile: Optional[gr.OAuthProfile], do_download: bool = True, folder: str = ".") -> (Optional[Dict[str, Any]], Optional[Dict[str, List[Any]]]):
    """Fetch, vet and optionally download a CivitAI LoRA.

    Pipeline: fetch API data -> NSFW policy check -> extract LoRA info ->
    (optionally) download the weight and media files into *folder*.

    Returns:
        ``(info, downloaded_files)``; ``downloaded_files`` is None when
        *do_download* is False.

    Raises:
        gr.Error: if the API data cannot be fetched, the model fails the NSFW
            check, or it is not a supported LoRA.
    """
    json_data = get_json_data(url)
    if not json_data:
        raise gr.Error("Could not fetch CivitAI API data. Check URL or model ID. Example: https://civitai.com/models/12345 or just 12345")

    if not check_nsfw(json_data, profile):
        raise gr.Error("This model is flagged as NSFW by CivitAI or its media exceeds the allowed NSFW level (max level 5).")

    info = extract_info(json_data)
    if not info:
        model_type = json_data.get("type", "Unknown type")
        base_models_in_json = [version.get("baseModel", "Unknown base") for version in json_data.get("modelVersions", [])]
        details = [
            "This LoRA is not supported. Details:",
            f"- Model Type: {model_type} (expected LORA)",
        ]
        if base_models_in_json:
            details.append(f"- Detected Base Models in CivitAI: {', '.join(set(base_models_in_json))}")
        details.append("Ensure it's a LORA for a supported base (SD, SDXL, Pony, Flux, LTXV, Hunyuan, Wan) and has primary files.")
        raise gr.Error("\n".join(details))

    downloaded_files_dict = download_files(info, folder) if do_download else None
    return info, downloaded_files_dict
# --- README Creation --- | |
# Hugging Face base models whose LoRAs generate video rather than still images.
_VIDEO_BASE_MODELS_HF = (
    "Lightricks/LTX-Video-0.9.7-dev",
    "hunyuanvideo-community/HunyuanVideo",
    "hunyuanvideo-community/HunyuanVideo-I2V",
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
    "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
)
# Base models loaded with torch.bfloat16 in the generated example (others use float16).
_FLUX_MODELS_BF16 = ("black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell")


def _yaml_single_quote(value: Any) -> str:
    """Render *value* as a one-line, single-quoted YAML scalar.

    Inside single quotes YAML only requires ``'`` to be doubled; line breaks
    are folded to spaces so the scalar stays on one line.
    """
    text = str(value).replace("\r", " ").replace("\n", " ")
    return "'" + text.replace("'", "''") + "'"


def _readme_tags(info: Dict[str, Any], is_video_model: bool, is_i2v_model: bool) -> List[str]:
    """Return the sorted, de-duplicated model-card tags (defaults + CivitAI tags).

    CivitAI tags have ":" removed and surrounding whitespace stripped; tags
    that end up empty are dropped.
    """
    tags = {"lora", "diffusers", "migrated"}
    if is_video_model:
        tags.update({"video", "image-to-video" if is_i2v_model else "text-to-video", "template:video-lora"})
    else:
        tags.update({"text-to-image", "stable-diffusion", "template:sd-lora"})
    raw_tags = info.get("tags", [])
    if isinstance(raw_tags, list):
        cleaned = (str(raw).replace(":", "").strip() for raw in raw_tags)
        tags.update(tag for tag in cleaned if tag)
    return sorted(tags)


def _readme_widget_lines(media_items: List[Dict[str, Any]], limit: int = 5) -> List[str]:
    """Return the YAML lines of the ``widget:`` list for the first *limit* media items."""
    if not media_items:
        return ["# No example media available for widget."]
    lines: List[str] = []
    for item in media_items[:limit]:
        prompt = item.get("prompt") or ""
        negative_prompt = item.get("negative_prompt") or ""
        # Empty prompts are written as a single space.
        lines.append(f"- text: {_yaml_single_quote(prompt or ' ')}")
        if negative_prompt:
            lines.append("  parameters:")
            lines.append(f"    negative_prompt: {_yaml_single_quote(negative_prompt)}")
        lines.append("  output:")
        lines.append("    url: >-")
        lines.append(f"      {item['filename']}")
    return lines


def _readme_diffusers_snippet(base_model: str, user_repo_id: str, weight_name: str,
                              example_prompt: str, is_video_model: bool) -> str:
    """Return a fenced ```py block showing how to load the LoRA with diffusers.

    Interpolated values are emitted with ``repr`` so quotes or backslashes in
    prompts, repo ids or file names cannot produce invalid Python.
    """
    dtype = "torch.bfloat16" if base_model in _FLUX_MODELS_BF16 else "torch.float16"
    prompt_literal = repr(example_prompt.replace("\r", " ").replace("\n", " "))
    if is_video_model:
        pipeline_import = "DiffusionPipeline"
        call_lines = [
            "# Example prompt for video generation",
            f"prompt = {prompt_literal}",
            "# Adjust parameters like num_frames, num_inference_steps, height, width as needed for the specific pipeline.",
            "# video_frames = pipeline(prompt, num_frames=16, guidance_scale=7.5, num_inference_steps=25).frames # Example parameters",
        ]
        if "LTX-Video" in base_model:
            call_lines.append("# LTX-Video uses a specific setup (e.g. LTXPipeline). Check its model card on Hugging Face.")
        elif "HunyuanVideo" in base_model:
            call_lines.append("# HunyuanVideo has a dedicated pipeline class (HunyuanVideoPipeline). Check its HF model card.")
        elif "Wan-AI" in base_model:
            call_lines.append("# Wan 2.1 models use WanPipeline (text-to-video) or WanImageToVideoPipeline (image-to-video). Check model card for usage.")
    else:
        pipeline_import = "AutoPipelineForText2Image"
        call_lines = [f"image = pipeline({prompt_literal}).images[0]"]

    code_lines = [
        f"from diffusers import {pipeline_import}",
        "import torch",
        "",
        'device = "cuda" if torch.cuda.is_available() else "cpu"',
        "",
        f"# Note: The pipeline class '{pipeline_import}' is a general suggestion.",
        "# For specific video models (LTX, Hunyuan, Wan), you may prefer the dedicated pipeline class",
        "# (e.g., LTXPipeline, HunyuanVideoPipeline, WanPipeline, WanImageToVideoPipeline).",
        f"# Please refer to the documentation of the base model '{base_model}' on Hugging Face for precise usage.",
        "",
        f"pipeline = {pipeline_import}.from_pretrained({base_model!r}, torch_dtype={dtype})",
        "pipeline.to(device)",
        "",
        "# Load LoRA weights",
        f"pipeline.load_lora_weights({user_repo_id!r}, weight_name={weight_name!r})",
        "",
        "# For some pipelines, you might need to fuse LoRA layers before inference",
        "# and unfuse them after, or apply scaling. Check model card.",
        '# Example: pipeline.fuse_lora() or pipeline.set_adapters(["default"], adapter_weights=[0.8])',
        "",
        "# Example generation call (adjust parameters as needed for the specific pipeline)",
        *call_lines,
        "",
        "# If using fused LoRA:",
        "# pipeline.unfuse_lora()",
    ]
    return "```py\n" + "\n".join(code_lines) + "\n```"


def create_readme(info: Dict[str, Any], downloaded_files: Dict[str, List[Any]], user_repo_id: str, link_civit: bool = False, is_author: bool = True, folder: str = "."):
    """Write ``README.md`` (a Hugging Face model card) for a migrated LoRA into *folder*.

    Args:
        info: Model description as produced by ``extract_info``.
        downloaded_files: Result of ``download_files`` (keys "media_items", "weightName").
        user_repo_id: Target repo id, used in the download link and code sample.
        link_civit: Add a link back to the CivitAI model page.
        is_author: When False, add a disclaimer crediting the CivitAI creator.
        folder: Directory the README is written to.
    """
    original_url = f"https://civitai.com/models/{info['modelId']}"
    creator = info["creator"]
    base_model = info["baseModel"]
    is_video_model = base_model in _VIDEO_BASE_MODELS_HF
    is_i2v_model = "i2v" in base_model.lower()

    trained_words = [str(word) for word in (info.get("trainedWords") or []) if word]
    formatted_words = ", ".join(f"`{word}`" for word in trained_words)

    media_items = downloaded_files.get("media_items") or []
    weight_names = downloaded_files.get("weightName") or []
    weight_name = weight_names[0] if weight_names else "your_lora_weights.safetensors"

    # Prefer a real prompt from the example media; otherwise the plain trigger
    # words (no markdown backticks, since this ends up inside Python code).
    example_prompt = ", ".join(trained_words) or "Your custom prompt"
    if media_items and media_items[0].get("prompt"):
        example_prompt = media_items[0]["prompt"]

    license_link = (
        "https://multimodal.art/civitai-licenses"
        f"?allowNoCredit={info['allowNoCredit']}"
        f"&allowCommercialUse={info['allowCommercialUse']}"
        f"&allowDerivatives={info['allowDerivatives']}"
        f"&allowDifferentLicense={info['allowDifferentLicense']}"
    )
    front_matter = [
        "---",
        "license: other",
        "license_name: bespoke-lora-trained-license",
        f"license_link: {license_link}",
        "tags:",
        # Quoted so CivitAI tags with YAML-special characters cannot break the metadata.
        *(f"- {_yaml_single_quote(tag)}" for tag in _readme_tags(info, is_video_model, is_i2v_model)),
        f"base_model: {base_model}",
        f"instance_prompt: {_yaml_single_quote(trained_words[0]) if trained_words else ''}",
        "widget:",
        *_readme_widget_lines(media_items),
        "---",
    ]

    # Blank lines separate blocks; in particular <Gallery /> must be followed by
    # one, otherwise markdown treats the following text as part of the HTML block.
    body = [f"# {info['name']}", "", "<Gallery />", ""]
    if not is_author:
        body += [
            f'This model was originally uploaded on [CivitAI]({original_url}), by [{creator}](https://civitai.com/user/{creator}/models). The information below was provided by the author on CivitAI:',
            "",
        ]
    if link_civit:
        body += [f"([CivitAI]({original_url}))", ""]
    body += ["## Model description", "", info.get("description") or "No description provided.", ""]
    if formatted_words:
        body += ["## Trigger words", "", f"You should use {formatted_words} to trigger the generation.", ""]
    body += [
        "## Download model",
        "",
        "Weights for this model are available in Safetensors format.",
        "",
        f"[Download](/{user_repo_id}/tree/main/{weight_name}) the LoRA in the Files & versions tab.",
        "",
        "## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)",
        "",
        _readme_diffusers_snippet(base_model, user_repo_id, weight_name, example_prompt, is_video_model),
        "",
        "For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters).",
    ]

    readme_path = os.path.join(folder, "README.md")
    with open(readme_path, "w", encoding="utf-8") as file:
        file.write("\n".join(front_matter + body) + "\n")
# --- Hugging Face Profile / Authorship --- | |
def _empty_creator_response() -> Dict:
    """Return a fresh placeholder shaped like a creator response with no links."""
    return {"result": {"data": {"json": {"links": []}}}}


def get_creator(username: str) -> Dict:
    """Fetch a CivitAI creator profile through the authenticated tRPC endpoint.

    Requires the COOKIE_INFO environment variable (a CivitAI session cookie).

    Args:
        username: CivitAI username.

    Returns:
        The decoded JSON response, or ``{"result": {"data": {"json": {"links": []}}}}``
        when the cookie is missing or the request fails (the reason is printed).
    """
    from urllib.parse import quote

    cookie = os.environ.get("COOKIE_INFO")
    if not cookie:
        print("Warning: COOKIE_INFO env var not set. Cannot fetch CivitAI creator's HF username.")
        return _empty_creator_response()
    # Let requests URL-encode the JSON input so usernames with special characters work.
    query = {"input": json.dumps({"json": {"username": username, "authed": True}}, separators=(",", ":"))}
    headers = {
        "authority": "civitai.com", "accept": "*/*", "accept-language": "en-US,en;q=0.9",
        "content-type": "application/json", "cookie": cookie,
        # Header values must be latin-1, so the path component is percent-encoded.
        "referer": f"https://civitai.com/user/{quote(username)}/models",
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/100.0.0.0 Safari/537.36",
    }
    try:
        response = requests.get("https://civitai.com/api/trpc/user.getCreator", params=query, headers=headers, timeout=10)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as e:
        print(f"Error fetching CivitAI creator data for {username}: {e}")
        return _empty_creator_response()
def huggingface_username_from_url(url: Any) -> Optional[str]:
    """Return the Hugging Face account named in a profile or repo URL, else None.

    "https://huggingface.co/alice", "https://huggingface.co/alice/some-model" and
    "https://huggingface.co/spaces/alice/app" all yield "alice". Only http(s)
    URLs on huggingface.co / www.huggingface.co are accepted.
    """
    from urllib.parse import urlsplit

    if not isinstance(url, str):
        return None
    parts = urlsplit(url.strip())
    if parts.scheme not in ("http", "https") or parts.hostname not in ("huggingface.co", "www.huggingface.co"):
        return None
    segments = [segment for segment in parts.path.split("/") if segment]
    # Repo-type prefixes come before the owner in Hub URLs.
    if segments and segments[0] in ("models", "spaces", "datasets"):
        segments = segments[1:]
    return segments[0] if segments else None


def extract_huggingface_username(civitai_username: str) -> Optional[str]:
    """Return the Hugging Face username a CivitAI creator linked on their profile.

    Looks through the profile links returned by ``get_creator`` and returns the
    account of the first Hugging Face link, or None if there is none.
    """
    node: Any = get_creator(civitai_username)
    # Walk result -> data -> json -> links, tolerating missing or null levels.
    for key in ('result', 'data', 'json', 'links'):
        node = node.get(key) if isinstance(node, dict) else None
    if not isinstance(node, list):
        return None
    for link in node:
        if not isinstance(link, dict):
            continue
        hf_username = huggingface_username_from_url(link.get('url'))
        if hf_username:
            return hf_username
    return None
# --- Gradio UI Logic Functions --- | |
def check_civit_link(profile_state: Optional[gr.OAuthProfile], url_input: str):
    """Check a CivitAI URL and whether the logged-in user may upload that model.

    Returns a 4-tuple for the UI:
      1. an HTML message;
      2. an update that is interactive only when upload is allowed
         (presumably the upload button; the event wiring is outside this chunk);
      3. an update shown only when the user must fix the Hugging Face username
         on their CivitAI profile;
      4. an update shown only when upload is allowed.

    Text coming from CivitAI or exceptions (error details, usernames) is
    HTML-escaped because the message is rendered as HTML.
    """
    import html

    def as_html(text: Any) -> str:
        # Escape externally sourced text and keep its line breaks visible.
        return html.escape(str(text)).replace("\n", "<br/>")

    url_input = (url_input or "").strip()
    if not url_input:
        return "", gr.update(interactive=False, visible=False), gr.update(visible=False), gr.update(visible=False)
    if not profile_state:
        return "Please log in with Hugging Face first.", gr.update(interactive=False, visible=False), gr.update(visible=False), gr.update(visible=False)
    try:
        info, _ = process_url(url_input, profile_state, do_download=False)
    except gr.Error as e:  # NSFW, unsupported model, API fetch failure, ...
        # Prefer the raw message attribute; str() of a gr.Error may be a repr of it.
        message = getattr(e, "message", None) or str(e)
        return as_html(message), gr.update(interactive=False, visible=True), gr.update(visible=False), gr.update(visible=False)
    except Exception as e:  # any other unexpected error during the check
        print(f"Unexpected error in check_civit_link during process_url: {e}\n{traceback.format_exc()}")
        return as_html(f"An unexpected error occurred: {str(e)}"), gr.update(interactive=False, visible=True), gr.update(visible=False), gr.update(visible=False)
    if not info:  # process_url normally raises instead; kept as a safeguard
        return "Could not process this CivitAI URL. Model might be unsupported or invalid.", gr.update(interactive=False, visible=True), gr.update(visible=False), gr.update(visible=False)

    civitai_creator_username = info['creator']
    creator_html = as_html(civitai_creator_username)
    user_html = as_html(profile_state.username)

    # Trusted uploaders skip authorship verification, so skip the CivitAI lookup too.
    if profile_state.username in TRUSTED_UPLOADERS:
        return f'Welcome, trusted uploader {user_html}! You can upload this model by "{creator_html}".', gr.update(interactive=True, visible=True), gr.update(visible=False), gr.update(visible=True)

    hf_username_on_civitai = extract_huggingface_username(civitai_creator_username)
    if not hf_username_on_civitai:
        no_username_text = (
            f'If you are "{creator_html}" on CivitAI, hi! Your CivitAI profile does not seem to have a Hugging Face username linked. '
            f'Please visit <a href="https://civitai.com/user/account" target="_blank">your CivitAI account settings</a> and add your 🤗 username ({user_html}). '
            f'Example: <br/><img width="60%" src="https://i.imgur.com/hCbo9uL.png" alt="CivitAI profile settings example"/><br/>'
            f'(If you are not "{creator_html}", you cannot submit their model at this time.)'
        )
        return no_username_text, gr.update(interactive=False, visible=False), gr.update(visible=True), gr.update(visible=False)

    if profile_state.username.lower() != hf_username_on_civitai.lower():
        linked_html = as_html(hf_username_on_civitai)
        unmatched_username_text = (
            f'The Hugging Face username on "{creator_html}"\'s CivitAI profile ("{linked_html}") '
            f'does not match your logged-in Hugging Face account ("{user_html}"). '
            f'Please update it on <a href="https://civitai.com/user/account" target="_blank">CivitAI</a> or log in to Hugging Face as "{linked_html}".<br/>'
            f'<img src="https://i.imgur.com/hCbo9uL.png" alt="CivitAI profile settings example"/>'
        )
        return unmatched_username_text, gr.update(interactive=False, visible=False), gr.update(visible=True), gr.update(visible=False)

    return f'Authorship verified for "{creator_html}" (🤗 {user_html}). Ready to upload!', gr.update(interactive=True, visible=True), gr.update(visible=False), gr.update(visible=True)
def handle_auth_change_and_update_state(profile: Optional[gr.OAuthProfile]):
    """Reset the UI after a login/logout and hand the profile back for storage.

    Returns, in order: the profile (None when logged out); an update visible
    only when logged out; an update visible only when logged in; an empty
    string; an update clearing a value; a hidden, non-interactive update; and
    a hidden update. The mapping to components is wired outside this chunk.
    """
    logged_in = bool(profile)
    return (
        profile if logged_in else None,
        gr.update(visible=not logged_in),
        gr.update(visible=logged_in),
        "",
        gr.update(value=""),
        gr.update(interactive=False, visible=False),
        gr.update(visible=False),
    )
def show_output_area():
    """Return a Gradio update that makes the bound component visible."""
    make_visible = gr.update(visible=True)
    return make_visible
def list_civit_models(username: str, max_pages: int = 1) -> str:
    """List the CivitAI model URLs of a user's LoRAs, newest first.

    Args:
        username: CivitAI username; surrounding whitespace is ignored.
        max_pages: Number of API pages (up to 100 models each) to fetch.

    Returns:
        One ``https://civitai.com/models/<id>/<slug>`` URL per line, or "" when
        the username is blank or nothing could be found.
    """
    username = (username or "").strip()
    if not username:
        return ""
    next_url = "https://civitai.com/api/v1/models"
    # Only the first request needs explicit params (limit is capped at 100 per
    # page); "nextPage" is already a complete URL.
    params = {"username": username, "limit": 100, "sort": "Newest"}
    json_models_list = []
    page_count = 0
    gr.Info(f"Fetching LoRAs for CivitAI user: {username}...")
    while next_url and page_count < max_pages:
        try:
            response = requests.get(next_url, params=params, timeout=15)
            response.raise_for_status()
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            gr.Warning(f"Failed to fetch page {page_count + 1} for {username}: {e}")
            break
        params = None
        json_models_list.extend(
            item for item in (data.get('items') or [])
            if item.get("type") == "LORA" and item.get("name")
        )
        next_url = (data.get('metadata') or {}).get('nextPage')
        page_count += 1
    if not json_models_list:
        gr.Info(f"No suitable LoRA models found for {username} or failed to fetch.")
        return ""
    urls_text = "\n".join(
        f'https://civitai.com/models/{model["id"]}/{slugify(model["name"])}'
        for model in json_models_list
    )
    gr.Info(f"Found {len(json_models_list)} LoRA models for {username}.")
    return urls_text.strip()
# --- Main Upload Functions --- | |
def upload_civit_to_hf(profile: Optional[gr.OAuthProfile], oauth_token_obj: gr.OAuthToken, url: str, link_civit_checkbox_val: bool):
    """Migrate one CivitAI LoRA to ``<username>/<slugified model name>`` on the Hub.

    Files are downloaded into a per-call folder under ``temp_uploads``, a README
    is written, the repo is created private, the folder is uploaded and the
    repo is then made public. The temporary folder is always removed.

    Returns:
        Markdown announcing the repository URL.

    Raises:
        gr.Error: when the user is not logged in; when the model is rejected or
            cannot be processed (the original message is passed through
            unchanged); or when a Hub operation fails.
    """
    if not profile or not profile.username:
        raise gr.Error("User profile not available. Please log in.")
    if not oauth_token_obj or not oauth_token_obj.token:
        raise gr.Error("Hugging Face token not available. Please log in again.")
    hf_auth_token = oauth_token_obj.token
    folder_path = os.path.join("temp_uploads", str(uuid.uuid4()))
    os.makedirs(folder_path, exist_ok=True)
    gr.Info(f"Starting processing of model {url}")
    try:
        info, downloaded_data = process_url(url, profile, do_download=True, folder=folder_path)
        if not info or not downloaded_data:
            # process_url should raise gr.Error itself; this is a fallback.
            raise gr.Error("Failed to process URL or download files after initial checks.")

        # slugify can return "" (e.g. emoji-only names), which is not a valid repo name.
        slug_name = slugify(info["name"]) or f"civitai-lora-{info['modelId']}"
        user_repo_id = f"{profile.username}/{slug_name}"

        # Re-verify authorship just before upload (trusted uploaders need no lookup).
        is_author = profile.username in TRUSTED_UPLOADERS
        if not is_author:
            hf_username_on_civitai = extract_huggingface_username(info.get('creator', 'Unknown Creator'))
            is_author = bool(hf_username_on_civitai) and profile.username.lower() == hf_username_on_civitai.lower()

        create_readme(info, downloaded_data, user_repo_id, link_civit_checkbox_val, is_author=is_author, folder=folder_path)
        repo_url_huggingface = f"https://huggingface.co/{user_repo_id}"
        gr.Info(f"Creating/updating repository {user_repo_id} on Hugging Face...")
        create_repo(repo_id=user_repo_id, private=True, exist_ok=True, token=hf_auth_token)
        gr.Info(f"Starting upload to {repo_url_huggingface}...")
        upload_folder(
            folder_path=folder_path, repo_id=user_repo_id, repo_type="model",
            token=hf_auth_token, commit_message=f"Upload LoRA: {info['name']} from CivitAI ID {info['modelId']}",
        )
        update_repo_visibility(repo_id=user_repo_id, private=False, token=hf_auth_token)
        gr.Info("Model uploaded successfully!")
        return f"# Model uploaded to 🤗!\n## Access it here [{user_repo_id}]({repo_url_huggingface}) "
    except gr.Error:
        # Already user-facing (unsupported, NSFW, download failure): don't blame the token.
        raise
    except Exception as e:
        print(f"Error during Hugging Face repo operations for {url}: {e}\n{traceback.format_exc()}")
        raise gr.Error(f"Upload failed for {url}: {str(e)}. Token might be expired. Try re-logging or check server logs.") from e
    finally:
        try:
            if os.path.exists(folder_path):
                shutil.rmtree(folder_path)
        except OSError as e_clean:
            print(f"Error cleaning up folder {folder_path}: {e_clean}")
def bulk_upload(profile: Optional[gr.OAuthProfile], oauth_token_obj: gr.OAuthToken, urls_text: str, link_civit_checkbox_val: bool):
    """Upload every CivitAI URL listed (one per line) in *urls_text*.

    A failure on one URL is reported and recorded but does not stop the batch.

    Returns:
        The per-URL result messages separated by horizontal rules, or a notice
        when there was nothing to process.

    Raises:
        gr.Error: if the user is not authenticated.
    """
    if not profile or not oauth_token_obj or not oauth_token_obj.token:
        raise gr.Error("Authentication missing for bulk upload. Please log in.")
    urls = [candidate for candidate in map(str.strip, urls_text.splitlines()) if candidate]
    if not urls:
        return "No URLs provided for bulk upload."

    total_urls = len(urls)
    upload_results = []
    gr.Info(f"Starting bulk upload for {total_urls} models.")
    for position, url in enumerate(urls, start=1):
        gr.Info(f"Processing model {position}/{total_urls}: {url}")
        try:
            upload_results.append(upload_civit_to_hf(profile, oauth_token_obj, url, link_civit_checkbox_val))
            gr.Info(f"Successfully processed {url}")
        except gr.Error as ge:  # expected, user-facing failures
            gr.Warning(f"Skipping model {url} due to error: {str(ge)}")
            upload_results.append(f"Failed to upload {url}: {str(ge)}")
        except Exception as e:  # anything else is logged with a traceback
            gr.Warning(f"Unhandled error uploading model {url}: {str(e)}")
            upload_results.append(f"Failed to upload {url}: Unhandled exception - {str(e)}")
            print(f"Unhandled exception during bulk upload for {url}: {e}\n{traceback.format_exc()}")

    if not upload_results:
        return "No URLs were processed or all failed."
    return "\n\n---\n\n".join(upload_results)
# --- Gradio UI Definition --- | |
# Styles for the Gradio UI: the "#..." selectors target components by elem_id,
# the ".gr-html ..." rules format content rendered in HTML components.
css = "\n" + "\n".join([
    "#login_button_area { margin-bottom: 10px; }",
    "#disabled_upload_area { opacity: 0.6; pointer-events: none; }",
    ".gr-html ul { list-style-type: disc; margin-left: 20px; }",
    ".gr-html ol { list-style-type: decimal; margin-left: 20px; }",
    ".gr-html a { color: #007bff; text-decoration: underline; }",
    ".gr-html img { max-width: 100%; height: auto; margin-top: 5px; margin-bottom: 5px; border: 1px solid #ddd; }",
    "#instructions_area { padding: 10px; border: 1px solid #eee; border-radius: 5px; margin-top: 10px; background-color: #f9f9f9; }",
]) + "\n"
# Page layout: a login row, a disabled placeholder shown to logged-out users, and
# the real single/bulk upload tabs revealed after login; event wiring follows.
with gr.Blocks(css=css, title="CivitAI to Hugging Face LoRA Uploader") as demo:
    # Filled with the profile returned as the first output of
    # handle_auth_change_and_update_state on page load; read by check_civit_link.
    auth_profile_state = gr.State()
    gr.Markdown('''# Upload your CivitAI LoRA to Hugging Face 🤗
By uploading your LoRAs to Hugging Face you get diffusers compatibility, a free GPU-based Inference Widget, you'll be listed in [LoRA Studio](https://lorastudio.co/models) after a short review, and get the possibility to submit your model to the [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer) ✨
''')
    with gr.Row(elem_id="login_button_area"):
        # Starts the Hugging Face OAuth flow. The resulting session is picked up by
        # the demo.load handler below on the next page load.
        login_button = gr.LoginButton()

    # Placeholder shown while logged out (dimmed and non-clickable via CSS).
    with gr.Column(visible=True, elem_id="disabled_upload_area") as disabled_area:
        gr.HTML("<h3>Please log in with Hugging Face to enable uploads.</h3>")
        gr.Textbox(
            placeholder="e.g., https://civitai.com/models/12345/my-lora or just 12345",
            label="CivitAI Model URL or ID (Log in to enable)",
            interactive=False
        )

    # Real upload UI, revealed by the auth handler once a user is logged in.
    with gr.Column(visible=False) as enabled_area:
        gr.HTML("<h3 style='color:green;'>Logged in! You can now upload models.</h3>")
        with gr.Tabs():
            with gr.TabItem("Single Model Upload"):
                submit_source_civit_enabled = gr.Textbox(
                    placeholder="e.g., https://civitai.com/models/12345/my-lora or just 12345",
                    label="CivitAI Model URL or ID",
                    info="Enter the full URL or just the numeric ID of the CivitAI LoRA model page.",
                )
                instructions_html = gr.HTML(elem_id="instructions_area")  # Authorship/validation feedback
                try_again_button = gr.Button("I've updated my CivitAI profile (Re-check Authorship)", visible=False)
                link_civit_checkbox_single = gr.Checkbox(label="Add a link back to CivitAI in the README?", value=True, visible=True)
                submit_button_single_model = gr.Button("Upload This Model to Hugging Face", interactive=False, visible=False, variant="primary")
            with gr.TabItem("Bulk Upload"):
                civit_username_to_bulk = gr.Textbox(
                    label="Your CivitAI Username (Optional)",
                    info="Enter your CivitAI username to auto-populate the list below with your LoRAs (up to 100 newest)."
                )
                submit_bulk_civit_urls = gr.Textbox(
                    label="CivitAI Model URLs or IDs (One per line)",
                    info="Paste multiple CivitAI model page URLs or just IDs here, one on each line.",
                    lines=8,
                )
                link_civit_checkbox_bulk = gr.Checkbox(label="Add a link back to CivitAI in READMEs?", value=True)
                bulk_upload_button = gr.Button("Start Bulk Upload", variant="primary")
        output_markdown_area = gr.Markdown(label="Upload Progress & Results", visible=False)

    # --- Event Handlers Wiring ---
    # OAuth values are never listed in `inputs`. Gradio injects the logged-in user
    # into any handler parameter type-hinted as gr.OAuthProfile / Optional[gr.OAuthProfile]
    # and the token into one hinted as gr.OAuthToken. Listing them again (or
    # passing a gr.OAuthToken object, which is a dataclass rather than a
    # component) shifts the remaining arguments.
    # NOTE(review): OAuth scopes are presumably requested through the Space's
    # README metadata (hf_oauth_scopes), not from code; confirm write access is configured there.

    # Runs on each page load and syncs the visible areas with the login state.
    # NOTE(review): assumes handle_auth_change_and_update_state type-hints its
    # profile parameter as Optional[gr.OAuthProfile]; confirm against its definition.
    demo.load(
        fn=handle_auth_change_and_update_state,
        inputs=None,
        outputs=[auth_profile_state, disabled_area, enabled_area, instructions_html, submit_source_civit_enabled, submit_button_single_model, try_again_button],
        api_name=False, queue=False
    )
    # NOTE(review): if check_civit_link also type-hints its profile parameter as
    # gr.OAuthProfile, Gradio injects it and auth_profile_state must be dropped from
    # these inputs; confirm against its signature.
    submit_source_civit_enabled.change(
        fn=check_civit_link,
        inputs=[auth_profile_state, submit_source_civit_enabled],
        outputs=[instructions_html, submit_button_single_model, try_again_button, submit_button_single_model],  # submit_button_single_model is repeated to control both interactivity and visibility
        api_name=False
    )
    try_again_button.click(
        fn=check_civit_link,
        inputs=[auth_profile_state, submit_source_civit_enabled],
        outputs=[instructions_html, submit_button_single_model, try_again_button, submit_button_single_model],
        api_name=False
    )
    civit_username_to_bulk.submit(
        fn=list_civit_models,
        inputs=[civit_username_to_bulk],
        outputs=[submit_bulk_civit_urls],
        api_name=False
    )
    # profile/token are injected from upload_civit_to_hf's OAuth type hints
    # (assumed to mirror bulk_upload's signature), so only UI values are passed.
    submit_button_single_model.click(
        fn=show_output_area, inputs=[], outputs=[output_markdown_area], api_name=False
    ).then(
        fn=upload_civit_to_hf,
        inputs=[submit_source_civit_enabled, link_civit_checkbox_single],
        outputs=[output_markdown_area],
        api_name="upload_single_model"
    )
    # bulk_upload declares `profile: Optional[gr.OAuthProfile]` and
    # `oauth_token_obj: gr.OAuthToken`, both injected by Gradio.
    bulk_upload_button.click(
        fn=show_output_area, inputs=[], outputs=[output_markdown_area], api_name=False
    ).then(
        fn=bulk_upload,
        inputs=[submit_bulk_civit_urls, link_civit_checkbox_bulk],
        outputs=[output_markdown_area],
        api_name="upload_bulk_models"
    )
# At most 3 events run concurrently by default; at most 10 requests may wait in the queue.
demo.queue(max_size=10, default_concurrency_limit=3)
if __name__ == "__main__":
    # Optional environment variables for running this app locally:
    #   COOKIE_INFO        - CivitAI session cookie, used for creator verification
    #   CIVITAI_API_TOKEN  - CivitAI API key, for downloads that may require auth
    #   GRADIO_SERVER_NAME - e.g. "0.0.0.0" to expose the server on the local network
    #   GRADIO_SHARE       - set to "true" to request a public Gradio share link
    # NOTE(review): outside a Space, gr.LoginButton presumably falls back to Gradio's
    # mocked OAuth (using the token from `huggingface-cli login`); confirm for the
    # installed Gradio version.
    share_link = os.environ.get("GRADIO_SHARE") == "true"
    demo.launch(debug=True, share=share_link)