import os
from contextlib import nullcontext

import gradio as gr
import torch
import gdown
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image

# --- Download from Google Drive ---
def download_from_gdrive(file_id, dest_path):
    if not os.path.exists(dest_path):
        print(f"📥 Downloading {dest_path}...")
        gdown.download(f"https://drive.google.com/uc?id={file_id}", dest_path, quiet=False)
        print(f"✅ Downloaded {dest_path} from Google Drive")
    else:
        print(f"✅ {dest_path} already exists")

# Download the base model checkpoint
print("🚀 Starting BitKun LoRA Generator...")
ckpt_id = "1OdP2SDB6MyR6JDK6_ekyqFHz91rkT2ZK"  # Your base model ID
download_from_gdrive(ckpt_id, "AnyLoRA_noVae_fp16-pruned.ckpt")
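# NOTE: only the base checkpoint is fetched above. If the LoRA file
# (bitkun-10.safetensors) also lives on Google Drive, it needs its own
# download call; the file ID below is a placeholder, not a real ID:
# download_from_gdrive("<LORA_FILE_ID>", "bitkun-10.safetensors")
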
# ⚙️ Configuration
project_name = "bitkun"
epoch_number = 10
lora_path = f"{project_name}-{epoch_number:02d}.safetensors"
base_model_path = "AnyLoRA_noVae_fp16-pruned.ckpt"
output_folder = "generated_images"
os.makedirs(output_folder, exist_ok=True)

# Auto-detect device and set appropriate dtype
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"🔧 Using device: {device} with dtype: {dtype}")
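# float16 halves memory use and speeds up CUDA inference; float32 is kept on
# CPU because many PyTorch CPU kernels do not support half precision.
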
# 📦 Load base model with optimizations
print("📦 Loading base model...")
try:
    pipe = StableDiffusionPipeline.from_single_file(
        base_model_path,
        torch_dtype=dtype,
        safety_checker=None,
        requires_safety_checker=False,
        use_safetensors=True
    ).to(device)

    # ⏩ Use a faster scheduler
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        use_karras_sigmas=True,
        algorithm_type="dpmsolver++"
    )
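    # DPM-Solver++ with Karras sigmas converges in far fewer steps than the
    # default scheduler (usable results around 15-25 steps), which is what
    # makes the low-step speed presets below practical.
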
    # Enable attention slicing to reduce peak memory use
    if hasattr(pipe, "enable_attention_slicing"):
        pipe.enable_attention_slicing()

    # Enable memory-efficient attention via xFormers when installed
    if hasattr(pipe, "enable_xformers_memory_efficient_attention"):
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print("✅ XFormers memory efficient attention enabled")
        except Exception:
            print("⚠️ XFormers not available, using standard attention")

    print("✅ Base model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading base model: {e}")
    raise

# 🔄 Load LoRA
lora_loaded = False
try:
    if os.path.exists(lora_path):
        pipe.load_lora_weights(lora_path, adapter_name="default")
        pipe.set_adapters(["default"], adapter_weights=[0.8])
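        # adapter_weights=[0.8] blends the LoRA in at 80% strength; values
        # toward 1.0 strengthen the BitKun style, lower values soften it.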
        lora_loaded = True
        print("✅ LoRA loaded and pipeline ready!")
    else:
        print(f"⚠️ LoRA file not found: {lora_path}")
        print("Pipeline will run with base model only.")
except Exception as e:
    print(f"⚠️ Could not load LoRA weights: {e}")
    print("Pipeline will run with base model only.")

# 🎨 Optimized generation function
def generate_bitkun(prompt, negative_prompt, num_images, steps, guidance_scale, width, height):
    if not prompt.strip():
        return [], "⚠️ Please enter a prompt!"

    # Add the trigger word "bitkun" to the prompt if not present
    if "bitkun" not in prompt.lower():
        prompt = f"bitkun, {prompt}"
    seed = 42
    images = []
    for i in range(num_images):
        try:
            print(f"🎨 Generating image {i + 1}/{num_images}...")
            generator = torch.Generator(device=device).manual_seed(seed + i)

            # Autocast accelerates fp16 inference on CUDA; on CPU fall back
            # to a no-op context so the call site stays identical
            ctx = torch.autocast("cuda") if device == "cuda" else nullcontext()
            with ctx:
                result = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=steps,
                    guidance_scale=guidance_scale,
                    width=width,
                    height=height,
                    generator=generator
                )
            image = result.images[0]

            # Save image
            filename = f"{project_name}_custom_{i + 1}.png"
            filepath = os.path.join(output_folder, filename)
            image.save(filepath)
            images.append(image)
        except Exception as e:
            error_msg = f"❌ Error generating image {i + 1}: {str(e)}"
            print(error_msg)
            continue
    if not images:
        final_status = "❌ Failed to generate any images. Please try again with different settings."
    else:
        lora_status = "with LoRA" if lora_loaded else "without LoRA"
        final_status = f"🎉 Successfully generated {len(images)}/{num_images} image(s) {lora_status}!"
    return images, final_status
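
# Quick smoke test without the UI (hypothetical values; uncomment to run):
# imgs, status = generate_bitkun("happy, smiling", "blurry, bad anatomy", 1, 15, 7.5, 384, 384)
# print(status)
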
# 🚀 Gradio UI with speed presets
with gr.Blocks(title="BitKun LoRA Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🐶 BitKun LoRA Generator 🎨 (FAST VERSION)
    ### Custom Prompt Support
    """)

    # Show LoRA status
    lora_status_text = "✅ LoRA loaded successfully!" if lora_loaded else "⚠️ Running with base model only (LoRA not found)"
    gr.Markdown(f"**Status:** {lora_status_text}")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🎯 Generation Settings")

            custom_prompt = gr.Textbox(
                label="📝 Prompt",
                placeholder="e.g. happy, smiling, cartoon style, colorful background",
                lines=3,
                info="'bitkun' will be automatically added if not included"
            )

            negative_prompt = gr.Textbox(
                label="🚫 Negative Prompt",
                value="realistic, human skin, photo, blurry, distorted, extra limbs, bad anatomy",
                lines=2
            )

            # Speed presets
            gr.Markdown("### ⚡ Speed Presets")
            with gr.Row():
                speed_preset = gr.Radio(
                    choices=[
                        ("🚀 Ultra Fast (10 steps, 256x256)", "ultra_fast"),
                        ("⚡ Fast (15 steps, 384x384)", "fast"),
                        ("🎯 Balanced (20 steps, 512x512)", "balanced"),
                        ("🎨 Quality (25 steps, 512x512)", "quality"),
                        ("🔧 Custom", "custom")
                    ],
                    value="fast",
                    label="Choose Speed vs Quality"
                )

            with gr.Row():
                num_images = gr.Slider(
                    label="🖼️ Number of Images",
                    minimum=1,
                    maximum=3,
                    value=1,
                    step=1,
                    info="More images = longer processing time"
                )

            # Advanced settings (hidden by default)
            with gr.Accordion("🔧 Advanced Settings", open=False):
                steps = gr.Slider(
                    label="🔄 Inference Steps",
                    minimum=5,
                    maximum=50,
                    value=15,
                    step=1,
                    info="More steps = higher quality but slower"
                )
                guidance_scale = gr.Slider(
                    label="🎚️ Guidance Scale",
                    minimum=1.0,
                    maximum=15.0,
                    value=7.5,
                    step=0.5,
                    info="Higher = more prompt adherence"
                )
                with gr.Row():
                    width = gr.Slider(
                        label="📏 Width",
                        minimum=256,
                        maximum=768,
                        value=384,
                        step=64
                    )
                    height = gr.Slider(
                        label="📐 Height",
                        minimum=256,
                        maximum=768,
                        value=384,
                        step=64
                    )

            generate_btn = gr.Button(
                "🎨 Generate Images",
                variant="primary",
                size="lg"
            )

        with gr.Column(scale=1):
            gr.Markdown("### 🖼️ Generated Images")
            gallery = gr.Gallery(
                label="Generated Images",
                columns=2,
                rows=2,
                height=400,
                show_label=False
            )
            status_text = gr.Textbox(
                label="📊 Generation Status",
                interactive=False,
                lines=2
            )

    # Speed preset change handler
    def update_settings(preset):
        if preset == "ultra_fast":
            return 10, 7.0, 256, 256
        elif preset == "fast":
            return 15, 7.5, 384, 384
        elif preset == "balanced":
            return 20, 7.5, 512, 512
        elif preset == "quality":
            return 25, 8.0, 512, 512
        else:  # custom: start from the "fast" defaults and tweak manually
            return 15, 7.5, 384, 384

    speed_preset.change(
        fn=update_settings,
        inputs=[speed_preset],
        outputs=[steps, guidance_scale, width, height]
    )
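
    # Choosing a preset overwrites the advanced sliders through the change
    # handler above; the sliders stay editable for further manual tuning.
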
    # Example prompts section
    gr.Markdown("""
    ### 📋 Quick Examples:
    **⚡ For fastest results, try these short prompts:**
    - `happy, smiling`
    - `sad, crying`
    - `angry, red face`
    - `surprised, shocked`
    - `sleepy, tired`
    - `superhero, cape`
    """)

    # Button click event
    generate_btn.click(
        fn=generate_bitkun,
        inputs=[custom_prompt, negative_prompt, num_images, steps, guidance_scale, width, height],
        outputs=[gallery, status_text],
        show_progress=True
    )
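    # Note: recent Gradio releases type show_progress as "full" / "minimal" /
    # "hidden"; if a bool is rejected by your version, use show_progress="full".
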
# Launch the app
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )