├── data ├── ckpt │ └── .gitkeep ├── test │ └── .gitkeep ├── logger │ └── .gitkeep ├── output │ └── .gitkeep └── demo │ ├── train │ └── 6897fa9de148 │ │ └── 2bfbb7fd2e8b │ │ ├── 1997c99c9d59.dcm │ │ ├── 2c8ed843c858.dcm │ │ ├── 487d9ab5531f.dcm │ │ ├── 4f03aef72206.dcm │ │ ├── 52b6b0b793bb.dcm │ │ ├── 5fe975eb5ce9.dcm │ │ ├── 945e7e32955c.dcm │ │ ├── baedb900c69c.dcm │ │ ├── bc1390934263.dcm │ │ ├── c6f29ac6659b.dcm │ │ ├── d9e3d9934410.dcm │ │ └── db937353ea10.dcm │ ├── README.csv │ ├── demo.csv │ └── resnet18_demo.yaml ├── tests ├── __init__.py ├── test_dataset.py ├── test_dataloader_window.py ├── test_dataloader2d.py └── test_dataloader3d.py ├── pe_models ├── __init__.py ├── preprocess │ ├── __init__.py │ ├── build_hdf5_labels.py │ ├── stanford.py │ ├── lidc.py │ └── rsna.py ├── lightning │ ├── __init__.py │ ├── classification_lightning_model.py │ └── window_classification_lightning_model.py ├── models │ ├── __init__.py │ ├── backbones3d │ │ ├── layers │ │ │ └── penet │ │ │ │ ├── __init__.py │ │ │ │ ├── gap_linear.py │ │ │ │ ├── penet_lateral.py │ │ │ │ ├── penet_encoder.py │ │ │ │ ├── se_block.py │ │ │ │ ├── penet_decoder.py │ │ │ │ ├── penet_asp_pool.py │ │ │ │ └── penet_bottleneck.py │ │ ├── penet_classifier.py │ │ ├── resnet_2d3d.py │ │ ├── resnet3d.py │ │ ├── s3dg.py │ │ └── r21d.py │ ├── models_2d.py │ ├── models_3d.py │ └── models_1d.py ├── datasets │ ├── __init__.py │ ├── data_module.py │ ├── dataset_base.py │ └── dataset_1d.py ├── loss.py ├── constants.py ├── utils.py └── builder.py ├── configs ├── resnet18_n_n_lidc0.1_lr1e-4.yaml ├── lrcn_y_n_lidc0.1_lr1e-3.yaml ├── r2plus1d_y_n_rsna0.1_lr1e-5.yaml └── r2plus1d_y_n_lidc1.0_lr1e-5.yaml ├── .gitignore ├── environment.yml ├── README.md ├── LICENSE └── run.py /data/ckpt/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/test/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/logger/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/output/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pe_models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import builder, utils 2 | -------------------------------------------------------------------------------- /pe_models/preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import rsna, lidc, stanford 2 | -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/1997c99c9d59.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/1997c99c9d59.dcm -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/2c8ed843c858.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/2c8ed843c858.dcm -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/487d9ab5531f.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/487d9ab5531f.dcm -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/4f03aef72206.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/4f03aef72206.dcm -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/52b6b0b793bb.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/52b6b0b793bb.dcm -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/5fe975eb5ce9.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/5fe975eb5ce9.dcm -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/945e7e32955c.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/945e7e32955c.dcm -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/baedb900c69c.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/baedb900c69c.dcm -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/bc1390934263.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/bc1390934263.dcm -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/c6f29ac6659b.dcm: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/c6f29ac6659b.dcm -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/d9e3d9934410.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/d9e3d9934410.dcm -------------------------------------------------------------------------------- /data/demo/train/6897fa9de148/2bfbb7fd2e8b/db937353ea10.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajpurkarlab/chest-ct-pretraining/HEAD/data/demo/train/6897fa9de148/2bfbb7fd2e8b/db937353ea10.dcm -------------------------------------------------------------------------------- /pe_models/lightning/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification_lightning_model import PEClassificationLightningModel 2 | from .window_classification_lightning_model import PEWindowClassificationLightningModel -------------------------------------------------------------------------------- /pe_models/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models_1d import * 2 | from .models_2d import * 3 | from .models_3d import * 4 | 5 | ALL_MODELS = { 6 | "model_1d": PEModel1D, 7 | "model_2d": PEModel2D, 8 | "model_3d": PEModel3D 9 | } 10 | -------------------------------------------------------------------------------- /data/demo/README.csv: -------------------------------------------------------------------------------- 1 | The imaging study included in this folder has been modified with dummy labels and is only meant for demo purposes. For the full dataset, please refer to the [RSNA PE Dataset](https://www.kaggle.com/c/rsna-str-pulmonary-embolism-detection). 2 | -------------------------------------------------------------------------------- /pe_models/models/backbones3d/layers/penet/__init__.py: -------------------------------------------------------------------------------- 1 | from .gap_linear import GAPLinear 2 | from .se_block import SEBlock 3 | from .penet_asp_pool import PENetASPPool 4 | from .penet_bottleneck import PENetBottleneck 5 | from .penet_encoder import PENetEncoder 6 | from .penet_decoder import PENetDecoder 7 | -------------------------------------------------------------------------------- /pe_models/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .
import data_module, dataset_1d, dataset_2d, dataset_3d 2 | 3 | ALL_DATASETS = { 4 | "1d": dataset_1d.PEDataset1D, 5 | "1d_stanford": dataset_1d.PEDataset1DStanford, 6 | "2d": dataset_2d.PEDataset2D, 7 | "2d_mae": dataset_2d.PEDataset2DMAE, 8 | "2d_stanford": dataset_2d.PEDataset2DStanford, 9 | "3d": dataset_3d.PEDataset3D, 10 | "window": dataset_3d.PEDatasetWindow, 11 | "window_stanford": dataset_3d.PEDatasetWindowStanford, 12 | "lidc-window": dataset_3d.LIDCDatasetWindow, 13 | "lidc-2d": dataset_2d.LIDCDataset2D, 14 | "lidc-1d": dataset_1d.LIDCDataset1D, 15 | "demo": dataset_2d.DemoDataset2D, 16 | } 17 | -------------------------------------------------------------------------------- /pe_models/models/backbones3d/layers/penet/gap_linear.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class GAPLinear(nn.Module): 5 | def __init__(self, in_channels, out_channels): 6 | """Global average pooling (3D) followed by a linear layer. 7 | 8 | Args: 9 | in_channels: Number of input channels. 10 | out_channels: Number of output channels 11 | """ 12 | super(GAPLinear, self).__init__() 13 | self.avg_pool = nn.AdaptiveAvgPool3d(1) 14 | self.fc = nn.Linear(in_channels, out_channels) 15 | self.fc.is_output_head = True 16 | 17 | def forward(self, x): 18 | x = self.avg_pool(x) 19 | x = x.view(x.size(0), -1) 20 | x = self.fc(x) 21 | 22 | return x 23 | -------------------------------------------------------------------------------- /pe_models/models/backbones3d/layers/penet/penet_lateral.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class PENetLateral(nn.Module): 5 | """Lateral connection layer for PENet.""" 6 | def __init__(self, in_channels, out_channels): 7 | super(PENetLateral, self).__init__() 8 | 9 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False) 10 | self.norm = nn.GroupNorm(out_channels // 16, out_channels) 11 | self.relu = nn.LeakyReLU() 12 | 13 | def forward(self, x, x_skip): 14 | # Reduce number of channels in skip connection 15 | x_skip = self.conv(x_skip) 16 | x_skip = self.norm(x_skip) 17 | x_skip = self.relu(x_skip) 18 | 19 | # Add reduced feature map 20 | x += x_skip 21 | 22 | return x 23 | -------------------------------------------------------------------------------- /pe_models/models/backbones3d/layers/penet/penet_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from . import PENetBottleneck 4 | 5 | 6 | class PENetEncoder(nn.Module): 7 | def __init__(self, in_channels, channels, num_blocks, cardinality, block_idx, total_blocks, stride=1): 8 | super(PENetEncoder, self).__init__() 9 | 10 | # Get PENet blocks 11 | penet_blocks = [PENetBottleneck(in_channels, channels, block_idx, total_blocks, cardinality, stride)] 12 | 13 | for i in range(1, num_blocks): 14 | penet_blocks += [PENetBottleneck(channels * PENetBottleneck.expansion, channels, 15 | block_idx + i, total_blocks, cardinality)] 16 | self.penet_blocks = nn.Sequential(*penet_blocks) 17 | 18 | def forward(self, x): 19 | x = self.penet_blocks(x) 20 | 21 | return x 22 | -------------------------------------------------------------------------------- /pe_models/models/backbones3d/layers/penet/se_block.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SEBlock(nn.Module): 5 | """Squeeze-and-Excitation Block. 
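Squeezes global spatial context into per-channel descriptors with 3D global average pooling, then excites (rescales) each channel with a sigmoid gate produced by the two-layer bottleneck MLP below.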
6 | 7 | Based on the paper: 8 | "Squeeze-and-Excitation Networks" 9 | by Jie Hu, Li Shen, Gang Sun 10 | (https://arxiv.org/abs/1709.01507). 11 | """ 12 | 13 | def __init__(self, num_channels, reduction=16): 14 | super(SEBlock, self).__init__() 15 | self.squeeze = nn.AdaptiveAvgPool3d(1) 16 | self.excite = nn.Sequential(nn.Linear(num_channels, num_channels // reduction), 17 | nn.LeakyReLU(inplace=True), 18 | nn.Linear(num_channels // reduction, num_channels), 19 | nn.Sigmoid()) 20 | 21 | def forward(self, x): 22 | num_channels = x.size(1) 23 | 24 | # Squeeze 25 | z = self.squeeze(x) 26 | z = z.view(-1, num_channels) 27 | 28 | # Excite 29 | s = self.excite(z) 30 | s = s.view(-1, num_channels, 1, 1, 1) 31 | 32 | # Apply gate 33 | x = x * s 34 | 35 | return x 36 | -------------------------------------------------------------------------------- /data/demo/demo.csv: -------------------------------------------------------------------------------- 1 | ,Unnamed: 0,StudyInstanceUID,SeriesInstanceUID,SOPInstanceUID,pe_present_on_image,negative_exam_for_pe,Split,RescaleIntercept,RescaleSlope,InstancePath 2 | 0,0,6897fa9de148,2bfbb7fd2e8b,baedb900c69c,1,1,train,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/baedb900c69c.dcm 3 | 1,1,6897fa9de148,2bfbb7fd2e8b,52b6b0b793bb,1,1,train,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/52b6b0b793bb.dcm 4 | 2,2,6897fa9de148,2bfbb7fd2e8b,1997c99c9d59,0,0,train,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/1997c99c9d59.dcm 5 | 3,3,6897fa9de148,2bfbb7fd2e8b,c6f29ac6659b,0,0,train,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/c6f29ac6659b.dcm 6 | 4,4,6897fa9de148,2bfbb7fd2e8b,487d9ab5531f,1,1,valid,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/487d9ab5531f.dcm 7 | 5,5,6897fa9de148,2bfbb7fd2e8b,5fe975eb5ce9,1,1,valid,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/5fe975eb5ce9.dcm 8 | 6,6,6897fa9de148,2bfbb7fd2e8b,4f03aef72206,0,0,valid,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/4f03aef72206.dcm 9 | 7,7,6897fa9de148,2bfbb7fd2e8b,bc1390934263,0,0,valid,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/bc1390934263.dcm 10 | 8,8,6897fa9de148,2bfbb7fd2e8b,db937353ea10,1,1,test,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/db937353ea10.dcm 11 | 9,9,6897fa9de148,2bfbb7fd2e8b,2c8ed843c858,1,1,test,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/2c8ed843c858.dcm 12 | 10,10,6897fa9de148,2bfbb7fd2e8b,945e7e32955c,0,0,test,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/945e7e32955c.dcm 13 | 11,11,6897fa9de148,2bfbb7fd2e8b,d9e3d9934410,0,0,test,-1024.0,1.0,6897fa9de148/2bfbb7fd2e8b/d9e3d9934410.dcm 14 | -------------------------------------------------------------------------------- /pe_models/models/backbones3d/layers/penet/penet_decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .penet_lateral import PENetLateral 4 | 5 | 6 | class PENetDecoder(nn.Module): 7 | """Decoder (up-sampling layer) for PENet""" 8 | def __init__(self, skip_channels, in_channels, mid_channels, out_channels, kernel_size=4, stride=2): 9 | super(PENetDecoder, self).__init__() 10 | 11 | if skip_channels > 0: 12 | self.lateral = PENetLateral(skip_channels, in_channels) 13 | 14 | self.conv1 = nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False) 15 | self.norm1 = nn.GroupNorm(mid_channels // 16, mid_channels) 16 | self.relu1 = nn.LeakyReLU() 17 | 18 | self.conv2 = nn.ConvTranspose3d(mid_channels, mid_channels, kernel_size=kernel_size, stride=stride, padding=1) 19 | self.norm2 = nn.GroupNorm(mid_channels // 16, mid_channels) 20 | self.relu2 = nn.LeakyReLU() 21 | 22 | self.conv3 = nn.Conv3d(mid_channels, 
out_channels, kernel_size=3, padding=1, bias=False) 23 | self.norm3 = nn.GroupNorm(out_channels // 16, out_channels) 24 | self.relu3 = nn.LeakyReLU() 25 | 26 | def forward(self, x, x_skip=None): 27 | if x_skip is not None: 28 | x = self.lateral(x, x_skip) 29 | 30 | x = self.conv1(x) 31 | x = self.norm1(x) 32 | x = self.relu1(x) 33 | 34 | x = self.conv2(x) 35 | x = self.norm2(x) 36 | x = self.relu2(x) 37 | 38 | x = self.conv3(x) 39 | x = self.norm3(x) 40 | x = self.relu3(x) 41 | 42 | return x 43 | -------------------------------------------------------------------------------- /data/demo/resnet18_demo.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: 'resnet18_demo' 2 | trial_name: 'lr-4' 3 | phase: 'train' 4 | 5 | lightning: 6 | trainer: 7 | gpus: 1 8 | max_epochs: 10 9 | lr: 1e-4 10 | precision: 16 11 | auto_lr_find: false 12 | benchmark: true 13 | replace_sampler_ddp: false 14 | checkpoint_callback: 15 | monitor: 'val/mean_auroc' 16 | dirpath: './data/ckpt' 17 | save_last: true 18 | mode: 'max' 19 | save_top_k: 10 20 | early_stopping_callback: 21 | monitor: 'val/mean_auroc' 22 | min_delta: 0.00 23 | patience: 5 24 | verbose: False 25 | mode: 'max' 26 | logger: 27 | logger_type: 'WandbLogger' 28 | save_dir: './data/logger/' 29 | name: 'resnet_18' 30 | project: 'pe_models' 31 | 32 | 33 | model: 34 | type: 'model_2d' 35 | model_name: 'resnet_18' 36 | freeze_cnn: false 37 | pretrained: false 38 | 39 | data: 40 | use_hdf5: true 41 | type: 'demo' # 1d, 2d, 3d, window, lidc-window, lidc-2d 42 | targets: 'rsna_pe_target' 43 | channels: 'window' # repeat, neighbor, window 44 | weighted_sample: true 45 | positive_only: false 46 | imsize: 256 47 | 48 | transforms: 49 | type: 'imagenet' 50 | ShiftScaleRotate: 51 | shift_limit: 0.05 52 | scale_limit: 0.05 53 | rotate_limit: 20 54 | p: 0.5 55 | RandomCrop: 56 | height: 224 57 | width: 224 58 | 59 | train: 60 | batch_size: 4 61 | num_workers: 4 62 | weighted_loss: false 63 | loss_fn: 64 | name: 'BCEWithLogitsLoss' 65 | optimizer: 66 | name: 'Adam' 67 | weight_decay: 1e-6 68 | scheduler: 69 | name: 'ReduceLROnPlateau' 70 | monitor: 'val_loss' 71 | interval: 'epoch' 72 | frequency: 3 73 | factor: 0.5 74 | patience: 5 75 | -------------------------------------------------------------------------------- /configs/resnet18_n_n_lidc0.1_lr1e-4.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: 'lidc_2d_resnet18_novid' 2 | trial_name: 'lr-4' 3 | phase: 'train' 4 | 5 | lightning: 6 | trainer: 7 | gpus: 1 8 | max_epochs: 50 9 | lr: 1e-4 10 | precision: 16 11 | auto_lr_find: false 12 | benchmark: true 13 | replace_sampler_ddp: false 14 | limit_train_batches: 0.1 15 | checkpoint_callback: 16 | monitor: 'val/mean_auroc' 17 | dirpath: './data/ckpt' 18 | save_last: true 19 | mode: 'max' 20 | save_top_k: 10 21 | early_stopping_callback: 22 | monitor: 'val/mean_auroc' 23 | min_delta: 0.00 24 | patience: 5 25 | verbose: False 26 | mode: 'max' 27 | logger: 28 | logger_type: 'WandbLogger' 29 | save_dir: './data/logger/' 30 | name: 'resnet_18' 31 | project: 'pe_models' 32 | 33 | 34 | model: 35 | type: 'model_2d' 36 | model_name: 'resnet_18' 37 | freeze_cnn: false 38 | pretrained: false 39 | 40 | data: 41 | use_hdf5: true 42 | dataset: 'rsna' 43 | type: 'lidc-2d' # 1d, 2d, 3d, window, lidc-window, lidc-2d 44 | targets: 'rsna_pe_target' 45 | channels: 'window' # repeat, neighbor, window 46 | weighted_sample: true 47 | positive_only: false 48 | imsize: 256 49 | 50 
| transforms: 51 | type: 'imagenet' 52 | ShiftScaleRotate: 53 | shift_limit: 0.05 54 | scale_limit: 0.05 55 | rotate_limit: 20 56 | p: 0.5 57 | RandomCrop: 58 | height: 224 59 | width: 224 60 | 61 | train: 62 | batch_size: 32 63 | num_workers: 8 64 | weighted_loss: false 65 | loss_fn: 66 | name: 'BCEWithLogitsLoss' 67 | optimizer: 68 | name: 'Adam' 69 | weight_decay: 1e-6 70 | scheduler: 71 | name: 'ReduceLROnPlateau' 72 | monitor: 'val_loss' 73 | interval: 'epoch' 74 | frequency: 3 75 | factor: 0.5 76 | patience: 5 -------------------------------------------------------------------------------- /configs/lrcn_y_n_lidc0.1_lr1e-3.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: 'lidc_1d_resnext101' 2 | trial_name: 'lr-3' 3 | phase: 'train' 4 | 5 | lightning: 6 | trainer: 7 | gpus: 1 8 | max_epochs: 50 9 | precision: 16 10 | lr: 1e-3 11 | auto_lr_find: false 12 | benchmark: true 13 | profiler: 'simple' 14 | replace_sampler_ddp: false 15 | limit_train_batches: 0.1 16 | checkpoint_callback: 17 | monitor: 'val/mean_auroc' 18 | dirpath: './data/ckpt' 19 | save_last: true 20 | mode: 'max' 21 | save_top_k: 10 22 | early_stopping_callback: 23 | monitor: 'val/mean_auroc' 24 | min_delta: 0.00 25 | patience: 5 26 | verbose: False 27 | mode: 'max' 28 | logger: 29 | logger_type: 'WandbLogger' 30 | save_dir: './data/logger/' 31 | name: 'lrcn_resnext101' 32 | project: 'pe_models' 33 | 34 | model: 35 | type: 'model_1d' 36 | aggregation: 'mean' # mean, max, attention 37 | seq_encoder: 38 | rnn_type: 'GRU' # lstm, gru 39 | hidden_size: 256 40 | bidirectional: True 41 | num_layers: 1 42 | dropout_prob: 0.5 43 | 44 | data: 45 | use_hdf5: true 46 | hdf5_path: '/deep2/u/alexke/pe_models_benchmark/data/output/lidc_2d_resnext101_lr-3_positive_only_weighted_sample/1/2022_03_10_23_49_55/features.hdf5' 47 | feature_size: 2048 48 | dataset: 'rsna' 49 | type: 'lidc-1d' # 1d, 2d, 3d, window 50 | targets: 'rsna_pe_target' 51 | num_slices: 150 52 | sample_strategy: 'random' # fix, random 53 | contextualize_slice: true 54 | weighted_sample: false 55 | 56 | train: 57 | batch_size: 32 58 | num_workers: 8 59 | weighted_loss: false 60 | loss_fn: 61 | name: 'BCEWithLogitsLoss' 62 | optimizer: 63 | name: 'Adam' 64 | weight_decay: 1e-6 65 | scheduler: 66 | name: 'ReduceLROnPlateau' 67 | monitor: 'val/mean_auroc' 68 | interval: 'epoch' 69 | frequency: 3 70 | factor: 0.5 71 | patience: 5 -------------------------------------------------------------------------------- /configs/r2plus1d_y_n_rsna0.1_lr1e-5.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: 'rsna_window_r2plus1d' 2 | trial_name: 'lr-5' 3 | phase: 'train' 4 | 5 | lightning: 6 | trainer: 7 | gpus: 2 8 | distributed_backend: 'ddp' # if more than 1 gpu 9 | max_epochs: 50 10 | lr: 1e-5 11 | precision: 16 12 | auto_lr_find: false 13 | benchmark: true 14 | replace_sampler_ddp: false 15 | limit_train_batches: 0.1 # for lr search 16 | checkpoint_callback: 17 | monitor: 'val/mean_auroc' 18 | dirpath: './data/ckpt' 19 | save_last: true 20 | mode: 'max' 21 | save_top_k: 10 22 | early_stopping_callback: 23 | monitor: 'val/mean_auroc' 24 | min_delta: 0.00 25 | patience: 5 26 | verbose: False 27 | mode: 'max' 28 | logger: 29 | logger_type: 'WandbLogger' 30 | save_dir: './data/logger/' 31 | name: 'r2plus1d_r50' 32 | project: 'pe_models' 33 | 34 | 35 | model: 36 | type: 'model_3d' 37 | model_name: 'r2plus1d_r50' 38 | freeze_cnn: false 39 | pretrained: true 40 | 41 | data: 
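# window semantics (a sketch inferred from process_window_df in pe_models/preprocess/stanford.py): each example is a run of num_slices consecutive slices, and a window is labeled positive when at least min_abnormal_slice of its slices show PE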
42 | use_hdf5: true 43 | dataset: 'rsna' 44 | type: 'window' # 1d, 2d, 3d, window, lidc-window, lidc-2d 45 | num_slices: 32 46 | min_abnormal_slice: 4 47 | min_positive_slices: 24 48 | targets: 'rsna_pe_target' 49 | channels: 'repeat' # repeat, neighbor, window 50 | weighted_sample: true 51 | imsize: 256 52 | 53 | transforms: 54 | type: 'imagenet' 55 | Rotate: 56 | rotate_limit: 20 57 | p: 0.5 58 | RandomCrop: 59 | height: 224 60 | width: 224 61 | 62 | train: 63 | batch_size: 8 # change this with sbatch 64 | num_workers: 8 # change this with sbatch 65 | weighted_loss: false 66 | loss_fn: 67 | name: 'BCEWithLogitsLoss' 68 | optimizer: 69 | name: 'Adam' 70 | weight_decay: 1e-6 71 | scheduler: 72 | name: 'ReduceLROnPlateau' 73 | monitor: 'val_loss' 74 | interval: 'epoch' 75 | frequency: 3 76 | factor: 0.5 77 | patience: 5 -------------------------------------------------------------------------------- /configs/r2plus1d_y_n_lidc1.0_lr1e-5.yaml: -------------------------------------------------------------------------------- 1 | experiment_name: 'lidc_window_r2plus1d_full' 2 | trial_name: 'lr-5' 3 | phase: 'train' 4 | 5 | lightning: 6 | trainer: 7 | gpus: 2 8 | distributed_backend: 'ddp' # if more than 1 gpu 9 | max_epochs: 50 10 | lr: 1e-5 11 | precision: 16 12 | auto_lr_find: false 13 | benchmark: true 14 | replace_sampler_ddp: false 15 | # limit_train_batches: 0.1 # for lr search 16 | checkpoint_callback: 17 | monitor: 'val/mean_auroc' 18 | dirpath: './data/ckpt' 19 | save_last: true 20 | mode: 'max' 21 | save_top_k: 10 22 | early_stopping_callback: 23 | monitor: 'val/mean_auroc' 24 | min_delta: 0.00 25 | patience: 5 26 | verbose: False 27 | mode: 'max' 28 | logger: 29 | logger_type: 'WandbLogger' 30 | save_dir: './data/logger/' 31 | name: 'r2plus1d_r50' 32 | project: 'pe_models' 33 | 34 | 35 | model: 36 | type: 'model_3d' 37 | model_name: 'r2plus1d_r50' 38 | freeze_cnn: false 39 | pretrained: true 40 | 41 | data: 42 | use_hdf5: true 43 | dataset: 'rsna' 44 | type: 'lidc-window' # 1d, 2d, 3d, window, lidc-window, lidc-2d 45 | num_slices: 32 46 | min_abnormal_slice: 4 47 | min_positive_slices: 24 48 | targets: 'rsna_pe_target' 49 | channels: 'repeat' # repeat, neighbor, window 50 | weighted_sample: true 51 | imsize: 256 52 | 53 | transforms: 54 | type: 'imagenet' 55 | Rotate: 56 | rotate_limit: 20 57 | p: 0.5 58 | RandomCrop: 59 | height: 224 60 | width: 224 61 | 62 | train: 63 | batch_size: 8 # change this with sbatch 64 | num_workers: 8 # change this with sbatch 65 | weighted_loss: false 66 | loss_fn: 67 | name: 'BCEWithLogitsLoss' 68 | optimizer: 69 | name: 'Adam' 70 | weight_decay: 1e-6 71 | scheduler: 72 | name: 'ReduceLROnPlateau' 73 | monitor: 'val_loss' 74 | interval: 'epoch' 75 | frequency: 3 76 | factor: 0.5 77 | patience: 5 -------------------------------------------------------------------------------- /pe_models/models/models_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from omegaconf import OmegaConf 4 | from . 
import vision_backbones 5 | 6 | 7 | class PEModel2D(nn.Module): 8 | def __init__(self, cfg, num_classes=1, **kwargs): 9 | super(PEModel2D, self).__init__() 10 | 11 | # define cnn model 12 | model_function = getattr(vision_backbones, cfg.model.model_name) 13 | self.model, self.feature_dim = model_function(pretrained=cfg.model.pretrained) 14 | self.classifier = nn.Linear(self.feature_dim, num_classes) 15 | self.cfg = cfg 16 | 17 | # freeze cnn 18 | if cfg.model.freeze_cnn: 19 | for param in self.model.parameters(): 20 | param.requires_grad = False 21 | 22 | if not OmegaConf.is_none(cfg.model, 'checkpoint') and OmegaConf.is_none(cfg, 'checkpoint'): 23 | ckpt = torch.load(cfg.model.checkpoint) 24 | 25 | if 'state_dict' in ckpt: 26 | state_dict = ckpt['state_dict'] 27 | elif 'model' in ckpt: 28 | state_dict = ckpt['model'] 29 | else: 30 | raise Exception('ckpt key incorrect') 31 | 32 | state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()} 33 | state_dict = {k:v for k,v in state_dict.items() if 'head' not in k} 34 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 35 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 36 | 37 | msg = self.model.load_state_dict(state_dict, strict=False) 38 | print('='*80) 39 | print(msg) 40 | print('='*80) 41 | 42 | def forward(self, x, get_features=False): 43 | x = self.model(x) 44 | pred = self.classifier(x) 45 | if get_features: 46 | return pred, x 47 | else: 48 | return pred 49 | 50 | -------------------------------------------------------------------------------- /pe_models/models/backbones3d/layers/penet/penet_asp_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class PENetASPPool(nn.Module): 6 | """Atrous Spatial Pyramid Pooling layer. 7 | 8 | Based on the paper: 9 | "Rethinking Atrous Convolution for Semantic Image Segmentation" 10 | by Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam 11 | (https://arxiv.org/abs/1706.05587). 
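Fuses four parallel branches (a 1x1 convolution, 3x3 convolutions with dilations 6 and 12, and a globally average-pooled 1x1 convolution broadcast back to the input resolution) by concatenating them and applying a final 1x1 convolution.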
12 | """ 13 | def __init__(self, in_channels, out_channels): 14 | super(PENetASPPool, self).__init__() 15 | 16 | self.mid_channels = out_channels // 4 17 | self.in_conv = nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=2, dilation=2), 18 | nn.GroupNorm(out_channels // 16, out_channels), 19 | nn.LeakyReLU(inplace=True)) 20 | 21 | self.conv1 = nn.Conv3d(out_channels, self.mid_channels, kernel_size=1) 22 | self.conv2 = nn.Conv3d(out_channels, self.mid_channels, kernel_size=3, padding=6, dilation=6) 23 | self.conv3 = nn.Conv3d(out_channels, self.mid_channels, kernel_size=3, padding=12, dilation=12) 24 | self.conv4 = nn.Sequential(nn.AdaptiveAvgPool3d(1), 25 | nn.Conv3d(out_channels, self.mid_channels, kernel_size=1)) 26 | self.norm = nn.GroupNorm(out_channels // 16, out_channels) 27 | self.relu = nn.LeakyReLU(inplace=True) 28 | 29 | self.out_conv = nn.Sequential(nn.Conv3d(out_channels, out_channels, kernel_size=1), 30 | nn.GroupNorm(out_channels // 16, out_channels), 31 | nn.LeakyReLU(inplace=True)) 32 | 33 | def forward(self, x): 34 | x = self.in_conv(x) 35 | 36 | # Four parallel paths with different dilation factors 37 | x_1 = self.conv1(x) 38 | x_2 = self.conv2(x) 39 | x_3 = self.conv3(x) 40 | x_4 = self.conv4(x) 41 | x_4 = x_4.expand(-1, -1, x_1.size(2), x_1.size(3), x_1.size(4)) 42 | 43 | # Combine parallel pathways 44 | x = torch.cat((x_1, x_2, x_3, x_4), dim=1) 45 | x = self.norm(x) 46 | x = self.relu(x) 47 | 48 | x = self.out_conv(x) 49 | 50 | return x 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /pe_models/models/models_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from . import vision_backbones 5 | 6 | 7 | class PEModel3D(nn.Module): 8 | def __init__(self, cfg, num_classes=1, **kwargs): 9 | super(PEModel3D, self).__init__() 10 | 11 | # define cnn model 12 | model_function = getattr(vision_backbones, cfg.model.model_name) 13 | self.model, self.feature_dim = model_function( 14 | pretrained=cfg.model.pretrained, num_frames=cfg.data.num_slices, 15 | cfg=cfg 16 | ) 17 | self.classifier = nn.Linear(self.feature_dim, num_classes) 18 | self.cfg = cfg 19 | 20 | # freeze cnn 21 | if cfg.model.freeze_cnn: 22 | for param in self.model.parameters(): 23 | param.requires_grad = False 24 | 25 | def forward(self, x, get_features=False): 26 | x = self.model(x) 27 | pred = self.classifier(x) 28 | 29 | if get_features: 30 | return pred, x 31 | else: 32 | return pred 33 | 34 | def fine_tuning_parameters(self, fine_tuning_boundary, fine_tuning_lr=0.0): 35 | """Get parameters for fine-tuning the model. 36 | Args: 37 | fine_tuning_boundary: Name of first layer after the fine-tuning layers. 38 | fine_tuning_lr: Learning rate to apply to fine-tuning layers (all layers before `boundary_layer`). 39 | Returns: 40 | List of dicts that can be passed to an optimizer. 41 | """ 42 | 43 | def gen_params(boundary_layer_name, fine_tuning): 44 | """Generate parameters, if fine_tuning generate the params before boundary_layer_name. 45 | If unfrozen, generate the params at boundary_layer_name and beyond.""" 46 | saw_boundary_layer = False 47 | for name, param in self.named_parameters(): 48 | if name.startswith(boundary_layer_name): 49 | saw_boundary_layer = True 50 | 51 | if saw_boundary_layer and fine_tuning: 52 | return 53 | elif not saw_boundary_layer and not fine_tuning: 54 | continue 55 | else: 56 | yield param 57 | 58 | # Fine-tune the network's layers from encoder.2 onwards 59 | optimizer_parameters = [{'params': gen_params(fine_tuning_boundary, fine_tuning=True), 'lr': fine_tuning_lr}, 60 | {'params': gen_params(fine_tuning_boundary, fine_tuning=False)}] 61 | 62 | return optimizer_parameters 63 | -------------------------------------------------------------------------------- /pe_models/datasets/data_module.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchvision.transforms as transforms 3 | 4 | from torch.utils.data import DataLoader 5 | from .. 
import builder 6 | 7 | 8 | class PEDataModule(pl.LightningDataModule): 9 | def __init__(self, cfg): 10 | super().__init__() 11 | 12 | self.cfg = cfg 13 | self.dataset = builder.build_dataset(cfg) 14 | 15 | def train_dataloader(self): 16 | transform = builder.build_transformation(self.cfg, "train") 17 | dataset = self.dataset(self.cfg, split="train", transform=transform) 18 | 19 | if self.cfg.data.weighted_sample: 20 | sampler = dataset.get_sampler() 21 | return DataLoader( 22 | dataset, 23 | pin_memory=True, 24 | drop_last=True, 25 | shuffle=False, 26 | sampler=sampler, 27 | batch_size=self.cfg.train.batch_size, 28 | num_workers=self.cfg.train.num_workers, 29 | ) 30 | else: 31 | return DataLoader( 32 | dataset, 33 | pin_memory=True, 34 | drop_last=True, 35 | shuffle=True, 36 | batch_size=self.cfg.train.batch_size, 37 | num_workers=self.cfg.train.num_workers, 38 | ) 39 | 40 | def val_dataloader(self): 41 | transform = builder.build_transformation(self.cfg, "val") 42 | dataset = self.dataset(self.cfg, split="valid", transform=transform) 43 | return DataLoader( 44 | dataset, 45 | pin_memory=True, 46 | drop_last=True, 47 | shuffle=False, 48 | batch_size=self.cfg.train.batch_size, 49 | num_workers=self.cfg.train.num_workers, 50 | ) 51 | 52 | def test_dataloader(self): 53 | transform = builder.build_transformation(self.cfg, "test") 54 | dataset = self.dataset(self.cfg, split=self.cfg.test_split, transform=transform) 55 | return DataLoader( 56 | dataset, 57 | pin_memory=True, 58 | shuffle=False, 59 | batch_size=self.cfg.train.batch_size, 60 | num_workers=self.cfg.train.num_workers, 61 | ) 62 | 63 | def all_dataloader(self): 64 | transform = builder.build_transformation(self.cfg, "all") 65 | dataset = self.dataset(self.cfg, split=self.cfg.test_split, transform=transform) 66 | return DataLoader( 67 | dataset, 68 | pin_memory=True, 69 | shuffle=False, 70 | batch_size=self.cfg.train.batch_size, 71 | num_workers=self.cfg.train.num_workers, 72 | ) 73 | -------------------------------------------------------------------------------- /pe_models/preprocess/build_hdf5_labels.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import cv2 4 | import h5py 5 | from tqdm import tqdm 6 | 7 | import pylidc as pl 8 | 9 | def process_dicom(dcm): 10 | pixel_array = dcm.pixel_array 11 | 12 | intercept = dcm.RescaleIntercept 13 | slope = dcm.RescaleSlope 14 | pixel_array = pixel_array * slope + intercept 15 | 16 | resize_shape = 256 # self.cfg.data.imsize 17 | pixel_array = cv2.resize( 18 | pixel_array, (resize_shape, resize_shape), interpolation=cv2.INTER_AREA 19 | ) 20 | return pixel_array 21 | 22 | records = [] 23 | hdf5_fh = h5py.File("lidc_study.hdf5", "a") 24 | for scan in tqdm(pl.query(pl.Scan).all()): 25 | dicoms = scan.load_all_dicom_images(verbose=False) 26 | if len(dicoms) == 0: 27 | continue 28 | series = np.stack([process_dicom(dcm) for dcm in dicoms]) 29 | hdf5_fh.create_dataset( 30 | scan.study_instance_uid, 31 | data=series, 32 | dtype="float32", 33 | chunks=True 34 | ) 35 | 36 | if len(scan.annotations) > 0: 37 | start_stop = np.stack([annot.bbox_matrix()[2] for annot in scan.annotations]) 38 | for i, dcm in enumerate(dicoms): 39 | records.append({ 40 | "PatientID": dcm.PatientID, 41 | "StudyInstanceUID": dcm.StudyInstanceUID, 42 | "SeriesInstanceUID": dcm.SeriesInstanceUID, 43 | "SOPInstanceUID": dcm.SOPInstanceUID, 44 | "image_index": i, 45 | "nodule_present_on_image": int( 46 | sum((start_stop[:, 0] <= i) & (i <= 
start_stop[:, 1])) >= 3 47 | ) if len(scan.annotations) > 0 else 0, 48 | "InstanceNumber": dcm.InstanceNumber, 49 | "ImagePositionPatient_0": dcm.ImagePositionPatient[0], 50 | "ImagePositionPatient_1": dcm.ImagePositionPatient[1], 51 | "ImagePositionPatient_2": dcm.ImagePositionPatient[2], 52 | "ImageOrientationPatient_0": dcm.ImageOrientationPatient[0], 53 | "ImageOrientationPatient_1": dcm.ImageOrientationPatient[1], 54 | "ImageOrientationPatient_2": dcm.ImageOrientationPatient[2], 55 | "ImageOrientationPatient_3": dcm.ImageOrientationPatient[3], 56 | "ImageOrientationPatient_4": dcm.ImageOrientationPatient[4], 57 | "ImageOrientationPatient_5": dcm.ImageOrientationPatient[5], 58 | "PixelSpacing_0": dcm.PixelSpacing[0], 59 | "PixelSpacing_1": dcm.PixelSpacing[1], 60 | "RescaleIntercept": dcm.RescaleIntercept, 61 | "RescaleSlope": dcm.RescaleSlope, 62 | "WindowCenter": dcm.get("WindowCenter"), 63 | "WindowWidth": dcm.get("WindowWidth"), 64 | }) 65 | hdf5_fh.close() 66 | 67 | pd.DataFrame.from_records(records).to_csv("lidc_2d.csv", index=False) -------------------------------------------------------------------------------- /pe_models/preprocess/stanford.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import mdai 3 | import pydicom 4 | import sys 5 | import os 6 | import glob 7 | import tqdm 8 | import numpy as np 9 | import h5py 10 | 11 | sys.path.append(os.getcwd()) 12 | 13 | from collections import defaultdict 14 | from pe_models.constants import * 15 | from pe_models import utils 16 | from sklearn.model_selection import train_test_split 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | 20 | def process_window_df(df:pd.DataFrame, num_slices:int=24, min_abnormal_slice:int=4, stride:int=6): 21 | 22 | window_labels_csv = STANFORD_DATA_DIR / f"stanford_window_{num_slices}_min_abnormal_{min_abnormal_slice}_stride_{stride}.csv" 23 | print(window_labels_csv) 24 | 25 | # count number of windows per study 26 | #count_num_windows = lambda x: x // num_slices\ 27 | # + (1 if x % num_slices > 0 else 0) 28 | count_num_windows = lambda x: (x - num_slices) // stride 29 | df['NumWindows'] = df['NumSlices'].apply( 30 | count_num_windows) 31 | 32 | # get windows list 33 | df_study = df.groupby(['StudyInstanceUID']).head(1) 34 | window_labels = defaultdict(list) 35 | 36 | # studies 37 | for _, row in tqdm.tqdm(df_study.iterrows(), total=df_study.shape[0]): 38 | study_name = row['StudyInstanceUID'] 39 | split = row['Split'] 40 | 41 | study_df = df[df['StudyInstanceUID'] == study_name] 42 | study_df = study_df.sort_values("ImagePositionPatient_2") 43 | 44 | # windows 45 | for idx in range(row['NumWindows']): 46 | #start_idx = idx * num_slices 47 | #end_idx = (idx+1) * num_slice 48 | start_idx = idx * stride 49 | end_idx = (idx * stride) + num_slices 50 | 51 | window_df = study_df.iloc[start_idx: end_idx] 52 | num_positives_slices = window_df['pe_present_on_image'].sum() 53 | label = 1 if num_positives_slices >= min_abnormal_slice else 0 54 | window_labels['StudyInstanceUID'].append(study_name) 55 | window_labels['index'].append(idx) 56 | window_labels['Label'].append(label) 57 | window_labels['Split'].append(split) 58 | window_labels['InstancePath'].append(window_df['InstancePath'].tolist()) 59 | window_labels["ImagePositionPatient_2"].append(window_df["ImagePositionPatient_2"].tolist()) 60 | window_labels['pe_present_on_image'].append(window_df['pe_present_on_image'].tolist()) 61 |
window_labels['SOPInstanceUID'].append(window_df['SOPInstanceUID'].tolist()) 62 | window_labels['SliceOrder'].append(window_df['SliceOrder'].tolist()) 63 | window_labels['num_positive_slices'].append(num_positives_slices) 64 | df = pd.DataFrame.from_dict(window_labels) 65 | df.to_csv(window_labels_csv) 66 | 67 | return df 68 | 69 | -------------------------------------------------------------------------------- /pe_models/models/backbones3d/layers/penet/penet_bottleneck.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch.nn as nn 3 | 4 | from . import SEBlock 5 | 6 | 7 | class PENetBottleneck(nn.Module): 8 | """PENet bottleneck block, similar to a pre-activation ResNeXt bottleneck block. 9 | 10 | Based on the paper: 11 | "Aggregated Residual Transformations for Deep Neural Networks" 12 | by Saining Xie, Ross Girshick, Piotr Dollár, Zhuowen Tu, Kaiming He 13 | (https://arxiv.org/abs/1611.05431). 14 | """ 15 | 16 | expansion = 2 17 | 18 | def __init__(self, in_channels, channels, block_idx, total_blocks, cardinality=32, stride=1): 19 | super(PENetBottleneck, self).__init__() 20 | mid_channels = cardinality * int(channels / cardinality) 21 | out_channels = channels * self.expansion 22 | self.survival_prob = self._get_survival_prob(block_idx, total_blocks) 23 | 24 | self.down_sample = None 25 | if stride != 1 or in_channels != channels * PENetBottleneck.expansion: 26 | self.down_sample = nn.Sequential( 27 | nn.Conv3d(in_channels, channels * PENetBottleneck.expansion, kernel_size=1, stride=stride, bias=False), 28 | nn.GroupNorm(channels * PENetBottleneck.expansion // 16, channels * PENetBottleneck.expansion)) 29 | 30 | self.conv1 = nn.Conv3d(in_channels, mid_channels, kernel_size=1, bias=False) 31 | self.norm1 = nn.GroupNorm(mid_channels // 16, mid_channels) 32 | self.relu1 = nn.LeakyReLU(inplace=True) 33 | 34 | self.conv2 = nn.Conv3d(mid_channels, mid_channels, kernel_size=3, 35 | stride=stride, padding=1, groups=cardinality, bias=False) 36 | self.norm2 = nn.GroupNorm(mid_channels // 16, mid_channels) 37 | self.relu2 = nn.LeakyReLU(inplace=True) 38 | 39 | self.conv3 = nn.Conv3d(mid_channels, out_channels, kernel_size=1, bias=False) 40 | self.norm3 = nn.GroupNorm(out_channels // 16, out_channels) 41 | self.norm3.is_last_norm = True 42 | self.relu3 = nn.LeakyReLU(inplace=True) 43 | 44 | self.se_block = SEBlock(out_channels, reduction=16) 45 | 46 | @staticmethod 47 | def _get_survival_prob(block_idx, total_blocks, p_final=0.5): 48 | """Get survival probability for stochastic depth. Uses linearly decreasing 49 | survival probability as described in "Deep Networks with Stochastic Depth". 50 | 51 | Args: 52 | block_idx: Index of residual block within entire network. 53 | total_blocks: Total number of residual blocks in entire network. 54 | p_final: Survival probability of the final layer. 55 | """ 56 | return 1. - block_idx / total_blocks * (1.
- p_final) 57 | 58 | def forward(self, x): 59 | x_skip = x if self.down_sample is None else self.down_sample(x) 60 | 61 | # Stochastic depth dropout 62 | if self.training and random.random() > self.survival_prob: 63 | return x_skip 64 | 65 | x = self.conv1(x) 66 | x = self.norm1(x) 67 | x = self.relu1(x) 68 | 69 | x = self.conv2(x) 70 | x = self.norm2(x) 71 | x = self.relu2(x) 72 | 73 | x = self.conv3(x) 74 | x = self.norm3(x) 75 | 76 | x = self.se_block(x) 77 | x += x_skip 78 | 79 | x = self.relu3(x) 80 | 81 | return x 82 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pe_models 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=4.5=1_gnu 7 | - blas=1.0=mkl 8 | - bottleneck=1.3.2=py39hdd57654_1 9 | - ca-certificates=2021.10.26=h06a4308_2 10 | - certifi=2021.10.8=py39h06a4308_0 11 | - intel-openmp=2021.3.0=h06a4308_3350 12 | - ld_impl_linux-64=2.35.1=h7274673_9 13 | - libffi=3.3=he6710b0_2 14 | - libgcc-ng=9.3.0=h5101ec6_17 15 | - libgomp=9.3.0=h5101ec6_17 16 | - libstdcxx-ng=9.3.0=hd4cf53a_17 17 | - mkl=2021.3.0=h06a4308_520 18 | - mkl-service=2.4.0=py39h7f8727e_0 19 | - mkl_fft=1.3.1=py39hd3c417c_0 20 | - mkl_random=1.2.2=py39h51133e4_0 21 | - ncurses=6.2=he6710b0_1 22 | - numexpr=2.7.3=py39h22e1b3c_1 23 | - numpy=1.21.2=py39h20f2e39_0 24 | - numpy-base=1.21.2=py39h79a1101_0 25 | - openssl=1.1.1l=h7f8727e_0 26 | - pandas=1.3.3=py39h8c16a72_0 27 | - pip=21.2.4=py39h06a4308_0 28 | - python=3.9.0=hdb3f193_2 29 | - python-dateutil=2.8.2=pyhd3eb1b0_0 30 | - pytz=2021.3=pyhd3eb1b0_0 31 | - readline=8.1=h27cfd23_0 32 | - setuptools=58.0.4=py39h06a4308_0 33 | - six=1.16.0=pyhd3eb1b0_0 34 | - sqlite=3.36.0=hc218d9a_0 35 | - tk=8.6.11=h1ccaba5_0 36 | - tzdata=2021e=hda174b7_0 37 | - wheel=0.37.0=pyhd3eb1b0_1 38 | - xz=5.2.5=h7b6447c_0 39 | - zlib=1.2.11=h7b6447c_3 40 | - pip: 41 | - absl-py==0.15.0 42 | - aiohttp==3.7.4.post0 43 | - albumentations==1.1.0 44 | - antlr4-python3-runtime==4.8 45 | - arrow==1.2.1 46 | - async-timeout==3.0.1 47 | - attrs==21.2.0 48 | - cachetools==4.2.4 49 | - chardet==4.0.0 50 | - charset-normalizer==2.0.7 51 | - click==8.0.3 52 | - configparser==5.0.2 53 | - cycler==0.10.0 54 | - dicom2nifti==2.3.0 55 | - docker-pycreds==0.4.0 56 | - fsspec==2021.10.1 57 | - future==0.18.2 58 | - gitdb==4.0.9 59 | - gitpython==3.1.24 60 | - google-auth==2.3.2 61 | - google-auth-oauthlib==0.4.6 62 | - grpcio==1.41.1 63 | - h5py==3.5.0 64 | - idna==3.3 65 | - imageio==2.9.0 66 | - iniconfig==1.1.1 67 | - joblib==1.1.0 68 | - kiwisolver==1.3.2 69 | - markdown==3.3.4 70 | - matplotlib==3.4.3 71 | - mdai==0.8.0 72 | - multidict==5.2.0 73 | - networkx==2.6.3 74 | - nibabel==3.2.1 75 | - oauthlib==3.1.1 76 | - omegaconf==2.1.1 77 | - opencv-python==4.5.4.58 78 | - opencv-python-headless==4.5.4.58 79 | - packaging==21.0 80 | - pathtools==0.1.2 81 | - pillow==8.4.0 82 | - pluggy==1.0.0 83 | - promise==2.3 84 | - protobuf==3.19.0 85 | - psutil==5.8.0 86 | - py==1.10.0 87 | - pyasn1==0.4.8 88 | - pyasn1-modules==0.2.8 89 | - pydeprecate==0.3.1 90 | - pydicom==2.2.2 91 | - pylibjpeg==1.3.0 92 | - pylibjpeg-libjpeg==1.2.0 93 | - pylibjpeg-openjpeg==1.1.1 94 | - pylibjpeg-rle==1.1.0 95 | - pyparsing==3.0.1 96 | - pytest==6.2.5 97 | - pytorch-lightning==1.4.9 98 | - pywavelets==1.1.1 99 | - pyyaml==6.0 100 | - qudida==0.0.4 101 | - requests==2.26.0 102 | - requests-oauthlib==1.3.0 103 | - retrying==1.3.3 104 | - 
rsa==4.7.2 105 | - scikit-image==0.18.3 106 | - scikit-learn==1.0.1 107 | - scipy==1.7.1 108 | - sentry-sdk==1.4.3 109 | - shortuuid==1.0.1 110 | - smmap==5.0.0 111 | - subprocess32==3.5.4 112 | - tensorboard==2.7.0 113 | - tensorboard-data-server==0.6.1 114 | - tensorboard-plugin-wit==1.8.0 115 | - termcolor==1.1.0 116 | - threadpoolctl==3.0.0 117 | - tifffile==2021.10.12 118 | - timm==0.4.12 119 | - toml==0.10.2 120 | - torch==1.10.0 121 | - torchaudio==0.10.0 122 | - torchmetrics==0.5.1 123 | - torchvision==0.11.1 124 | - tqdm==4.62.3 125 | - typing-extensions==3.10.0.2 126 | - urllib3==1.26.7 127 | - wandb==0.12.5 128 | - werkzeug==2.0.2 129 | - yarl==1.7.0 130 | - yaspin==2.1.0 131 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import unittest 3 | import os 4 | import sys 5 | 6 | sys.path.append(os.getcwd()) 7 | 8 | from pandas.core.algorithms import unique 9 | 10 | from pe_models.constants import * 11 | 12 | 13 | class RSNADatasetTestCase(unittest.TestCase): 14 | def setUp(self): 15 | self.df = pd.read_csv(RSNA_TRAIN_CSV) 16 | 17 | self.df_train = self.df[self.df[RSNA_SPLIT_COL] == "train"] 18 | self.df_val = self.df[self.df[RSNA_SPLIT_COL] == "valid"] 19 | self.df_test = self.df[self.df[RSNA_SPLIT_COL] == "test"] 20 | 21 | self.df_ins_train = self.df[self.df[RSNA_INSTITUTION_SPLIT_COL] == "train"] 22 | self.df_ins_val = self.df[self.df[RSNA_INSTITUTION_SPLIT_COL] == "valid"] 23 | self.df_ins_test = self.df[self.df[RSNA_INSTITUTION_SPLIT_COL] == "test"] 24 | self.df_ins_stanford = self.df[ 25 | self.df[RSNA_INSTITUTION_SPLIT_COL] == "stanford_test" 26 | ] 27 | 28 | def test_num_study_in_split_gt_zero(self): 29 | 30 | num_train_studies = self.df_train[RSNA_STUDY_COL].nunique() 31 | num_val_studies = self.df_val[RSNA_STUDY_COL].nunique() 32 | num_test_studies = self.df_test[RSNA_STUDY_COL].nunique() 33 | 34 | num_ins_train_studies = self.df_ins_train[RSNA_STUDY_COL].nunique() 35 | num_ins_val_studies = self.df_ins_val[RSNA_STUDY_COL].nunique() 36 | num_ins_test_studies = self.df_ins_test[RSNA_STUDY_COL].nunique() 37 | num_ins_stanford_studies = self.df_ins_stanford[RSNA_STUDY_COL].nunique() 38 | 39 | self.assertEqual(num_train_studies, 5095, "no studies in training set") 40 | self.assertEqual(num_val_studies, 1092, "no studies in val set") 41 | self.assertEqual(num_test_studies, 1092, "no studies in test set") 42 | self.assertEqual( 43 | num_ins_train_studies, 4620, "no studies in institutional training set" 44 | ) 45 | self.assertEqual( 46 | num_ins_val_studies, 990, "no studies in institutional val set" 47 | ) 48 | self.assertEqual( 49 | num_ins_test_studies, 991, "no studies in institutional test set" 50 | ) 51 | self.assertEqual( 52 | num_ins_stanford_studies, 678, "no studies in institutional stanford set" 53 | ) 54 | 55 | def test_no_overlap_study_in_split(self): 56 | 57 | train_studies = self.df_train[RSNA_STUDY_COL].unique() 58 | val_studies = self.df_val[RSNA_STUDY_COL].unique() 59 | test_studies = self.df_test[RSNA_STUDY_COL].unique() 60 | 61 | self.assertFalse( 62 | bool(set(train_studies) & set(val_studies)), "train and val overlap" 63 | ) 64 | self.assertFalse( 65 | bool(set(train_studies) & set(test_studies)), "train and test overlap" 66 | ) 67 | self.assertFalse( 68 | bool(set(val_studies) & set(test_studies)), "val and test overlap" 69 | ) 70 | 71 | def test_no_overlap_study_in_institution_split(self): 
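"""Ensure no StudyInstanceUID is shared across the institutional train/valid/test/stanford_test splits."""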
72 | 73 | train_studies = self.df_ins_train[RSNA_STUDY_COL].unique() 74 | val_studies = self.df_ins_val[RSNA_STUDY_COL].unique() 75 | test_studies = self.df_ins_test[RSNA_STUDY_COL].unique() 76 | stanford_studies = self.df_ins_stanford[RSNA_STUDY_COL].unique() 77 | 78 | self.assertFalse( 79 | bool(set(train_studies) & set(val_studies)), "train and val overlap" 80 | ) 81 | self.assertFalse( 82 | bool(set(train_studies) & set(test_studies)), "train and test overlap" 83 | ) 84 | self.assertFalse( 85 | bool(set(val_studies) & set(test_studies)), "val and test overlap" 86 | ) 87 | self.assertFalse( 88 | bool(set(stanford_studies) & set(test_studies)), "test and stanford overlap" 89 | ) 90 | self.assertFalse( 91 | bool(set(stanford_studies) & set(train_studies)), 92 | "train and stanford overlap", 93 | ) 94 | self.assertFalse( 95 | bool(set(stanford_studies) & set(val_studies)), "val and stanford overlap" 96 | ) 97 | 98 | 99 | if __name__ == "__main__": 100 | runner = unittest.TextTestRunner(verbosity=2) 101 | unittest.main(testRunner=runner) 102 | -------------------------------------------------------------------------------- /tests/test_dataloader_window.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import unittest 3 | import torch 4 | import os 5 | import sys 6 | 7 | sys.path.append(os.getcwd()) 8 | 9 | from pe_models import utils 10 | from pe_models.constants import * 11 | from pe_models.datasets import dataset_3d 12 | from torch.utils.data import DataLoader 13 | from omegaconf import OmegaConf 14 | 15 | 16 | class RSNAWindowDataLoaderTestCase(unittest.TestCase): 17 | def test_repeat_channel(self): 18 | config = { 19 | "data": { 20 | "use_hdf5": True, 21 | "dataset": "rsna", 22 | "type": "window", 23 | "targets": "rsna_pe_target", 24 | "channels": "repeat", 25 | "weighted_sample": True, 26 | "positive_only": True, 27 | "num_slices": 24, 28 | "min_abnormal_slice": 4, 29 | "sample_strategy": "random", 30 | "imsize": 256, 31 | }, 32 | "transforms": {"RandomCrop": {"height": 224, "width": 224}}, 33 | } 34 | config = OmegaConf.create(config) 35 | dataset = dataset_3d.PEDatasetWindow(config, split="train") 36 | dataloader = DataLoader(dataset, batch_size=2, num_workers=0) 37 | 38 | x, y, ids = next(iter(dataloader)) 39 | 40 | self.assertEqual(x.shape[0], 2, "batch size incorrect") 41 | self.assertEqual(x.shape[1], 3, "number of channels incorrect") 42 | self.assertEqual(x.shape[2], 24, "slice number incorrect") 43 | self.assertEqual(x.shape[3], 224, "height incorrect") 44 | self.assertEqual(x.shape[4], 224, "width incorrect") 45 | self.assertTrue( 46 | torch.all(x[0, 0, :, :, :].eq(x[0, 1, :, :, :])), "channels not repeating" 47 | ) 48 | self.assertTrue( 49 | torch.all(x[0, 1, :, :, :].eq(x[0, 2, :, :, :])), "channels not repeating" 50 | ) 51 | self.assertTrue((x.max() <= 1.0 and x.min() >= -1.0), "input not normalized") 52 | 53 | utils.visualize_examples(x[0].numpy(), 10, "./data/test/rsna_window_repeat.png") 54 | 55 | def test_eq_read_from_dicom(self): 56 | hdf5_config = { 57 | "data": { 58 | "use_hdf5": True, 59 | "dataset": "rsna", 60 | "type": "window", 61 | "targets": "rsna_pe_target", 62 | "channels": "repeat", 63 | "weighted_sample": True, 64 | "positive_only": True, 65 | "num_slices": 24, 66 | "min_abnormal_slice": 4, 67 | "sample_strategy": "random", 68 | "imsize": 256, 69 | }, 70 | "transforms": {"RandomCrop": {"height": 224, "width": 224}}, 71 | } 72 | hdf5_config = OmegaConf.create(hdf5_config) 73 | raw_config = { 74 |
"data": { 75 | "use_hdf5": False, 76 | "dataset": "rsna", 77 | "type": "window", 78 | "targets": "rsna_pe_target", 79 | "channels": "repeat", 80 | "weighted_sample": True, 81 | "positive_only": True, 82 | "num_slices": 24, 83 | "min_abnormal_slice": 4, 84 | "sample_strategy": "random", 85 | "imsize": 256, 86 | }, 87 | "transforms": {"RandomCrop": {"height": 224, "width": 224}}, 88 | } 89 | raw_config = OmegaConf.create(raw_config) 90 | hdf5_dataset = dataset_3d.PEDatasetWindow(hdf5_config, split="test") 91 | raw_dataset = dataset_3d.PEDatasetWindow(raw_config, split="test") 92 | hdf5_dataloader = DataLoader(hdf5_dataset, batch_size=4, num_workers=0, shuffle=False) 93 | raw_dataloader = DataLoader(raw_dataset, batch_size=4, num_workers=0, shuffle=False) 94 | 95 | hdf5_x, hdf5_y, hdf5_ids = next(iter(hdf5_dataloader)) 96 | raw_x, raw_y, raw_ids = next(iter(raw_dataloader)) 97 | 98 | self.assertTrue( 99 | torch.all(hdf5_x.eq(raw_x)), "different inputs values when reading from raw dicom vs hdf5" 100 | ) 101 | self.assertTrue( 102 | torch.all(hdf5_y.eq(raw_y)), "different target labels when reading from raw dicom vs hdf5" 103 | ) 104 | 105 | 106 | if __name__ == "__main__": 107 | runner = unittest.TextTestRunner(verbosity=2) 108 | unittest.main(testRunner=runner) 109 | -------------------------------------------------------------------------------- /pe_models/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BinaryFocalLoss(nn.Module): 8 | """Focal loss for binary classification. 9 | Adapted from: 10 | https://gist.github.com/AdrienLE/bf31dfe94569319f6e47b2de8df13416#file-focal_dice_1-py 11 | """ 12 | def __init__(self, gamma=2, size_average=True): 13 | super(BinaryFocalLoss, self).__init__() 14 | self.gamma = gamma 15 | self.take_mean = size_average 16 | 17 | def forward(self, logits, target): 18 | # Inspired by the implementation of binary_cross_entropy_with_logits 19 | if not (target.size() == logits.size()): 20 | raise ValueError("Target size ({}) must be the same as input size ({})" 21 | .format(target.size(), logits.size())) 22 | 23 | max_val = (-logits).clamp(min=0) 24 | loss = logits - logits * target + max_val + ((-max_val).exp() + (-logits - max_val).exp()).log() 25 | 26 | # This formula gives us the log sigmoid of 1-p if y is 0 and of p if y is 1 27 | inv_probs = F.logsigmoid(-logits * (target * 2 - 1)) 28 | loss = (inv_probs * self.gamma).exp() * loss 29 | 30 | if self.take_mean: 31 | loss = loss.mean() 32 | 33 | return loss 34 | 35 | 36 | class DINOLoss(nn.Module): 37 | """ 38 | DINO loss, adapted from: 39 | https://github.com/facebookresearch/dino/blob/main/main_dino.py 40 | """ 41 | def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, 42 | warmup_teacher_temp_epochs, nepochs, student_temp=0.1, 43 | center_momentum=0.9): 44 | super().__init__() 45 | self.student_temp = student_temp 46 | self.center_momentum = center_momentum 47 | self.ncrops = ncrops 48 | self.register_buffer("center", torch.zeros(1, out_dim)) 49 | # we apply a warm up for the teacher temperature because 50 | # a too high temperature makes the training instable at the beginning 51 | self.teacher_temp_schedule = np.concatenate(( 52 | np.linspace(warmup_teacher_temp, 53 | teacher_temp, warmup_teacher_temp_epochs), 54 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp 55 | )) 56 | 57 | def forward(self, student_output, 
58 |         """
59 |         Cross-entropy between softmax outputs of the teacher and student networks.
60 |         """
61 |         student_out = student_output / self.student_temp
62 |         student_out = student_out.chunk(self.ncrops)
63 | 
64 |         # teacher centering and sharpening
65 |         temp = self.teacher_temp_schedule[epoch]
66 |         teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
67 |         teacher_out = teacher_out.detach().chunk(2)
68 | 
69 |         total_loss = 0
70 |         n_loss_terms = 0
71 |         for iq, q in enumerate(teacher_out):
72 |             for v in range(len(student_out)):
73 |                 if v == iq:
74 |                     # we skip cases where student and teacher operate on the same view
75 |                     continue
76 |                 loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
77 |                 total_loss += loss.mean()
78 |                 n_loss_terms += 1
79 |                 # optional per-term debugging output
80 |                 if verbose:
81 |                     print('------------- Loss Function -------------')
82 |                     print(f'loss = {loss}')
83 |                     print(f'total_loss = {total_loss}')
84 |                     print(f'n_loss_terms = {n_loss_terms}')
85 |                     print(f'q = {q}')
86 |                     print(f'log_softmax = {F.log_softmax(student_out[v], dim=-1)}')
87 |                     print(f'student_out = {student_out[v]}')
88 | 
89 |         total_loss /= n_loss_terms
90 |         self.update_center(teacher_output)
91 |         # optional summary debugging output
92 |         if verbose:
93 |             print(f'total_loss = {total_loss}')
94 |             print(f'n_loss_terms = {n_loss_terms}')
95 | 
96 |         return total_loss, n_loss_terms
97 | 
98 |     @torch.no_grad()
99 |     def update_center(self, teacher_output):
100 |         """
101 |         Update center used for teacher output.
102 |         """
103 |         batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
104 |         # dist.all_reduce(batch_center)
105 |         # batch_center = batch_center / (len(teacher_output) * dist.get_world_size())
106 |         batch_center = batch_center / (len(teacher_output))
107 | 
108 |         # ema update
109 |         self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
--------------------------------------------------------------------------------
/tests/test_dataloader2d.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import unittest
3 | import torch
4 | import os
5 | import sys
6 | 
7 | sys.path.append(os.getcwd())
8 | 
9 | from pe_models.constants import *
10 | from pe_models import utils
11 | from pe_models.datasets import dataset_2d, dataset_3d
12 | from torch.utils.data import DataLoader
13 | from omegaconf import OmegaConf
14 | 
15 | 
16 | class RSNA2DDataLoaderTestCase(unittest.TestCase):
17 |     def test_repeat_channel(self):
18 |         config = {
19 |             "data": {
20 |                 "use_hdf5": False,
21 |                 "dataset": "rsna",
22 |                 "type": "2d",
23 |                 "targets": "rsna_pe_slice_target",
24 |                 "channels": "repeat",
25 |                 "weighted_sample": True,
26 |                 "positive_only": True,
27 |                 "imsize": 512,
28 |             }
29 |         }
30 |         config = OmegaConf.create(config)
31 |         dataset = dataset_2d.PEDataset2D(config, split="train")
32 |         dataloader = DataLoader(dataset, batch_size=10, num_workers=8, shuffle=True)
33 | 
34 |         x, y, ids = next(iter(dataloader))
35 | 
36 |         self.assertEqual(x.shape[0], 10, "batch size incorrect")
37 |         self.assertEqual(x.shape[1], 3, "incorrect num channels")
38 |         self.assertEqual(x.shape[2], 512, "incorrect height")
39 |         self.assertEqual(x.shape[3], 512, "incorrect width")
40 |         self.assertTrue(
41 |             torch.all(x[0, 0, :, :].eq(x[0, 1, :, :])), "channels not repeating"
42 |         )
43 |         self.assertTrue(
44 |             torch.all(x[0, 1, :, :].eq(x[0, 2, :, :])), "channels not repeating"
45 |         )
46 |         self.assertTrue((x.max() <= 1.0 and x.min() >= -1.0), "input not normalized")
47 |         utils.visualize_examples(x.numpy(), 10, "./data/test/rsna_2d_repeat.png")
48 | 
49 |     def test_window_channel(self):
50 |         config = {
51 |             "data": {
52 |                 "use_hdf5": False,
53 |                 "dataset": "rsna",
54 |                 "type": "2d",
55 |                 "targets": "rsna_pe_slice_target",
56 |                 "channels": "window",
57 |                 "weighted_sample": True,
58 |                 "positive_only": True,
59 |                 "imsize": 512,
60 |             }
61 |         }
62 |         config = OmegaConf.create(config)
63 |         dataset = dataset_2d.PEDataset2D(config, split="train")
64 |         dataloader = DataLoader(dataset, batch_size=10, num_workers=8, shuffle=True)
65 | 
66 |         x, y, ids = next(iter(dataloader))
67 | 
68 |         self.assertEqual(x.shape[0], 10, "batch size incorrect")
69 |         self.assertEqual(x.shape[1], 3, "incorrect num channels")
70 |         self.assertEqual(x.shape[2], 512, "incorrect height")
71 |         self.assertEqual(x.shape[3], 512, "incorrect width")
72 |         self.assertFalse(
73 |             torch.all(x[0, 0, :, :].eq(x[0, 1, :, :])), "channels repeating"
74 |         )
75 |         self.assertFalse(
76 |             torch.all(x[0, 1, :, :].eq(x[0, 2, :, :])), "channels repeating"
77 |         )
78 |         self.assertTrue((x.max() <= 1.0 and x.min() >= -1.0), "input not normalized")
79 |         utils.visualize_examples(x.numpy(), 10, "./data/test/rsna_2d_window.png")
80 | 
81 |     def test_neighbor_channel(self):
82 |         config = {
83 |             "data": {
84 |                 "use_hdf5": False,
85 |                 "dataset": "rsna",
86 |                 "type": "2d",
87 |                 "targets": "rsna_pe_slice_target",
88 |                 "channels": "neighbor",
89 |                 "weighted_sample": True,
90 |                 "positive_only": True,
91 |                 "imsize": 512,
92 |             }
93 |         }
94 |         config = OmegaConf.create(config)
95 |         dataset = dataset_2d.PEDataset2D(config, split="train")
96 |         dataloader = DataLoader(dataset, batch_size=10, num_workers=8, shuffle=True)
97 | 
98 |         x, y, ids = next(iter(dataloader))
99 | 
100 |         self.assertEqual(x.shape[0], 10, "batch size incorrect")
101 |         self.assertEqual(x.shape[1], 3, "incorrect num channels")
102 |         self.assertEqual(x.shape[2], 512, "incorrect height")
103 |         self.assertEqual(x.shape[3], 512, "incorrect width")
104 |         self.assertFalse(
105 |             torch.all(x[0, 0, :, :].eq(x[0, 1, :, :])), "channels repeating"
106 |         )
107 |         self.assertFalse(
108 |             torch.all(x[0, 1, :, :].eq(x[0, 2, :, :])), "channels repeating"
109 |         )
110 |         self.assertTrue((x.max() <= 1.0 and x.min() >= -1.0), "input not normalized")
111 |         utils.visualize_examples(x.numpy(), 10, "./data/test/rsna_2d_neighbor.png")
112 | 
113 | 
114 | if __name__ == "__main__":
115 |     runner = unittest.TextTestRunner(verbosity=2)
116 |     unittest.main(testRunner=runner)
117 | 
--------------------------------------------------------------------------------
/pe_models/constants.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | PROJECT_DATA_DIR = Path("")
4 | if not PROJECT_DATA_DIR.is_dir():
5 |     print(
6 |         "\nProject data directory not specified. Please update data path in "
7 |         + "PROJECT_DATA_DIR in pe_models/constants.py "
8 |     )
9 |     PROJECT_DATA_DIR = Path("./data/")
10 | 
11 | ## RSNA
12 | RSNA_DATA_DIR = PROJECT_DATA_DIR / "rsna"
13 | if not RSNA_DATA_DIR.is_dir():
14 |     print(
15 |         "\nPlease download the RSNA dataset from \n"
16 |         + "  https://www.kaggle.com/c/rsna-str-pulmonary-embolism-detection\n"
17 |         + f"and place the downloaded dataset in {PROJECT_DATA_DIR}"
18 |     )
19 |     RSNA_DATA_DIR = PROJECT_DATA_DIR / 'demo'
20 | 
21 | RSNA_TRAIN_DIR = RSNA_DATA_DIR / "train"
22 | RSNA_ORIGINAL_TRAIN_CSV = RSNA_DATA_DIR / "train.csv"
23 | RSNA_TRAIN_CSV = RSNA_DATA_DIR / "rsna_train_master.csv"
24 | RSNA_TEST_DIR = RSNA_DATA_DIR / "test"
25 | RSNA_TEST_CSV = RSNA_DATA_DIR / "test.csv"
26 | RSNA_STUDY_HDF5 = RSNA_DATA_DIR / "rsna_study.hdf5"
27 | 
28 | # Dataframe headers
29 | RSNA_STUDY_COL = "StudyInstanceUID"
30 | RSNA_SERIES_COL = "SeriesInstanceUID"
31 | RSNA_INSTANCE_COL = "SOPInstanceUID"
32 | RSNA_PREV_INSTANCE_COL = "PrevSOPInstanceUID"
33 | RSNA_NEXT_INSTANCE_COL = "NextSOPInstanceUID"
34 | RSNA_PE_TARGET_COL = "negative_exam_for_pe"
35 | RSNA_ORIGINAL_SPLIT_COL = "OriginalSplit"
36 | RSNA_SPLIT_COL = "Split"
37 | RSNA_INSTITUTION_SPLIT_COL = "InstitutionSplit"
38 | RSNA_INSTITUTION_COL = "Institution"
39 | RSNA_NUM_SLICES_COL = "NumSlices"
40 | RSNA_SLICE_ORDER_COL = "SliceOrder"
41 | RSNA_NUM_WINDOW_COL = "NumWindows"
42 | RSNA_INSTANCE_PATH_COL = "InstancePath"
43 | RSNA_INSTANCE_ORDER_COL = "InstanceOrder"
44 | RSNA_PE_SLICE_COL = "pe_present_on_image"
45 | RSNA_TARGET_COLS = [
46 |     "negative_exam_for_pe",
47 |     "indeterminate",
48 |     "rv_lv_ratio_gte_1",
49 |     "rv_lv_ratio_lt_1",
50 |     "leftsided_pe",
51 |     "rightsided_pe",
52 |     "chronic_pe",
53 |     "acute_and_chronic_pe",
54 |     "central_pe",
55 | ]
56 | RSNA_SLICE_TARGET_COLS = [
57 |     "pe_present_on_image",
58 |     "indeterminate",
59 |     "rv_lv_ratio_gte_1",
60 |     "rv_lv_ratio_lt_1",
61 |     "leftsided_pe",
62 |     "rightsided_pe",
63 |     "chronic_pe",
64 |     "acute_and_chronic_pe",
65 |     "central_pe",
66 | ]
67 | RSNA_LOCATION_TARGET_COLS = [
68 |     "pe_present_on_image",
69 |     "leftsided_pe",
70 |     "rightsided_pe",
71 |     "central_pe",
72 | ]
73 | RSNA_PE_PROPERTY_TARGET_COLS = [
74 |     "pe_present_on_image",
75 |     "leftsided_pe",
76 |     "rightsided_pe",
77 |     "central_pe",
78 |     "chronic_pe",
79 |     "acute_and_chronic_pe",
80 | ]
81 | 
82 | RSNA_TARGET_TYPES = {
83 |     "rsna_targets": RSNA_TARGET_COLS,
84 |     "rsna_slice_targets": RSNA_SLICE_TARGET_COLS,
85 |     "rsna_location_targets": RSNA_LOCATION_TARGET_COLS,
86 |     "rsna_pe_slice_target": [RSNA_PE_SLICE_COL],
87 |     "rsna_pe_target": [RSNA_PE_TARGET_COL],
88 |     "rsna_pe_property_targets": RSNA_PE_PROPERTY_TARGET_COLS,
89 | }
90 | 
91 | RSNA_LOSS_WEIGHT = {
92 |     "negative_exam_for_pe": 0.0736196319,
93 |     "indeterminate": 0.09202453988,
94 |     "rv_lv_ratio_gte_1": 0.2346625767,
95 |     "rv_lv_ratio_lt_1": 0.0782208589,
96 |     "leftsided_pe": 0.06257668712,
97 |     "rightsided_pe": 0.06257668712,
98 |     "chronic_pe": 0.1042944785,
99 |     "acute_and_chronic_pe": 0.1042944785,
100 |     "central_pe": 0.1877300613,
101 |     "pe_present_on_image": 0.07361963,
102 | }
103 | 
104 | RSNA_DICOM_HEADERS = [
105 |     "SOPInstanceUID",
106 |     "SeriesInstanceUID",
107 |     "StudyInstanceUID",
108 |     "InstanceNumber",
109 |     "ImagePositionPatient",
110 |     "ImageOrientationPatient",
111 |     "PixelSpacing",
112 |     "RescaleIntercept",
113 |     "RescaleSlope",
114 |     "WindowCenter",
115 |     "WindowWidth",
116 | ]
117 | 
118 | AIR_HU_VAL = -1000.0
119 | RANDOM_SEED = 2
120 | 
121 | 
122 | # Used to extract Stanford studies from RSNA
123 | # Requires token and mapping
124 | RSNA_MDAI_JSON = None
125 | RSNA_STANFORD_MAPPING = None
126 | MDAI_TOKEN = None
127 | MDAI_DOMAIN = None
128 | STANFORD_PE_METADATA = None
129 | 
130 | 
131 | # Stanford Cohort
132 | STANFORD_DATA_DIR = PROJECT_DATA_DIR / 'stanford'
133 | STANFORD_NO_RSNA_CSV = STANFORD_DATA_DIR / 'stanford_ctpe_no_rsna.csv'
134 | STANFORD_NO_RSNA_HDF5 = STANFORD_DATA_DIR / 'stanford_ctpe_no_rsna.hdf5'
135 | 
136 | ## LIDC
137 | LIDC_DATA_DIR = PROJECT_DATA_DIR / "lidc"
138 | LIDC_TRAIN_CSV = LIDC_DATA_DIR / "lidc_train.csv"
139 | LIDC_STUDY_HDF5 = LIDC_DATA_DIR / "lidc_study.hdf5"
140 | LIDC_DICOM_CSV = LIDC_DATA_DIR / "lidc_2d.csv"
141 | 
142 | LIDC_PATIENT_COL = "PatientID"
143 | LIDC_STUDY_COL = "StudyInstanceUID"
144 | LIDC_SERIES_COL = "SeriesInstanceUID"
145 | LIDC_INSTANCE_COL = "SOPInstanceUID"
146 | LIDC_NUM_WINDOW_COL = "NumWindows"
147 | LIDC_NUM_SLICES_COL = "NumSlices"
148 | LIDC_SPLIT_COL = "Split"
149 | LIDC_NOD_SLICE_COL = "nodule_present_on_image"
150 | LIDC_INSTANCE_ORDER_COL = "image_index"  # TODO: confirm whether this should be InstanceNumber
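The target lookup tables above are what ultimately size the classification head. A minimal illustration of how they are consumed (this mirrors `build_model` in `pe_models/builder.py`):

```python
from pe_models.constants import RSNA_TARGET_TYPES

# "rsna_pe_target" maps to a single column, so the model gets a 1-unit head;
# "rsna_targets" maps to all nine exam-level labels.
print(len(RSNA_TARGET_TYPES["rsna_pe_target"]))  # 1
print(len(RSNA_TARGET_TYPES["rsna_targets"]))    # 9
```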
--------------------------------------------------------------------------------
/pe_models/models/models_1d.py:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # TODO: needs testing
3 | ################################################################################
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | import torch.nn.functional as F
8 | 
9 | 
10 | class Attention(nn.Module):
11 |     """
12 |     Adapted from:
13 |     https://github.com/GuanshuoXu/RSNA-STR-Pulmonary-Embolism-Detection/blob/main/trainall/2nd_level/seresnext101_192.py
14 |     """
15 | 
16 |     def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
17 |         super(Attention, self).__init__(**kwargs)
18 |         print("=" * 80)
19 |         print("Using attention")
20 |         print("=" * 80)
21 | 
22 |         self.supports_masking = True
23 |         self.bias = bias
24 |         self.feature_dim = feature_dim
25 |         self.step_dim = step_dim
26 |         self.features_dim = 0
27 | 
28 |         weight = torch.zeros(feature_dim, 1)
29 |         nn.init.xavier_uniform_(weight)
30 |         self.weight = nn.Parameter(weight)
31 | 
32 |         if bias:
33 |             self.b = nn.Parameter(torch.zeros(step_dim))
34 | 
35 |     def forward(self, x, mask=None):
36 |         feature_dim = self.feature_dim
37 |         step_dim = self.step_dim
38 | 
39 |         eij = torch.mm(x.contiguous().view(-1, feature_dim), self.weight).view(
40 |             -1, step_dim
41 |         )
42 | 
43 |         if self.bias:
44 |             eij = eij + self.b
45 | 
46 |         eij = torch.tanh(eij)
47 |         a = torch.exp(eij)
48 | 
49 |         if mask is not None:
50 |             a = a * mask
51 | 
52 |         a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)  # epsilon guards against all-zero weights
53 |         weighted_input = x * torch.unsqueeze(a, -1)
54 | 
55 |         return torch.sum(weighted_input, 1), self.weight
56 | 
57 | 
58 | class RNNSequentialEncoder(nn.Module):
59 |     """Model to encode series of encoded 2D CT slices using RNN
60 | 
61 |     Args:
62 |         feature_size (int): number of features for input feature vector
63 |         rnn_type (str): either lstm or gru
64 |         hidden_size (int): number of hidden units
65 |         bidirectional (bool): use bidirectional rnn
66 |         num_layers (int): number of rnn layers
67 |         dropout_prob (float): dropout probability
68 |     """
69 | 
70 |     def __init__(
71 |         self,
72 |         feature_size: int,
73 |         rnn_type: str = "lstm",
74 |         hidden_size: int = 128,
75 |         bidirectional: bool = True,
76 |         num_layers: int = 1,
77 |         dropout_prob: float = 0.0,
78 |     ):
79 | 
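        # Expected input to forward(): (batch, num_slices, feature_size); the
        # output keeps that layout, with hidden_size features per slice
        # (doubled when bidirectional=True).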
80 |         super(RNNSequentialEncoder, self).__init__()
81 | 
82 |         self.feature_size = feature_size
83 |         self.rnn_type = rnn_type.upper()  # accept "lstm"/"LSTM", "gru"/"GRU"
84 |         self.hidden_size = hidden_size
85 |         self.bidirectional = bidirectional
86 |         self.dropout_prob = dropout_prob
87 |         self.num_layers = num_layers
88 | 
89 |         if self.rnn_type not in ["LSTM", "GRU"]:
90 |             raise Exception("RNN type has to be either LSTM or GRU")
91 | 
92 |         self.rnn = getattr(nn, self.rnn_type)(
93 |             self.feature_size,
94 |             self.hidden_size,
95 |             batch_first=True,
96 |             num_layers=self.num_layers,
97 |             dropout=self.dropout_prob,
98 |             bidirectional=bidirectional,
99 |         )
100 | 
101 |     def forward(self, x):
102 |         # x: (Batch, Slice, Feature); the RNN was constructed with
103 |         # batch_first=True, so no transposes are needed around the recurrence
104 |         x, _ = self.rnn(x)  # (Batch, Slice, Feature)
105 |         return x
106 | 
107 | 
108 | class PEModel1D(nn.Module):
109 |     def __init__(self, cfg, num_classes=1):
110 |         super(PEModel1D, self).__init__()
111 | 
112 |         # rnn input size
113 |         seq_input_size = cfg.data.feature_size
114 |         if cfg.data.contextualize_slice:
115 |             seq_input_size = seq_input_size * 3
116 | 
117 |         # classifier input size
118 |         cls_input_size = cfg.model.seq_encoder.hidden_size
119 |         if cfg.model.seq_encoder.bidirectional:
120 |             cls_input_size = cls_input_size * 2
121 | 
122 |         self.seq_encoder = RNNSequentialEncoder(seq_input_size, **cfg.model.seq_encoder)
123 | 
124 |         if "attention" in cfg.model.aggregation:
125 |             self.attention = Attention(cls_input_size, cfg.data.num_slices)
126 | 
127 |         self.batch_norm_layer = torch.nn.BatchNorm1d(cls_input_size)
128 |         self.classifier = nn.Linear(cls_input_size, num_classes)
129 |         self.cfg = cfg
130 | 
131 |     def forward(self, x, get_features=False):
132 |         x = self.seq_encoder(x)
133 |         x, w = self.aggregate(x)
134 |         x = self.batch_norm_layer(x)
135 |         pred = self.classifier(x)
136 |         return pred, x
137 | 
138 |     def aggregate(self, x):
139 | 
140 |         if self.cfg.model.aggregation == "attention":
141 |             return self.attention(x)
142 |         elif self.cfg.model.aggregation == "mean":
143 |             x = torch.mean(x, 1)
144 |             return x, None
145 |         elif self.cfg.model.aggregation == "max":
146 |             x, _ = torch.max(x, 1)
147 |             return x, None
148 |         else:
149 |             raise Exception(
150 |                 "Aggregation method should be one of 'attention', 'mean' or 'max'"
151 |             )
152 | 
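Since this module is flagged as untested, here is a quick shape-level smoke test of `PEModel1D`. The config fields mirror what the constructor reads; the values themselves are illustrative, not taken from the repo's configs:

```python
import torch
from omegaconf import OmegaConf

from pe_models.models.models_1d import PEModel1D

cfg = OmegaConf.create({
    "data": {"feature_size": 512, "contextualize_slice": False, "num_slices": 24},
    "model": {
        "aggregation": "attention",
        "seq_encoder": {"rnn_type": "lstm", "hidden_size": 128,
                        "bidirectional": True, "num_layers": 1, "dropout_prob": 0.0},
    },
})
model = PEModel1D(cfg, num_classes=1)

x = torch.randn(4, 24, 512)   # (batch, slices, features)
pred, feats = model(x)        # pred: (4, 1), feats: (4, 256)
```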
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Video pretraining advances 3D deep learning on chest CTs
2 | 
3 | This repository contains code to train and evaluate models on the RSNA PE dataset and the LIDC-IDRI dataset for our paper [Video pretraining advances 3D deep learning on chest CTs](https://arxiv.org/abs/2304.00546).
4 | 
5 | ## Table of Contents
6 | 0. [System Requirements](#SystemRequirements)
7 | 0. [Installation](#Installation)
8 | 0. [Datasets](#Datasets)
9 | 0. [Usage](#Usage)
10 | 0. [Demo](#Demo)
11 | 
12 | ## System Requirements
13 | 
14 | ### Hardware requirements
15 | 
16 | The data processing steps require only a standard computer with enough RAM to support the in-memory operations.
17 | 
18 | For training and testing models, a computer with sufficient GPU memory is recommended.
19 | 
20 | ### Software requirements
21 | #### OS requirements
22 | All models have been trained and tested on a Linux system (Ubuntu 16.04).
23 | 
24 | #### Python dependencies
25 | 
26 | All dependencies can be found in **environment.yml**
27 | 
28 | 
29 | ## Installation
30 | 
31 | 1. Please install [Anaconda](https://docs.conda.io/en/latest/miniconda.html) in order to create a Python environment.
32 | 2. Clone this repo (from the command-line: `git clone git@github.com:rajpurkarlab/2021-fall-chest-ct.git`).
33 | 3. Create the environment: `conda env create -f environment.yml`.
34 | 4. Activate the environment: `source activate pe_models`.
35 | 5. Install [PyTorch 1.7.1](https://pytorch.org/get-started/locally/) with the right CUDA version.
36 | 
37 | Installation should take less than 10 minutes with stable internet.
38 | 
39 | ## Datasets
40 | 
41 | ### RSNA
42 | 
43 | Download dataset from: [RSNA PE Dataset](https://www.kaggle.com/c/rsna-str-pulmonary-embolism-detection)
44 | 
45 | Make sure to update **PROJECT_DATA_DIR** in `pe_models/constants.py` with path to the directory that contains the RSNA dataset.
46 | 
47 | #### Preprocessing
48 | 
49 | Please download the pre-processed label file that contains data split and DICOM header information using this [link](https://stanfordmedicine.box.com/s/nlatp1dgg47qry1g7hhr0n87mlavj887) and place it in the RSNA data directory.
50 | 
51 | Alternatively, you can create the pre-processed file by running:
52 | ```bash
53 | $ python pe_models/preprocess/rsna.py
54 | ```
55 | 
56 | #### Test
57 | To ensure that the dataset is correct and that data are loading in the correct format, run the following unittest:
58 | 
59 | ```bash
60 | $ python -W ignore -m unittest
61 | ```
62 | 
63 | Note that this might take a couple of minutes to complete.
64 | 
65 | You can also visually inspect example inputs in `data/test/` after the unittest is complete.
66 | 
67 | ### LIDC
68 | 
69 | Download dataset from [TCIA Public Access](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI) into a `PROJECT_DATA_DIR/lidc` folder.
70 | 
71 | #### Preprocessing
72 | 
73 | Install *pylidc* and set up your `~/.pylidcrc` file using the [official installation instructions](https://pylidc.github.io/install.html).
74 | 
75 | You can then create all the necessary pre-processed files by running:
76 | 
77 | ```bash
78 | $ python pe_models/preprocess/lidc.py
79 | ```
80 | 
81 | You can then set the `type` in an experiment YAML to `lidc-window` or `lidc-2d` to train on the LIDC dataset.
82 | 
83 | ## Usage
84 | 
85 | To train a model, run the following:
86 | 
87 | ```bash
88 | python run.py --config <path_to_config> --train
89 | ```
90 | 
91 | For more documentation, please run:
92 | 
93 | ```bash
94 | python run.py --help
95 | ```
96 | 
97 | To test a model, use the `--test` flag, making sure that either the `--checkpoint` flag is specified or that the config YAML contains a **checkpoint** entry:
98 | 
99 | ```bash
100 | python run.py --config <path_to_config> --checkpoint <path_to_checkpoint> --test
101 | ```
102 | 
103 | To featurize all studies in a dataset (to run a 1d model for example), use the `--test_split all` flag.
104 | 
105 | Example configs can be found in **./configs/**
106 | 
107 | ### Run hyperparameter sweep with wandb
108 | 
109 | Example hyperparameter sweep configs for each model can be found in **./configs/**
110 | 
111 | ```bash
112 | wandb sweep <path_to_sweep_config>
113 | wandb agent <sweep_id>
114 | ```
115 | ### Custom dataset:
116 | To train/test a model on custom datasets:
117 | 1. Please ensure that your data adhere to the same format as the RSNA/LIDC dataset. (See [Example](https://stanfordmedicine.box.com/s/nlatp1dgg47qry1g7hhr0n87mlavj887))
118 | 2. Create a dataloader similar to RSNA/LIDC in `pe_models/datasets` and update `pe_models/datasets/__init__.py` to include the name of your custom dataloader.
119 | 3. Make sure the *data.type* in your config file points to the name of your dataloader.
120 | 
121 | ## Demo
122 | 
123 | To run the train/test script on a simulated demo dataset, use:
124 | 
125 | ```bash
126 | python run.py --config ./data/demo/resnet18_demo.yaml --checkpoint <path_to_checkpoint> --test
127 | ```
128 | 
129 | You should expect the following results:
130 | 
131 | ```
132 | {'test/mean_auprc': 0.9107142686843872,
133 |  'test/mean_auroc': 0.9166666865348816,
134 |  'test/negative_exam_for_pe_auprc': 0.9107142686843872,
135 |  'test/negative_exam_for_pe_auroc': 0.9166666865348816,
136 |  'test_loss': 0.6920164227485657,
137 |  'test_loss_epoch': 0.6920164227485657}
138 | ```
139 | With a GPU, this should take less than 10 minutes to run.
140 | 
141 | ## Citation
142 | 
143 | If our work was useful in your research, please consider citing:
144 | 
145 | ```bibtex
146 | @inproceedings{ke2023video,
147 |   title={Video Pretraining Advances 3D Deep Learning on Chest CT Tasks},
148 |   author={Alexander Ke and Shih-Cheng Huang and Chloe P O'Connell and Michal Klimont and Serena Yeung and Pranav Rajpurkar},
149 |   booktitle={Medical Imaging with Deep Learning},
150 |   year={2023},
151 |   eprint={2304.00546},
152 |   archivePrefix={arXiv},
153 |   primaryClass={eess.IV}
154 | }
155 | ```
--------------------------------------------------------------------------------
/pe_models/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import collections.abc
3 | import yaml
4 | import pandas as pd
5 | import requests
6 | import pydicom
7 | import cv2
8 | import torch
9 | import math
10 | 
11 | from matplotlib import pyplot as plt
12 | from sklearn.metrics import average_precision_score, roc_auc_score
13 | 
14 | 
15 | def get_auroc(y, prob, keys):
16 | 
17 |     if type(y) == torch.Tensor:
18 |         y = y.detach().cpu().numpy()
19 |     if type(prob) == torch.Tensor:
20 |         prob = prob.detach().cpu().numpy()
21 | 
22 |     auroc_dict = {}
23 |     for i, k in enumerate(keys):
24 |         y_cls = y[:, i]
25 |         prob_cls = prob[:, i]
26 | 
27 |         if np.isnan(prob_cls).any():
28 |             auroc_dict[k] = 0
29 |         elif len(set(y_cls)) == 1:
30 |             auroc_dict[k] = 0
31 |         else:
32 |             auroc_dict[k] = roc_auc_score(y_cls, prob_cls)
33 |     auroc_dict["mean"] = np.mean([v for _, v in auroc_dict.items()])
34 |     return auroc_dict
35 | 
36 | 
37 | def get_auprc(y, prob, keys):
38 | 
39 |     if type(y) == torch.Tensor:
40 |         y = y.detach().cpu().numpy()
41 |     if type(prob) == torch.Tensor:
42 |         prob = prob.detach().cpu().numpy()
43 | 
44 |     auprc_dict = {}
45 |     for i, k in enumerate(keys):
46 |         y_cls = y[:, i]
47 |         prob_cls = prob[:, i]
48 | 
49 |         if np.isnan(prob_cls).any():
50 |             auprc_dict[k] = 0
51 |         elif len(set(y_cls)) == 1:
52 |             auprc_dict[k] = 0
53 |         else:
54 |             auprc_dict[k] = average_precision_score(y_cls, prob_cls)
55 |     auprc_dict["mean"] = np.mean([v for _, v in auprc_dict.items()])
56 |     return auprc_dict
57 | 
58 | 
59 | def flatten(d, parent_key="", sep="."):
60 |     """flatten a nested dictionary"""
61 |     items = []
62 |     for k, v in d.items():
63 |         new_key = parent_key + sep + k if parent_key else k
64 |         if isinstance(v, collections.abc.MutableMapping):
65 |             items.extend(flatten(v, new_key, sep=sep).items())
66 |         else:
67 |             items.append((new_key, v))
68 |     return dict(items)
69 | 
70 | 
71 | def get_best_ckpt_path(ckpt_paths, ascending=False):
72 |     """get best ckpt path from a list of ckpt paths
73 | 
74 |     ckpt_paths: YAML file mapping each ckpt path to a metric value
75 |     ascending: sort paths based on ascending or descending metrics
76 |     """
77 | 
78 |     with open(ckpt_paths, "r") as stream:
79 |         ckpts = yaml.safe_load(stream)
80 | 
81 |     ckpts_df = pd.DataFrame.from_dict(ckpts, orient="index").reset_index()
82 |     ckpts_df.columns = ["path", "metric"]
83 |     best_ckpt_path = (
84 |         ckpts_df.sort_values("metric", ascending=ascending).head(1)["path"].item()
85 |     )
86 | 
87 |     return best_ckpt_path
88 | 
89 | 
90 | def read_dicom(file_path: str, imsize):
91 |     """TODO: repeated between dataset base """
92 | 
93 |     # read dicom
94 |     dcm = pydicom.dcmread(file_path)
95 |     pixel_array = dcm.pixel_array
96 | 
97 |     # rescale
98 |     intercept = dcm.RescaleIntercept
99 |     slope = dcm.RescaleSlope
100 |     pixel_array = pixel_array * slope + intercept
101 | 
102 |     # resize
103 |     resize_shape = imsize
104 |     pixel_array = cv2.resize(
105 |         pixel_array, (resize_shape, resize_shape), interpolation=cv2.INTER_AREA
106 |     )
107 | 
108 |     return pixel_array
109 | 
110 | def windowing(pixel_array: np.ndarray, window_center: int, window_width: int):
111 |     """TODO: repeated between dataset base """
112 | 
113 |     lower = window_center - window_width // 2
114 |     upper = window_center + window_width // 2
115 |     pixel_array = np.clip(pixel_array.copy(), lower, upper)
116 |     pixel_array = (pixel_array - lower) / (upper - lower)
117 | 
118 |     return pixel_array
119 | 
120 | 
121 | def download_file_from_google_drive(id, destination):
122 |     def get_confirm_token(response):
123 |         for key, value in response.cookies.items():
124 |             if key.startswith("download_warning"):
125 |                 return value
126 | 
127 |         return None
128 | 
129 |     def save_response_content(response, destination):
130 |         CHUNK_SIZE = 32768
131 | 
132 |         with open(destination, "wb") as f:
133 |             for chunk in response.iter_content(CHUNK_SIZE):
134 |                 if chunk:  # filter out keep-alive new chunks
135 |                     f.write(chunk)
136 | 
137 |     URL = "https://docs.google.com/uc?export=download"
138 | 
139 |     session = requests.Session()
140 | 
141 |     response = session.get(URL, params={"id": id}, stream=True)
142 |     token = get_confirm_token(response)
143 | 
144 |     if token:
145 |         params = {"id": id, "confirm": token}
146 |         response = session.get(URL, params=params, stream=True)
147 | 
148 |     save_response_content(response, destination)
149 | 
150 | 
151 | def visualize_examples(x, num_viz, save_dir):
152 | 
153 |     f, axarr = plt.subplots(num_viz, 3, figsize=(3, num_viz))
154 |     plt.subplots_adjust(wspace=0, hspace=0)
155 | 
156 |     # make sure channel first
157 |     if x.shape[1] == 3:
158 |         x = np.transpose(x, (1, 0, 2, 3))
159 |     factor = x.shape[1] // num_viz
160 | 
161 |     for i in range(num_viz):
162 |         for j in range(3):
163 |             image = x[j, i * factor, :, :]
164 |             image = np.repeat(np.expand_dims(image, -1), 3, axis=-1)
165 |             axarr[i][j].imshow(image)
166 | 
167 |     plt.setp(axarr, xticks=[], yticks=[])
168 |     plt.subplots_adjust(
169 |         left=None, bottom=None, right=None, top=None, wspace=0.05, hspace=0.05
170 |     )
171 |     plt.savefig(save_dir, dpi=100)
172 |     print(f"Example images saved at: {save_dir}")
173 | 
174 | 
175 | def linear_warmup_then_cosine(last_iter, warmup, max_iter, delay=None):
176 |     if delay is not None:
177 |         last_iter = max(0, last_iter - delay)
178 | 
179 |     if last_iter < warmup:
180 |         # Linear warmup period
181 |         return float(last_iter) / warmup
182 |     elif last_iter < max_iter:
183 |         # Cosine annealing
184 |         return (1 + math.cos(math.pi * (last_iter - warmup) / max_iter)) / 2
185 |     else:
186 |         # Done
187 |         return 0.
--------------------------------------------------------------------------------
/pe_models/models/backbones3d/penet_classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from: https://github.com/marshuang80/penet/blob/master/models/penet_classifier.py
3 | """
4 | 
5 | import math
6 | import torch
7 | import torch.nn as nn
8 | 
9 | from .layers.penet import *
10 | 
11 | 
12 | class PENetClassifier(nn.Module):
13 |     """PENet stripped down for classification.
14 |     The idea is to pre-train this network, then use the pre-trained
15 |     weights for the encoder in a full PENet.
16 |     """
17 | 
18 |     def __init__(self, cfg, **kwargs):
19 |         super(PENetClassifier, self).__init__()
20 | 
21 |         self.in_channels = 64
22 |         self.model_depth = cfg.model.model_depth
23 |         self.cardinality = cfg.model.cardinality
24 |         self.num_channels = cfg.model.num_channels
25 |         self.num_classes = cfg.model.num_classes
26 | 
27 |         self.in_conv = nn.Sequential(nn.Conv3d(self.num_channels, self.in_channels, kernel_size=7,
28 |                                                stride=(1, 2, 2), padding=(3, 3, 3), bias=False),
29 |                                      nn.GroupNorm(self.in_channels // 16, self.in_channels),
30 |                                      nn.LeakyReLU(inplace=True))
31 |         self.max_pool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
32 | 
33 |         # Encoders
34 |         if cfg.model.model_depth != 50:
35 |             raise ValueError('Unsupported model depth: {}'.format(cfg.model.model_depth))
36 |         encoder_config = [3, 4, 6, 3]
37 |         total_blocks = sum(encoder_config)
38 |         block_idx = 0
39 | 
40 |         self.encoders = nn.ModuleList()
41 |         for i, num_blocks in enumerate(encoder_config):
42 |             out_channels = 2 ** i * 128
43 |             stride = 1 if i == 0 else 2
44 |             encoder = PENetEncoder(self.in_channels, out_channels, num_blocks, self.cardinality,
45 |                                    block_idx, total_blocks, stride=stride)
46 |             self.encoders.append(encoder)
47 |             self.in_channels = out_channels * PENetBottleneck.expansion
48 |             block_idx += num_blocks
49 | 
50 |         self.classifier = GAPLinear(self.in_channels, cfg.model.num_classes)
51 | 
52 |         if cfg.model.init_method is not None:
53 |             self._initialize_weights(cfg.model.init_method, focal_pi=0.01)
54 | 
55 |     def _initialize_weights(self, init_method, gain=0.2, focal_pi=None):
56 |         """Initialize all weights in the network."""
57 |         for m in self.modules():
58 |             if isinstance(m, nn.Conv3d) or isinstance(m, nn.ConvTranspose3d) or isinstance(m, nn.Linear):
59 |                 if init_method == 'normal':
60 |                     nn.init.normal_(m.weight, mean=0, std=gain)
61 |                 elif init_method == 'xavier':
62 |                     nn.init.xavier_normal_(m.weight, gain=gain)
63 |                 elif init_method == 'kaiming':
64 |                     nn.init.kaiming_normal_(m.weight)
65 |                 else:
66 |                     raise NotImplementedError('Invalid initialization method: {}'.format(init_method))
67 |                 if hasattr(m, 'bias') and m.bias is not None:
68 |                     if focal_pi is not None and hasattr(m, 'is_output_head') and m.is_output_head:
69 |                         # Focal loss prior (~0.01 prob for positive, see RetinaNet Section 4.1)
70 |                         nn.init.constant_(m.bias, -math.log((1 - focal_pi) / focal_pi))
71 |                     else:
72 |                         nn.init.constant_(m.bias, 0)
73 |             elif isinstance(m, nn.GroupNorm) and m.affine:
74 |                 # Gamma for last GroupNorm in each residual block gets set to 0
75 |                 init_gamma = 0 if hasattr(m, 'is_last_norm') and m.is_last_norm else 1
76 |                 nn.init.constant_(m.weight, init_gamma)
77 |                 nn.init.constant_(m.bias, 0)
78 | 
79 |     def forward(self, x):
80 | 
81 |         # Expand input (allows pre-training on RGB videos, fine-tuning on Hounsfield Units)
82 |         if x.size(1) < self.num_channels:
83 |             x = x.expand(-1, self.num_channels // x.size(1), -1, -1, -1)
84 | 
85 |         x = self.in_conv(x)
86 |         x = self.max_pool(x)
87 | 
88 |         # Encoders
89 |         for encoder in self.encoders:
90 |             x = encoder(x)
91 | 
92 |         # Classifier
93 |         x = self.classifier(x)
94 | 
95 |         return x
96 | 
97 |     def args_dict(self):
98 |         """Get a dictionary of args that can be used to reconstruct this architecture.
99 |         To use the returned dict, initialize the model with `PENet(**model_args)`.
100 |         """
101 |         model_args = {
102 |             'model_depth': self.model_depth,
103 |             'cardinality': self.cardinality,
104 |             'num_classes': self.num_classes,
105 |             'num_channels': self.num_channels
106 |         }
107 | 
108 |         return model_args
109 | 
110 |     def load_pretrained(self, ckpt_path):
111 |         """Load parameters from a pre-trained PENetClassifier from checkpoint at ckpt_path.
112 |         Args:
113 |             ckpt_path: Path to checkpoint for PENetClassifier.
114 |         Adapted from:
115 |             https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/2
116 |         """
117 |         try:
118 |             pretrained_dict = torch.load(ckpt_path)['model_state']
119 |             model_dict = self.state_dict()
120 | 
121 |             # Filter out unnecessary keys
122 |             pretrained_dict = {k[len('module.'):]: v for k, v in pretrained_dict.items()}
123 |             pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
124 |         except KeyError:
125 |             # Lightning-style checkpoints store weights under 'state_dict' instead
126 |             pretrained_dict = torch.load(ckpt_path)['state_dict']
127 |             model_dict = self.state_dict()
128 |             pretrained_dict = {k[len('module.'):]: v for k, v in pretrained_dict.items()}
129 |             pretrained_dict = {k: v for k, v in pretrained_dict.items() if k not in ['classifier.fc.weight', 'classifier.fc.bias']}
130 | 
131 | 
132 |         # Overwrite entries in the existing state dict
133 |         model_dict.update(pretrained_dict)
134 | 
135 |         # Load the new state dict
136 |         msg = self.load_state_dict(model_dict, strict=False)
137 |         print(msg)
138 | 
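A minimal instantiation sketch: the `cfg` fields mirror exactly what `__init__` reads, but the cardinality, channel, and input-size values here are illustrative assumptions, not taken from the repo's configs:

```python
import torch
from omegaconf import OmegaConf

from pe_models.models.backbones3d.penet_classifier import PENetClassifier

cfg = OmegaConf.create({"model": {"model_depth": 50, "cardinality": 32,
                                  "num_channels": 3, "num_classes": 1,
                                  "init_method": "kaiming"}})
model = PENetClassifier(cfg)

x = torch.randn(1, 3, 24, 192, 192)  # (batch, channels, slices, H, W)
logits = model(x)                    # (1, num_classes)
```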
--------------------------------------------------------------------------------
/pe_models/builder.py:
--------------------------------------------------------------------------------
1 | import omegaconf
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import albumentations as A
7 | import cv2
8 | 
9 | from . import models
10 | from . import lightning
11 | from . import datasets
12 | from . import utils
13 | from .loss import BinaryFocalLoss
14 | from albumentations.pytorch import ToTensorV2
15 | from .constants import *
16 | from functools import partial
17 | from omegaconf import OmegaConf
18 | 
19 | 
20 | def build_data_module(cfg):
21 |     return datasets.data_module.PEDataModule(cfg)
22 | 
23 | 
24 | def build_dataset(cfg):
25 |     if cfg.data.type.lower() in datasets.ALL_DATASETS:
26 |         return datasets.ALL_DATASETS[cfg.data.type.lower()]
27 |     else:
28 |         raise NotImplementedError(
29 |             f"Dataset not implemented for {cfg.data.type.lower()}"
30 |         )
31 | 
32 | 
33 | def build_lightning_model(cfg):
34 | 
35 |     if cfg.data.type == 'window' or cfg.data.type == 'lidc-window':
36 |         model = lightning.PEWindowClassificationLightningModel
37 |     else:
38 |         model = lightning.PEClassificationLightningModel
39 | 
40 |     # TODO: clean up this logic
41 |     # (MAE ViT uses its own data loading)
42 |     if OmegaConf.is_none(cfg, "checkpoint"):
43 |         return model(cfg)
44 |     else:
45 |         checkpoint_path = cfg.checkpoint
46 |         print('='*80)
47 |         print(f'*** Using checkpoint: {checkpoint_path}')
48 |         print('='*80)
49 |         return model.load_from_checkpoint(checkpoint_path, cfg=cfg)
50 | 
51 | 
52 | def build_model(cfg):
53 |     num_class = len(RSNA_TARGET_TYPES[cfg.data.targets])
54 |     return models.ALL_MODELS[cfg.model.type.lower()](cfg, num_class)
55 | 
56 | 
57 | def build_optimizer(cfg, model):
58 |     # use separate parameter groups when fine-tuning from a boundary layer
59 |     if not OmegaConf.is_none(cfg.model, "boundary_layer_name"):
60 |         print('Finetuning parameters')
61 |         params = model.fine_tuning_parameters(
62 |             cfg.model.boundary_layer_name, cfg.lightning.trainer.lr)
63 |     else:
64 |         params = [p for p in model.parameters() if p.requires_grad]
65 |     optimizer_name = cfg.train.optimizer.pop("name")
66 |     optimizer_fn = getattr(torch.optim, optimizer_name)
67 |     optimizer = optimizer_fn(params, lr=cfg.lightning.trainer.lr, **cfg.train.optimizer)
68 |     cfg.train.optimizer.name = optimizer_name
69 |     return optimizer
70 | 
71 | 
72 | def build_scheduler(cfg, optimizer):
73 | 
74 |     if cfg.train.scheduler.name is not None:
75 |         scheduler_name = cfg.train.scheduler.pop("name")
76 |         monitor = cfg.train.scheduler.pop("monitor")
77 |         interval = cfg.train.scheduler.pop("interval")
78 |         frequency = cfg.train.scheduler.pop("frequency")
79 | 
80 |         if scheduler_name == 'CosineWarmup':
81 |             # If pretrained, delay for warmup steps to allow randomly initialized head to settle down
82 |             if len(optimizer.param_groups) > 1:
83 |                 ft_lambda_fn = partial(
84 |                     utils.linear_warmup_then_cosine,
85 |                     delay=cfg.train.scheduler.lr_warmup_steps,
86 |                     warmup=cfg.train.scheduler.lr_warmup_steps,
87 |                     max_iter=cfg.train.scheduler.lr_decay_step)
88 |                 reg_lambda_fn = partial(
89 |                     utils.linear_warmup_then_cosine,
90 |                     warmup=cfg.train.scheduler.lr_warmup_steps,
91 |                     max_iter=cfg.train.scheduler.lr_decay_step)
92 |                 lr_fns = [ft_lambda_fn, reg_lambda_fn] if cfg.model.pretrained else reg_lambda_fn
93 |             else:
94 |                 lr_fns = partial(
95 |                     utils.linear_warmup_then_cosine,
96 |                     warmup=cfg.train.scheduler.lr_warmup_steps,
97 |                     max_iter=cfg.train.scheduler.lr_decay_step)
98 |             scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_fns)
99 |         else:
100 |             scheduler_class = getattr(torch.optim.lr_scheduler, scheduler_name)
101 |             scheduler = scheduler_class(optimizer, **cfg.train.scheduler)
102 | 
103 |         # restore the popped entries so the in-memory config matches the YAML
104 |         # again on later calls (build_optimizer and build_loss do the same)
105 |         cfg.train.scheduler.name = scheduler_name
106 |         cfg.train.scheduler.monitor = monitor
107 |         cfg.train.scheduler.interval = interval
108 |         cfg.train.scheduler.frequency = frequency
109 | 
110 |     else:
111 |         scheduler = None
112 |         monitor = None
113 |         interval = None
114 |         frequency = None
115 | 
116 |     scheduler = {
117 |         "scheduler": scheduler,
118 |         "monitor": monitor,
119 |         "interval": interval,
120 |         "frequency": frequency,
121 |     }
122 | 
123 |     return scheduler
124 | 
125 | 
126 | def build_loss(cfg):
127 |     # get loss function
128 |     loss_fn_name = cfg.train.loss_fn.pop("name")
129 |     if loss_fn_name == 'BinaryFocalLoss':
130 |         loss_fn = BinaryFocalLoss
131 |     else:
132 |         loss_fn = getattr(nn, loss_fn_name)
133 |     loss_function = loss_fn(**cfg.train.loss_fn)
134 |     cfg.train.loss_fn.name = loss_fn_name
135 |     return loss_function
136 | 
137 | 
138 | def build_transformation(cfg, split):
139 | 
140 |     if OmegaConf.is_none(cfg, 'transforms'):
141 |         return None
142 |     elif cfg.data.type in ["3d", "window", "lidc-window", "window_stanford"]:  # volume datasets handle augmentation in the dataset itself
143 |         return None
144 | 
145 |     transforms = []
146 | 
147 |     if split == "train":
148 |         # handle ShiftScaleRotate separately so it gets a constant border
149 |         if "ShiftScaleRotate" in cfg.transforms:
150 |             transforms.append(
151 |                 A.ShiftScaleRotate(
152 |                     border_mode=cv2.BORDER_CONSTANT,
153 |                     **cfg.transforms.ShiftScaleRotate
154 |                 )
155 |             )
156 |         for transform_name, arguments in cfg.transforms.items():
157 |             if transform_name in ('type', 'ShiftScaleRotate'): continue
158 |             transforms.append(getattr(A, transform_name)(**arguments))
159 |     else:
160 |         if "RandomCrop" in cfg.transforms:
161 |             transforms.append(A.CenterCrop(**cfg.transforms.RandomCrop))
162 | 
163 |     transforms += [ToTensorV2()]
164 |     transforms = A.Compose(transforms)
165 | 
166 |     return transforms
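For reference, the `CosineWarmup` branch above boils down to the following standalone sketch (the toy optimizer and the `warmup`/`max_iter` values are illustrative only):

```python
from functools import partial

import torch

from pe_models.utils import linear_warmup_then_cosine

# toy model/optimizer purely for illustration
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# LambdaLR multiplies the base lr by the returned factor each step:
# a linear ramp from 0 to 1 over `warmup` steps, then a cosine decay toward 0
lr_fn = partial(linear_warmup_then_cosine, warmup=100, max_iter=1000)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_fn)

for step in range(1000):
    optimizer.step()
    scheduler.step()
```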
--------------------------------------------------------------------------------
/pe_models/datasets/dataset_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import h5py
3 | import pydicom
4 | import numpy as np
5 | import pandas as pd
6 | import cv2
7 | import numpy.random as random
8 | 
9 | from ..constants import *
10 | from torch.utils.data import Dataset
11 | from scipy.ndimage.interpolation import rotate
12 | from omegaconf import OmegaConf
13 | 
14 | 
15 | class PEDatasetBase(Dataset):
16 |     def __init__(self, cfg, split="train", transform=None):
17 | 
18 |         self.cfg = cfg
19 |         self.transform = transform
20 |         self.split = split
21 |         self.hdf5_dataset = None
22 | 
23 |     def __getitem__(self, index):
24 |         raise NotImplementedError
25 | 
26 |     def __len__(self):
27 |         raise NotImplementedError
28 | 
29 |     def read_dicom(self, file_path: str):
30 | 
31 |         # read dicom
32 |         dcm = pydicom.dcmread(file_path)
33 |         pixel_array = dcm.pixel_array
34 | 
35 |         # rescale
36 |         intercept = dcm.RescaleIntercept
37 |         slope = dcm.RescaleSlope
38 |         pixel_array = pixel_array * slope + intercept
39 | 
40 |         # resize
41 |         # TODO: maybe use augmentation resize instead
42 |         resize_shape = self.cfg.data.imsize
43 |         pixel_array = cv2.resize(
44 |             pixel_array, (resize_shape, resize_shape), interpolation=cv2.INTER_AREA
45 |         )
46 | 
47 |         return pixel_array
48 | 
49 |     def windowing(self, pixel_array: np.ndarray, window_center: int, window_width: int):
50 | 
51 |         lower = window_center - window_width // 2
52 |         upper = window_center + window_width // 2
53 |         pixel_array = np.clip(pixel_array.copy(), lower, upper)
54 |         pixel_array = (pixel_array - lower) / (upper - lower)
55 | 
56 |         return pixel_array
57 | 
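    # Each channel strategy in process_slice below yields a 3-channel slice:
    # "repeat" applies the PE window once (replication to 3 channels happens
    # later as a cheap tensor op), "neighbor" stacks the previous/current/next
    # slices, and the default stacks lung / PE / mediastinal windows.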
58 |     def process_slice(self, slice_info: pd.Series, train_dir: str = None):
59 |         """process slice with windowing, resize and transforms"""
60 | 
61 |         if train_dir is None:
62 |             TRAIN_DIR = RSNA_TRAIN_DIR
63 |         else:
64 |             TRAIN_DIR = train_dir
65 | 
66 |         # window
67 |         if self.cfg.data.channels == "repeat":
68 |             slice_path = TRAIN_DIR / slice_info[RSNA_INSTANCE_PATH_COL]
69 |             slice_array = self.read_dicom(slice_path)
70 |             ct_slice = self.windowing(
71 |                 slice_array, 400, 1000
72 |             )  # use PE window by default
73 |             # create 3 channels after converting to Tensor
74 |             # using torch.repeat won't take up 3x memory
75 | 
76 |         elif self.cfg.data.channels == "neighbor":
77 |             slice_paths = [
78 |                 TRAIN_DIR / slice_info[RSNA_PREV_INSTANCE_COL],
79 |                 TRAIN_DIR / slice_info[RSNA_INSTANCE_PATH_COL],
80 |                 TRAIN_DIR / slice_info[RSNA_NEXT_INSTANCE_COL],
81 |             ]
82 |             slice_arrays = np.array(
83 |                 [self.read_dicom(slice_path) for slice_path in slice_paths]
84 |             )
85 |             ct_slice = self.windowing(slice_arrays, 400, 1000)
86 |             ct_slice = np.stack(ct_slice)
87 | 
88 |         else:
89 |             slice_path = TRAIN_DIR / slice_info[RSNA_INSTANCE_PATH_COL]
90 |             slice_array = self.read_dicom(slice_path)
91 |             ct_slice = [
92 |                 self.windowing(slice_array, -600, 1500),  # LUNG window
93 |                 self.windowing(slice_array, 400, 1000),   # PE window
94 |                 self.windowing(slice_array, 40, 400),     # MEDIASTINAL window
95 |             ]
96 |             ct_slice = np.stack(ct_slice)
97 | 
98 |         return ct_slice
99 | 
100 |     def read_from_hdf5(self, key, slice_idx=None, hdf5_path=RSNA_STUDY_HDF5):
101 |         if self.hdf5_dataset is None:
102 |             self.hdf5_dataset = h5py.File(hdf5_path, 'r')
103 | 
104 |         if slice_idx is None:
105 |             arr = self.hdf5_dataset[key][:]
106 |         else:
107 |             arr = self.hdf5_dataset[key][slice_idx]
108 |         return arr
109 | 
110 |     def fix_slice_number(self, df: pd.DataFrame):
111 | 
112 |         num_slices = min(self.cfg.data.num_slices, df.shape[0])
113 |         if self.cfg.data.sample_strategy == "random":
114 |             slice_idx = np.random.choice(
115 |                 np.arange(df.shape[0]), replace=False, size=num_slices
116 |             )
117 |             slice_idx = list(np.sort(slice_idx))
118 |             df = df.iloc[slice_idx, :]
119 |         elif self.cfg.data.sample_strategy == "fix":
120 |             df = df.iloc[:num_slices, :]
121 |         else:
122 |             raise Exception("Sampling strategy either 'random' or 'fix'")
123 |         return df
124 | 
125 |     def fix_series_slice_number(self, series):
126 | 
127 |         num_slices = min(self.cfg.data.num_slices, series.shape[0])
128 |         if self.cfg.data.sample_strategy == "random":
129 |             slice_idx = np.random.choice(
130 |                 np.arange(series.shape[0]), replace=False, size=num_slices
131 |             )
132 |             slice_idx = list(np.sort(slice_idx))
133 |             series = series[slice_idx, :]
134 |         elif self.cfg.data.sample_strategy == "fix":
135 |             series = series[:num_slices, :]
136 |         else:
137 |             raise Exception("Sampling strategy either 'random' or 'fix'")
138 |         return series
139 | 
140 |     def fill_series_to_num_slicess(self, series, num_slices):
141 |         x = torch.zeros(()).new_full((num_slices, *series.shape[1:]), 0.0)
142 |         x[: series.shape[0]] = series
143 |         return x
144 | 
145 |     def augment_series(self, series: np.ndarray):
146 |         """Series level augmentation"""
147 | 
148 |         if len(series.shape) == 3:
149 |             series = np.expand_dims(series, 1)
150 | 
151 |         # crop volume slice-wise
152 |         if self.cfg.transforms.RandomCrop is not None:
153 | 
154 |             h = self.cfg.transforms.RandomCrop.height
155 |             w = self.cfg.transforms.RandomCrop.width
156 | 
157 |             row_margin = max(0, series.shape[-2] - h)
158 |             col_margin = max(0, series.shape[-1] - w)
159 | 
160 |             # Random crop during training, center crop during test inference
161 |             row = (
162 |                 random.randint(0, row_margin)
163 |                 if self.split == "train"
164 |                 else row_margin // 2
165 |             )
166 |             col = (
167 |                 random.randint(0, col_margin)
168 |                 if self.split == "train"
169 |                 else col_margin // 2
170 |             )
171 |             series = series[:, :, row : row + h, col : col + w]
172 | 
173 |         # rotate
174 |         if (self.cfg.transforms.Rotate is not None) and (
175 |             self.split == "train"
176 |         ):
177 |             rotate_limit = self.cfg.transforms.Rotate.rotate_limit
178 |             angle = random.randint(-rotate_limit, rotate_limit)
179 | 
180 |             series = rotate(series, angle, (-2, -1), reshape=False, cval=AIR_HU_VAL)
181 | 
182 |         return series
183 | 
--------------------------------------------------------------------------------
/tests/test_dataloader3d.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import unittest
3 | import torch
4 | import os
5 | import sys
6 | 
7 | sys.path.append(os.getcwd())
8 | 
9 | from pe_models import utils
10 | from pe_models.constants import *
11 | from pe_models.datasets import dataset_3d
12 | from torch.utils.data import DataLoader
13 | from omegaconf import OmegaConf
14 | 
15 | 
16 | class RSNA3DDataLoaderTestCase(unittest.TestCase):
17 |     def test_repeat_channel(self):
18 |         config = {
19 |             "data": {
20 |                 "use_hdf5": False,
21 |                 "dataset": "rsna",
22 |                 "type": "3d",
23 |                 "targets": "rsna_pe_target",
24 |                 "channels": "repeat",
25 |                 "weighted_sample": True,
26 |                 "positive_only": True,
27 |                 "num_slices": 150,
28 |                 "sample_strategy": "random",
29 |                 "imsize": 256,
30 |             },
31 |             "transforms": {"RandomCrop": {"height": 224, "width": 224}},
32 |         }
33 |         config = OmegaConf.create(config)
34 |         dataset = dataset_3d.PEDataset3D(config, split="train")
35 |         dataloader = DataLoader(dataset, batch_size=2, num_workers=0)
36 | 
37 |         x, y, ids = next(iter(dataloader))
38 | 
39 |         self.assertEqual(x.shape[0], 2, "batch size incorrect")
40 |         self.assertEqual(x.shape[1], 3, "number of channels incorrect")
41 |         self.assertEqual(x.shape[2], 150, "slice number incorrect")
42 |         self.assertEqual(x.shape[3], 224, "width incorrect")
43 |         self.assertEqual(x.shape[4], 224, "height incorrect")
44 |         self.assertTrue(
45 |             torch.all(x[0, 0, :, :, :].eq(x[0, 1, :, :, :])), "channels not repeating"
46 |         )
47 |         self.assertTrue(
48 |             torch.all(x[0, 1, :, :, :].eq(x[0, 2, :, :, :])), "channels not repeating"
49 |         )
50 |         self.assertTrue((x.max() <= 1.0 and x.min() >= -1.0), "input not normalized")
51 | 
52 |         utils.visualize_examples(x[0].numpy(), 10, "./data/test/rsna_3d_repeat.png")
53 | 
54 |     def test_window_channel(self):
55 |         config = {
56 |             "data": {
57 |                 "use_hdf5": False,
58 |                 "dataset": "rsna",
59 |                 "type": "3d",
60 |                 "targets": "rsna_pe_target",
61 |                 "channels": "window",
62 |                 "weighted_sample": True,
63 |                 "positive_only": True,
64 |                 "num_slices": 150,
65 |                 "sample_strategy": "random",
66 |                 "imsize": 256,
67 |             },
68 |             "transforms": {"RandomCrop": {"height": 224, "width": 224}},
69 |         }
70 |         config = OmegaConf.create(config)
71 |         dataset = dataset_3d.PEDataset3D(config, split="train")
72 |         dataloader = DataLoader(dataset, batch_size=2, num_workers=0)
73 | 
74 |         x, y, ids = next(iter(dataloader))
75 | 
76 |         self.assertEqual(x.shape[0], 2, "batch size incorrect")
77 |         self.assertEqual(x.shape[1], 3, "number of channels incorrect")
78 |         self.assertEqual(x.shape[2], 150, "slice number incorrect")
79 |         self.assertEqual(x.shape[3], 224, "width incorrect")
80 |         self.assertEqual(x.shape[4], 224, "height incorrect")
81 |         self.assertFalse(
82 |             torch.all(x[0, 0, :, :].eq(x[0, 1, :, :])), "channels repeating"
83 |         )
84 |         self.assertFalse(
85 |             torch.all(x[0, 1, :, :].eq(x[0, 2, :, :])), "channels repeating"
86 |         )
87 |         self.assertTrue((x.max() <= 1.0 and x.min() >= -1.0), "input not normalized")
88 | 
89 |         utils.visualize_examples(x[0].numpy(), 10, "./data/test/rsna_3d_window.png")
90 | 
91 |     def test_neighbor_channel(self):
92 |         config = {
93 |             "data": {
94 |                 "use_hdf5": False,
95 |                 "dataset": "rsna",
96 |                 "type": "3d",
97 |                 "targets": "rsna_pe_target",
98 |                 "channels": "neighbor",
99 |                 "weighted_sample": True,
100 |                 "positive_only": True,
101 |                 "num_slices": 150,
102 |                 "sample_strategy": "random",
103 |                 "imsize": 256,
104 |             },
105 |             "transforms": {"RandomCrop": {"height": 224, "width": 224}},
106 |         }
107 |         config = OmegaConf.create(config)
108 |         dataset = dataset_3d.PEDataset3D(config, split="train")
109 |         dataloader = DataLoader(dataset, batch_size=2, num_workers=0, shuffle=False)
110 | 
111 |         x, y, ids = next(iter(dataloader))
112 | 
113 |         self.assertEqual(x.shape[0], 2, "batch size incorrect")
114 |         self.assertEqual(x.shape[1], 3, "number of channels incorrect")
115 |         self.assertEqual(x.shape[2], 150, "slice number incorrect")
116 |         self.assertEqual(x.shape[3], 224, "width incorrect")
117 |         self.assertEqual(x.shape[4], 224, "height incorrect")
118 |         self.assertFalse(
119 |             torch.all(x[1, 0, :, :].eq(x[1, 1, :, :])), "channels repeating"
120 |         )
121 |         self.assertFalse(
122 |             torch.all(x[1, 1, :, :].eq(x[1, 2, :, :])), "channels repeating"
123 |         )
124 |         self.assertTrue((x.max() <= 1.0 and x.min() >= -1.0), "input not normalized")
125 |         utils.visualize_examples(x[0].numpy(), 10, "./data/test/rsna_3d_neighbor.png")
126 | 
127 |     def test_repeat_channel_from_hdf5(self):
128 |         config = {
129 |             "data": {
130 |                 "use_hdf5": True,
131 |                 "dataset": "rsna",
132 |                 "type": "3d",
133 |                 "targets": "rsna_pe_target",
134 |                 "channels": "repeat",
135 |                 "weighted_sample": True,
136 |                 "positive_only": True,
137 |                 "num_slices": 150,
138 |                 "sample_strategy": "random",
139 |                 "imsize": 256,
140 |             },
141 |             "transforms": {"RandomCrop": {"height": 224, "width": 224}},
142 |         }
143 |         config = OmegaConf.create(config)
144 |         dataset = dataset_3d.PEDataset3D(config, split="train")
145 |         dataloader = DataLoader(dataset, batch_size=2, num_workers=0)
146 | 
147 |         x, y, ids = next(iter(dataloader))
148 | 
149 |         self.assertEqual(x.shape[0], 2, "batch size incorrect")
150 |         self.assertEqual(x.shape[1], 3, "number of channels incorrect")
151 |         self.assertEqual(x.shape[2], 150, "slice number incorrect")
152 |         self.assertEqual(x.shape[3], 224, "width incorrect")
153 |         self.assertEqual(x.shape[4], 224, "height incorrect")
154 |         self.assertTrue(
155 |             torch.all(x[0, 0, :, :, :].eq(x[0, 1, :, :, :])), "channels not repeating"
156 |         )
157 |         self.assertTrue(
158 |             torch.all(x[0, 1, :, :, :].eq(x[0, 2, :, :, :])), "channels not repeating"
159 |         )
160 |         self.assertTrue((x.max() <= 1.0 and x.min() >= -1.0), "input not normalized")
161 | 
162 |         utils.visualize_examples(x[0].numpy(), 10, "./data/test/rsna_3d_repeat_hdf5.png")
163 | 
164 | 
165 | if __name__ == "__main__":
166 |     runner = unittest.TextTestRunner(verbosity=2)
167 |     unittest.main(testRunner=runner)
168 | 
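Outside the unittest harness, the same loader can be driven directly. A minimal sketch (the config mirrors `test_repeat_channel` above and assumes the RSNA data, or the bundled demo fallback, is in place):

```python
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from pe_models.datasets import dataset_3d

config = OmegaConf.create({
    "data": {
        "use_hdf5": False, "dataset": "rsna", "type": "3d",
        "targets": "rsna_pe_target", "channels": "repeat",
        "weighted_sample": True, "positive_only": True,
        "num_slices": 150, "sample_strategy": "random", "imsize": 256,
    },
    "transforms": {"RandomCrop": {"height": 224, "width": 224}},
})
dataset = dataset_3d.PEDataset3D(config, split="train")
loader = DataLoader(dataset, batch_size=2, num_workers=0)

x, y, ids = next(iter(loader))  # x: (batch=2, channels=3, slices=150, 224, 224)
```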
--------------------------------------------------------------------------------
/pe_models/models/backbones3d/resnet_2d3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import math
6 | 
7 | __all__ = [
8 |     'ResNet2d3d', 'r2d3d50', 'r3d50'
9 | ]
10 | 
11 | def conv3x3x3(in_planes, out_planes, stride=1, bias=False):
12 |     # 3x3x3 convolution with padding
13 |     return nn.Conv3d(
14 |         in_planes,
15 |         out_planes,
16 |         kernel_size=3,
17 |         stride=stride,
18 |         padding=1,
19 |         bias=bias)
20 | 
21 | def conv1x3x3(in_planes, out_planes, stride=1, bias=False):
22 |     # 1x3x3 convolution with padding
23 |     return nn.Conv3d(
24 |         in_planes,
25 |         out_planes,
26 |         kernel_size=(1,3,3),
27 |         stride=(1,stride,stride),
28 |         padding=(0,1,1),
29 |         bias=bias)
30 | 
31 | 
32 | def downsample_basic_block(x, planes, stride):
33 |     out = F.avg_pool3d(x, kernel_size=1, stride=stride)
34 |     zero_pads = torch.Tensor(
35 |         out.size(0), planes - out.size(1), out.size(2), out.size(3),
36 |         out.size(4)).zero_()
37 |     if isinstance(out.data, torch.cuda.FloatTensor):
38 |         zero_pads = zero_pads.cuda()
39 | 
40 |     out = Variable(torch.cat([out.data, zero_pads], dim=1))
41 | 
42 |     return out
43 | 
44 | 
45 | 
46 | class Bottleneck3d(nn.Module):
47 |     expansion = 4
48 | 
49 |     def __init__(self, inplanes, planes, stride=1, downsample=None, use_final_relu=True):
50 |         super(Bottleneck3d, self).__init__()
51 |         bias = False
52 |         self.use_final_relu = use_final_relu
53 |         self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=(3,1,1), padding=(1,0,0), bias=bias)
54 |         self.bn1 = nn.BatchNorm3d(planes)
55 | 
56 |         self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1,3,3), stride=(1,stride,stride), padding=(0,1,1), bias=bias)
57 |         self.bn2 = nn.BatchNorm3d(planes)
58 | 
59 |         self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=bias)
60 |         self.bn3 = nn.BatchNorm3d(planes * 4)
61 | 
62 |         self.relu = nn.ReLU(inplace=True)
63 |         self.downsample = downsample
64 |         self.stride = stride
65 | 
66 |     def forward(self, x):
67 |         residual = x
68 | 
69 |         out = self.conv1(x)
70 |         out = self.bn1(out)
71 |         out = self.relu(out)
72 | 
73 |         out = self.conv2(out)
74 |         out = self.bn2(out)
75 |         out = self.relu(out)
76 | 
77 |         out = self.conv3(out)
78 |         out = self.bn3(out)
79 | 
80 |         if self.downsample is not None:
81 |             residual = self.downsample(x)
82 | 
83 |         out += residual
84 |         if self.use_final_relu: out = self.relu(out)
85 | 
86 |         return out
87 | 
88 | 
89 | class Bottleneck2d(nn.Module):
90 |     expansion = 4
91 | 
92 |     def __init__(self, inplanes, planes, stride=1, downsample=None, use_final_relu=True):
93 |         super(Bottleneck2d, self).__init__()
94 |         bias = False
95 |         self.use_final_relu = use_final_relu
96 |         self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=bias)
97 |         self.bn1 = nn.BatchNorm3d(planes)
98 | 
99 |         self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1,3,3), stride=(1,stride,stride), padding=(0,1,1), bias=bias)
100 |         self.bn2 = nn.BatchNorm3d(planes)
101 | 
102 |         self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=bias)
103 |         self.bn3 = nn.BatchNorm3d(planes * 4)
104 | 
105 |         self.relu = nn.ReLU(inplace=True)
106 |         self.downsample = downsample
107 |         self.stride = stride
108 | 
109 |     def forward(self, x):
110 |         residual = x
111 | 
112 |         out = self.conv1(x)
113 |         out = self.bn1(out)
114 |         out = self.relu(out)
115 | 
116 |         out = self.conv2(out)
117 |         out = self.bn2(out)
118 |         out = self.relu(out)
119 | 
120 |         out = self.conv3(out)
121 |         out = self.bn3(out)
122 | 
123 |         if self.downsample is not None:
124 |             residual = self.downsample(x)
125 | 
126 |         out += residual
127 |         if self.use_final_relu: out = self.relu(out)
128 | 
129 |         return out
130 | 
131 | 
132 | 
133 | class ResNet2d3d(nn.Module):
134 |     def __init__(self, block, layers, input_channel=3):
135 |         super(ResNet2d3d, self).__init__()
136 |         self.inplanes = 64
137 |         bias = False
138 |         self.conv1 = nn.Conv3d(input_channel, 64, kernel_size=(5,7,7), stride=(2, 2, 2), padding=(2, 3, 3), bias=bias)
139 |         self.bn1 = nn.BatchNorm3d(64)
140 |         self.relu = nn.ReLU(inplace=True)
141 |         self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
142 | 
143 |         if not isinstance(block, list):
144 |             block = [block] * 4
145 | 
146 |         self.layer1 = self._make_layer(block[0], 64, layers[0])
147 |         self.layer2 = self._make_layer(block[1], 128, layers[1], stride=(1,2,2))
148 |         self.layer3 = self._make_layer(block[2], 256, layers[2], stride=(1,2,2))
149 |         self.layer4 = self._make_layer(block[3], 512, layers[3], stride=(1,2,2), is_final=True)
150 |         for m in self.modules():
151 |             if isinstance(m, nn.Conv3d):
152 |                 m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
153 |                 if m.bias is not None: m.bias.data.zero_()
154 |             elif isinstance(m, nn.BatchNorm3d):
155 |                 m.weight.data.fill_(1)
156 |                 m.bias.data.zero_()
157 | 
158 |     def _make_layer(self, block, planes, blocks, stride=1, is_final=False):
159 |         downsample = None
160 |         if stride != 1 or self.inplanes != planes * block.expansion:
161 |             # customized_stride to deal with 2d or 3d residual blocks
162 |             if isinstance(stride, int):
163 |                 if block == Bottleneck2d:  # 2d blocks only stride spatially
164 |                     customized_stride = (1, stride, stride)
165 |                 else:
166 |                     customized_stride = stride
167 |             elif isinstance(stride, tuple):
168 |                 customized_stride = stride
169 |                 stride = stride[-1]
170 |             else:
171 |                 raise NotImplementedError
172 | 
173 |             downsample = nn.Sequential(
174 |                 nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=customized_stride, bias=False),
175 |                 nn.BatchNorm3d(planes * block.expansion)
176 |             )
177 | 
178 |         layers = []
179 |         layers.append(block(self.inplanes, planes, stride, downsample))
180 |         self.inplanes = planes * block.expansion
181 |         if is_final:  # if is final block, no ReLU in the final output
182 |             for i in range(1, blocks-1):
183 |                 layers.append(block(self.inplanes, planes))
184 |             layers.append(block(self.inplanes, planes, use_final_relu=False))
185 |         else:
186 |             for i in range(1, blocks):
187 |                 layers.append(block(self.inplanes, planes))
188 | 
189 |         return nn.Sequential(*layers)
190 | 
191 |     def forward(self, x):
192 |         x = self.conv1(x)
193 |         x = self.bn1(x)
194 |         x = self.relu(x)
195 |         x = self.maxpool(x)
196 | 
197 |         x = self.layer1(x)
198 |         x = self.layer2(x)
199 |         x = self.layer3(x)
200 |         x = self.layer4(x)
201 | 
202 |         return F.relu(x)
203 | 
''' 215 | model = ResNet2d3d([Bottleneck3d, Bottleneck3d, Bottleneck3d, Bottleneck3d], 216 | [3, 4, 6, 3], **kwargs) 217 | return model 218 | 219 | 220 | 221 | if __name__ == '__main__': 222 | mymodel = r2d3d50() 223 | mydata = torch.FloatTensor(4, 3, 16, 224, 224) 224 | nn.init.normal_(mydata) 225 | import ipdb; ipdb.set_trace() 226 | mymodel(mydata) -------------------------------------------------------------------------------- /pe_models/models/backbones3d/resnet3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/kenshohara/3D-ResNets-PyTorch/blob/master/models/resnet.py 3 | """ 4 | 5 | import math 6 | from functools import partial 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | def get_inplanes(): 14 | return [64, 128, 256, 512] 15 | 16 | 17 | def conv3x3x3(in_planes, out_planes, stride=1): 18 | return nn.Conv3d( 19 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 20 | ) 21 | 22 | 23 | def conv1x1x1(in_planes, out_planes, stride=1): 24 | return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, in_planes, planes, stride=1, downsample=None): 31 | super().__init__() 32 | 33 | self.conv1 = conv3x3x3(in_planes, planes, stride) 34 | self.bn1 = nn.BatchNorm3d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3x3(planes, planes) 37 | self.bn2 = nn.BatchNorm3d(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | residual = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, in_planes, planes, stride=1, downsample=None): 64 | super().__init__() 65 | 66 | self.conv1 = conv1x1x1(in_planes, planes) 67 | self.bn1 = nn.BatchNorm3d(planes) 68 | self.conv2 = conv3x3x3(planes, planes, stride) 69 | self.bn2 = nn.BatchNorm3d(planes) 70 | self.conv3 = conv1x1x1(planes, planes * self.expansion) 71 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet3D(nn.Module): 100 | def __init__( 101 | self, 102 | block, 103 | layers, 104 | block_inplanes, 105 | n_input_channels=3, 106 | conv1_t_size=7, 107 | conv1_t_stride=1, 108 | no_max_pool=False, 109 | shortcut_type="B", 110 | widen_factor=1.0, 111 | n_classes=400, 112 | ): 113 | super().__init__() 114 | 115 | block_inplanes = [int(x * widen_factor) for x in block_inplanes] 116 | 117 | self.in_planes = block_inplanes[0] 118 | self.no_max_pool = no_max_pool 119 | 120 | self.conv1 = nn.Conv3d( 121 | n_input_channels, 122 | self.in_planes, 123 | 
kernel_size=(conv1_t_size, 7, 7), 124 | stride=(conv1_t_stride, 2, 2), 125 | padding=(conv1_t_size // 2, 3, 3), 126 | bias=False, 127 | ) 128 | self.bn1 = nn.BatchNorm3d(self.in_planes) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) 131 | self.layer1 = self._make_layer( 132 | block, block_inplanes[0], layers[0], shortcut_type 133 | ) 134 | self.layer2 = self._make_layer( 135 | block, block_inplanes[1], layers[1], shortcut_type, stride=2 136 | ) 137 | self.layer3 = self._make_layer( 138 | block, block_inplanes[2], layers[2], shortcut_type, stride=2 139 | ) 140 | self.layer4 = self._make_layer( 141 | block, block_inplanes[3], layers[3], shortcut_type, stride=2 142 | ) 143 | 144 | self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) 145 | self.fc = nn.Linear(block_inplanes[3] * block.expansion, n_classes) 146 | 147 | for m in self.modules(): 148 | if isinstance(m, nn.Conv3d): 149 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 150 | elif isinstance(m, nn.BatchNorm3d): 151 | nn.init.constant_(m.weight, 1) 152 | nn.init.constant_(m.bias, 0) 153 | 154 | def _downsample_basic_block(self, x, planes, stride): 155 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 156 | zero_pads = torch.zeros( 157 | out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4) 158 | ) 159 | if isinstance(out.data, torch.cuda.FloatTensor): 160 | zero_pads = zero_pads.cuda() 161 | 162 | out = torch.cat([out.data, zero_pads], dim=1) 163 | 164 | return out 165 | 166 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 167 | downsample = None 168 | if stride != 1 or self.in_planes != planes * block.expansion: 169 | if shortcut_type == "A": 170 | downsample = partial( 171 | self._downsample_basic_block, 172 | planes=planes * block.expansion, 173 | stride=stride, 174 | ) 175 | else: 176 | downsample = nn.Sequential( 177 | conv1x1x1(self.in_planes, planes * block.expansion, stride), 178 | nn.BatchNorm3d(planes * block.expansion), 179 | ) 180 | 181 | layers = [] 182 | layers.append( 183 | block( 184 | in_planes=self.in_planes, 185 | planes=planes, 186 | stride=stride, 187 | downsample=downsample, 188 | ) 189 | ) 190 | self.in_planes = planes * block.expansion 191 | for i in range(1, blocks): 192 | layers.append(block(self.in_planes, planes)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def forward(self, x): 197 | x = self.conv1(x) 198 | x = self.bn1(x) 199 | x = self.relu(x) 200 | if not self.no_max_pool: 201 | x = self.maxpool(x) 202 | 203 | x = self.layer1(x) 204 | x = self.layer2(x) 205 | x = self.layer3(x) 206 | x = self.layer4(x) 207 | 208 | x = self.avgpool(x) 209 | 210 | x = x.view(x.size(0), -1) 211 | x = self.fc(x) 212 | 213 | return x 214 | 215 | 216 | def generate_model(model_depth, n_classes): 217 | assert model_depth in [10, 18, 34, 50, 101, 152, 200] 218 | 219 | kwargs = { 220 | "n_input_channels": 3, 221 | "conv1_t_size": 7, 222 | "conv1_t_stride": 1, 223 | "no_max_pool": False, 224 | "shortcut_type": "B", 225 | "widen_factor": 1.0, 226 | "n_classes": n_classes, 227 | } 228 | 229 | if model_depth == 10: 230 | model = ResNet3D(BasicBlock, [1, 1, 1, 1], get_inplanes(), **kwargs) 231 | elif model_depth == 18: 232 | model = ResNet3D(BasicBlock, [2, 2, 2, 2], get_inplanes(), **kwargs) 233 | elif model_depth == 34: 234 | model = ResNet3D(BasicBlock, [3, 4, 6, 3], get_inplanes(), **kwargs) 235 | elif model_depth == 50: 236 | model = ResNet3D(Bottleneck, [3, 4, 6, 3], get_inplanes(), **kwargs) 
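# Note added for clarity (not in the original source): depths 10/18/34 build on
# BasicBlock (expansion=1, final feature width 512), while depths 50/101/152/200
# use Bottleneck (expansion=4, so the fc layer sees 512 * 4 = 2048 features).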
237 | elif model_depth == 101: 238 | model = ResNet3D(Bottleneck, [3, 4, 23, 3], get_inplanes(), **kwargs) 239 | elif model_depth == 152: 240 | model = ResNet3D(Bottleneck, [3, 8, 36, 3], get_inplanes(), **kwargs) 241 | elif model_depth == 200: 242 | model = ResNet3D(Bottleneck, [3, 24, 36, 3], get_inplanes(), **kwargs) 243 | 244 | return model 245 | -------------------------------------------------------------------------------- /pe_models/preprocess/lidc.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import mdai 3 | import pydicom 4 | import sys 5 | import os 6 | import glob 7 | import tqdm 8 | import numpy as np 9 | import h5py 10 | 11 | sys.path.append(os.getcwd()) 12 | 13 | from collections import defaultdict 14 | from pe_models.constants import * 15 | from pe_models import utils 16 | from sklearn.model_selection import train_test_split 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | import cv2 20 | import pylidc as pl 21 | 22 | 23 | def process_window_df(df:pd.DataFrame, num_slices:int=24, min_abnormal_slice:int=4): 24 | 25 | window_labels_csv = LIDC_DATA_DIR / f"lidc_window_{num_slices}_min_abnormal_{min_abnormal_slice}.csv" 26 | 27 | # count number of windows per slice 28 | count_num_windows = lambda x: x // num_slices\ 29 | + (1 if x % num_slices > 0 else 0) 30 | df[LIDC_NUM_WINDOW_COL] = df[LIDC_NUM_SLICES_COL].apply( 31 | count_num_windows) 32 | 33 | # get windows list 34 | df_study = df.groupby([LIDC_STUDY_COL]).head(1) 35 | window_labels = defaultdict(list) 36 | 37 | # studies 38 | for _, row in tqdm.tqdm(df_study.iterrows(), total=df_study.shape[0]): 39 | study_name = row[LIDC_STUDY_COL] 40 | split = row[LIDC_SPLIT_COL] 41 | study_df = df[df[LIDC_STUDY_COL] == study_name] 42 | study_df = study_df.sort_values("ImagePositionPatient_2") 43 | 44 | # windows 45 | for idx in range(row[LIDC_NUM_WINDOW_COL]): 46 | start_idx = idx * num_slices 47 | end_idx = (idx+1) * num_slices 48 | 49 | window_df = study_df.iloc[start_idx: end_idx] 50 | num_positives_slices = window_df[LIDC_NOD_SLICE_COL].sum() 51 | label = 1 if num_positives_slices >= min_abnormal_slice else 0 52 | window_labels[LIDC_STUDY_COL].append(study_name) 53 | window_labels['index'].append(idx) 54 | window_labels['label'].append(label) 55 | window_labels[LIDC_SPLIT_COL].append(split) 56 | window_labels["ImagePositionPatient_2"].append(window_df["ImagePositionPatient_2"].tolist()) 57 | window_labels[LIDC_NOD_SLICE_COL].append(window_df[LIDC_NOD_SLICE_COL].tolist()) 58 | window_labels[LIDC_INSTANCE_COL].append(window_df[LIDC_INSTANCE_COL].tolist()) 59 | window_labels[LIDC_INSTANCE_ORDER_COL].append(window_df[LIDC_INSTANCE_ORDER_COL].tolist()) 60 | df = pd.DataFrame.from_dict(window_labels) 61 | df.to_csv(window_labels_csv) 62 | 63 | return df 64 | 65 | 66 | def process_dicom(dcm): 67 | pixel_array = dcm.pixel_array 68 | 69 | intercept = dcm.RescaleIntercept 70 | slope = dcm.RescaleSlope 71 | pixel_array = pixel_array * slope + intercept 72 | 73 | resize_shape = 256 # self.cfg.data.imsize 74 | pixel_array = cv2.resize( 75 | pixel_array, (resize_shape, resize_shape), interpolation=cv2.INTER_AREA 76 | ) 77 | return pixel_array 78 | 79 | 80 | def process_study_to_hdf5(): 81 | records = [] 82 | hdf5_fh = h5py.File(LIDC_STUDY_HDF5, "a") 83 | for scan in tqdm.tqdm(pl.query(pl.Scan).all()): 84 | dicoms = scan.load_all_dicom_images(verbose=False) 85 | if len(dicoms) == 0: 86 | continue 87 | series = np.stack([process_dicom(dcm) for dcm in dicoms]) 88 | 
hdf5_fh.create_dataset( 89 | scan.study_instance_uid, 90 | data=series, 91 | dtype="float32", 92 | chunks=True 93 | ) 94 | 95 | if len(scan.annotations) > 0: 96 | start_stop = np.stack([annot.bbox_matrix()[2] for annot in scan.annotations]) 97 | for i, dcm in enumerate(dicoms): 98 | records.append({ 99 | "PatientID": dcm.PatientID, 100 | "StudyInstanceUID": dcm.StudyInstanceUID, 101 | "SeriesInstanceUID": dcm.SeriesInstanceUID, 102 | "SOPInstanceUID": dcm.SOPInstanceUID, 103 | "image_index": i, 104 | "nodule_present_on_image": int( 105 | sum((start_stop[:, 0] <= i) & (i <= start_stop[:, 1])) >= 3 106 | ) if len(scan.annotations) > 0 else 0, 107 | "InstanceNumber": dcm.InstanceNumber, 108 | "ImagePositionPatient_0": dcm.ImagePositionPatient[0], 109 | "ImagePositionPatient_1": dcm.ImagePositionPatient[1], 110 | "ImagePositionPatient_2": dcm.ImagePositionPatient[2], 111 | "ImageOrientationPatient_0": dcm.ImageOrientationPatient[0], 112 | "ImageOrientationPatient_1": dcm.ImageOrientationPatient[1], 113 | "ImageOrientationPatient_2": dcm.ImageOrientationPatient[2], 114 | "ImageOrientationPatient_3": dcm.ImageOrientationPatient[3], 115 | "ImageOrientationPatient_4": dcm.ImageOrientationPatient[4], 116 | "ImageOrientationPatient_5": dcm.ImageOrientationPatient[5], 117 | "PixelSpacing_0": dcm.PixelSpacing[0], 118 | "PixelSpacing_1": dcm.PixelSpacing[1], 119 | "RescaleIntercept": dcm.RescaleIntercept, 120 | "RescaleSlope": dcm.RescaleSlope, 121 | "WindowCenter": dcm.get("WindowCenter"), 122 | "WindowWidth": dcm.get("WindowWidth"), 123 | }) 124 | hdf5_fh.close() 125 | 126 | pd.DataFrame.from_records(records).to_csv(LIDC_DICOM_CSV, index=False) 127 | 128 | 129 | def add_split_to_label_df( 130 | label_df: pd.DataFrame, 131 | val_size: float = 0.15, 132 | test_size: float = 0.15, 133 | patient_col: str = LIDC_PATIENT_COL, 134 | split_col: str = LIDC_SPLIT_COL, 135 | ): 136 | patients = label_df[patient_col].unique() 137 | 138 | # split between train and val+test 139 | split_ratio = val_size + test_size 140 | train_patients, test_val_patients = train_test_split( 141 | patients, test_size=split_ratio, random_state=RANDOM_SEED 142 | ) 143 | 144 | # split between val and test 145 | test_split_ratio = test_size / (val_size + test_size) 146 | val_patients, test_patients = train_test_split( 147 | test_val_patients, test_size=test_split_ratio, random_state=RANDOM_SEED 148 | ) 149 | train_rows = label_df[patient_col].isin(train_patients) 150 | label_df.loc[train_rows, split_col] = "train" 151 | val_rows = label_df[patient_col].isin(val_patients) 152 | label_df.loc[val_rows, split_col] = "valid" 153 | test_rows = label_df[patient_col].isin(test_patients) 154 | label_df.loc[test_rows, split_col] = "test" 155 | 156 | return label_df 157 | 158 | 159 | if __name__ == "__main__": 160 | if (not LIDC_STUDY_HDF5.is_file()) or (not LIDC_DICOM_CSV.is_file()): 161 | print('\n'+'='*80) 162 | print(f'\nParsing study HDF5 to {LIDC_STUDY_HDF5} and creating {LIDC_DICOM_CSV}') 163 | print('-'*80) 164 | process_study_to_hdf5() 165 | else: 166 | print('='*80) 167 | print(f'\n{LIDC_STUDY_HDF5} and {LIDC_DICOM_CSV} already existed and processed') 168 | print('-'*80) 169 | 170 | if not LIDC_TRAIN_CSV.is_file(): 171 | print('='*80) 172 | print(f'\nProcessing LIDC dataset metadata and save as {LIDC_TRAIN_CSV}') 173 | print('-'*80) 174 | 175 | lidc = pd.read_csv(LIDC_DICOM_CSV) 176 | 177 | # full dataset split 178 | lidc = add_split_to_label_df(lidc, split_col=LIDC_SPLIT_COL) 179 | for split in ["train", "valid", "test"]: 180 | print( 
181 | f"Full split (patients) {split}: " 182 | + f"{lidc[lidc[LIDC_SPLIT_COL] == split][LIDC_PATIENT_COL].nunique()}" 183 | ) 184 | 185 | # get slice number 186 | unique_studies = pd.DataFrame(lidc[LIDC_STUDY_COL].value_counts()).reset_index() 187 | unique_studies.columns = [LIDC_STUDY_COL, LIDC_NUM_SLICES_COL] 188 | lidc = lidc.merge(unique_studies, on=LIDC_STUDY_COL) 189 | 190 | # create windowed dataset: default window_size=24, min_abnormal_slice: 4 191 | window_df = process_window_df(lidc) 192 | 193 | lidc = lidc.to_csv(LIDC_TRAIN_CSV, index=False) 194 | else: 195 | print('='*80) 196 | print(f'\n{LIDC_TRAIN_CSV} already existed and processed') 197 | print('-'*80) 198 | -------------------------------------------------------------------------------- /pe_models/lightning/classification_lightning_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import wandb 5 | import json 6 | import pandas as pd 7 | import pickle 8 | import os 9 | import h5py 10 | 11 | from .. import builder 12 | from .. import utils 13 | from ..constants import * 14 | from sklearn.metrics import average_precision_score, roc_auc_score 15 | from pytorch_lightning.core import LightningModule 16 | 17 | 18 | class PEClassificationLightningModel(LightningModule): 19 | """Pytorch-Lightning Module""" 20 | 21 | def __init__(self, cfg): 22 | """Pass in hyperparameters to the model""" 23 | # initalize superclass 24 | super().__init__() 25 | 26 | self.cfg = cfg 27 | self.model = builder.build_model(cfg) 28 | self.loss = builder.build_loss(cfg) 29 | self.target_names = RSNA_TARGET_TYPES[cfg.data.targets] 30 | 31 | if self.cfg.train.weighted_loss: 32 | print("Using weighted loss") 33 | self.loss_weights = torch.tensor( 34 | [RSNA_LOSS_WEIGHT[t] for t in self.target_names] 35 | ) 36 | else: 37 | self.loss_weights = None 38 | 39 | def configure_optimizers(self): 40 | optimizer = builder.build_optimizer(self.cfg, self.model) 41 | scheduler = builder.build_scheduler(self.cfg, optimizer) 42 | return {"optimizer": optimizer, "lr_scheduler": scheduler} 43 | 44 | def training_step(self, batch, batch_idx): 45 | return self.shared_step(batch, "train") 46 | 47 | def validation_step(self, batch, batch_idx): 48 | return self.shared_step(batch, "val") 49 | 50 | def test_step(self, batch, batch_idx): 51 | return self.shared_step(batch, "test") 52 | 53 | def training_epoch_end(self, training_step_outputs): 54 | return self.shared_epoch_end(training_step_outputs, "train") 55 | 56 | def validation_epoch_end(self, validation_step_outputs): 57 | return self.shared_epoch_end(validation_step_outputs, "val") 58 | 59 | def test_epoch_end(self, test_step_outputs): 60 | return self.shared_epoch_end(test_step_outputs, "test") 61 | 62 | def shared_step(self, batch, split, extract_features=False): 63 | """Similar to traning step""" 64 | 65 | #x, y, instance_id, _ = batch 66 | x, y, instance_id = batch 67 | logit, features = self.model(x, get_features=True) 68 | 69 | if self.loss_weights is not None: 70 | weight = self.loss_weights.to(y.device) 71 | loss = self.loss(logit, y).mean(0) 72 | loss = torch.sum(weight * loss) 73 | else: 74 | loss = self.loss(logit, y) 75 | 76 | self.log( f"{split}_loss", loss, on_epoch=True, on_step=True, logger=False, prog_bar=True,) 77 | 78 | return_dict = {"y": y, "loss": loss, "logit": logit, "instance_id": instance_id, "features": features.cpu().detach()} 79 | return return_dict 80 | 81 | def 
82 | 
83 |         instance_id = [ids for x in step_outputs for ids in x["instance_id"]]
84 |         logit = torch.cat([x["logit"] for x in step_outputs])
85 |         y = torch.cat([x["y"] for x in step_outputs])
86 |         prob = torch.sigmoid(logit)
87 |         features = torch.cat([x["features"] for x in step_outputs])
88 | 
89 |         if '-' in instance_id[0]:
90 |             study_id = [i.split('-')[1] for i in instance_id]
91 |             instance_id = [i.split('-')[0] for i in instance_id]
92 | 
93 |         # log auroc
94 |         auroc_dict = utils.get_auroc(y, prob, self.target_names)
95 |         for k, v in auroc_dict.items():
96 |             self.log(f"{split}/{k}_auroc", v, on_epoch=True, logger=True, prog_bar=True)
97 | 
98 |         # log auprc
99 |         auprc_dict = utils.get_auprc(y, prob, self.target_names)
100 |         for k, v in auprc_dict.items():
101 |             self.log(f"{split}/{k}_auprc", v, on_epoch=True, logger=True, prog_bar=True)
102 | 
103 |         if split == "test":
104 |             # save results
105 |             meta_dict = {"split": split}
106 |             if not os.path.exists(self.cfg.output_dir):  # create the output directory if it does not exist
107 |                 os.makedirs(self.cfg.output_dir)
108 |             results_csv = os.path.join(self.cfg.output_dir, "results.csv")
109 |             auroc_dict = {f"{split}/{k}_auroc": [v] for k, v in auroc_dict.items()}
110 |             auprc_dict = {f"{split}/{k}_auprc": [v] for k, v in auprc_dict.items()}
111 |             results = {**meta_dict, **auroc_dict, **auprc_dict}
112 | 
113 |             results_df = pd.DataFrame.from_dict(results, orient="columns")
114 |             if os.path.exists(results_csv):
115 |                 df = pd.read_csv(results_csv)
116 |                 results_df = pd.concat([df, results_df], ignore_index=True)
117 |             results_df.to_csv(results_csv, index=False)
118 |             print(f"\nResults saved at: {results_csv}")
119 | 
120 |             # save predictions
121 |             y = [x[0] for x in list(y.cpu().detach().numpy())]
122 |             prob = [x[0] for x in list(prob.cpu().detach().numpy())]
123 |             prediction_dict = {
124 |                 "target": y,
125 |                 "prob": prob,
126 |                 "id": instance_id,
127 |             }
128 |             prediction_df = pd.DataFrame(prediction_dict)
129 |             prediction_path = os.path.join(self.cfg.output_dir, "preds.csv")
130 |             prediction_df.to_csv(prediction_path, index=False)
131 |             print(f"\nPredictions saved at: {prediction_path}")
132 | 
133 |             # save features
134 |             if self.cfg.data.type == '2d' or self.cfg.data.type == 'lidc-2d':
135 |                 _train_csv = RSNA_TRAIN_CSV if self.cfg.data.type == '2d' else LIDC_TRAIN_CSV
136 |                 _instance_col = RSNA_INSTANCE_COL if self.cfg.data.type == '2d' else LIDC_INSTANCE_COL
137 |                 _study_col = RSNA_STUDY_COL if self.cfg.data.type == '2d' else LIDC_STUDY_COL
138 | 
139 |                 df_full = pd.read_csv(_train_csv)
140 |                 prediction_df.columns = ['target', 'prob', _instance_col]
141 |                 prediction_df = prediction_df.set_index(_instance_col)
142 |                 df_full = df_full.set_index(_instance_col)
143 |                 df = prediction_df.join(df_full).reset_index()
144 | 
145 |                 print('Instance-level performance')
146 |                 print('--------------------------')
147 |                 #for split in ['train', 'valid', 'test']:
148 |                 for split in ['test']:
149 |                     print(f'{split} AUROC: {roc_auc_score(df[df.Split == split].target, df[df.Split == split].prob)}')
150 |                     print(f'{split} AUPRC: {average_precision_score(df[df.Split == split].target, df[df.Split == split].prob)}')
151 | 
152 |                 print('\n\n')
153 |                 print('Study-level performance')
154 |                 print('-----------------------')
155 |                 study_df = df[[_study_col, 'target', 'prob', 'Split']].groupby(_study_col).max().reset_index()
156 |                 for split in ['test']:  # only test-split predictions exist here (see the instance-level loop above)
157 |                     print(f'{split} AUROC: {roc_auc_score(study_df[study_df.Split == split].target, study_df[study_df.Split == split].prob)}')
158 |                     print(f'{split} AUPRC: {average_precision_score(study_df[study_df.Split == split].target, study_df[study_df.Split == split].prob)}')
159 | 
160 |                 df = df.groupby(_study_col)
161 | 
162 |                 id2feature = {k:v.numpy() for k,v in zip(instance_id, features)}
163 | 
164 |                 features_path = os.path.join(self.cfg.output_dir, "features.hdf5")
165 |                 hdf5_fn = h5py.File(features_path, 'a')
166 |                 for study, grouped_df in df:
167 |                     grouped_df = grouped_df.sort_values(by=['InstanceOrder'])
168 |                     features = np.stack([id2feature[ids] for ids in grouped_df[_instance_col].tolist()])
169 |                     hdf5_fn.create_dataset(study, data=features, dtype='float32', chunks=True)
170 |                 hdf5_fn.close()
171 |                 print(f"\nFeatures saved at: {features_path}")
--------------------------------------------------------------------------------
/pe_models/models/backbones3d/s3dg.py:
--------------------------------------------------------------------------------
1 | # modified from https://raw.githubusercontent.com/qijiezhao/s3d.pytorch/master/S3DG_Pytorch.py
2 | import torch.nn as nn
3 | import torch
4 | 
5 | ## pytorch default: torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
6 | ## tensorflow s3d code: torch.nn.BatchNorm3d(num_features, eps=1e-3, momentum=0.001, affine=True, track_running_stats=True)
7 | 
8 | class BasicConv3d(nn.Module):
9 |     def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
10 |         super(BasicConv3d, self).__init__()
11 |         self.conv = nn.Conv3d(in_planes, out_planes,
12 |                               kernel_size=kernel_size, stride=stride,
13 |                               padding=padding, bias=False)
14 | 
15 |         # self.bn = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
16 |         self.bn = nn.BatchNorm3d(out_planes)
17 |         self.relu = nn.ReLU(inplace=True)
18 | 
19 |         # init
20 |         self.conv.weight.data.normal_(mean=0, std=0.01)  # original s3d is truncated normal within 2 std
21 |         self.bn.weight.data.fill_(1)
22 |         self.bn.bias.data.zero_()
23 | 
24 |     def forward(self, x):
25 |         x = self.conv(x)
26 |         x = self.bn(x)
27 |         x = self.relu(x)
28 |         return x
29 | 
30 | class STConv3d(nn.Module):
31 |     def __init__(self,in_planes,out_planes,kernel_size,stride,padding=0):
32 |         super(STConv3d, self).__init__()
33 |         if isinstance(stride, tuple):
34 |             t_stride = stride[0]
35 |             stride = stride[-1]
36 |         else: # int
37 |             t_stride = stride
38 | 
39 |         self.conv1 = nn.Conv3d(in_planes, out_planes, kernel_size=(1,kernel_size,kernel_size),
40 |                                stride=(1,stride,stride),padding=(0,padding,padding), bias=False)
41 |         self.conv2 = nn.Conv3d(out_planes,out_planes,kernel_size=(kernel_size,1,1),
42 |                                stride=(t_stride,1,1),padding=(padding,0,0), bias=False)
43 | 
44 |         # self.bn1=nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
45 |         # self.bn2=nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
46 |         self.bn1=nn.BatchNorm3d(out_planes)
47 |         self.bn2=nn.BatchNorm3d(out_planes)
48 |         self.relu = nn.ReLU(inplace=True)
49 | 
50 |         # init
51 |         self.conv1.weight.data.normal_(mean=0, std=0.01)  # original s3d is truncated normal within 2 std
52 |         self.conv2.weight.data.normal_(mean=0, std=0.01)  # original s3d is truncated normal within 2 std
53 |         self.bn1.weight.data.fill_(1)
54 |         self.bn1.bias.data.zero_()
55 |         self.bn2.weight.data.fill_(1)
56 |         self.bn2.bias.data.zero_()
57 | 
58 |     def forward(self,x):
59 |         x=self.conv1(x)
60 |         x=self.bn1(x)
61 |         x=self.relu(x)
62 |         x=self.conv2(x)
63 |         x=self.bn2(x)
64 |         x=self.relu(x)
65 |         return x
66 | 
67 | 
68 | class
SelfGating(nn.Module): 69 | def __init__(self, input_dim): 70 | super(SelfGating, self).__init__() 71 | self.fc = nn.Linear(input_dim, input_dim) 72 | 73 | def forward(self, input_tensor): 74 | """Feature gating as used in S3D-G""" 75 | spatiotemporal_average = torch.mean(input_tensor, dim=[2, 3, 4]) 76 | weights = self.fc(spatiotemporal_average) 77 | weights = torch.sigmoid(weights) 78 | return weights[:, :, None, None, None] * input_tensor 79 | 80 | 81 | class SepInception(nn.Module): 82 | def __init__(self, in_planes, out_planes, gating=False): 83 | super(SepInception, self).__init__() 84 | 85 | assert len(out_planes) == 6 86 | assert isinstance(out_planes, list) 87 | 88 | [num_out_0_0a, 89 | num_out_1_0a, num_out_1_0b, 90 | num_out_2_0a, num_out_2_0b, 91 | num_out_3_0b] = out_planes 92 | 93 | self.branch0 = nn.Sequential( 94 | BasicConv3d(in_planes, num_out_0_0a, kernel_size=1, stride=1), 95 | ) 96 | self.branch1 = nn.Sequential( 97 | BasicConv3d(in_planes, num_out_1_0a, kernel_size=1, stride=1), 98 | STConv3d(num_out_1_0a, num_out_1_0b, kernel_size=3, stride=1, padding=1), 99 | ) 100 | self.branch2 = nn.Sequential( 101 | BasicConv3d(in_planes, num_out_2_0a, kernel_size=1, stride=1), 102 | STConv3d(num_out_2_0a, num_out_2_0b, kernel_size=3, stride=1, padding=1), 103 | ) 104 | self.branch3 = nn.Sequential( 105 | nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1), 106 | BasicConv3d(in_planes, num_out_3_0b, kernel_size=1, stride=1), 107 | ) 108 | 109 | self.out_channels = sum([num_out_0_0a, num_out_1_0b, num_out_2_0b, num_out_3_0b]) 110 | 111 | self.gating = gating 112 | if gating: 113 | self.gating_b0 = SelfGating(num_out_0_0a) 114 | self.gating_b1 = SelfGating(num_out_1_0b) 115 | self.gating_b2 = SelfGating(num_out_2_0b) 116 | self.gating_b3 = SelfGating(num_out_3_0b) 117 | 118 | 119 | def forward(self, x): 120 | x0 = self.branch0(x) 121 | x1 = self.branch1(x) 122 | x2 = self.branch2(x) 123 | x3 = self.branch3(x) 124 | if self.gating: 125 | x0 = self.gating_b0(x0) 126 | x1 = self.gating_b1(x1) 127 | x2 = self.gating_b2(x2) 128 | x3 = self.gating_b3(x3) 129 | 130 | out = torch.cat((x0, x1, x2, x3), 1) 131 | 132 | return out 133 | 134 | 135 | class S3D(nn.Module): 136 | 137 | def __init__(self, input_channel=3, gating=False, slow=False): 138 | super(S3D, self).__init__() 139 | self.gating = gating 140 | self.slow = slow 141 | 142 | if slow: 143 | self.Conv_1a = STConv3d(input_channel, 64, kernel_size=7, stride=(1,2,2), padding=3) 144 | else: # normal 145 | self.Conv_1a = STConv3d(input_channel, 64, kernel_size=7, stride=2, padding=3) 146 | 147 | self.block1 = nn.Sequential(self.Conv_1a) # (64, 32, 112, 112) 148 | 149 | ################################### 150 | 151 | self.MaxPool_2a = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)) 152 | self.Conv_2b = BasicConv3d(64, 64, kernel_size=1, stride=1) 153 | self.Conv_2c = STConv3d(64, 192, kernel_size=3, stride=1, padding=1) 154 | 155 | self.block2 = nn.Sequential( 156 | self.MaxPool_2a, # (64, 32, 56, 56) 157 | self.Conv_2b, # (64, 32, 56, 56) 158 | self.Conv_2c) # (192, 32, 56, 56) 159 | 160 | ################################### 161 | 162 | self.MaxPool_3a = nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)) 163 | self.Mixed_3b = SepInception(in_planes=192, out_planes=[64, 96, 128, 16, 32, 32], gating=gating) 164 | self.Mixed_3c = SepInception(in_planes=256, out_planes=[128, 128, 192, 32, 96, 64], gating=gating) 165 | 166 | self.block3 = nn.Sequential( 167 | self.MaxPool_3a, # (192, 32, 28, 28) 168 | 
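# Clarifying comment (inferred from SepInception above, not in the original source):
# out_planes = [b0_1x1, b1_reduce, b1_3x3, b2_reduce, b2_3x3, b3_pool_1x1], and the
# concatenated output width is b0_1x1 + b1_3x3 + b2_3x3 + b3_pool_1x1, e.g.
# 64 + 128 + 32 + 32 = 256 for Mixed_3b, which is why Mixed_3c takes in_planes=256.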
            self.Mixed_3b,   # (256, 32, 28, 28)
169 |             self.Mixed_3c)   # (480, 32, 28, 28)
170 | 
171 |         ###################################
172 | 
173 |         self.MaxPool_4a = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
174 |         self.Mixed_4b = SepInception(in_planes=480, out_planes=[192, 96, 208, 16, 48, 64], gating=gating)
175 |         self.Mixed_4c = SepInception(in_planes=512, out_planes=[160, 112, 224, 24, 64, 64], gating=gating)
176 |         self.Mixed_4d = SepInception(in_planes=512, out_planes=[128, 128, 256, 24, 64, 64], gating=gating)
177 |         self.Mixed_4e = SepInception(in_planes=512, out_planes=[112, 144, 288, 32, 64, 64], gating=gating)
178 |         self.Mixed_4f = SepInception(in_planes=528, out_planes=[256, 160, 320, 32, 128, 128], gating=gating)
179 | 
180 |         self.block4 = nn.Sequential(
181 |             self.MaxPool_4a,   # (480, 16, 14, 14)
182 |             self.Mixed_4b,     # (512, 16, 14, 14)
183 |             self.Mixed_4c,     # (512, 16, 14, 14)
184 |             self.Mixed_4d,     # (512, 16, 14, 14)
185 |             self.Mixed_4e,     # (528, 16, 14, 14)
186 |             self.Mixed_4f)     # (832, 16, 14, 14)
187 | 
188 |         ###################################
189 | 
190 |         self.MaxPool_5a = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0))
191 |         self.Mixed_5b = SepInception(in_planes=832, out_planes=[256, 160, 320, 32, 128, 128], gating=gating)
192 |         self.Mixed_5c = SepInception(in_planes=832, out_planes=[384, 192, 384, 48, 128, 128], gating=gating)
193 | 
194 |         self.block5 = nn.Sequential(
195 |             self.MaxPool_5a,   # (832, 8, 7, 7)
196 |             self.Mixed_5b,     # (832, 8, 7, 7)
197 |             self.Mixed_5c)     # (1024, 8, 7, 7)
198 | 
199 |         ###################################
200 | 
201 |         # self.AvgPool_0a = nn.AvgPool3d(kernel_size=(2, 7, 7), stride=1)
202 |         # self.Dropout_0b = nn.Dropout3d(dropout_keep_prob)
203 |         # self.Conv_0c = nn.Conv3d(1024, num_classes, kernel_size=1, stride=1, bias=True)
204 | 
205 |         # self.classifier = nn.Sequential(
206 |         #     self.AvgPool_0a,
207 |         #     self.Dropout_0b,
208 |         #     self.Conv_0c)
209 | 
210 | 
211 |     def forward(self, x):
212 |         x = self.block1(x)
213 |         x = self.block2(x)
214 |         x = self.block3(x)
215 |         x = self.block4(x)
216 |         x = self.block5(x)
217 |         return x
218 | 
219 | 
220 | 
221 | if __name__ == '__main__':
222 |     model = S3D()  # S3D takes no num_classes argument; the classifier head above is commented out
--------------------------------------------------------------------------------
/pe_models/models/backbones3d/r21d.py:
--------------------------------------------------------------------------------
1 | # modified from https://raw.githubusercontent.com/leftthomas/R2Plus1D-C3D/master/models/R2Plus1D.py
2 | import math
3 | from collections import OrderedDict
4 | 
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn.modules.utils import _triple
8 | 
9 | 
10 | class SpatioTemporalConv(nn.Module):
11 |     """Applies a factored 3D convolution over an input signal composed of several input
12 |     planes with distinct spatial and time axes, by performing a 2D convolution over the
13 |     spatial axes to an intermediate subspace, followed by a 1D convolution over the time
14 |     axis to produce the final output.
15 |     Args:
16 |         in_channels (int): Number of channels in the input tensor
17 |         out_channels (int): Number of channels produced by the convolution
18 |         kernel_size (int or tuple): Size of the convolving kernel
19 |         stride (int or tuple, optional): Stride of the convolution. Default: 1
20 |         padding (int or tuple, optional): Zero-padding added to the sides of the input during their respective convolutions. Default: 0
21 |         bias (bool, optional): If ``True``, adds a learnable bias to the output.
Default: ``False``
22 |     """
23 | 
24 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, first_conv=False):
25 |         super(SpatioTemporalConv, self).__init__()
26 | 
27 |         # if ints are entered, convert them to iterables, 1 -> [1, 1, 1]
28 |         kernel_size = _triple(kernel_size)
29 |         stride = _triple(stride)
30 |         padding = _triple(padding)
31 | 
32 |         # decomposing the parameters into spatial and temporal components by
33 |         # masking out the values with the defaults on the axis that
34 |         # won't be convolved over. This is necessary to avoid unintentional
35 |         # behavior such as padding being added twice
36 |         spatial_kernel_size = (1, kernel_size[1], kernel_size[2])
37 |         spatial_stride = (1, stride[1], stride[2])
38 |         spatial_padding = (0, padding[1], padding[2])
39 | 
40 |         temporal_kernel_size = (kernel_size[0], 1, 1)
41 |         temporal_stride = (stride[0], 1, 1)
42 |         temporal_padding = (padding[0], 0, 0)
43 | 
44 |         # compute the number of intermediary channels (M) using formula
45 |         # from the paper section 3.5
46 |         intermed_channels = int(
47 |             math.floor((kernel_size[0] * kernel_size[1] * kernel_size[2] * in_channels * out_channels) / \
48 |                        (kernel_size[1] * kernel_size[2] * in_channels + kernel_size[0] * out_channels)))
49 |         # print(intermed_channels)
50 | 
51 |         # the spatial conv is effectively a 2D conv due to the
52 |         # spatial_kernel_size, followed by batch_norm and ReLU
53 |         self.spatial_conv = nn.Conv3d(in_channels, intermed_channels, spatial_kernel_size,
54 |                                       stride=spatial_stride, padding=spatial_padding, bias=bias)
55 |         self.bn = nn.BatchNorm3d(intermed_channels)
56 |         self.relu = nn.ReLU()
57 | 
58 |         # the temporal conv is effectively a 1D conv, but has batch norm
59 |         # and ReLU added inside the model constructor, not here. This is an
60 |         # intentional design choice, to allow this module to externally act
61 |         # identical to a standard Conv3D, so it can be reused easily in any
62 |         # other codebase
63 |         self.temporal_conv = nn.Conv3d(intermed_channels, out_channels, temporal_kernel_size,
64 |                                        stride=temporal_stride, padding=temporal_padding, bias=bias)
65 | 
66 |     def forward(self, x):
67 |         x = self.relu(self.bn(self.spatial_conv(x)))
68 |         x = self.temporal_conv(x)
69 |         return x
70 | 
71 | 
72 | class SpatioTemporalResBlock(nn.Module):
73 |     r"""Single block for the ResNet network. Uses SpatioTemporalConv in
74 |     the standard ResNet block layout (conv->batchnorm->ReLU->conv->batchnorm->sum->ReLU)
75 |     Args:
76 |         in_channels (int): Number of channels in the input tensor.
77 |         out_channels (int): Number of channels in the output produced by the block.
78 |         kernel_size (int or tuple): Size of the convolving kernels.
79 |         downsample (bool, optional): If ``True``, the output size is to be smaller than the input. Default: ``False``
80 |     """
81 | 
82 |     def __init__(self, in_channels, out_channels, kernel_size, downsample=False):
83 |         super(SpatioTemporalResBlock, self).__init__()
84 | 
85 |         # If downsample == True, the first conv of the layer has stride = 2
86 |         # to halve the residual output size, and the input x is passed
87 |         # through a separate 1x1x1 conv with stride = 2 to also halve it.
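# Worked example (illustrative numbers, not from the original source): factoring a
# 3x3x3 kernel with in_channels=64 and out_channels=128 gives
# intermed_channels = floor(27*64*128 / (9*64 + 3*128)) = floor(221184 / 960) = 230,
# so the factored conv (9*64*230 + 3*230*128 = 220,800 weights) roughly matches the
# parameter count of the full 3D conv (27*64*128 = 221,184 weights). Likewise, with
# downsample=True an input of (B, 64, 16, 56, 56) becomes (B, 128, 8, 28, 28) on
# both the main path (conv1, stride 2) and the 1x1x1 shortcut, so the sum is valid.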
88 | 
89 |         # no pooling layers are used inside ResNet
90 |         self.downsample = downsample
91 | 
92 |         # to allow for SAME padding
93 |         padding = kernel_size // 2
94 | 
95 |         if self.downsample:
96 |             # downsample the input x with stride = 2
97 |             self.downsampleconv = SpatioTemporalConv(in_channels, out_channels, 1, stride=2)
98 |             self.downsamplebn = nn.BatchNorm3d(out_channels)
99 | 
100 |             # downsample with stride = 2 when producing the residual
101 |             self.conv1 = SpatioTemporalConv(in_channels, out_channels, kernel_size, padding=padding, stride=2)
102 |         else:
103 |             self.conv1 = SpatioTemporalConv(in_channels, out_channels, kernel_size, padding=padding)
104 | 
105 |         self.bn1 = nn.BatchNorm3d(out_channels)
106 |         self.relu1 = nn.ReLU()
107 | 
108 |         # standard conv->batchnorm->ReLU
109 |         self.conv2 = SpatioTemporalConv(out_channels, out_channels, kernel_size, padding=padding)
110 |         self.bn2 = nn.BatchNorm3d(out_channels)
111 |         self.outrelu = nn.ReLU()
112 | 
113 |     def forward(self, x):
114 |         res = self.relu1(self.bn1(self.conv1(x)))
115 |         res = self.bn2(self.conv2(res))
116 | 
117 |         if self.downsample:
118 |             x = self.downsamplebn(self.downsampleconv(x))
119 | 
120 |         return self.outrelu(x + res)
121 | 
122 | 
123 | class SpatioTemporalResLayer(nn.Module):
124 |     r"""Forms a single layer of the ResNet network, with a number of repeating
125 |     blocks of the same output size stacked on top of each other
126 |     Args:
127 |         in_channels (int): Number of channels in the input tensor.
128 |         out_channels (int): Number of channels in the output produced by the layer.
129 |         kernel_size (int or tuple): Size of the convolving kernels.
130 |         layer_size (int): Number of blocks to be stacked to form the layer
131 |         block_type (Module, optional): Type of block that is to be used to form the layer. Default: SpatioTemporalResBlock.
132 |         downsample (bool, optional): If ``True``, the first block in layer will implement downsampling. Default: ``False``
133 |     """
134 | 
135 |     def __init__(self, in_channels, out_channels, kernel_size, layer_size, block_type=SpatioTemporalResBlock,
136 |                  downsample=False):
137 | 
138 |         super(SpatioTemporalResLayer, self).__init__()
139 | 
140 |         # implement the first block
141 |         self.block1 = block_type(in_channels, out_channels, kernel_size, downsample)
142 | 
143 |         # prepare module list to hold all (layer_size - 1) blocks
144 |         self.blocks = nn.ModuleList([])
145 |         for i in range(layer_size - 1):
146 |             # all these blocks are identical, and have downsample = False by default
147 |             self.blocks += [block_type(out_channels, out_channels, kernel_size)]
148 | 
149 |     def forward(self, x):
150 |         x = self.block1(x)
151 |         for block in self.blocks:
152 |             x = block(x)
153 | 
154 |         return x
155 | 
156 | 
157 | class R2Plus1DNet(nn.Module):
158 |     r"""Forms the overall ResNet feature extractor by initializing 5 layers, with the number of blocks in
159 |     each layer set by layer_sizes, and by performing a global average pool at the end producing a
160 |     512-dimensional vector for each element in the batch.
161 |     Args:
162 |         layer_sizes (tuple): An iterable containing the number of blocks in each layer
163 |         block_type (Module, optional): Type of block that is to be used to form the layers. Default: SpatioTemporalResBlock.
164 | """ 165 | 166 | def __init__(self, layer_sizes=(1,1,1,1), block_type=SpatioTemporalResBlock, num_classes=4): 167 | super(R2Plus1DNet, self).__init__() 168 | # self.num_classes = num_classes 169 | 170 | # first conv, with stride 1x2x2 and kernel size 1x7x7 171 | self.conv1 = SpatioTemporalConv(3, 64, (3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3)) 172 | self.bn1 = nn.BatchNorm3d(64) 173 | self.relu1 = nn.ReLU() 174 | # output of conv2 is same size as of conv1, no downsampling needed. kernel_size 3x3x3 175 | self.conv2 = SpatioTemporalResLayer(64, 64, 3, layer_sizes[0], block_type=block_type) 176 | # each of the final three layers doubles num_channels, while performing downsampling 177 | # inside the first block 178 | self.conv3 = SpatioTemporalResLayer(64, 128, 3, layer_sizes[1], block_type=block_type, downsample=True) 179 | self.conv4 = SpatioTemporalResLayer(128, 256, 3, layer_sizes[2], block_type=block_type, downsample=True) 180 | self.conv5 = SpatioTemporalResLayer(256, 512, 3, layer_sizes[3], block_type=block_type, downsample=True) 181 | 182 | # global average pooling of the output 183 | self.pool = nn.AdaptiveAvgPool3d(1) 184 | 185 | # self.linear = nn.Linear(512, self.num_classes) 186 | 187 | def forward(self, x): 188 | x = self.relu1(self.bn1(self.conv1(x))) 189 | x = self.conv2(x) 190 | x = self.conv3(x) 191 | x = self.conv4(x) 192 | x = self.conv5(x) 193 | 194 | x = self.pool(x) 195 | # x = x.view(-1, 512) 196 | 197 | # x = self.linear(x) 198 | 199 | return x 200 | 201 | 202 | if __name__ == '__main__': 203 | r21d = R2Plus1DNet((1, 1, 1, 1)) -------------------------------------------------------------------------------- /pe_models/lightning/window_classification_lightning_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import wandb 5 | import json 6 | import pandas as pd 7 | import pickle 8 | import os 9 | import h5py 10 | 11 | from .. import builder 12 | from .. 
import utils
13 | from ..constants import *
14 | from sklearn.metrics import average_precision_score, roc_auc_score
15 | from pytorch_lightning.core import LightningModule
16 | from collections import defaultdict
17 | 
18 | 
19 | class PEWindowClassificationLightningModel(LightningModule):
20 |     """Pytorch-Lightning Module"""
21 | 
22 |     def __init__(self, cfg):
23 |         """Pass in hyperparameters to the model"""
24 |         # initialize superclass
25 |         super().__init__()
26 | 
27 |         self.cfg = cfg
28 |         self.model = builder.build_model(cfg)
29 |         self.loss = builder.build_loss(cfg)
30 |         self.target_names = RSNA_TARGET_TYPES[cfg.data.targets]
31 | 
32 |         if self.cfg.train.weighted_loss:
33 |             print("Using weighted loss")
34 |             self.loss_weights = torch.tensor(
35 |                 [RSNA_LOSS_WEIGHT[t] for t in self.target_names]
36 |             )
37 |         else:
38 |             self.loss_weights = None
39 | 
40 |         # first layer: split
41 |         # second layer: study_2_pred, study_2_label
42 |         # third layer: study list
43 |         self.results = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
44 | 
45 | 
46 |     def configure_optimizers(self):
47 |         optimizer = builder.build_optimizer(self.cfg, self.model)
48 |         scheduler = builder.build_scheduler(self.cfg, optimizer)
49 |         return {"optimizer": optimizer, "lr_scheduler": scheduler}
50 | 
51 |     def training_step(self, batch, batch_idx):
52 |         return self.shared_step(batch, "train")
53 | 
54 |     def validation_step(self, batch, batch_idx):
55 |         return self.shared_step(batch, "val")
56 | 
57 |     def test_step(self, batch, batch_idx):
58 |         return self.shared_step(batch, "test")
59 | 
60 |     def training_epoch_end(self, training_step_outputs):
61 |         return self.shared_epoch_end(training_step_outputs, "train")
62 | 
63 |     def validation_epoch_end(self, validation_step_outputs):
64 |         return self.shared_epoch_end(validation_step_outputs, "val")
65 | 
66 |     def test_epoch_end(self, test_step_outputs):
67 |         return self.shared_epoch_end(test_step_outputs, "test")
68 | 
69 |     def shared_step(self, batch, split, extract_features=False):
70 |         """Similar to training step"""
71 | 
72 |         x, y, instance_id = batch
73 |         if self.cfg.data.type == "2d" or self.cfg.data.type == "lidc-2d":
74 |             instance_id = [[i.split("-")[1] for i in instance_id], [i.split("-")[0] for i in instance_id]]
75 |         logit, features = self.model(x, get_features=True)
76 | 
77 |         if self.loss_weights is not None:
78 |             weight = self.loss_weights.to(y.device)
79 |             loss = self.loss(logit, y).mean(0)
80 |             loss = torch.sum(weight * loss)
81 |         else:
82 |             loss = self.loss(logit, y)
83 | 
84 |         # map window prediction to study
85 |         batch_pred = torch.sigmoid(logit.clone().detach()).cpu().numpy()
86 |         batch_label = y.clone().detach().cpu().numpy()
87 |         # accumulate per-study window predictions and labels for the study-level metrics
88 |         for s, p, l in zip(instance_id[0], batch_pred, batch_label):
89 |             self.results[split]['study_2_pred'][s].append(p)
90 |             self.results[split]['study_2_label'][s].append(l)
91 | 
92 |         self.log(
93 |             f"{split}_loss",
94 |             loss,
95 |             on_epoch=True,
96 |             on_step=True,
97 |             logger=False,
98 |             prog_bar=True,
99 |         )
100 | 
101 |         return_dict = {"y": y, "loss": loss, "logit": logit.detach(), "instance_id": instance_id, "features": features.cpu().detach()}
102 |         return return_dict
103 | 
104 |     def shared_epoch_end(self, step_outputs, split):
105 | 
106 |         instance_id = [ids for x in step_outputs for ids in x["instance_id"]]
107 |         logit = torch.cat([x["logit"] for x in step_outputs])
108 |         y = torch.cat([x["y"] for x in step_outputs])
109 |         prob = torch.sigmoid(logit)
110 |         features = torch.cat([x["features"] for x in step_outputs])
111 | 
112 |         # log window auroc
113 |         auroc_dict = utils.get_auroc(y, prob, self.target_names)
114 |         for k, v in auroc_dict.items():
115 |             self.log(f"{split}/window_{k}_auroc", v, on_epoch=True, logger=True, prog_bar=True)
116 | 
117 |         # log window auprc
118 |         auprc_dict = utils.get_auprc(y, prob, self.target_names)
119 |         for k, v in auprc_dict.items():
120 |             self.log(f"{split}/window_{k}_auprc", v, on_epoch=True, logger=True, prog_bar=True)
121 | 
122 |         # study level prediction (based on max pred)
123 |         study_y, study_prob, study_ids = [], [], []
124 |         for study in self.results[split]['study_2_label'].keys():
125 |             label = 1 if 1 in self.results[split]['study_2_label'][study] else 0
126 |             pred = max(self.results[split]['study_2_pred'][study])
127 |             study_y.append(label)
128 |             study_prob.append(pred)
129 |             study_ids.append(study)
130 |         study_y = np.expand_dims(np.array(study_y), -1)
131 |         study_prob = np.array(study_prob)
132 | 
133 |         # study level metrics
134 |         auroc_dict = utils.get_auroc(study_y, study_prob, self.target_names)
135 |         for k, v in auroc_dict.items():
136 |             self.log(f"{split}/{k}_auroc", v, on_epoch=True, logger=True, prog_bar=True)
137 |         auprc_dict = utils.get_auprc(study_y, study_prob, self.target_names)
138 |         for k, v in auprc_dict.items():
139 |             self.log(f"{split}/{k}_auprc", v, on_epoch=True, logger=True, prog_bar=True)
140 | 
141 |         # reset results
142 |         self.results[split] = defaultdict(lambda: defaultdict(list))
143 | 
144 |         if split == "test":
145 |             # save results
146 |             meta_dict = {"split": split}
147 |             if not os.path.exists(self.cfg.output_dir):  # create the output directory if it does not exist
148 |                 os.makedirs(self.cfg.output_dir)
149 |             results_csv = os.path.join(self.cfg.output_dir, "results.csv")
150 |             auroc_dict = {f"{split}/{k}_auroc": [v] for k, v in auroc_dict.items()}
151 |             auprc_dict = {f"{split}/{k}_auprc": [v] for k, v in auprc_dict.items()}
152 |             results = {**meta_dict, **auroc_dict, **auprc_dict}
153 | 
154 |             results_df = pd.DataFrame.from_dict(results, orient="columns")
155 |             if os.path.exists(results_csv):
156 |                 df = pd.read_csv(results_csv)
157 |                 results_df = pd.concat([df, results_df], ignore_index=True)
158 |             results_df.to_csv(results_csv, index=False)
159 |             print(f"\nResults saved at: {results_csv}")
160 | 
161 |             # save predictions
162 |             y = [a[0] for a in list(y.cpu().detach().numpy())]
163 |             prob = [a[0] for a in list(prob.cpu().detach().numpy())]
164 |             study_id = [item for sublist in instance_id[::2] for item in sublist]
165 |             if self.cfg.data.type == '2d' or self.cfg.data.type == 'lidc-2d':
166 |                 window_idx = [item for sublist in instance_id[1::2] for item in sublist]
167 |             else:
168 |                 window_idx = [item for sublist in instance_id[1::2] for item in list(sublist.cpu().numpy())]
169 | 
170 |             prediction_dict = {
171 |                 "target": y,
172 |                 "prob": prob,
173 |                 "study_id": study_id,
174 |                 "window_idx": window_idx
175 |             }
176 |             prediction_path = os.path.join(self.cfg.output_dir, "preds.csv")
177 |             #pickle.dump(prediction_dict, open(prediction_path, "wb"))
178 |             pred_df = pd.DataFrame.from_dict(prediction_dict)
179 |             pred_df.to_csv(prediction_path)
180 |             print(f"\nPredictions saved at: {prediction_path}")
181 | 
182 | 
183 |             # save features
184 |             if self.cfg.data.type == '2d' or self.cfg.data.type == 'lidc-2d':
185 |                 _train_csv = RSNA_TRAIN_CSV if self.cfg.data.type == '2d' else LIDC_TRAIN_CSV
186 |                 _instance_col = RSNA_INSTANCE_COL if self.cfg.data.type == '2d' else LIDC_INSTANCE_COL
187 |                 _study_col = RSNA_STUDY_COL if self.cfg.data.type == '2d' else LIDC_STUDY_COL
188 | 
189 |                 df_full = pd.read_csv(_train_csv)
190 |                 prediction_df = pred_df.copy()
191 |                 prediction_df = prediction_df.drop("study_id", axis=1)
192 |                 prediction_df.columns = ['target', 'prob', _instance_col]
193 |                 prediction_df = prediction_df.set_index(_instance_col)
194 |                 df_full = df_full.set_index(_instance_col)
195 |                 df = prediction_df.join(df_full).reset_index()
196 | 
197 |                 print('Instance-level performance')
198 |                 print('--------------------------')
199 |                 for split in ['test']:  # only test-split predictions are collected here (matches classification_lightning_model)
200 |                     print(f'{split} AUROC: {roc_auc_score(df[df.Split == split].target, df[df.Split == split].prob)}')
201 |                     print(f'{split} AUPRC: {average_precision_score(df[df.Split == split].target, df[df.Split == split].prob)}')
202 | 
203 |                 print('\n\n')
204 |                 print('Study-level performance')
205 |                 print('-----------------------')
206 |                 study_df = df[[_study_col, 'target', 'prob', 'Split']].groupby(_study_col).max().reset_index()
207 |                 for split in ['test']:  # likewise, only the test split has predictions to score
208 |                     print(f'{split} AUROC: {roc_auc_score(study_df[study_df.Split == split].target, study_df[study_df.Split == split].prob)}')
209 |                     print(f'{split} AUPRC: {average_precision_score(study_df[study_df.Split == split].target, study_df[study_df.Split == split].prob)}')
210 | 
211 | 
212 |                 df = df.groupby(_study_col)
213 |                 instance_study_id = [f"{x}-{y}" for x, y in zip(window_idx, study_id)]
214 |                 id2feature = {k:v.numpy() for k,v in zip(instance_study_id, features)}
215 | 
216 |                 features_path = os.path.join(self.cfg.output_dir, "features.hdf5")
217 |                 hdf5_fn = h5py.File(features_path, 'a')
218 |                 for study, grouped_df in df:
219 |                     _image_order_col = "InstanceOrder" if self.cfg.data.type == '2d' else "image_index"
220 |                     grouped_df = grouped_df.sort_values(by=[_image_order_col])
221 |                     features = np.stack([
222 |                         id2feature[ids] for ids in (grouped_df[_instance_col] + "-" + study).tolist()
223 |                     ])
224 |                     hdf5_fn.create_dataset(study, data=features, dtype='float32', chunks=True)
225 |                 hdf5_fn.close()
226 |                 print(f"\nFeatures saved at: {features_path}")
227 | 
--------------------------------------------------------------------------------
/pe_models/datasets/dataset_1d.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | import pandas as pd
6 | import h5py
7 | import random
8 | 
9 | from ..constants import *
10 | from .dataset_base import PEDatasetBase
11 | from omegaconf import OmegaConf
12 | 
13 | 
14 | class PEDataset1D(PEDatasetBase):
15 |     def __init__(self, cfg, split="train", transform=None):
16 |         super().__init__(cfg, split)
17 | 
18 |         self.cfg = cfg
19 |         self.df = pd.read_csv(RSNA_TRAIN_CSV)
20 | 
21 |         if split != "all":
22 |             self.df = self.df[self.df[RSNA_SPLIT_COL] == split]
23 |             #self.df = self.df[self.df[RSNA_INSTITUTION_SPLIT_COL] == split]
24 | 
25 |         # hdf5 path
26 |         self.hdf5_path = self.cfg.data.hdf5_path
27 |         if self.hdf5_path is None:
28 |             raise Exception("Encoded slice HDF5 required")
29 | 
30 |         with h5py.File(self.hdf5_path, 'r') as f:
31 |             study = list(f.keys())
32 |             self.df = self.df[self.df[RSNA_STUDY_COL].isin(study)]
33 | 
34 |         if split == 'train':
35 |             if not OmegaConf.is_none(cfg.data, 'sample_frac'):
36 |                 num_study = len(study)
37 |                 num_sample = int(num_study * cfg.data.sample_frac)
38 |                 sampled_study = np.random.choice(study, num_sample, replace=False)
39 |                 self.df = self.df[self.df[RSNA_STUDY_COL].isin(sampled_study)]
40 | 
41 |         # get studies
42 |         grouped_df = self.df.groupby(RSNA_STUDY_COL).head(1)[
43 |             [RSNA_PE_TARGET_COL, RSNA_STUDY_COL]
44 |         ]
45 | 
46 |         # use positive as 1
47 |         self.pe_labels = [1 if t == 0 else 0 for t in grouped_df[RSNA_PE_TARGET_COL].tolist()]
48 |         self.study = grouped_df[RSNA_STUDY_COL].tolist()
49 | 
50 | 
51 |     def __getitem__(self, index):
52 | 
53 |         # read featurized series
54 |         study = self.study[index]
55 |         x = self.read_from_hdf5(study, hdf5_path=self.hdf5_path)
56 | 
57 |         # fix number of slices
58 |         x = self.fix_series_slice_number(x)
59 | 
60 |         # contextualize slices
61 |         if self.cfg.data.contextualize_slice:
62 |             x = self.contextualize_slice(x)
63 | 
64 |         # create torch tensor
65 |         x = torch.from_numpy(x).float()
66 | 
67 |         # fill
68 |         x = self.fill_series_to_num_slicess(x, self.cfg.data.num_slices)
69 | 
70 |         # get target
71 |         y = self.pe_labels[index]
72 |         y = torch.tensor(y).float().unsqueeze(-1)
73 | 
74 |         return x, y, study
75 | 
76 |     def __len__(self):
77 |         return len(self.study)
78 | 
79 |     def contextualize_slice(self, arr):
80 | 
81 |         # make new empty array
82 |         new_arr = np.zeros((arr.shape[0], arr.shape[1] * 3), dtype=np.float32)
83 | 
84 |         # fill first third of new array with original features
85 |         for i in range(len(arr)):
86 |             new_arr[i, : arr.shape[1]] = arr[i]
87 | 
88 |         # difference between previous neighbor
89 |         new_arr[1:, arr.shape[1] : arr.shape[1] * 2] = (
90 |             new_arr[1:, : arr.shape[1]] - new_arr[:-1, : arr.shape[1]]
91 |         )
92 | 
93 |         # difference between next neighbor
94 |         new_arr[:-1, arr.shape[1] * 2 :] = (
95 |             new_arr[:-1, : arr.shape[1]] - new_arr[1:, : arr.shape[1]]
96 |         )
97 | 
98 |         return new_arr
99 | 
100 |     def get_sampler(self):
101 | 
102 |         neg_class_count = (np.array(self.pe_labels) == 0).sum()
103 |         pos_class_count = (np.array(self.pe_labels) == 1).sum()
104 |         class_weight = [1 / neg_class_count, 1 / pos_class_count]
105 |         weights = [class_weight[i] for i in self.pe_labels]
106 | 
107 |         weights = torch.Tensor(weights).double()
108 |         sampler = torch.utils.data.sampler.WeightedRandomSampler(
109 |             weights, num_samples=len(weights), replacement=True
110 |         )
111 | 
112 |         return sampler
113 | 
114 | 
115 | class LIDCDataset1D(PEDatasetBase):
116 |     def __init__(self, cfg, split="train", transform=None):
117 |         super().__init__(cfg, split)
118 | 
119 |         self.cfg = cfg
120 |         self.df = pd.read_csv(LIDC_TRAIN_CSV)
121 | 
122 |         if split != "all":
123 |             self.df = self.df[self.df[LIDC_SPLIT_COL] == split]
124 | 
125 |         # hdf5 path
126 |         self.hdf5_path = self.cfg.data.hdf5_path
127 |         if self.hdf5_path is None:
128 |             raise Exception("Encoded slice HDF5 required")
129 | 
130 |         # TODO:
131 |         with h5py.File(self.hdf5_path, 'r') as f:
132 |             study = list(f.keys())
133 |             self.df = self.df[self.df[LIDC_STUDY_COL].isin(study)]
134 | 
135 |         if split == 'train':
136 |             if not OmegaConf.is_none(cfg.data, 'sample_frac'):
137 |                 num_study = len(study)
138 |                 num_sample = int(num_study * cfg.data.sample_frac)
139 |                 sampled_study = np.random.choice(study, num_sample, replace=False)
140 |                 self.df = self.df[self.df[LIDC_STUDY_COL].isin(sampled_study)]
141 | 
142 |         # get studies
143 |         grouped_df = self.df.groupby(LIDC_STUDY_COL)[LIDC_NOD_SLICE_COL].max()
144 |         grouped_df = grouped_df.reset_index()
145 |         grouped_df = grouped_df.rename({
146 |             LIDC_NOD_SLICE_COL: "nodule_present_in_study"
147 |         }, axis=1)
148 | 
149 |         # use positive as 1
150 |         self.pe_labels = grouped_df["nodule_present_in_study"].tolist()
151 |         self.study = grouped_df[LIDC_STUDY_COL].tolist()
152 |         print(f"LIDCDataset1D {split} split: {len(self.study)} studies")
153 | 
154 | 
155 |     def __getitem__(self, index):
156 | 
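# Shape sketch (inferred from the surrounding code, not stated in the original):
# read_from_hdf5 returns a (num_slices, feat_dim) array of per-slice features;
# contextualize_slice widens this to (num_slices, feat_dim * 3) by appending
# differences with the previous and the next slice; fill_series_to_num_slicess
# then pads the series to cfg.data.num_slices rows so batches stack, and the
# label y is returned as a (1,) float tensor.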
157 |         # read featurized series
158 |         study = self.study[index]
159 |         x = self.read_from_hdf5(study, hdf5_path=self.hdf5_path)
160 | 
161 |         # fix number of slices
162 |         x = self.fix_series_slice_number(x)
163 | 
164 |         # contextualize slices
165 |         if self.cfg.data.contextualize_slice:
166 |             x = self.contextualize_slice(x)
167 | 
168 |         # create torch tensor
169 |         x = torch.from_numpy(x).float()
170 | 
171 |         # fill
172 |         x = self.fill_series_to_num_slicess(x, self.cfg.data.num_slices)
173 | 
174 |         # get target
175 |         y = self.pe_labels[index]
176 |         y = torch.tensor(y).float().unsqueeze(-1)
177 | 
178 |         return x, y, study
179 | 
180 |     def __len__(self):
181 |         return len(self.study)
182 | 
183 |     def contextualize_slice(self, arr):
184 | 
185 |         # make new empty array
186 |         new_arr = np.zeros((arr.shape[0], arr.shape[1] * 3), dtype=np.float32)
187 | 
188 |         # fill first third of new array with original features
189 |         for i in range(len(arr)):
190 |             new_arr[i, : arr.shape[1]] = arr[i]
191 | 
192 |         # difference between previous neighbor
193 |         new_arr[1:, arr.shape[1] : arr.shape[1] * 2] = (
194 |             new_arr[1:, : arr.shape[1]] - new_arr[:-1, : arr.shape[1]]
195 |         )
196 | 
197 |         # difference between next neighbor
198 |         new_arr[:-1, arr.shape[1] * 2 :] = (
199 |             new_arr[:-1, : arr.shape[1]] - new_arr[1:, : arr.shape[1]]
200 |         )
201 | 
202 |         return new_arr
203 | 
204 |     def get_sampler(self):
205 | 
206 |         neg_class_count = (np.array(self.pe_labels) == 0).sum()
207 |         pos_class_count = (np.array(self.pe_labels) == 1).sum()
208 |         class_weight = [1 / neg_class_count, 1 / pos_class_count]
209 |         weights = [class_weight[i] for i in self.pe_labels]
210 | 
211 |         weights = torch.Tensor(weights).double()
212 |         sampler = torch.utils.data.sampler.WeightedRandomSampler(
213 |             weights, num_samples=len(weights), replacement=True
214 |         )
215 | 
216 |         return sampler
217 | 
218 | 
219 | 
220 | class PEDataset1DStanford(PEDatasetBase):
221 |     def __init__(self, cfg, split="train", transform=None):
222 |         super().__init__(cfg, split)
223 | 
224 |         self.cfg = cfg
225 |         self.df = pd.read_csv(STANFORD_NO_RSNA_CSV)
226 | 
227 |         if split != "all":
228 |             self.df = self.df[self.df["Split"] == split]
229 | 
230 |         # hdf5 path
231 |         self.hdf5_path = self.cfg.data.hdf5_path
232 |         if self.hdf5_path is None:
233 |             raise Exception("Encoded slice HDF5 required")
234 | 
235 |         # TODO:
236 |         with h5py.File(self.hdf5_path, 'r') as f:
237 |             study = list(f.keys())
238 |             self.df = self.df[self.df['StudyInstanceUID'].isin(study)]
239 | 
240 |         if split == 'train':
241 |             if not OmegaConf.is_none(cfg.data, 'sample_frac'):
242 |                 num_study = len(study)
243 |                 num_sample = int(num_study * cfg.data.sample_frac)
244 |                 sampled_study = np.random.choice(study, num_sample, replace=False)
245 |                 self.df = self.df[self.df['StudyInstanceUID'].isin(sampled_study)]
246 | 
247 |         # get studies
248 |         grouped_df = self.df.groupby('StudyInstanceUID').head(1)[
249 |             ['negative_exam_for_pe', 'StudyInstanceUID']
250 |         ]
251 | 
252 |         # use positive as 1 (invert negative_exam_for_pe, as in PEDataset1D above)
253 |         self.pe_labels = [1 if t == 0 else 0 for t in grouped_df['negative_exam_for_pe'].tolist()]
254 |         self.study = grouped_df['StudyInstanceUID'].tolist()
255 | 
256 |     def __getitem__(self, index):
257 | 
258 |         # read featurized series
259 |         study = self.study[index]
260 |         x = self.read_from_hdf5(study, hdf5_path=self.hdf5_path)
261 | 
262 |         # fix number of slices
263 |         x = self.fix_series_slice_number(x)
264 | 
265 |         # contextualize slices
266 |         if self.cfg.data.contextualize_slice:
267 |             x = self.contextualize_slice(x)
268 | 
269 |         # create torch tensor
270 |         x = torch.from_numpy(x).float()
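# (Illustrative note, not in the original: at this point x has shape
# (num_slices, feat_dim * 3) when cfg.data.contextualize_slice is set,
# otherwise (num_slices, feat_dim).)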
220 | class PEDataset1DStanford(PEDatasetBase):
221 |     def __init__(self, cfg, split="train", transform=None):
222 |         super().__init__(cfg, split)
223 | 
224 |         self.cfg = cfg
225 |         self.df = pd.read_csv(STANFORD_NO_RSNA_CSV)
226 | 
227 |         if split != "all":
228 |             self.df = self.df[self.df["Split"] == split]
229 | 
230 |         # hdf5 path
231 |         self.hdf5_path = self.cfg.data.hdf5_path
232 |         if self.hdf5_path is None:
233 |             raise Exception("Encoded slice HDF5 required")
234 | 
235 |         # keep only the studies that are present in the HDF5 file
236 |         with h5py.File(self.hdf5_path, 'r') as f:
237 |             study = list(f.keys())
238 |         self.df = self.df[self.df['StudyInstanceUID'].isin(study)]
239 | 
240 |         if split == 'train':
241 |             if not OmegaConf.is_none(cfg.data, 'sample_frac'):
242 |                 num_study = len(study)
243 |                 num_sample = int(num_study * cfg.data.sample_frac)
244 |                 sampled_study = np.random.choice(study, num_sample, replace=False)
245 |                 self.df = self.df[self.df['StudyInstanceUID'].isin(sampled_study)]
246 | 
247 |         # get studies
248 |         grouped_df = self.df.groupby('StudyInstanceUID').head(1)[
249 |             ['negative_exam_for_pe', 'StudyInstanceUID']
250 |         ]
251 | 
252 |         # per-study label; uses 'negative_exam_for_pe' directly, which assumes
253 |         # the CSV already encodes positive exams as 1
254 |         self.pe_labels = grouped_df['negative_exam_for_pe'].tolist()
255 |         self.study = grouped_df['StudyInstanceUID'].tolist()
256 | 
257 |     def __getitem__(self, index):
258 |         # read featurized series
259 |         study = self.study[index]
260 |         x = self.read_from_hdf5(study, hdf5_path=self.hdf5_path)
261 | 
262 |         # fix number of slices
263 |         x = self.fix_series_slice_number(x)
264 | 
265 |         # contextualize slices
266 |         if self.cfg.data.contextualize_slice:
267 |             x = self.contextualize_slice(x)
268 | 
269 |         # create torch tensor
270 |         x = torch.from_numpy(x).float()
271 | 
272 |         # pad series to the configured number of slices
273 |         x = self.fill_series_to_num_slicess(x, self.cfg.data.num_slices)
274 | 
275 |         # get target
276 |         y = self.pe_labels[index]
277 |         y = torch.tensor(y).float().unsqueeze(-1)
278 | 
279 |         return x, y, study
280 | 
281 |     def __len__(self):
282 |         return len(self.study)
283 | 
284 |     def contextualize_slice(self, arr):
285 | 
286 |         # make new empty array
287 |         new_arr = np.zeros((arr.shape[0], arr.shape[1] * 3), dtype=np.float32)
288 | 
289 |         # fill first third of new array with original features
290 |         for i in range(len(arr)):
291 |             new_arr[i, : arr.shape[1]] = arr[i]
292 | 
293 |         # difference from the previous neighbor
294 |         new_arr[1:, arr.shape[1] : arr.shape[1] * 2] = (
295 |             new_arr[1:, : arr.shape[1]] - new_arr[:-1, : arr.shape[1]]
296 |         )
297 | 
298 |         # difference from the next neighbor
299 |         new_arr[:-1, arr.shape[1] * 2 :] = (
300 |             new_arr[:-1, : arr.shape[1]] - new_arr[1:, : arr.shape[1]]
301 |         )
302 | 
303 |         return new_arr
304 | 
305 |     def get_sampler(self):
306 |         # inverse class-frequency weights to rebalance positives and negatives
307 |         neg_class_count = (np.array(self.pe_labels) == 0).sum()
308 |         pos_class_count = (np.array(self.pe_labels) == 1).sum()
309 |         class_weight = [1 / neg_class_count, 1 / pos_class_count]
310 |         weights = [class_weight[i] for i in self.pe_labels]
311 | 
312 |         weights = torch.Tensor(weights).double()
313 |         sampler = torch.utils.data.sampler.WeightedRandomSampler(
314 |             weights, num_samples=len(weights), replacement=True
315 |         )
316 | 
317 |         return sampler
318 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/pe_models/preprocess/rsna.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import mdai
3 | import pydicom
4 | import sys
5 | import os
6 | import glob
7 | import tqdm
8 | import numpy as np
9 | import h5py
10 | 
11 | sys.path.append(os.getcwd())
12 | 
13 | from collections import defaultdict
14 | from pe_models.constants import *
15 | from pe_models import utils
16 | from sklearn.model_selection import train_test_split
17 | from torch.utils.data import Dataset, DataLoader
18 | 
19 | 
20 | def process_window_df(df: pd.DataFrame, num_slices: int = 24, min_abnormal_slice: int = 4, stride: int = None):
21 | 
22 |     # default to non-overlapping windows when no stride is given
23 |     if stride is None:
24 |         stride = num_slices
25 | 
26 |     window_labels_csv = RSNA_DATA_DIR / f"rsna_window_{num_slices}_min_abnormal_{min_abnormal_slice}_stride_{stride}.csv"
27 | 
28 |     # count the sliding windows per study (including the final full window)
29 |     count_num_windows = lambda x: (x - num_slices) // stride + 1
30 |     df[RSNA_NUM_WINDOW_COL] = df[RSNA_NUM_SLICES_COL].apply(count_num_windows)
31 | 
32 |     # get windows list
33 |     df_study = df.groupby([RSNA_STUDY_COL]).head(1)
34 |     window_labels = defaultdict(list)
35 | 
36 |     # studies
37 |     for _, row in tqdm.tqdm(df_study.iterrows(), total=df_study.shape[0]):
38 |         study_name = row[RSNA_STUDY_COL]
39 |         split = row[RSNA_SPLIT_COL]
40 | 
41 |         if RSNA_INSTITUTION_SPLIT_COL in df_study.columns:
42 |             institution_split = row[RSNA_INSTITUTION_SPLIT_COL]
43 |         else:
44 |             institution_split = None
45 |         study_df = df[df[RSNA_STUDY_COL] == study_name]
46 |         study_df = study_df.sort_values("ImagePositionPatient_2")
47 | 
48 |         # windows
49 |         for idx in range(row[RSNA_NUM_WINDOW_COL]):
50 |             start_idx = idx * stride
51 |             end_idx = (idx * stride) + num_slices
52 | 
53 |             window_df = study_df.iloc[start_idx: end_idx]
54 |             num_positives_slices = window_df[RSNA_PE_SLICE_COL].sum()
55 |             label = 1 if num_positives_slices >= min_abnormal_slice else 0
56 |             window_labels[RSNA_STUDY_COL].append(study_name)
57 |             window_labels['index'].append(idx)
58 |             window_labels['label'].append(label)
59 |             window_labels[RSNA_SPLIT_COL].append(split)
60 |             window_labels[RSNA_INSTITUTION_SPLIT_COL].append(institution_split)
61 |             window_labels[RSNA_INSTANCE_PATH_COL].append(window_df[RSNA_INSTANCE_PATH_COL].tolist())
62 |             window_labels["ImagePositionPatient_2"].append(window_df["ImagePositionPatient_2"].tolist())
63 |             window_labels[RSNA_PE_SLICE_COL].append(window_df[RSNA_PE_SLICE_COL].tolist())
64 |             window_labels[RSNA_INSTANCE_COL].append(window_df[RSNA_INSTANCE_COL].tolist())
65 |             window_labels[RSNA_INSTANCE_ORDER_COL].append(window_df[RSNA_INSTANCE_ORDER_COL].tolist())
66 |     df = pd.DataFrame.from_dict(window_labels)
67 |     df.to_csv(window_labels_csv)
68 | 
69 |     return df
70 | 
71 | 
72 | def process_study_to_hdf5(csv_path: str = RSNA_TRAIN_CSV):
73 | 
74 |     df = pd.read_csv(csv_path)
75 | 
76 |     # write each study's slices, ordered by patient position, to HDF5
77 |     hdf5_fh = h5py.File(RSNA_STUDY_HDF5, 'a')
78 |     for study_name in tqdm.tqdm(
79 |         df[RSNA_STUDY_COL].unique(), total=df[RSNA_STUDY_COL].nunique()
80 |     ):
81 |         study_df = df[df[RSNA_STUDY_COL] == study_name].copy()
82 | 
83 |         # order study instances
84 |         study_df = study_df.sort_values("ImagePositionPatient_2")
85 | 
86 |         # save the stacked series to hdf5
87 |         instance_paths = study_df[RSNA_INSTANCE_PATH_COL].tolist()
88 |         series = np.stack([utils.read_dicom(RSNA_TRAIN_DIR / path, 256) for path in instance_paths])
89 |         hdf5_fh.create_dataset(study_name, data=series, dtype='float32', chunks=True)
90 | 
91 |     # clean up
92 |     hdf5_fh.close()
93 | 
94 | 
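# Illustrative sketch (editorial, not part of the original file): the HDF5
# written above maps each StudyInstanceUID to one float32 dataset of shape
# (num_slices, 256, 256), assuming utils.read_dicom returns a 256x256 slice.
# Reading a single study back:
#
#   import h5py
#   with h5py.File(RSNA_STUDY_HDF5, "r") as f:
#       study_names = list(f.keys())
#       series = f[study_names[0]][:]  # np.ndarray of shape (num_slices, 256, 256)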
95 | class Metadata(Dataset):
96 |     def __init__(self, df: pd.DataFrame):
97 |         self.df = df
98 | 
99 |     def __len__(self):
100 |         return len(self.df)
101 | 
102 |     def __getitem__(self, idx):
103 | 
104 |         row = self.df.iloc[idx]
105 |         dcm = pydicom.dcmread(
106 |             RSNA_TRAIN_DIR / row[RSNA_INSTANCE_PATH_COL], stop_before_pixels=True
107 |         )
108 | 
109 |         metadata = {}
110 |         for k in RSNA_DICOM_HEADERS:
111 |             try:
112 |                 att = getattr(dcm, k)
113 |                 if k in ["InstanceNumber", "RescaleSlope", "RescaleIntercept"]:
114 |                     metadata[k] = float(att)
115 |                 elif k in [
116 |                     "PixelSpacing",
117 |                     "ImagePositionPatient",
118 |                     "ImageOrientationPatient",
119 |                 ]:
120 |                     for ind, coord in enumerate(att):
121 |                         metadata[f"{k}_{ind}"] = float(coord)
122 |                 else:
123 |                     metadata[k] = str(att)
124 |             except Exception as e:
125 |                 print(e)
126 | 
127 |         return pd.DataFrame(metadata, index=[0])
128 | 
129 | 
130 | def add_split_to_label_df(
131 |     label_df: pd.DataFrame,
132 |     val_size: float = 0.15,
133 |     test_size: float = 0.15,
134 |     study_col: str = RSNA_STUDY_COL,
135 |     split_col: str = RSNA_SPLIT_COL,
136 | ):
137 |     patients = label_df[study_col].unique()
138 | 
139 |     # split between train and val+test
140 |     split_ratio = val_size + test_size
141 |     train_patients, test_val_patients = train_test_split(
142 |         patients, test_size=split_ratio, random_state=RANDOM_SEED
143 |     )
144 | 
145 |     # split between val and test
146 |     test_split_ratio = test_size / (val_size + test_size)
147 |     val_patients, test_patients = train_test_split(
148 |         test_val_patients, test_size=test_split_ratio, random_state=RANDOM_SEED
149 |     )
150 |     train_rows = label_df[study_col].isin(train_patients)
151 |     label_df.loc[train_rows, split_col] = "train"
152 |     val_rows = label_df[study_col].isin(val_patients)
153 |     label_df.loc[val_rows, split_col] = "valid"
154 |     test_rows = label_df[study_col].isin(test_patients)
155 |     label_df.loc[test_rows, split_col] = "test"
156 | 
157 |     return label_df
158 | 
159 | 
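# Editorial worked example: the split above operates on unique study IDs, so
# every slice of a study stays in one split. With the defaults val_size=0.15
# and test_size=0.15:
#   first split:  70% train vs. 30% held out   (split_ratio = 0.15 + 0.15 = 0.30)
#   second split: the held-out 30% is halved   (test_split_ratio = 0.15 / 0.30 = 0.5)
# which yields a 70/15/15 train/valid/test split of studies.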
160 | if __name__ == "__main__":
161 | 
162 |     if not RSNA_TRAIN_CSV.is_file():
163 |         print('=' * 80)
164 |         print(f'\nProcessing RSNA dataset metadata and saving it as {RSNA_TRAIN_CSV}')
165 |         print('-' * 80)
166 | 
167 |         # apply the train/val/test split only to the train csv (test.csv does not contain labels)
168 |         rsna = pd.read_csv(RSNA_ORIGINAL_TRAIN_CSV)
169 | 
170 |         # full dataset split
171 |         rsna[RSNA_ORIGINAL_SPLIT_COL] = "train"
172 |         rsna = add_split_to_label_df(rsna, split_col=RSNA_SPLIT_COL)
173 |         for split in ["train", "valid", "test"]:
174 |             print(
175 |                 f"Full split {split}: "
176 |                 + f"{rsna[rsna.Split == split][RSNA_STUDY_COL].nunique()}"
177 |             )
178 | 
179 |         # if the raw rsna annotations are available, extract the stanford studies
180 |         if RSNA_MDAI_JSON is not None:
181 | 
182 |             # read RSNA annotations
183 |             mdai_client = mdai.Client(domain=MDAI_DOMAIN, access_token=MDAI_TOKEN)
184 |             results = mdai.common_utils.json_to_dataframe(RSNA_MDAI_JSON)
185 |             rsna_anno = results["annotations"]
186 |             rsna_studies = rsna_anno[RSNA_STUDY_COL].unique()
187 | 
188 |             # extract overlapping studies
189 |             stanford_df = pd.read_csv(STANFORD_PE_METADATA)
190 |             stanford_overlapping_df = stanford_df[
191 |                 stanford_df.StudyInstanceUID.isin(rsna_studies)
192 |             ]
193 |             anon_stanford_studies = list(stanford_overlapping_df[RSNA_STUDY_COL].unique())
194 |             rsna_stanford_df = rsna_anno[
195 |                 rsna_anno[RSNA_STUDY_COL].isin(anon_stanford_studies)
196 |             ]
197 | 
198 |             # create mapping from stanford study UID to rsna hash
199 |             rsna_stanford_mapping_df = pd.read_csv(RSNA_STANFORD_MAPPING)
200 |             stanford_2_rsna = dict(
201 |                 zip(
202 |                     rsna_stanford_mapping_df.SOPInstanceUID,
203 |                     rsna_stanford_mapping_df.SOPInstanceUID_hash,
204 |                 )
205 |             )
206 |             rsna_stanford_df["anon_instance"] = rsna_stanford_df["SOPInstanceUID"].apply(
207 |                 lambda x: stanford_2_rsna[x]
208 |                 if (x in stanford_2_rsna and x is not None)
209 |                 else None
210 |             )
211 |             overlap_anon_instance = rsna_stanford_df["anon_instance"].unique()
212 | 
213 |             # split stanford vs other institutions
214 |             stanford_studies = rsna[rsna[RSNA_INSTANCE_COL].isin(overlap_anon_instance)][
215 |                 RSNA_STUDY_COL
216 |             ].unique()
217 |             rsna_external_df = rsna[~rsna[RSNA_STUDY_COL].isin(stanford_studies)]
218 |             rsna_stanford_df = rsna[rsna[RSNA_STUDY_COL].isin(stanford_studies)]
219 |             rsna_stanford_df.loc[:, RSNA_INSTITUTION_COL] = "Stanford"
220 |             rsna_external_df.loc[:, RSNA_INSTITUTION_COL] = "Other"
221 | 
222 |             # create split for data from other institutions
223 |             rsna_external_df = add_split_to_label_df(
224 |                 rsna_external_df, split_col=RSNA_INSTITUTION_SPLIT_COL
225 |             )
226 |             rsna_stanford_df.loc[:, RSNA_INSTITUTION_SPLIT_COL] = "stanford_test"
227 |             rsna = pd.concat([rsna_external_df, rsna_stanford_df])
228 | 
229 |             for split in ["train", "valid", "test", "stanford_test"]:
230 |                 print(
231 |                     f"Institution split {split}: "
232 |                     + f"{rsna[rsna[RSNA_INSTITUTION_SPLIT_COL] == split][RSNA_STUDY_COL].nunique()}"
233 |                 )
234 | 
235 |         # add instance path
236 |         rsna[RSNA_INSTANCE_PATH_COL] = rsna.apply(
237 |             lambda x: f"{x[RSNA_STUDY_COL]}/{x[RSNA_SERIES_COL]}/{x[RSNA_INSTANCE_COL]}.dcm",
238 |             axis=1,
239 |         )
240 | 
241 |         # create dataset and loader to extract metadata
242 |         dataset = Metadata(rsna)
243 |         loader = DataLoader(
244 |             dataset, batch_size=1, shuffle=False, num_workers=12, collate_fn=lambda x: x
245 |         )
246 | 
247 |         # get metadata
248 |         meta = []
249 |         for data in tqdm.tqdm(loader, total=len(loader)):
250 |             meta += [data[0]]
251 |         meta_df = pd.concat(meta, axis=0, ignore_index=True)
252 | 
253 |         # get slice number
254 |         unique_studies = pd.DataFrame(meta_df[RSNA_STUDY_COL].value_counts()).reset_index()
255 |         unique_studies.columns = [RSNA_STUDY_COL, RSNA_NUM_SLICES_COL]
256 |         meta_df = meta_df.merge(unique_studies, on=RSNA_STUDY_COL)
257 | 
258 |         # join the metadata with the labels
259 |         rsna = rsna.set_index([RSNA_STUDY_COL, RSNA_SERIES_COL, RSNA_INSTANCE_COL])
260 |         meta_df = meta_df.set_index([RSNA_STUDY_COL, RSNA_SERIES_COL, RSNA_INSTANCE_COL])
261 |         rsna = rsna.join(meta_df, how="left").reset_index()
262 | 
263 |         # record neighboring slices based on patient position
264 |         study_dfs = []
265 |         for study_name in tqdm.tqdm(
266 |             rsna[RSNA_STUDY_COL].unique(), total=rsna[RSNA_STUDY_COL].nunique()
267 |         ):
268 |             study_df = rsna[rsna[RSNA_STUDY_COL] == study_name].copy()
269 | 
270 |             # order study instances
271 |             study_df = study_df.sort_values("ImagePositionPatient_2")
272 |             study_df[RSNA_INSTANCE_ORDER_COL] = np.arange(len(study_df))
273 | 
274 |             # get neighbors paths
275 |             instance_paths = study_df[RSNA_INSTANCE_PATH_COL].tolist()
276 |             instance_paths = [instance_paths[0]] + instance_paths + [instance_paths[-1]]
277 |             study_df[RSNA_PREV_INSTANCE_COL] = instance_paths[:-2]
278 |             study_df[RSNA_NEXT_INSTANCE_COL] = instance_paths[2:]
279 | 
280 |             study_dfs.append(study_df)
281 | 
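        # Editorial worked example: padding the ordered path list with its own
        # first and last entries makes each edge slice act as its own missing
        # neighbor. For paths [a, b, c]:
        #   padded = [a, a, b, c, c]
        #   prev   = padded[:-2]  ->  [a, a, b]
        #   next   = padded[2:]   ->  [b, c, c]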
282 |         rsna = pd.concat(study_dfs, axis=0, ignore_index=True)
283 |         rsna.to_csv(RSNA_TRAIN_CSV, index=False)
284 | 
285 |         # create windowed dataset: defaults num_slices=24, min_abnormal_slice=4
286 |         window_df = process_window_df(rsna)
287 |     else:
288 |         print('=' * 80)
289 |         print(f'\n{RSNA_TRAIN_CSV} already exists and has been processed')
290 |         print('-' * 80)
291 | 
292 |     if not RSNA_STUDY_HDF5.is_file():
293 |         print('\n' + '=' * 80)
294 |         print(f'\nWriting studies to HDF5 at {RSNA_STUDY_HDF5}')
295 |         print('-' * 80)
296 |         process_study_to_hdf5()
297 |     else:
298 |         print('=' * 80)
299 |         print(f'\n{RSNA_STUDY_HDF5} already exists and has been processed')
300 |         print('-' * 80)
301 | 
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import pe_models
4 | import datetime
5 | import os
6 | import numpy as np
7 | import pandas as pd
8 | import wandb
9 | import yaml
10 | 
11 | from collections import defaultdict
12 | from pathlib import Path
13 | from dateutil import tz
14 | from omegaconf import OmegaConf
15 | from pytorch_lightning import seed_everything
16 | from pytorch_lightning import loggers as pl_loggers
17 | from pytorch_lightning.trainer import Trainer
18 | from pytorch_lightning.callbacks import (
19 |     ModelCheckpoint,
20 |     EarlyStopping,
21 |     LearningRateMonitor,
22 | )
23 | 
24 | 
25 | seed_everything(23)
26 | torch.backends.cudnn.deterministic = True
27 | torch.backends.cudnn.benchmark = False
28 | 
29 | 
30 | def parse_configs():
31 |     parser = argparse.ArgumentParser()
32 |     parser.add_argument("--config", type=str, help="path to the base config file")
33 |     parser.add_argument(
34 |         "--train", action="store_true", default=False, help="specify to train model")
35 |     parser.add_argument(
36 |         "--debug", action="store_true", default=False, help="specify to debug model")
37 |     parser.add_argument(
38 |         "--test",
39 |         action="store_true",
40 |         default=False,
41 |         help="specify to test model; "
42 |         "by default run.py trains a model based on the config file",)
43 |     parser.add_argument(
44 |         "--checkpoint", type=str, default=None, help="path to a model checkpoint")
45 |     parser.add_argument(
46 |         "--cv_split", type=int, default=1, help="cross validation split")
47 |     parser.add_argument(
48 |         "--top_n", type=int, default=1, help="test based on the top-N best ckpt")
49 |     parser.add_argument(
50 |         "--sample_frac", type=float, default=None, help="fraction of training studies to sample")
51 |     parser.add_argument(
52 |         "--batch_size", type=int, default=None, help="override the batch size from the config")
53 |     parser.add_argument(
54 |         "--test_split", type=str, default='test', help="data split to evaluate on")
55 | 
56 |     parser = Trainer.add_argparse_args(parser)
57 | 
58 |     args, unknown = parser.parse_known_args()
59 |     cli = [u.strip("--") for u in unknown]  # strip the leading "--" from unknown flags
60 | 
61 |     # add command line arguments to config
62 |     cfg = OmegaConf.load(args.config)
63 |     cli = OmegaConf.from_dotlist(cli)
64 |     cli_flat = pe_models.utils.flatten(cli)
65 |     cfg.hyperparameters = cli_flat  # hyperparameter defaults
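    # Illustrative sketch (editorial; the override keys below are hypothetical):
    # unknown flags become OmegaConf dot-list entries, so an invocation such as
    #
    #   python run.py --config data/demo/resnet18_demo.yaml --train \
    #       --train.batch_size=4 --model.lr=1e-4
    #
    # is parsed into {train: {batch_size: 4}, model: {lr: 0.0001}}, flattened,
    # and stored under cfg.hyperparameters as the sweep's defaults.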
66 |     cfg.data.cv_split = args.cv_split
67 |     if args.gpus is not None:
68 |         cfg.lightning.trainer.gpus = str(args.gpus)
69 | 
70 | 
71 |     cfg.data.sample_frac = args.sample_frac
72 |     if args.checkpoint is not None:
73 |         cfg.checkpoint = args.checkpoint
74 |     if args.batch_size is not None:
75 |         cfg.train.batch_size = args.batch_size
76 |     cfg.test_split = args.test_split
77 | 
78 |     # edit experiment name
79 |     if not OmegaConf.is_none(cfg, "trial_name"):
80 |         cfg.experiment_name = f"{cfg.experiment_name}_{cfg.trial_name}"
81 | 
82 |     if not OmegaConf.is_none(cfg.data, 'positive_only'):
83 |         if cfg.data.positive_only:
84 |             cfg.experiment_name = f"{cfg.experiment_name}_positive_only"
85 |         if cfg.data.weighted_sample:
86 |             cfg.experiment_name = f"{cfg.experiment_name}_weighted_sample"
87 | 
88 | 
89 |     # subsample training data and tag the experiment name with the fraction
90 |     if args.sample_frac is not None:
91 |         cfg.data.sample_frac = args.sample_frac
92 |         cfg.experiment_name = f"{cfg.experiment_name}_frac{args.sample_frac}"
93 | 
94 |     # get current time
95 |     now = datetime.datetime.now(tz.tzlocal())
96 |     timestamp = now.strftime("%Y_%m_%d_%H_%M_%S")
97 |     cfg.extension = timestamp
98 | 
99 |     # add checkpoint
100 |     if args.checkpoint is not None:
101 |         cfg.checkpoint = args.checkpoint
102 | 
103 |     # check debug
104 |     if args.debug:
105 |         cfg.train.num_workers = 0
106 |         cfg.lightning.trainer.gpus = 1
107 |         cfg.lightning.trainer.distributed_backend = None
108 | 
109 |     return cfg, args
110 | 
111 | 
112 | def create_directories(cfg):
113 | 
114 |     # set directory names
115 |     if cfg.phase == "pretrain":
116 |         cfg.output_dir = f"./data/output/{cfg.experiment_name}/{cfg.data.cv_split}/{cfg.extension}/pretrain"
117 |         cfg.lightning.logger.name = (
118 |             f"{cfg.experiment_name}/{cfg.data.cv_split}/{cfg.extension}"
119 |         )
120 |         cfg.lightning.checkpoint_callback.dirpath = f"./data/ckpt/{cfg.experiment_name}/{cfg.data.cv_split}/{cfg.extension}/pretrain"
121 |     else:
122 |         cfg.output_dir = (
123 |             f"./data/output/{cfg.experiment_name}/{cfg.data.cv_split}/{cfg.extension}"
124 |         )
125 |         cfg.lightning.logger.name = (
126 |             f"{cfg.experiment_name}/{cfg.data.cv_split}/{cfg.extension}"
127 |         )
128 |         cfg.lightning.checkpoint_callback.dirpath = (
129 |             f"./data/ckpt/{cfg.experiment_name}/{cfg.data.cv_split}/{cfg.extension}"
130 |         )
131 | 
132 | 
133 |     # create directories
134 |     if not os.path.exists(cfg.lightning.logger.save_dir):
135 |         os.makedirs(cfg.lightning.logger.save_dir)
136 |     if not os.path.exists(cfg.lightning.checkpoint_callback.dirpath):
137 |         print(cfg.lightning.checkpoint_callback.dirpath)
138 |         os.makedirs(cfg.lightning.checkpoint_callback.dirpath)
139 |     if not os.path.exists(cfg.output_dir):
140 |         os.makedirs(cfg.output_dir)
141 | 
142 |     return cfg
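# Editorial summary of the layout created above, for a standard (non-pretrain)
# run:
#   ./data/output/<experiment_name>/<cv_split>/<timestamp>/  (metrics, config.yaml)
#   ./data/ckpt/<experiment_name>/<cv_split>/<timestamp>/    (checkpoints, config.yaml)
# with "/pretrain" appended to both paths when cfg.phase == "pretrain".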
143 | 
144 | 
145 | def setup(cfg, test_split=False):
146 | 
147 |     # create output, logger and ckpt directories when not in test-only mode
148 |     if not test_split:
149 |         cfg = create_directories(cfg)
150 | 
151 |         # logging
152 |         loggers = [pl_loggers.csv_logs.CSVLogger(cfg.output_dir)]
153 |         if "logger" in cfg.lightning:
154 | 
155 |             logger_type = cfg.lightning.logger.pop("logger_type")
156 |             logger_class = getattr(pl_loggers, logger_type)
157 |             logger = logger_class(**cfg.lightning.logger)
158 |             loggers.append(logger)
159 |             cfg.lightning.logger.logger_type = logger_type
160 | 
161 |             """
162 |             if logger_type == "WandbLogger":
163 |                 # set sweep defaults
164 |                 hyperparameter_defaults = cfg.hyperparameters
165 |                 run = logger.experiment
166 |                 run.config.setdefaults(hyperparameter_defaults)
167 | 
168 |                 # update cfg with new sweep parameters
169 |                 run_config = [f"{k}={v}" for k, v in run.config.items()]
170 |                 run_config = OmegaConf.from_dotlist(run_config)
171 |                 cfg = OmegaConf.merge(cfg, run_config)  # update defaults to CLI
172 | 
173 |                 # set best metric
174 |                 if cfg.lightning.checkpoint_callback.mode == "max":
175 |                     goal = "maximize"
176 |                 else:
177 |                     goal = "minimize"
178 |                 metric = cfg.lightning.checkpoint_callback.monitor
179 |                 wandb.define_metric(f"{metric}", summary="best", goal=goal)
180 |             """
181 | 
182 |         # callbacks
183 |         callbacks = [LearningRateMonitor(logging_interval="step")]
184 |         if "checkpoint_callback" in cfg.lightning:
185 |             checkpoint_callback = ModelCheckpoint(**cfg.lightning.checkpoint_callback)
186 |             callbacks.append(checkpoint_callback)
187 |         if "early_stopping_callback" in cfg.lightning:
188 |             early_stopping_callback = EarlyStopping(
189 |                 **cfg.lightning.early_stopping_callback
190 |             )
191 |             callbacks.append(early_stopping_callback)
192 | 
193 |         # save config
194 |         config_path = os.path.join(cfg.output_dir, "config.yaml")
195 |         config_path_ckpt = os.path.join(
196 |             cfg.lightning.checkpoint_callback.dirpath, "config.yaml"
197 |         )
198 |         with open(config_path, "w") as fp:
199 |             OmegaConf.save(config=cfg, f=fp.name)
200 |         with open(config_path_ckpt, "w") as fp:
201 |             OmegaConf.save(config=cfg, f=fp.name)
202 | 
203 |     else:
204 |         loggers = []
205 |         callbacks = []
206 |         checkpoint_callback = None
207 | 
208 |     # get datamodule
209 |     dm = pe_models.builder.build_data_module(cfg)
210 |     cfg.data.num_batches = len(dm.train_dataloader())
211 | 
212 |     # define lightning module
213 |     model = pe_models.builder.build_lightning_model(cfg)
214 | 
215 |     # setup pytorch-lightning trainer
216 |     trainer_args = argparse.Namespace(**cfg.lightning.trainer)
217 |     trainer = Trainer.from_argparse_args(
218 |         args=trainer_args, deterministic=True, callbacks=callbacks, logger=loggers
219 |     )
220 | 
221 |     # auto learning rate finder
222 |     if trainer_args.auto_lr_find is not False:
223 |         lr_finder = trainer.tuner.lr_find(model, datamodule=dm)
224 |         new_lr = lr_finder.suggestion()
225 |         model.lr = new_lr
226 |         print(f"learning rate updated to {new_lr}")
227 | 
228 |     return trainer, model, dm, checkpoint_callback
229 | 
230 | 
231 | def save_best_checkpoints(checkpoint_callback, cfg, return_best=True):
232 |     ckpt_paths = os.path.join(
233 |         cfg.lightning.checkpoint_callback.dirpath, "best_ckpts.yaml"
234 |     )
235 |     checkpoint_callback.to_yaml(filepath=ckpt_paths)
236 |     if return_best:
237 |         ascending = cfg.lightning.checkpoint_callback.mode == "min"
238 |         best_ckpt_path = pe_models.utils.get_best_ckpt_path(ckpt_paths, ascending)
239 |         return best_ckpt_path
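# Illustrative sketch (editorial): ModelCheckpoint.to_yaml dumps the tracked
# best_k_models mapping as {checkpoint_path: monitored_score}, so best_ckpts.yaml
# looks roughly like
#
#   ./data/ckpt/<experiment>/<cv_split>/<timestamp>/epoch=7-step=1234.ckpt: 0.91
#
# pe_models.utils.get_best_ckpt_path is then assumed to sort this mapping by
# score (ascending when the monitored mode is "min") and return the top path.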
240 | 
241 | 
242 | def find_best_ckpt(cfg, top_n):
243 |     """Find the best checkpoint from a wandb hyperparameter sweep."""
244 | 
245 |     output_dir = f"./data/output/{cfg.experiment_name}/{cfg.data.cv_split}"
246 |     results = defaultdict(list)
247 |     sweep_path = Path(output_dir)
248 |     experiment_dirs = [p for p in sweep_path.iterdir() if p.is_dir()]
249 | 
250 |     for p in experiment_dirs:
251 | 
252 |         metrics_csv = p / "default" / "version_0" / "metrics.csv"
253 |         config_file = p / "config.yaml"
254 | 
255 |         # run errored
256 |         if not metrics_csv.exists():
257 |             continue
258 | 
259 |         df = pd.read_csv(metrics_csv)
260 |         try:
261 |             curr_best_epoch = int(
262 |                 df.sort_values("val/mean_auroc", ascending=False).head(1)["epoch"].values[0]
263 |             )
264 |             curr_best_step = int(
265 |                 df.sort_values("val/mean_auroc", ascending=False).head(1)["step"].values[0]
266 |             )
267 |             curr_best_auroc = (
268 |                 df.sort_values("val/mean_auroc", ascending=False)
269 |                 .head(1)["val/mean_auroc"]
270 |                 .values[0]
271 |             )
272 |             curr_train_auroc = df[
273 |                 (df.epoch == curr_best_epoch) & (~df["train/mean_auroc"].isna())
274 |             ].iloc[0]["train/mean_auroc"]
275 |         except Exception:
276 |             # failed runs
277 |             print(f"Unable to read {metrics_csv}")
278 |             continue
279 | 
280 |         experiment_name = str(p).split("output")[1][1:]  # remove the leading '/'
281 |         results["experiment_name"].append(experiment_name)
282 |         results["val_auroc"].append(curr_best_auroc)
283 |         results["train_auroc"].append(curr_train_auroc)
284 |         results["config_file"].append(config_file)
285 |         results["results_path"].append(metrics_csv)
286 |         results["best_ckpt"].append(
287 |             str(
288 |                 Path("./data/ckpt")
289 |                 / experiment_name
290 |                 / f"epoch={curr_best_epoch}-step={curr_best_step}.ckpt"
291 |             )
292 |         )
293 | 
294 |     df = pd.DataFrame.from_dict(results)
295 |     df = df.sort_values("val_auroc", ascending=False)
296 |     loc = top_n - 1
297 |     best = df.iloc[loc]
298 |     best_train_auroc = best["train_auroc"]
299 |     best_auroc = best["val_auroc"]
300 |     best_ckpt = best["best_ckpt"]
301 |     best_config_file = best["config_file"]
302 |     best_result_path = best["results_path"]
303 | 
304 |     print(f"\nBest training AUROC: {best_train_auroc: .3f}")
305 |     print(f"Best validation AUROC: {best_auroc: .3f}")
306 |     print(f"Using checkpoint: {str(best_ckpt)}\n")
307 |     best_cfg = OmegaConf.load(best_config_file)
308 |     return str(best_ckpt), best_cfg.model, str(best_result_path)
309 | 
310 | 
311 | if __name__ == "__main__":
312 |     cfg, args = parse_configs()
313 | 
314 |     if args.train:
315 |         trainer, model, dm, checkpoint_callback = setup(cfg)
316 |         trainer.fit(model, dm)
317 |         best_ckpt = save_best_checkpoints(checkpoint_callback, cfg, return_best=True)
318 |         cfg.checkpoint = best_ckpt
319 |         print(f"Best checkpoint path: {best_ckpt}")
320 | 
321 |     if args.test:
322 |         if OmegaConf.is_none(cfg, "checkpoint"):
323 |             cfg.checkpoint, cfg.model, cfg.save_dir = find_best_ckpt(cfg, args.top_n)
324 | 
325 |         print("=" * 80)
326 |         print(cfg.checkpoint)
327 |         print("=" * 80)
328 |         cfg.output_dir = '/'.join(cfg.checkpoint.split('/')[:-1]).replace('ckpt', 'output')
329 |         print(f'Output dir: {cfg.output_dir}')
330 |         trainer, model, dm, checkpoint_callback = setup(cfg, test_split=True)
331 |         trainer.test(model=model, datamodule=dm)
332 | 
--------------------------------------------------------------------------------
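Editorial usage sketch (illustrative; assumes the demo config shipped under data/demo):

    # train, then evaluate the best checkpoint on the test split
    python run.py --config data/demo/resnet18_demo.yaml --train
    python run.py --config data/demo/resnet18_demo.yaml --test --top_n 1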