"""Mixins for all modalities. Each mixin provides:

- a preprocess function that takes a path or raw data and returns a tensor
- a construct_input function that takes a tensor and returns a dict with a
  batch dimension for model input
- a key string for the model input dict
"""

import random

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from transformers import BatchEncoding, PreTrainedTokenizer


class ECHO_Mixin:
    LOWER_YELLOW: list[int] = [20, 50, 50]
    UPPER_YELLOW: list[int] = [100, 255, 255]
    IMAGE_SIZE: tuple[int, int] = (224, 224)
    NORM_MEAN: tuple[float, float, float] = (0.48145466, 0.4578275, 0.40821073)
    NORM_STD: tuple[float, float, float] = (0.26862954, 0.26130258, 0.27577711)
    ECHO_TRANSFORMS = transforms.Compose(
        [
            transforms.ToTensor(),  # scales pixel values into [0, 1]
            transforms.Resize(IMAGE_SIZE),
            transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
        ]
    )
    ECHO_KEY: str = "echo"

    def grabimage(self, split: str, data: dict[str, np.ndarray]) -> np.ndarray:
        """Pick a frame from the case dictionary: a random frame of a random
        case for training, otherwise the first frame of a random case."""
        caseofinterest = random.choice(list(data.keys()))
        if split == "train":
            imageindice = random.randrange(data[caseofinterest].shape[0])
        else:
            imageindice = 0
        video = data[caseofinterest]
        return self.extract_echoframe(imageindice, video)

    def extract_echoframe(self, imageindice: int, video: np.ndarray) -> np.ndarray:
        """Zero out yellow overlay pixels and rescale the frame to uint8."""
        image = video[imageindice]
        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        lower_yellow = np.array(self.LOWER_YELLOW)  # lower bound of yellow hue
        upper_yellow = np.array(self.UPPER_YELLOW)  # upper bound of yellow hue
        mask = cv2.inRange(hsv_image, lower_yellow, upper_yellow)
        image[mask > 0] = [0, 0, 0]
        # Rescale intensities to the full [0, 255] range
        image = np.array(image, dtype=np.float32)
        image -= image.min()
        max_val = image.max()
        if max_val > 0:  # guard against division by zero on an all-black frame
            image /= max_val
        image *= 255
        return image.astype(np.uint8)

    def preprocess_echoseries(
        self, video_dict: dict[str, np.ndarray], split: str = "valid"
    ) -> torch.Tensor:
        """Assumes inference mode."""
        image = self.grabimage(split, video_dict)
        if not isinstance(image, np.ndarray):
            raise TypeError("Expected image to be a numpy ndarray")
        pil_image = Image.fromarray(image)
        transformed = self.ECHO_TRANSFORMS(pil_image)
        if not isinstance(transformed, torch.Tensor):  # type narrowing for Compose output
            transformed = transforms.ToTensor()(pil_image)
        return transformed

    def preprocess_single_echo(self, avi_path: str) -> torch.Tensor:
        """Assumes inference mode; opens an AVI file and processes its first frame.

        Output:
            image: torch.Tensor of shape (C, H, W)
        """
        cap = cv2.VideoCapture(avi_path)
        success, frame = cap.read()
        cap.release()
        if not success or frame is None:
            raise ValueError(f"Could not read frame from AVI file: {avi_path}")
        image = self.extract_echoframe(0, np.array([frame]))  # process first frame
        image = self.ECHO_TRANSFORMS(Image.fromarray(image))
        if not isinstance(image, torch.Tensor):  # type narrowing for Compose output
            image = torch.from_numpy(np.asarray(image))
        return image
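

# The module docstring above promises a construct_input helper for each mixin,
# but none appears in this file. The sketch below is an illustrative assumption
# of that convention for the echo modality, not part of the original code.
def demo_construct_echo_input(mixin: ECHO_Mixin, avi_path: str) -> dict[str, torch.Tensor]:
    """Batch a single preprocessed echo frame under the mixin's input key."""
    frame = mixin.preprocess_single_echo(avi_path)  # (C, H, W)
    return {mixin.ECHO_KEY: frame.unsqueeze(0)}  # (1, C, H, W) with batch dimension
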
# CXR
class CXR_Mixin:
    RESIZE: tuple[int, int] = (256, 256)
    IMAGE_SIZE: tuple[int, int] = (224, 224)
    NORM_MEAN: list[float] = [0.5862785803043838]
    NORM_STD: list[float] = [0.27950088968644304]
    VISION_KEY: str = "vision"
    CXR_TRANSFORMS = transforms.Compose(
        [
            transforms.ToTensor(),  # scales pixel values into [0, 1]
            transforms.Resize(RESIZE),
            transforms.CenterCrop(IMAGE_SIZE),
            transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
        ]
    )

    @staticmethod
    def remove_border(pixel_array: np.ndarray) -> np.ndarray:
        # Find where the image is not just background (0s)
        coords = np.column_stack(np.where(pixel_array > 0))
        if coords.size == 0:  # fully blank image: nothing to crop
            return pixel_array
        x_min, y_min = coords.min(axis=0)
        x_max, y_max = coords.max(axis=0)
        # Crop the image to the bounding box of nonzero pixels
        return pixel_array[x_min:x_max, y_min:y_max]

    def preprocess_loaded_cxr(self, img: np.ndarray) -> torch.Tensor:
        cxr = self.remove_border(img)
        # Convert grayscale image to 3-channel RGB
        cxr = np.repeat(cxr[..., np.newaxis], 3, axis=-1)
        cxr = Image.fromarray(cxr)
        transformed = self.CXR_TRANSFORMS(cxr)
        if not isinstance(transformed, torch.Tensor):  # type narrowing for Compose output
            transformed = transforms.ToTensor()(cxr)
        return transformed

    def preprocess_single_cxr(self, image_path: str) -> torch.Tensor:
        """Assumes inference mode."""
        with open(image_path, "rb") as fopen:
            image = Image.open(fopen).convert("RGB")
            image = np.array(image)[:, :, 0]  # keep one channel as grayscale
        return self.preprocess_loaded_cxr(image)


class ECG_Mixin:
    LENGTH: int = 1000
    FREQUENCY: int = 100  # we assume a 100 Hz sampling rate
    CHANNELS: int = 12
    NORM_MEAN: float = 0.02547506
    NORM_SCALE: float = 0.16486814
    NORM_VAR: float = 0.0271815
    ECG_KEY: str = "ecg"

    def manual_standardize(self, x: np.ndarray) -> torch.Tensor:
        """
        Apply manual standardization to ECG or other data.
        Equivalent to sklearn's StandardScaler with the given constants.

        Args:
            x (np.ndarray): input array of shape (12, 1000)

        Returns:
            torch.Tensor: scaled array of the same shape
        """
        return torch.from_numpy((x - self.NORM_MEAN) / self.NORM_SCALE).float()

    def check_ecg(self, ecg: np.ndarray) -> np.ndarray:
        if np.isnan(ecg).any() or np.isinf(ecg).any():
            raise ValueError("ECG contains NaN or Inf values")
        # Truncate to the first 1000 samples (10 seconds at 100 Hz)
        return ecg[:, : self.LENGTH]

    def preprocess_single_ecg(self, ecg_path: str) -> torch.Tensor:
        """Assumes inference mode; ecg_path points to a saved numpy array with 12 channels."""
        ecg = np.load(ecg_path)
        if ecg.ndim != 2 or ecg.shape[0] != self.CHANNELS:
            raise ValueError(f"Expected ECG with {self.CHANNELS} channels, got shape {ecg.shape}")
        ecg = self.check_ecg(ecg)
        return self.manual_standardize(ecg)


class Text_Mixin:
    MODALITY_LIST: dict[str, str] = {"echo": "echocardiogram", "ecg": "ecg", "vision": "cxr"}
    MAX_LENGTH: int = 120  # longer token budget to accommodate longer reports
    TEXT_LENGTH: int = 100  # 100 words

    def get_first_n_words(self, text: str, n: int = 100) -> str:
        """The 97.5th percentile of report length is under 35 words, so n = 100 is a generous cap."""
        words = text.split()  # split the text into words
        return " ".join(words[:n])  # join the first n words back into a string

    def createCaption(self, caption: str, modality: str = "") -> str:
        assert modality in self.MODALITY_LIST or modality == "", (
            f"modality should be in {self.MODALITY_LIST} or empty"
        )
        return f"text : {caption}, {modality} looks like : "

    def createTokenizedCaption(self, caption: str, tokenizer: PreTrainedTokenizer) -> BatchEncoding:
        return tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.MAX_LENGTH,
            return_tensors="pt",
        )

    def construct_caption(
        self, caption: str, tokenizer: PreTrainedTokenizer, modality: str = ""
    ) -> BatchEncoding:
        """Given a caption string, return a tokenized caption dict for model input.

        Output:
            dict with keys 'input_ids' and 'attention_mask', each of shape (1, L)
        """
        caption_str = self.createCaption(caption, modality)
        return self.createTokenizedCaption(caption_str, tokenizer)
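

# Hedged usage sketch (not part of the original module): the mixins appear
# designed to be composed into a single preprocessor; the class name and the
# synthetic ECG below are illustrative assumptions. Note that attribute names
# such as NORM_MEAN collide across mixins, but each *_TRANSFORMS pipeline was
# already bound at class-definition time, so behavior is unaffected.
class DemoPreprocessor(ECHO_Mixin, CXR_Mixin, ECG_Mixin, Text_Mixin):
    """Combines the per-modality mixins into one preprocessing object."""


if __name__ == "__main__":
    pre = DemoPreprocessor()
    # A synthetic 12-lead, 10 s @ 100 Hz signal stands in for a real recording.
    fake_ecg = np.random.randn(12, 1000).astype(np.float32)
    ecg_tensor = pre.manual_standardize(pre.check_ecg(fake_ecg))
    print(pre.ECG_KEY, tuple(ecg_tensor.shape))  # prints: ecg (12, 1000)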