imagen / app.py
cybergamer0123's picture
Update app.py
a7e8e3a verified
raw
history blame
3.1 kB
import asyncio
import os
from io import BytesIO

from diffusers import OnnxStableDiffusionPipeline
from fastapi import Depends, FastAPI, HTTPException, Response
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from huggingface_hub import snapshot_download
from PIL import Image
from pydantic import BaseModel, Field

app = FastAPI()

# Module-level cache for the loaded pipeline; populated by load_pipeline().
pipeline = None

# Label for the ONNX file this app was written around. It is not passed to
# snapshot_download() or from_pretrained(); the pipeline is loaded from a
# snapshot of ``repo_id`` instead.
model_id = "clip.opt/model.onnx"

# Hugging Face Hub repository id handed to snapshot_download() (a Hub repo id,
# not a local directory). Override with the REPO_ID environment variable.
# NOTE(review): this looks like a FLUX repository rather than a Stable Diffusion
# one; confirm OnnxStableDiffusionPipeline can actually load it.
repo_id = os.environ.get("REPO_ID", "black-forest-labs/FLUX.1-dev-onnx")
class ImageRequest(BaseModel):
    """Request body for ``POST /generate-image/``.

    Attributes:
        prompt: Text prompt forwarded to the pipeline.
        num_inference_steps: Number of inference steps (must be >= 1; default 50).
        guidance_scale: Guidance scale forwarded to the pipeline (default 7.5).
    """

    prompt: str
    # Non-positive step counts are rejected during request validation (FastAPI
    # answers 422) instead of reaching the pipeline and failing there.
    # NOTE(review): there is no upper bound; consider capping it to limit the
    # CPU time a single request can consume.
    num_inference_steps: int = Field(50, ge=1)
    guidance_scale: float = 7.5
# Guards the one-time pipeline load. Created lazily inside load_pipeline so it
# is constructed while the event loop is running.
_pipeline_lock = None


async def load_pipeline():
    """Download and load the ONNX pipeline once, caching it in ``pipeline``.

    snapshot_download() and OnnxStableDiffusionPipeline.from_pretrained() are
    blocking calls, so they run in Starlette's worker thread pool rather than
    on the event loop. An asyncio.Lock makes concurrent callers share a single
    load instead of each starting their own.

    Returns:
        The cached OnnxStableDiffusionPipeline instance.

    Raises:
        HTTPException: status 500 if downloading or loading fails; the
            underlying exception is chained as ``__cause__``.
    """
    global pipeline, _pipeline_lock
    if pipeline is not None:
        return pipeline
    if _pipeline_lock is None:
        # No await between the check and the assignment, so this cannot race.
        _pipeline_lock = asyncio.Lock()
    async with _pipeline_lock:
        # Another coroutine may have finished loading while we waited.
        if pipeline is not None:
            return pipeline
        try:
            local_model_path = await run_in_threadpool(
                snapshot_download,
                repo_id=repo_id,
                local_dir=".",  # Specify local_dir to ensure files are placed there
                # NOTE(review): only clip.opt/*.onnx is fetched, while
                # from_pretrained() below loads a whole pipeline directory
                # (presumably needing model_index.json and every component);
                # confirm these patterns are sufficient.
                allow_patterns=["clip.opt/*.onnx"],
            )
            pipeline = await run_in_threadpool(
                OnnxStableDiffusionPipeline.from_pretrained,
                local_model_path,  # Use the local path from snapshot_download
                provider="CPUExecutionProvider",  # Or "CUDAExecutionProvider" if you have GPU
            )
            print(f"ONNX Stable Diffusion pipeline loaded successfully from {repo_id} using snapshot_download from: {local_model_path}")
        except Exception as e:
            print(f"Error loading ONNX pipeline using snapshot_download: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to load ONNX Stable Diffusion pipeline using snapshot_download: {e}") from e
    return pipeline
async def get_pipeline():
    """FastAPI dependency that hands endpoints the shared pipeline.

    Awaits ``load_pipeline``, which returns the cached instance or loads it on
    first use; any HTTPException it raises propagates to the caller.
    """
    shared_pipeline = await load_pipeline()
    return shared_pipeline
@app.on_event("startup")
async def startup_event():
    """Warm up the shared pipeline once, when the application starts.

    Exceptions raised by ``load_pipeline`` propagate out of this handler.
    NOTE(review): recent FastAPI documentation marks ``on_event`` as deprecated
    in favour of lifespan handlers; consider migrating.
    """
    # load_pipeline caches the instance globally, so the return value is unused.
    _ = await load_pipeline()
# Serializes calls into the shared pipeline. The original code ran generation
# directly on the event loop, which implicitly allowed only one generation at
# a time; now that the work runs in a worker thread, this lock keeps that
# one-at-a-time behaviour. NOTE(review): presumably the pipeline object is not
# safe to call from several threads at once; confirm before removing the lock.
# Created lazily so it is constructed while the event loop is running.
_generation_lock = None


def _render_png(pipe, prompt, num_inference_steps, guidance_scale):
    """Run ``pipe`` synchronously and return its first image encoded as PNG bytes."""
    image = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


@app.post("/generate-image/")
async def generate_image(request: ImageRequest, pipeline_dep: OnnxStableDiffusionPipeline = Depends(get_pipeline)):
    """
    Generates an image based on the provided text prompt using the loaded ONNX Stable Diffusion pipeline.

    The pipeline call is synchronous and long-running, so it is executed in
    Starlette's worker thread pool; the event loop stays free to serve other
    requests while an image is being generated.

    Args:
        request: Prompt and sampling parameters.
        pipeline_dep: Pipeline injected by the ``get_pipeline`` dependency.

    Returns:
        A Response carrying the PNG bytes with media type ``image/png``.

    Raises:
        HTTPException: status 500 if generation or PNG encoding fails.
    """
    global _generation_lock
    if _generation_lock is None:
        # No await between the check and the assignment, so this cannot race.
        _generation_lock = asyncio.Lock()
    try:
        async with _generation_lock:
            png_bytes = await run_in_threadpool(
                _render_png,
                pipeline_dep,
                request.prompt,
                request.num_inference_steps,
                request.guidance_scale,
            )
    except Exception as e:
        print(f"Error during image generation: {e}")
        raise HTTPException(status_code=500, detail=f"Image generation failed: {e}") from e
    # The whole image is already in memory, so a plain Response (which also
    # sets Content-Length) is enough; no streaming is needed.
    return Response(content=png_bytes, media_type="image/png")
if __name__ == "__main__":
    import uvicorn

    # uvicorn.run() only supports reload=True when the app is given as an
    # import string ("module:attr"); with an app object it logs a warning and
    # exits without serving, so reload is not requested here. The port can be
    # overridden with the PORT environment variable (default 8000).
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))