import argparse
import os
import random
from urllib.parse import unquote, urlparse

import numpy as np
import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from PIL import Image


def download_lora(url, lora_dir):
    """Download a LoRA model from the given URL into lora_dir, showing a progress bar.

    If the file already exists, the download is skipped.

    Returns:
        The local path of the LoRA file, or None on failure.
    """
    try:
        import requests
        from tqdm import tqdm
    except ImportError:
        print("❌ The 'requests' and 'tqdm' libraries are required for the --lora feature.")
        print("   Please run: pip install requests tqdm")
        return None

    os.makedirs(lora_dir, exist_ok=True)

    # Derive the filename from the URL path.
    parsed_url = urlparse(url)
    filename = os.path.basename(unquote(parsed_url.path))
    if not filename:
        print(f"⚠️ Could not determine a filename from the URL; using 'downloaded.safetensors'. URL: {url}")
        filename = "downloaded.safetensors"

    lora_path = os.path.join(lora_dir, filename)

    if not os.path.exists(lora_path):
        print("📥 LoRA not found locally, downloading from URL...")
        try:
            response = requests.get(url, stream=True)
            response.raise_for_status()  # Fail fast on HTTP errors.

            total_size = int(response.headers.get('content-length', 0))
            block_size = 1024  # 1 KiB

            with open(lora_path, 'wb') as file, tqdm(
                desc=f"Downloading {filename}",
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(block_size):
                    bar.update(len(data))
                    file.write(data)

            if total_size != 0 and bar.n != total_size:
                print("❌ Error downloading LoRA: file size mismatch. Removed the incomplete file.")
                os.remove(lora_path)
                return None

            print(f"✅ LoRA download complete: {lora_path}")
            return lora_path
        except Exception as e:
            print(f"❌ Failed to download LoRA: {e}")
            if os.path.exists(lora_path):
                os.remove(lora_path)
            return None
    else:
        print(f"✅ LoRA file already exists: {lora_path}")
        return lora_path


def generate_image(pipe, prompt, seed=42, randomize_seed=False, width=768, height=768,
                   guidance_scale=4.5, num_inference_steps=20):
    """Generate an image; if randomize_seed=True, pick a random seed and return the seed actually used.

    Returns:
        (PIL.Image, used_seed)
    """
    MAX_SEED = np.iinfo(np.int32).max
    if randomize_seed:
        used_seed = random.randint(0, MAX_SEED)
    else:
        used_seed = int(seed)

    device = pipe.device if hasattr(pipe, "device") else ("cuda" if torch.cuda.is_available() else "cpu")
    generator = torch.Generator(device=device).manual_seed(used_seed)

    print(f"ℹ️ Using seed: {used_seed} (randomize={randomize_seed})")
    print("🚀 Generating image...")

    image = pipe(
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        output_type="pil",
    ).images[0]

    return image, used_seed


def main():
    parser = argparse.ArgumentParser(description="Generate an image from a text prompt with the FLUX.1-Krea-dev model.")
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for image generation.")
    parser.add_argument("--lora", type=str, default=None, help="[Optional] URL of a LoRA model (direct link to a .safetensors file).")
    parser.add_argument("--seed", type=int, default=42, help="Random seed. Set to 0 for a different random seed each run. Default: 42.")
    parser.add_argument("--steps", type=int, default=20, help="Number of inference steps.")
    parser.add_argument("--width", type=int, default=768, help="Image width.")
    parser.add_argument("--height", type=int, default=768, help="Image height.")
    parser.add_argument("--guidance", type=float, default=4.5, help="Guidance scale.")
    args = parser.parse_args()

    print("⏳ Loading model, please wait...")
    dtype = torch.bfloat16
    device = "cuda" if torch.cuda.is_available() else "cpu"

    good_vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", subfolder="vae", torch_dtype=dtype)
    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=dtype, vae=good_vae).to(device)

    lora_file_path = None
    if args.lora:
        print("-" * 20)
        print("🔧 Processing LoRA...")
        lora_file_path = download_lora(args.lora, lora_dir="loras")
        if lora_file_path:
            try:
                print(f"🔗 Loading LoRA weights from: {lora_file_path}")
                pipe.load_lora_weights(lora_file_path)
                print("✅ LoRA loaded successfully!")
            except Exception as e:
                print(f"❌ Error loading LoRA weights: {e}")
                print("   Continuing without LoRA.")
        else:
            print("⚠️ Could not obtain the LoRA file; continuing without LoRA.")
        print("-" * 20)

    if device == "cuda":
        torch.cuda.empty_cache()

    print(f"✅ Model loaded, using device: {device}")
    print(f"🎨 Generating image for prompt: '{args.prompt}'")

    # --seed 0 means "use a different random seed each run"; the placeholder 42 is ignored in that case.
    randomize = (args.seed == 0)
    seed_value = args.seed if not randomize else 42

    generated_image, used_seed = generate_image(
        pipe=pipe,
        prompt=args.prompt,
        seed=seed_value,
        randomize_seed=randomize,
        width=args.width,
        height=args.height,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance,
    )

    # Always write the result to a fixed output path: output/output.png
    output_dir = "output"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.abspath(os.path.join(output_dir, "output.png"))

    print(f"💾 Saving image to: {output_path}")
    generated_image.save(output_path)

    print(f"🎉 Done! File saved at: {output_path}")
    print(f"🔢 Seed used: {used_seed} (seed param was {'0 (random)' if args.seed == 0 else args.seed})")
    print(f"🖼️ Size: {args.width}x{args.height}, steps: {args.steps}, guidance: {args.guidance}")
    if args.lora and lora_file_path:
        print(f"🎨 LoRA used: {os.path.basename(lora_file_path)}")


if __name__ == "__main__":
    main()