# (Removed scrape artifact: Hugging Face Spaces page header "Spaces: Runtime
#  error" — web-page residue, not part of the program.)
"""FastAPI service that generates PNG images with an ONNX Stable Diffusion pipeline."""

from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from io import BytesIO
from diffusers import OnnxStableDiffusionPipeline
from huggingface_hub import snapshot_download
from PIL import Image
import os

app = FastAPI()

# Lazily-initialized global pipeline; populated once by load_pipeline().
pipeline = None
# Path of the ONNX model file inside the repository snapshot.
model_id = "clip.opt/model.onnx"
# Hugging Face Hub repository the ONNX weights are downloaded from.
# NOTE(review): FLUX.1 is not a Stable Diffusion architecture — confirm this
# repo actually works with OnnxStableDiffusionPipeline.
repo_id = "black-forest-labs/FLUX.1-dev-onnx"
class ImageRequest(BaseModel):
    """Request body for the image-generation endpoint."""

    # Text prompt describing the desired image.
    prompt: str
    # Number of denoising steps; more steps trade latency for quality.
    num_inference_steps: int = 50
    # Classifier-free guidance scale; higher values follow the prompt more closely.
    guidance_scale: float = 7.5
async def load_pipeline():
    """Load (once) and return the global ONNX Stable Diffusion pipeline.

    On first call, downloads the ONNX weights from the Hugging Face Hub via
    ``snapshot_download`` and builds the pipeline; later calls return the
    cached global instance.

    Returns:
        The loaded ``OnnxStableDiffusionPipeline``.

    Raises:
        HTTPException: 500 if the download or pipeline construction fails.
    """
    global pipeline
    if pipeline is None:
        try:
            # NOTE(review): allow_patterns fetches only clip.opt/*.onnx —
            # from_pretrained usually also needs model_index.json, scheduler,
            # and tokenizer files; confirm the snapshot is complete.
            local_model_path = snapshot_download(
                repo_id=repo_id,
                local_dir=".",  # materialize the snapshot in the working directory
                allow_patterns=["clip.opt/*.onnx"],
            )
            pipeline = OnnxStableDiffusionPipeline.from_pretrained(
                local_model_path,
                provider="CPUExecutionProvider",  # use "CUDAExecutionProvider" on GPU hosts
            )
            # Fixed message: the pipeline is loaded from the snapshot path,
            # not from model_id (which from_pretrained never receives).
            print(f"ONNX Stable Diffusion pipeline loaded successfully using snapshot_download from: {local_model_path}")
        except Exception as e:
            print(f"Error loading ONNX pipeline using snapshot_download: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to load ONNX Stable Diffusion pipeline using snapshot_download: {e}")
    return pipeline
async def get_pipeline():
    """FastAPI dependency: return the pipeline, loading it if needed."""
    return await load_pipeline()
@app.on_event("startup")
async def startup_event():
    """Eagerly load the pipeline when the server starts.

    Bug fix: the original function was never registered with the app, so the
    model was only downloaded lazily on the first request; the
    ``@app.on_event("startup")`` decorator wires it in.
    """
    await load_pipeline()
@app.post("/generate")
async def generate_image(request: ImageRequest, pipeline_dep: OnnxStableDiffusionPipeline = Depends(get_pipeline)):
    """Generate an image from a text prompt and stream it back as PNG.

    Bug fix: the original function had no route decorator, so the endpoint
    was unreachable; it is now exposed as ``POST /generate``.

    Args:
        request: Prompt plus sampling parameters.
        pipeline_dep: Loaded pipeline injected by ``get_pipeline``.

    Returns:
        StreamingResponse carrying the PNG bytes (``image/png``).

    Raises:
        HTTPException: 500 on any generation failure.
    """
    try:
        # NOTE(review): this inference call is synchronous and CPU-heavy — it
        # blocks the event loop; consider fastapi.concurrency.run_in_threadpool.
        image = pipeline_dep(
            request.prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
        ).images[0]

        # Serialize the PIL image to PNG bytes for the HTTP response.
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
        return StreamingResponse(content=iter([png_bytes]), media_type="image/png")
    except Exception as e:
        print(f"Error during image generation: {e}")
        raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
if __name__ == "__main__":
    import uvicorn

    # reload=True is only honored when the app is passed as an import string
    # ("module:app"); with an app object uvicorn warns and ignores it, so it
    # is dropped here.
    uvicorn.run(app, host="0.0.0.0", port=8000)