import argparse
import os
from io import BytesIO

import requests
import torch
from diffusers import QwenImageEditPipeline
from PIL import Image

# Per-socket-operation timeout (connect / read) for remote image downloads,
# so a stalled server cannot hang the script forever.
REQUEST_TIMEOUT_SECONDS = 30


def load_image(path_or_url: str) -> Image.Image:
    """Load an image from a local file path or an HTTP(S) URL as RGB.

    Args:
        path_or_url: A filesystem path, or a URL starting with ``http://``
            or ``https://``.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        requests.HTTPError: If the URL responds with an error status.
        requests.Timeout: If the server does not respond in time.
        OSError: If the local file cannot be opened or decoded.
    """
    if path_or_url.startswith(("http://", "https://")):
        # The whole body is needed in memory to decode it, so no streaming.
        response = requests.get(path_or_url, timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = path_or_url

    # convert() returns an independent, fully loaded copy, so the underlying
    # file handle can be closed as soon as the with-block exits.
    with Image.open(source) as img:
        return img.convert("RGB")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True, help="输入提示词")
    parser.add_argument("--input_image", type=str, default=None, help="输入图片路径或URL (可选)")
    parser.add_argument("--seed", type=int, default=42, help="随机种子 (0 表示随机)")
    parser.add_argument("--width", type=int, default=1664, help="图像宽度 (默认 1664)")
    parser.add_argument("--height", type=int, default=928, help="图像高度 (默认 928)")
    parser.add_argument("--steps", type=int, default=50, help="推理步数 (默认 50)")
    return parser


def main():
    """Run one Qwen-Image-Edit inference and save the result.

    Parses the CLI arguments, loads the pipeline, optionally loads an input
    image, runs the pipeline and writes the first output image to
    ``output/output.png`` (relative to the current working directory).
    """
    args = _build_arg_parser().parse_args()

    model_name = "Qwen/Qwen-Image-Edit"

    # Load the pipeline and move it to the GPU when one is available.
    pipe = QwenImageEditPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = pipe.to(device)
    pipe.set_progress_bar_config(disable=None)

    # A seed of 0 means "do not seed": no generator is passed to the pipeline.
    if args.seed == 0:
        generator = None
    else:
        generator = torch.Generator(device=device).manual_seed(args.seed)

    # A fixed negative prompt is always injected here.
    negative_prompt = "political figures, Chinese political figures, porn, nsfw"

    # Assemble the pipeline call arguments.
    inputs = {
        "prompt": args.prompt,
        "generator": generator,
        "true_cfg_scale": 4.0,
        "negative_prompt": negative_prompt,
        "num_inference_steps": args.steps,
        "width": args.width,
        "height": args.height,
    }

    # When an input image is given, pass it along for editing.
    # NOTE(review): --input_image is optional here; confirm that
    # QwenImageEditPipeline accepts a call without `image`.
    if args.input_image:
        inputs["image"] = load_image(args.input_image)

    # Inference
    with torch.inference_mode():
        output = pipe(**inputs)
    output_image = output.images[0]

    # Output directory (created on demand).
    output_dir = os.path.abspath("output")
    os.makedirs(output_dir, exist_ok=True)

    # Fixed file name: each run overwrites the previous result.
    output_path = os.path.join(output_dir, "output.png")
    output_image.save(output_path)

    seed_info = "random" if args.seed == 0 else args.seed
    mode = "Image Edit" if args.input_image else "Text-to-Image"
    print(f"✅ Image saved at: {output_path} (mode={mode}, seed={seed_info}, steps={args.steps}, size={args.width}x{args.height})")


if __name__ == "__main__":
    main()