{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/venom/miniforge3/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "/home/venom/miniforge3/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.\n", " @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)\n" ] } ], "source": [ "import os\n", "import pandas as pd\n", "from PIL import Image\n", "import torch\n", "from torch.utils.data import Dataset, DataLoader, random_split\n", "from torchvision import transforms\n", "import lightning as L\n", "import kornia as K\n", "import numpy as np\n", "import random\n", "import sys\n", "\n", "PROJECT_ROOT = os.path.abspath(os.path.normpath(\"/home/venom/repo/xray-exp/\"))\n", "sys.path.append(PROJECT_ROOT)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from models.model_loader import create_model\n", "from scripts.trainer import XrayReg" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class XrayInferenceDataset(Dataset):\n", "\n", " def __init__(self, root_dir, transform=None):\n", " self.root_dir = root_dir\n", " self.file_names = os.listdir(root_dir)\n", " self.transform = transform\n", "\n", " def __len__(self):\n", " return len(self.file_names)\n", "\n", " def __getitem__(self, idx):\n", " file_name = self.file_names[idx]\n", " img_path = os.path.join(self.root_dir, file_name)\n", " img = Image.open(img_path)\n", "\n", " img = img.convert(\"L\")\n", "\n", " if self.transform:\n", " img = self.transform(img)\n", "\n", " return img, file_name\n", "\n", "\n", "class XrayDataInference(L.LightningDataModule):\n", " common_seed = 42\n", "\n", " @staticmethod\n", " def seed_worker(worker_id):\n", " worker_seed = torch.initial_seed() % 2**32\n", " np.random.seed(worker_seed)\n", " random.seed(worker_seed)\n", "\n", " def __init__(self, root_dir, batch_size=32):\n", " super().__init__()\n", " self.root_dir = root_dir\n", " self.batch_size = batch_size\n", "\n", " torch.manual_seed(self.common_seed)\n", " torch.cuda.manual_seed_all(self.common_seed)\n", " torch.backends.cudnn.deterministic = True\n", "\n", " self.transform = transforms.Compose(\n", " [\n", " transforms.Resize((224, 224)),\n", " transforms.ToTensor(),\n", " ]\n", " )\n", " self.inference_dataset = XrayInferenceDataset(\n", " self.root_dir, transform=self.transform\n", " )\n", "\n", " def inference_dataloader(self):\n", " return DataLoader(\n", " self.inference_dataset,\n", " batch_size=self.batch_size,\n", " shuffle=False,\n", " num_workers=4,\n", " worker_init_fn=self.seed_worker,\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/venom/miniforge3/lib/python3.10/site-packages/timm/models/_factory.py:117: UserWarning: Mapping deprecated model name vit_large_patch16_224_in21k to current vit_large_patch16_224.augreg_in21k.\n", " model = create_fn(\n" ] } ], "source": [ "trainer_config = XrayReg.load_from_checkpoint(\n", " \"/home/venom/repo/xray-exp/xray_regression_noaug/912yp4l6/checkpoints/epoch=99-step=5900.ckpt\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "infer_ds = XrayDataInference(\n", " \"/home/venom/Downloads/CXR AI PNG- FINAL 13-12/\", batch_size=16\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# run inference against the infer_ds and log to a file (file name run_name)\n", "\n", "model = trainer_config.model\n", "\n", "model.eval()\n", "model = model.cuda()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# run inference against the infer_ds and log to a file (file name run_ID)\n", "RUN_ID = \"912yp4l6\"\n", "\n", "with open(f\"/home/venom/repo/xray-exp/inference_results/{RUN_ID}.csv\", \"w\") as f:\n", " f.write(\"file_name,predicted\\n\")\n", " for img, file_name in infer_ds.inference_dataloader():\n", " img = img.cuda()\n", " with torch.no_grad():\n", " pred = model(img)\n", " pred = pred.cpu().numpy()\n", " for i in range(len(pred)):\n", " f.write(f\"{file_name[i]},{pred[i][0]}\\n\")" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }