Spaces:
Sleeping
Sleeping
import torch | |
from PIL import Image | |
import numpy as np | |
from io import BytesIO | |
from huggingface_hub import hf_hub_download | |
from pathlib import Path | |
from src.preprocess import read_xray, enhance_exposure, unsharp_masking, apply_clahe, resize_pil_image, increase_brightness | |
from src.network.model import RealESRGAN | |
from src.app.exceptions import InputError, ModelLoadError, PreprocessingError, InferenceError,PostprocessingError | |
class ModelLoadError(Exception): | |
pass | |
class InferencePipeline: | |
def __init__(self, config): | |
""" | |
Initialize the inference pipeline using configuration. | |
Args: | |
config: Configuration dictionary. | |
""" | |
self.config = config | |
preferred_device = config["model"].get("device", "cuda") | |
if preferred_device == "cuda" and not torch.cuda.is_available(): | |
print("[Warning] CUDA requested but not available. Falling back to CPU.") | |
self.device = "cpu" | |
else: | |
self.device = preferred_device | |
self.scale = config["model"].get("scale", 4) | |
model_source = config["model"].get("source", "local") | |
self.model = RealESRGAN(self.device, scale=self.scale) | |
print(f"Using device: {self.device}") | |
try: | |
if model_source == "huggingface": | |
repo_id = config["model"]["repo_id"] | |
filename = config["model"]["filename"] | |
local_path = hf_hub_download(repo_id=repo_id, filename=filename) | |
self.load_weights(local_path) | |
else: | |
local_path = config["model"]["weights"] | |
self.load_weights(local_path) | |
except Exception as e: | |
raise ModelLoadError(f"Failed to load the model: {str(e)}") | |
def load_weights(self, model_weights): | |
""" | |
Load the model weights. | |
Args: | |
model_weights: Path to the model weights file. | |
""" | |
try: | |
self.model.load_weights(model_weights) | |
except FileNotFoundError: | |
raise ModelLoadError(f"Model weights not found at '{model_weights}'.") | |
except Exception as e: | |
raise ModelLoadError(f"Error loading weights: {str(e)}") | |
def preprocess(self, image_path_or_bytes, apply_pre_contrast_adjustment=True, is_dicom=False): | |
""" | |
Preprocess the input image. | |
Args: | |
image_path: Path to the input image file. | |
is_dicom: Boolean indicating if the input is a DICOM file. | |
Returns: | |
PIL Image: Preprocessed image. | |
""" | |
try: | |
if is_dicom: | |
img = read_xray(image_path_or_bytes) | |
else: | |
img = Image.open(image_path_or_bytes) | |
if apply_pre_contrast_adjustment: | |
img = enhance_exposure(np.array(img)) | |
if isinstance(img,np.ndarray): | |
img = Image.fromarray(((img / np.max(img))*255).astype(np.uint8)) | |
if img.mode not in ['RGB']: | |
img = img.convert('RGB') | |
img = unsharp_masking( | |
img, | |
self.config["preprocessing"]["unsharping_mask"].get("kernel_size", 7), | |
self.config["preprocessing"]["unsharping_mask"].get("strength", 2) | |
) | |
img = increase_brightness( | |
img, | |
self.config["preprocessing"]["brightness"].get("factor", 1.2), | |
) | |
if img.mode not in ['RGB']: | |
img = img.convert('RGB') | |
return img, img.size | |
except Exception as e: | |
raise PreprocessingError(f"Error during preprocessing: {str(e)}") | |
def postprocess(self, image_array): | |
""" | |
Postprocess the output from the model. | |
Args: | |
image_array: PIL.Image output from the model. | |
Returns: | |
PIL Image: Postprocessed image. | |
""" | |
try: | |
return apply_clahe( | |
image_array, | |
self.config["postprocessing"]["clahe"].get("clipLimit", 2.0), | |
tuple(self.config["postprocessing"]["clahe"].get("tileGridSize", [16, 16])) | |
) | |
except Exception as e: | |
raise PostprocessingError(f"Error during postprocessing: {str(e)}") | |
def is_dicom(self, file_path_or_bytes): | |
""" | |
Check if the input file is a DICOM file. | |
Args: | |
file_path_or_bytes (str or bytes or BytesIO): Path to the file, byte content, or BytesIO object. | |
Returns: | |
bool: True if the file is a DICOM file, False otherwise. | |
""" | |
try: | |
if isinstance(file_path_or_bytes, str): | |
# Check the file extension | |
file_extension = Path(file_path_or_bytes).suffix.lower() | |
if file_extension in ['.dcm', '.dicom']: | |
return True | |
# Open the file and check the header | |
with open(file_path_or_bytes, 'rb') as file: | |
header = file.read(132) | |
return header[-4:] == b'DICM' | |
elif isinstance(file_path_or_bytes, BytesIO): | |
file_path_or_bytes.seek(0) | |
header = file_path_or_bytes.read(132) | |
file_path_or_bytes.seek(0) # Reset the stream position | |
return header[-4:] == b'DICM' | |
elif isinstance(file_path_or_bytes, bytes): | |
header = file_path_or_bytes[:132] | |
return header[-4:] == b'DICM' | |
except Exception as e: | |
print(f"Error during DICOM validation: {e}") | |
return False | |
return False | |
def validate_input(self, input_data): | |
""" | |
Validate the input data to ensure it is suitable for processing. | |
Args: | |
input_data: Path to the input file, bytes content, or BytesIO object. | |
Returns: | |
bool: True if the input is valid, raises InputError otherwise. | |
""" | |
if isinstance(input_data, str): | |
# Check if the file exists | |
if not Path(input_data).exists(): | |
raise InputError(f"Input file '{input_data}' does not exist.") | |
# Check if the file type is supported | |
file_extension = Path(input_data).suffix.lower() | |
if file_extension not in ['.png', '.jpeg', '.jpg', '.dcm', '.dicom']: | |
raise InputError(f"Unsupported file type '{file_extension}'. Supported types are PNG, JPEG, and DICOM.") | |
elif isinstance(input_data, BytesIO): | |
# Check if BytesIO data is not empty | |
if input_data.getbuffer().nbytes == 0: | |
raise InputError("Input BytesIO data is empty.") | |
else: | |
raise InputError("Unsupported input type. Must be a file path, byte content, or BytesIO object.") | |
return True | |
def infer(self, input_image): | |
""" | |
Perform inference on a single image. | |
Args: | |
input_image: PIL Image to be processed. | |
Returns: | |
PIL Image: Super-resolved image. | |
""" | |
try: | |
# Perform inference | |
input_array = np.array(input_image) | |
sr_array = self.model.predict(input_array) | |
return sr_array | |
except Exception as e: | |
raise InferenceError(f"Error during inference: {str(e)}") | |
def run(self, input_path, apply_pre_contrast_adjustment = True, apply_clahe_postprocess=False, return_original_size = True): | |
""" | |
Process a single image and save the output. | |
Args: | |
input_path: Path to the input image file. | |
is_dicom: Boolean indicating if the input is a DICOM file. | |
apply_clahe_postprocess: Boolean indicating if CLAHE should be applied post-processing. | |
""" | |
# Validate the input | |
self.validate_input(input_path) | |
is_dicom =self.is_dicom(input_path) | |
img, original_size = self.preprocess(input_path, is_dicom=is_dicom, apply_pre_contrast_adjustment = apply_pre_contrast_adjustment) | |
if img is None: | |
raise InputError(f"Invalid Input") | |
sr_image = self.infer(img) | |
if apply_clahe_postprocess: | |
sr_image = self.postprocess(sr_image) | |
if return_original_size: | |
sr_image = resize_pil_image(sr_image, target_shape = original_size) | |
return sr_image |