├── train.sh ├── requirement.txt ├── save_backbone.py ├── main.py ├── dataset ├── nih_dataset.py ├── vin_dataset.py ├── transforms.py ├── chexpert_dataset.py ├── cxr_datamodule.py └── cxr_dataset.py ├── LICENSE ├── callbacks ├── nih_pseudo_callback.py ├── chexpert_pseudo_callback.py ├── vinbig_pseudo_callback.py ├── fusion_submit_callback.py └── submit_callback.py ├── README.md ├── model ├── loss.py ├── layers.py ├── cxr_model.py └── ml_decoder.py ├── .gitignore └── config.yaml /train.sh: -------------------------------------------------------------------------------- 1 | nohup python3 main.py fit --config config.yaml >train.log & 2 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | lightning 2 | pytorch-lightning[extra] 3 | timm 4 | jpeg4py 5 | albumentations 6 | iterative-stratification 7 | ffcv 8 | neptune-client 9 | positional-encodings[pytorch] -------------------------------------------------------------------------------- /save_backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import yaml 3 | from model.cxr_model import CxrModel 4 | 5 | config = yaml.load(open('config.yaml', "r"), Loader=yaml.FullLoader) 6 | model = CxrModel.load_from_checkpoint( 7 | config['ckpt_path'], **config['model'] 8 | ) 9 | 10 | torch.save(model.backbone.model.state_dict(), 'model.pth') 11 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lightning.pytorch.cli import LightningCLI 3 | from lightning.pytorch.loggers import NeptuneLogger 4 | from model.cxr_model import CxrModel 5 | from dataset.cxr_datamodule import CxrDataModule 6 | 7 | class MyLightningCLI(LightningCLI): 8 | def before_fit(self): 9 | if isinstance(self.trainer.logger, NeptuneLogger): 10 | self.trainer.logger.experiment["train/config"].upload('config.yaml') 11 | 12 | def cli_main(): 13 | torch.set_float32_matmul_precision('high') 14 | cli = MyLightningCLI(CxrModel, CxrDataModule, save_config_callback=None) 15 | 16 | if __name__ == "__main__": 17 | cli_main() 18 | -------------------------------------------------------------------------------- /dataset/nih_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from torch import from_numpy 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class NihDataset(Dataset): 9 | def __init__(self, cfg, df, transform=None): 10 | self.cfg = cfg 11 | self.df = df 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return len(self.df) 16 | 17 | def __getitem__(self, index): 18 | label = self.df.iloc[index][self.cfg['classes']].to_numpy().astype(np.float32) 19 | path = os.path.join(self.cfg['data_dir'], "nih/images_001/images", self.df.iloc[index]["id"]) 20 | 21 | img = cv2.imread(path) 22 | assert img.shape == (self.cfg['size'], self.cfg['size'], 3), f"{img.shape}" 23 | 24 | if self.transform: 25 | transformed = self.transform(image=img) 26 | img = transformed['image'] 27 | img = np.moveaxis(img, -1, 0) 28 | 29 | return img, label -------------------------------------------------------------------------------- /dataset/vin_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import 
numpy as np 4 | from torch import from_numpy 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class VinDataset(Dataset): 9 | def __init__(self, cfg, df, transform=None): 10 | self.cfg = cfg 11 | self.df = df 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return len(self.df) 16 | 17 | def __getitem__(self, index): 18 | label = self.df.iloc[index][self.cfg['classes']].to_numpy().astype(np.float32) 19 | path = os.path.join(self.cfg['data_dir'], "vinbig/train", self.df.iloc[index]["image_id"]+".png") 20 | 21 | img = cv2.imread(path) 22 | assert img.shape == (self.cfg['size'], self.cfg['size'], 3), f"{img.shape}" 23 | 24 | if self.transform: 25 | transformed = self.transform(image=img) 26 | img = transformed['image'] 27 | img = np.moveaxis(img, -1, 0) 28 | 29 | return img, label -------------------------------------------------------------------------------- /dataset/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import albumentations as A 3 | 4 | def get_transforms(size): 5 | transforms_train = A.Compose([ 6 | A.RandomResizedCrop(size,size, scale=(0.9, 1), p=1, interpolation=cv2.INTER_LANCZOS4), 7 | A.HorizontalFlip(p=0.5), 8 | A.ShiftScaleRotate(p=0.5), 9 | A.RandomBrightnessContrast(p=0.5), 10 | A.OneOf([ 11 | A.OpticalDistortion(), 12 | A.GridDistortion(), 13 | A.ElasticTransform(), 14 | ], p=0.2), 15 | A.OneOf([ 16 | A.GaussNoise(), 17 | A.GaussianBlur(), 18 | A.MotionBlur(), 19 | A.MedianBlur(), 20 | ], p=0.2), 21 | A.Resize(size,size, interpolation=cv2.INTER_LANCZOS4), 22 | A.Normalize(), 23 | ]) 24 | 25 | transforms_val = A.Compose([ 26 | A.Resize(size,size, interpolation=cv2.INTER_LANCZOS4), 27 | A.Normalize() 28 | ]) 29 | return transforms_train, transforms_val 30 | 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 DK 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /callbacks/nih_pseudo_callback.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pandas as pd 4 | from lightning.pytorch.callbacks import BasePredictionWriter 5 | 6 | 7 | class NihWriter(BasePredictionWriter): 8 | def __init__(self, nih_train_df_path, nih_pseudo_train_df_path, write_interval): 9 | super().__init__(write_interval) 10 | self.nih_train_df_path = nih_train_df_path 11 | self.nih_pseudo_train_df_path = nih_pseudo_train_df_path 12 | 13 | def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): 14 | predictions = torch.cat(predictions, dim=0) 15 | preds = predictions.float().squeeze(0).detach().cpu().numpy() 16 | 17 | np.save("nih_preds.npy", preds) 18 | 19 | nih_train_df = pd.read_csv(self.nih_train_df_path) 20 | org = np.array(nih_train_df.iloc[:, -26:].values).astype(np.float32) 21 | 22 | # Replace the original labels with the pseudo labels only if org value is -1 23 | idx = np.where(org == -1) 24 | org[idx] = preds[idx] 25 | 26 | nih_train_df.iloc[:, -26:] = org 27 | nih_train_df.to_csv(self.nih_pseudo_train_df_path, index=False) 28 | 29 | print(f"NIH pseudo labels saved to {self.nih_pseudo_train_df_path}") -------------------------------------------------------------------------------- /callbacks/chexpert_pseudo_callback.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pandas as pd 4 | from lightning.pytorch.callbacks import BasePredictionWriter 5 | 6 | 7 | class ChexpertWriter(BasePredictionWriter): 8 | def __init__(self, chexpert_train_df_path, chexpert_pseudo_train_df_path, write_interval): 9 | super().__init__(write_interval) 10 | self.chexpert_train_df_path = chexpert_train_df_path 11 | self.chexpert_pseudo_train_df_path = chexpert_pseudo_train_df_path 12 | 13 | def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): 14 | predictions = torch.cat(predictions, dim=0) 15 | preds = predictions.float().squeeze(0).detach().cpu().numpy() 16 | 17 | np.save("chexpert_preds.npy", preds) 18 | 19 | chexpert_train_df = pd.read_csv(self.chexpert_train_df_path) 20 | org = np.array(chexpert_train_df.iloc[:, -26:].values).astype(np.float32) 21 | 22 | # Replace the original labels with the pseudo labels only if org value is -1 23 | idx = np.where(org == -1) 24 | org[idx] = preds[idx] 25 | 26 | chexpert_train_df.iloc[:, -26:] = org 27 | chexpert_train_df.to_csv(self.chexpert_pseudo_train_df_path, index=False) 28 | 29 | print(f"Chexpert pseudo labels saved to {self.chexpert_pseudo_train_df_path}") -------------------------------------------------------------------------------- /dataset/chexpert_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import jpeg4py as jpeg 5 | from torch import from_numpy 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class ChexpertDataset(Dataset): 10 | def __init__(self, cfg, df, transform=None): 11 | self.cfg = cfg 12 | self.df = df 13 | self.transform = transform 14 | 15 | def __len__(self): 16 | return len(self.df) 17 | 18 | def __getitem__(self, index): 19 | label = self.df.iloc[index][self.cfg['classes']].to_numpy().astype(np.float32) 20 | path = os.path.join(self.cfg['data_dir'], 'chexpert/chexpertchestxrays-u20210408', self.df.iloc[index]["Path"]) 21 | 
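# Use a cached resized copy when available (deleting the bulky original); otherwise decode, resize, and write the cache: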
resized_path = path.replace(".jpg", f"_resized_{self.cfg['size']}.jpg") 22 | 23 | if os.path.exists(resized_path): 24 | img = jpeg.JPEG(resized_path).decode() 25 | if os.path.exists(path): 26 | os.remove(path) 27 | assert img.shape == (self.cfg['size'], self.cfg['size'], 3) 28 | else: 29 | img = jpeg.JPEG(path).decode() 30 | img = cv2.resize(img, (self.cfg['size'], self.cfg['size']), interpolation=cv2.INTER_LANCZOS4) 31 | cv2.imwrite(resized_path, img) 32 | 33 | if self.transform: 34 | transformed = self.transform(image=img) 35 | img = transformed['image'] 36 | img = np.moveaxis(img, -1, 0) 37 | 38 | return img, label -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CheXFusion 2 | 3 | This is the official PyTorch implementation of [**"CheXFusion: Effective Fusion of Multi-View Features using Transformers for Long-Tailed Chest X-Ray Classification"**](https://openaccess.thecvf.com/content/ICCV2023W/CVAMD/papers/Kim_CheXFusion_Effective_Fusion_of_Multi-View_Features_Using_Transformers_for_Long-Tailed_ICCVW_2023_paper.pdf) 4 | 5 | [Presentation Video](https://www.youtube.com/watch?v=E00Nv28o8a8&t=18s) 6 | 7 | [Slides](https://docs.google.com/presentation/d/1NhVnBgYEJGTDUae4eSgjvA9BlIB_ZieuNmbQOUwWQxk/edit#slide=id.g280a08a6593_2_115) 8 | 9 | [Poster](https://drive.google.com/file/d/1vAupRI-ElfDAvT9OgaBs9ya15cmZhAqw/view?usp=sharing) 10 | 11 | :trophy: **1st place solution** :trophy: for the ICCV CVAMD 2023 Shared Task: [CXR-LT: Multi-Label Long-Tailed Classification on Chest X-Rays](https://bionlplab.github.io/2023_ICCV_CVAMD/). 12 | 13 | 14 | The paper was published in the ICCV proceedings with an oral presentation. 15 | 16 | ![Imgur](https://i.imgur.com/wsC9vQP.png) 17 | 18 | ### Update June 7: 19 | 20 | A co-authored [paper](https://www.sciencedirect.com/science/article/pii/S136184152400149X) discussing the challenge was published in Medical Image Analysis! 21 | 22 | CheXFusion remained in 1st place, considerably outperforming other teams, even under various distribution shifts. 23 | 24 | 25 | ## Results 26 | 27 | CheXFusion achieved 1st-place results in all metrics.
[Public Leaderboard](https://codalab.lisn.upsaclay.fr/competitions/12599#results) 28 | 29 | ![Imgur](https://imgur.com/fRv7HoF.png) -------------------------------------------------------------------------------- /callbacks/vinbig_pseudo_callback.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pandas as pd 4 | from lightning.pytorch.callbacks import BasePredictionWriter 5 | 6 | 7 | class VinBigWriter(BasePredictionWriter): 8 | def __init__(self, vinbig_train_df_path, vinbig_pseudo_train_df_path, write_interval): 9 | super().__init__(write_interval) 10 | self.vinbig_train_df_path = vinbig_train_df_path 11 | self.vinbig_pseudo_train_df_path = vinbig_pseudo_train_df_path 12 | 13 | def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): 14 | predictions = torch.cat(predictions, dim=0) 15 | preds = predictions.float().squeeze(0).detach().cpu().numpy() 16 | 17 | vinbig_train_df = pd.read_csv(self.vinbig_train_df_path) 18 | org = np.array(vinbig_train_df.iloc[:, -26:].values).astype(np.float32) 19 | 20 | # Replace the original labels with the pseudo labels only if org value is -1 21 | idx = np.where(org == -1) 22 | org[idx] = preds[idx] 23 | 24 | # If both Mass (col 13) and Nodule (col 15) are labeled 1, soften the lower-confidence one to its predicted probability 25 | both_ones_indices = np.where((org[:, 13] == 1) & (org[:, 15] == 1))[0] 26 | 27 | for index in both_ones_indices: 28 | if preds[index, 13] < preds[index, 15]: 29 | org[index, 13] = preds[index, 13] 30 | else: 31 | org[index, 15] = preds[index, 15] 32 | 33 | vinbig_train_df.iloc[:, -26:] = org 34 | vinbig_train_df.to_csv(self.vinbig_pseudo_train_df_path, index=False) 35 | 36 | print(f"VinBig pseudo labels saved to {self.vinbig_pseudo_train_df_path}") 37 | 38 | -------------------------------------------------------------------------------- /callbacks/fusion_submit_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import torch 4 | import pandas as pd 5 | from lightning.pytorch.callbacks import BasePredictionWriter 6 | 7 | 8 | class FusionSubmissonWriter(BasePredictionWriter): 9 | def __init__(self, sample_submission_path, submission_path, submission_zip_path, submission_code_dir, pred_df_path, write_interval): 10 | super().__init__(write_interval) 11 | self.sample_submission_path = sample_submission_path 12 | self.submission_path = submission_path 13 | self.submission_zip_path = submission_zip_path 14 | self.submission_code_dir = submission_code_dir 15 | self.pred_df_path = pred_df_path 16 | 17 | def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): 18 | # Add predictions 19 | predictions = torch.cat(predictions, dim=0) 20 | torch.save(predictions, "predictions.pt") 21 | 22 | submit_df = pd.read_csv(self.sample_submission_path) 23 | pred_df = pd.read_csv(self.pred_df_path) 24 | submit_df['study_id'] = pred_df['study_id'] 25 | 26 | temp_df = pd.DataFrame(predictions, columns=submit_df.columns[-27:-1]) 27 | temp_df['study_id'] = list(pred_df.groupby('study_id').groups.keys()) 28 | submit_df = submit_df.merge(temp_df, on='study_id', how='left', suffixes=('_x', '')) 29 | 30 | # Remove _x columns 31 | submit_df = submit_df.loc[:, ~submit_df.columns.str.endswith('_x')] 32 | submit_df.drop(columns=['study_id'], inplace=True) 33 | 34 | # Save submission 35 | submit_df.to_csv(self.submission_path, index=False) 36 | with
zipfile.ZipFile(self.submission_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: 37 | # Add the folder and its contents to the zip 38 | for root, _, files in os.walk(self.submission_code_dir): 39 | for file in files: 40 | file_path = os.path.join(root, file) 41 | zipf.write(file_path, os.path.join('code', os.path.relpath(file_path, self.submission_code_dir))) 42 | 43 | # Add the file to the zip 44 | zipf.write(self.submission_path, os.path.basename(self.submission_path)) 45 | 46 | print("Submission saved!") 47 | 48 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def get_loss(type, class_instance_nums, total_instance_num): 6 | if type == 'bce': 7 | return nn.BCEWithLogitsLoss() 8 | elif type == 'wbce': 9 | return BCEwithClassWeights(class_instance_nums, total_instance_num) 10 | elif type == 'asl': 11 | return ASLwithClassWeight(class_instance_nums, total_instance_num) 12 | else: 13 | raise ValueError(f'Unknown loss type: {type}') 14 | 15 | 16 | class BCEwithClassWeights(nn.Module): 17 | def __init__(self, class_instance_nums, total_instance_num): 18 | super(BCEwithClassWeights, self).__init__() 19 | class_instance_nums = torch.tensor(class_instance_nums, dtype=torch.float32) 20 | p = class_instance_nums / total_instance_num 21 | self.pos_weights = torch.exp(1-p) 22 | self.neg_weights = torch.exp(p) 23 | 24 | def forward(self, pred, label): 25 | # https://www.cse.sc.edu/~songwang/document/cvpr21d.pdf (equation 4) 26 | weight = label * self.pos_weights.to(pred.device) + (1 - label) * self.neg_weights.to(pred.device) 27 | loss = nn.functional.binary_cross_entropy_with_logits(pred, label, weight=weight) 28 | return loss 29 | 30 | 31 | class ASLwithClassWeight(nn.Module): 32 | def __init__(self, class_instance_nums, total_instance_num, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8): 33 | super(ASLwithClassWeight, self).__init__() 34 | class_instance_nums = torch.tensor(class_instance_nums, dtype=torch.float32) 35 | p = class_instance_nums / total_instance_num 36 | self.pos_weights = torch.exp(1-p) 37 | self.neg_weights = torch.exp(p) 38 | self.gamma_neg = gamma_neg 39 | self.gamma_pos = gamma_pos 40 | self.clip = clip 41 | self.eps = eps 42 | 43 | def forward(self, pred, label): 44 | weight = label * self.pos_weights.to(pred.device) + (1 - label) * self.neg_weights.to(pred.device) 45 | 46 | # Calculating Probabilities 47 | xs_pos = torch.sigmoid(pred) 48 | xs_neg = 1.0 - xs_pos 49 | 50 | # Asymmetric Clipping 51 | if self.clip is not None and self.clip > 0: 52 | xs_neg.add_(self.clip).clamp_(max=1) 53 | 54 | # Basic CE calculation 55 | los_pos = label * torch.log(xs_pos.clamp(min=self.eps)) 56 | los_neg = (1 - label) * torch.log(xs_neg.clamp(min=self.eps)) 57 | loss = los_pos + los_neg 58 | loss *= weight 59 | 60 | # Asymmetric Focusing 61 | if self.gamma_neg > 0 or self.gamma_pos > 0: 62 | pt0 = xs_pos * label 63 | pt1 = xs_neg * (1 - label) # pt = p if t > 0 else 1-p 64 | pt = pt0 + pt1 65 | one_sided_gamma = self.gamma_pos * label + self.gamma_neg * (1 - label) 66 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 67 | loss *= one_sided_w 68 | 69 | return -loss.mean() 70 | -------------------------------------------------------------------------------- /callbacks/submit_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import torch 4 | import pandas as pd 5 | from
lightning.pytorch.callbacks import BasePredictionWriter 6 | 7 | 8 | class SubmissonWriter(BasePredictionWriter): 9 | def __init__(self, sample_submission_path, submission_path, submission_zip_path, submission_code_dir, pred_df_path, write_interval): 10 | super().__init__(write_interval) 11 | self.sample_submission_path = sample_submission_path 12 | self.submission_path = submission_path 13 | self.submission_zip_path = submission_zip_path 14 | self.submission_code_dir = submission_code_dir 15 | self.pred_df_path = pred_df_path 16 | 17 | def postprocess(self, submit_df): 18 | pred_df = pd.read_csv(self.pred_df_path) 19 | submit_df['study_id'] = pred_df['study_id'] 20 | 21 | submit_df['weight'] = 5 22 | submit_df.loc[pred_df['ViewPosition'].isin(['PA', 'AP']), 'weight'] = 7 23 | submit_df.loc[pred_df['ViewPosition'].isin(['LL', 'LATERAL']), 'weight'] = 3 24 | 25 | # Weighted-average the predictions across each study's images and assign that value to every image in the study 26 | classes = submit_df.columns[1:-2] 27 | submit_df[classes] = submit_df[classes].mul(submit_df['weight'], axis=0) 28 | submit_df[classes] = submit_df.groupby('study_id')[classes].transform('sum') 29 | submit_df['weight'] = submit_df.groupby('study_id')['weight'].transform('sum') 30 | submit_df[classes] = submit_df[classes].div(submit_df['weight'], axis=0) 31 | 32 | submit_df.drop(columns=['weight'], inplace=True) 33 | submit_df.drop(columns=['study_id'], inplace=True) 34 | 35 | return submit_df 36 | 37 | def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): 38 | # Add predictions 39 | predictions = torch.cat(predictions, dim=0) 40 | # torch.save(predictions, "predictions.pt") 41 | 42 | submit_df = pd.read_csv(self.sample_submission_path) 43 | submit_df.iloc[..., 1:] = predictions.float().squeeze(0).detach().cpu().numpy() 44 | 45 | 46 | # Post processing 47 | submit_df = self.postprocess(submit_df) 48 | 49 | # Save submission 50 | submit_df.to_csv(self.submission_path, index=False) 51 | with zipfile.ZipFile(self.submission_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: 52 | # Add the folder and its contents to the zip 53 | for root, _, files in os.walk(self.submission_code_dir): 54 | for file in files: 55 | file_path = os.path.join(root, file) 56 | zipf.write(file_path, os.path.join('code', os.path.relpath(file_path, self.submission_code_dir))) 57 | 58 | # Add the file to the zip 59 | zipf.write(self.submission_path, os.path.basename(self.submission_path)) 60 | 61 | print("Submission saved!") 62 | -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import timm 3 | import torch.nn as nn 4 | import copy 5 | import einops 6 | from model.ml_decoder import MLDecoder 7 | from positional_encodings.torch_encodings import PositionalEncoding2D, Summer 8 | 9 | 10 | class Backbone(nn.Module): 11 | def __init__(self, timm_init_args): 12 | super().__init__() 13 | self.model = timm.create_model(**timm_init_args) 14 | self.model.head = nn.Identity() 15 | self.pos_encoding = Summer(PositionalEncoding2D(768)) 16 | self.head = MLDecoder(num_classes=26, initial_num_features=768) 17 | 18 | def forward(self, x): 19 | x = self.model(x) 20 | x = self.pos_encoding(x) 21 | x = self.head(x) 22 | return x 23 | 24 | 25 | class FusionBackbone(nn.Module): 26 | def __init__(self, timm_init_args, pretrained_path=None): 27 | super().__init__() 28 | self.model = timm.create_model(**timm_init_args) 29 |
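# Rebuild the training-time MLDecoder head before loading so the checkpoint keys match
# ('model.pth' is written by save_backbone.py), then strip it: fusion reuses only the
# backbone feature maps.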
self.model.head = MLDecoder(num_classes=26, initial_num_features=768) 30 | if pretrained_path is not None: 31 | self.model.load_state_dict(torch.load(pretrained_path)) 32 | self.model.head = nn.Identity() 33 | self.conv2d = nn.Conv2d(768, 768, kernel_size=3, stride=2, padding=1) 34 | self.pos_encoding = Summer(PositionalEncoding2D(768)) 35 | self.padding_token = nn.Parameter(torch.randn(1, 768, 1, 1)) 36 | self.segment_embedding = nn.Parameter(torch.randn(4, 768, 1, 1)) 37 | 38 | self.head = MLDecoder(num_classes=26, initial_num_features=768) 39 | self.transformer_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=768, nhead=8), num_layers=2) 40 | 41 | def forward(self, x): 42 | b, s, _, _, _ = x.shape 43 | 44 | x = einops.rearrange(x, 'b s c h w -> (b s) c h w') 45 | no_pad = torch.nonzero(x.sum(dim=(1, 2, 3)) != 0).squeeze(1) 46 | x = x[no_pad] 47 | 48 | with torch.no_grad(): 49 | x = self.model(x).detach() 50 | 51 | x = self.conv2d(x) 52 | x = self.pos_encoding(x) 53 | 54 | pad_tokens = einops.repeat(self.padding_token, '1 c 1 1 -> (b s) c h w', b=b, s=s, h=x.shape[2], w=x.shape[3]).type_as(x) 55 | segment_embedding = einops.repeat(self.segment_embedding, 's c 1 1 -> (b s) c h w', b=b, h=x.shape[2], w=x.shape[3]).type_as(x) 56 | pad_tokens[no_pad] = x + segment_embedding[no_pad] 57 | x = pad_tokens 58 | 59 | x = einops.rearrange(x, '(b s) c h w -> b (s h w) c', b=b, s=s, h=x.shape[2], w=x.shape[3]) 60 | mask =(x.sum(dim=-1) == 0).transpose(0, 1) 61 | x = self.transformer_encoder(x, src_key_padding_mask=mask) 62 | x = self.head(x, mask) 63 | 64 | return x 65 | 66 | 67 | class PretrainedBackbone(nn.Module): 68 | def __init__(self, timm_init_args, pretrained_path): 69 | super().__init__() 70 | self.model = timm.create_model(**timm_init_args) 71 | self.new_head = copy.deepcopy(self.model.head) 72 | self.model.load_state_dict(torch.load(pretrained_path)) 73 | self.model.head = nn.Identity() 74 | 75 | def forward(self, x): 76 | with torch.no_grad(): 77 | x = self.model(x) 78 | x = self.new_head(x.detach()) 79 | return x 80 | -------------------------------------------------------------------------------- /model/cxr_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lightning.pytorch as pl 3 | from torch.optim import AdamW 4 | from torchmetrics import AveragePrecision, AUROC 5 | from transformers import get_cosine_schedule_with_warmup 6 | from model.layers import Backbone, FusionBackbone 7 | from model.loss import get_loss 8 | 9 | 10 | class CxrModel(pl.LightningModule): 11 | def __init__(self, lr, classes, loss_init_args, timm_init_args): 12 | super(CxrModel, self).__init__() 13 | self.lr = lr 14 | self.classes = classes 15 | self.backbone = FusionBackbone(timm_init_args, 'model.pth') 16 | # self.backbone = Backbone(timm_init_args) 17 | self.validation_step_outputs = [] 18 | self.val_ap = AveragePrecision(task='binary') 19 | self.val_auc = AUROC(task="binary") 20 | 21 | self.criterion_cls = get_loss(**loss_init_args) 22 | 23 | def forward(self, image): 24 | return self.backbone(image) 25 | 26 | def shared_step(self, batch, batch_idx): 27 | image, label = batch 28 | pred = self(image) 29 | 30 | loss = self.criterion_cls(pred, label) 31 | 32 | pred=torch.sigmoid(pred).detach() 33 | 34 | return dict( 35 | loss=loss, 36 | pred=pred, 37 | label=label, 38 | ) 39 | 40 | def training_step(self, batch, batch_idx): 41 | res = self.shared_step(batch, batch_idx) 42 | self.log_dict({'loss': res['loss'].detach()}, 
prog_bar=True) 43 | self.log_dict({'train_loss': res['loss'].detach()}, prog_bar=True, on_step=False, on_epoch=True) 44 | return res['loss'] 45 | 46 | def validation_step(self, batch, batch_idx): 47 | res = self.shared_step(batch, batch_idx) 48 | self.log_dict({'val_loss': res['loss'].detach()}, prog_bar=True) 49 | self.validation_step_outputs.append(res) 50 | 51 | def on_validation_epoch_end(self): 52 | preds = torch.cat([x['pred'] for x in self.validation_step_outputs]) 53 | labels = torch.cat([x['label'] for x in self.validation_step_outputs]) 54 | 55 | val_ap = [] 56 | val_auroc = [] 57 | for i in range(26): 58 | ap = self.val_ap(preds[:, i], labels[:, i].long()) 59 | auroc = self.val_auc(preds[:, i], labels[:, i].long()) 60 | val_ap.append(ap) 61 | val_auroc.append(auroc) 62 | print(f'{self.classes[i]}_ap: {ap}') 63 | 64 | head_idx = [0, 2, 4, 12, 14, 16, 20, 24] 65 | medium_idx = [1, 3, 5, 6, 8, 9, 10, 13, 15, 22] 66 | tail_idx = [7, 11, 17, 18, 19, 21, 23, 25] 67 | 68 | self.log_dict({'val_ap': sum(val_ap)/26}, prog_bar=True) 69 | self.log_dict({'val_auroc': sum(val_auroc)/26}, prog_bar=True) 70 | self.log_dict({'val_head_ap': sum([val_ap[i] for i in head_idx]) / len(head_idx)}, prog_bar=True) 71 | self.log_dict({'val_medium_ap': sum([val_ap[i] for i in medium_idx]) / len(medium_idx)}, prog_bar=True) 72 | self.log_dict({'val_tail_ap': sum([val_ap[i] for i in tail_idx]) / len(tail_idx)}, prog_bar=True) 73 | self.validation_step_outputs = [] 74 | 75 | def predict_step(self, batch, batch_idx): 76 | pred = self.shared_step(batch, batch_idx)['pred'] 77 | image, label = batch 78 | batch_1 = (image.flip(-1), label) 79 | pred_1 = self.shared_step(batch_1, batch_idx)['pred'] 80 | pred = (pred + pred_1) / 2 81 | return pred 82 | 83 | def configure_optimizers(self): 84 | optimizer = AdamW(self.backbone.parameters(), lr=self.lr) 85 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, 250000) 86 | return [optimizer], [scheduler] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .neptune 2 | .vscode 3 | data 4 | lightning_logs 5 | save 6 | model.pth 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
167 | #.idea/ 168 | -------------------------------------------------------------------------------- /dataset/cxr_datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import lightning.pytorch as pl 5 | from torch.utils.data import DataLoader, ConcatDataset 6 | from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit 7 | from dataset.cxr_dataset import CxrDataset, CxrBalancedDataset, CxrStudyIdDataset 8 | from dataset.vin_dataset import VinDataset 9 | from dataset.nih_dataset import NihDataset 10 | from dataset.chexpert_dataset import ChexpertDataset 11 | from dataset.transforms import get_transforms 12 | 13 | 14 | class CxrDataModule(pl.LightningDataModule): 15 | def __init__(self, datamodule_cfg, dataloader_init_args): 16 | super(CxrDataModule, self).__init__() 17 | self.cfg = datamodule_cfg 18 | self.df = pd.read_csv(self.cfg["train_df_path"]) 19 | self.dataloader_init_args = dataloader_init_args 20 | if self.cfg["use_pseudo_label"]: 21 | print("Using pseudo label") 22 | self.vin_df = pd.read_csv(self.cfg["vinbig_pseudo_train_df_path"]) 23 | self.nih_df = pd.read_csv(self.cfg["nih_pseudo_train_df_path"]) 24 | self.chexpert_df = pd.read_csv(self.cfg["chexpert_pseudo_train_df_path"]) 25 | 26 | def setup(self, stage): 27 | transforms_train, transforms_val = get_transforms(self.cfg["size"]) 28 | if stage in ('fit', 'validate'): 29 | # split train/val 30 | msss = MultilabelStratifiedShuffleSplit( 31 | n_splits=1, test_size=self.cfg["val_split"], random_state=self.cfg["seed"]) 32 | train_idx, val_idx = next(msss.split( 33 | self.df, self.df[self.cfg["classes"]].values)) 34 | train_df = self.df.iloc[train_idx] 35 | val_df = self.df.iloc[val_idx] 36 | 37 | self.train_dataset = CxrStudyIdDataset(self.cfg, train_df, transforms_train) 38 | self.val_dataset = CxrStudyIdDataset(self.cfg, val_df, transforms_val) 39 | 40 | if self.cfg["use_pseudo_label"]: 41 | vin_dataset = VinDataset(self.cfg, self.vin_df, transforms_train) 42 | nih_dataset = NihDataset(self.cfg, self.nih_df, transforms_train) 43 | chexpert_dataset = ChexpertDataset(self.cfg, self.chexpert_df, transforms_train) 44 | print(f"vin len: {len(vin_dataset)}") 45 | print(f"nih len: {len(nih_dataset)}") 46 | print(f"chexpert len: {len(chexpert_dataset)}") 47 | self.train_dataset = ConcatDataset([self.train_dataset, vin_dataset, nih_dataset, chexpert_dataset]) 48 | 49 | print(f"train len: {len(self.train_dataset)}") 50 | print(f"val len: {len(self.val_dataset)}") 51 | 52 | elif stage == 'predict': 53 | if self.cfg["predict_pseudo_label"] == "vinbig": 54 | print("predicting with vinbig dataset") 55 | pred_df = pd.read_csv(self.cfg["vinbig_train_df_path"]) 56 | self.pred_dataset = VinDataset(self.cfg, pred_df, transforms_val) 57 | elif self.cfg["predict_pseudo_label"] == "nih": 58 | print("predicting with nih dataset") 59 | pred_df = pd.read_csv(self.cfg["nih_train_df_path"]) 60 | self.pred_dataset = NihDataset(self.cfg, pred_df, transforms_val) 61 | elif self.cfg["predict_pseudo_label"] == "chexpert": 62 | print("predicting with chexpert dataset") 63 | pred_df = pd.read_csv(self.cfg["chexpert_train_df_path"]) 64 | self.pred_dataset = ChexpertDataset(self.cfg, pred_df, transforms_val) 65 | else: 66 | pred_df = pd.read_csv(self.cfg["pred_df_path"]) 67 | self.pred_dataset = CxrStudyIdDataset(self.cfg, pred_df, transforms_val) 68 | 69 | def train_dataloader(self): 70 | return DataLoader(self.train_dataset, 
**self.dataloader_init_args, shuffle=True) 71 | 72 | def val_dataloader(self): 73 | return DataLoader(self.val_dataset, **self.dataloader_init_args, shuffle=False) 74 | 75 | def predict_dataloader(self): 76 | return DataLoader(self.pred_dataset, **self.dataloader_init_args, shuffle=False) -------------------------------------------------------------------------------- /dataset/cxr_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import pandas as pd 4 | import numpy as np 5 | import jpeg4py as jpeg 6 | from torch import from_numpy 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class CxrDataset(Dataset): 11 | def __init__(self, cfg, df, transform=None): 12 | self.cfg = cfg 13 | self.df = df 14 | self.transform = transform 15 | 16 | def __len__(self): 17 | return len(self.df) 18 | 19 | def __getitem__(self, index): 20 | if all([c in self.df.columns for c in self.cfg['classes']]): 21 | label = self.df.iloc[index][self.cfg['classes']].to_numpy().astype(np.float32) 22 | else: 23 | label = np.zeros(len(self.cfg['classes'])) 24 | 25 | path = self.df.iloc[index]["path"] 26 | path = os.path.join(self.cfg['data_dir'], path) 27 | resized_path = path.replace(".jpg", f"_resized_{self.cfg['size']}.jpg") 28 | 29 | if os.path.exists(resized_path): 30 | img = jpeg.JPEG(resized_path).decode() 31 | if os.path.exists(path): 32 | os.remove(path) 33 | assert img.shape == (self.cfg['size'], self.cfg['size'], 3) 34 | else: 35 | img = jpeg.JPEG(path).decode() 36 | img = cv2.resize(img, (self.cfg['size'], self.cfg['size']), interpolation=cv2.INTER_LANCZOS4) 37 | cv2.imwrite(resized_path, img) 38 | 39 | if self.transform: 40 | transformed = self.transform(image=img) 41 | img = transformed['image'] 42 | img = np.moveaxis(img, -1, 0) 43 | 44 | return img, label 45 | 46 | 47 | class CxrBalancedDataset(Dataset): 48 | def __init__(self, cfg, df, transform=None): 49 | self.cfg = cfg 50 | self.df = df 51 | self.transform = transform 52 | 53 | def __len__(self): 54 | return len(self.df) 55 | 56 | def __getitem__(self, index): 57 | class_name = self.cfg['classes'][index%len(self.cfg['classes'])] 58 | df = self.df[self.df[class_name] == 1].sample(1).iloc[0] 59 | 60 | label = df[self.cfg['classes']].to_numpy().astype(np.float32) 61 | 62 | path = df["path"] 63 | path = os.path.join(self.cfg['data_dir'], path) 64 | resized_path = path.replace(".jpg", f"_resized_{self.cfg['size']}.jpg") 65 | 66 | if os.path.exists(resized_path): 67 | img = jpeg.JPEG(resized_path).decode() 68 | assert img.shape == (self.cfg['size'], self.cfg['size'], 3) 69 | else: 70 | img = jpeg.JPEG(path).decode() 71 | img = cv2.resize(img, (self.cfg['size'], self.cfg['size']), interpolation=cv2.INTER_LANCZOS4) 72 | cv2.imwrite(resized_path, img) 73 | 74 | if self.transform: 75 | transformed = self.transform(image=img) 76 | img = transformed['image'] 77 | img = np.moveaxis(img, -1, 0) 78 | 79 | return img, label 80 | 81 | 82 | class CxrStudyIdDataset(Dataset): 83 | def __init__(self, cfg, df, transform=None): 84 | self.cfg = cfg 85 | self.df = df.groupby("study_id") 86 | self.study_ids = list(self.df.groups.keys()) 87 | self.transform = transform 88 | 89 | def __len__(self): 90 | return len(self.df) 91 | 92 | def __getitem__(self, index): 93 | df = self.df.get_group(self.study_ids[index]) 94 | if len(df) > 4: 95 | df = df.sample(4) 96 | 97 | if all([c in df.columns for c in self.cfg['classes']]): 98 | label = df[self.cfg['classes']].iloc[0].to_numpy().astype(np.float32) 99 | else: 
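# Label columns are missing (e.g. the development/test split), so use all-zero placeholder labels: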
100 | label = np.zeros(len(self.cfg['classes'])) 101 | 102 | imgs = [] 103 | for i in range(len(df)): 104 | path = df.iloc[i]["path"] 105 | path = os.path.join(self.cfg['data_dir'], path) 106 | resized_path = path.replace(".jpg", f"_resized_{self.cfg['size']}.jpg") 107 | if os.path.exists(resized_path): 108 | img = jpeg.JPEG(resized_path).decode() 109 | if os.path.exists(path): 110 | os.remove(path) 111 | assert img.shape == (self.cfg['size'], self.cfg['size'], 3) 112 | else: 113 | img = jpeg.JPEG(path).decode() 114 | img = cv2.resize(img, (self.cfg['size'], self.cfg['size']), interpolation=cv2.INTER_LANCZOS4) 115 | cv2.imwrite(resized_path, img) 116 | 117 | if self.transform: 118 | transformed = self.transform(image=img) 119 | img = transformed['image'] 120 | img = np.moveaxis(img, -1, 0) 121 | imgs.append(img) 122 | 123 | img = np.stack(imgs, axis=0) 124 | img = np.concatenate([img, np.zeros((4-len(df), 3, self.cfg['size'], self.cfg['size']))], axis=0).astype(np.float32) 125 | return img, label 126 | 127 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | trainer: 3 | accelerator: auto 4 | strategy: auto 5 | devices: auto 6 | num_nodes: 1 7 | precision: 16-mixed 8 | logger: 9 | - class_path: lightning.pytorch.loggers.NeptuneLogger 10 | init_args: 11 | project: dongkyuk/CXR 12 | name: 384, tiny, baseline 13 | log_model_checkpoints: false 14 | callbacks: 15 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 16 | init_args: 17 | dirpath: save/transformer-fusion-padding-segembedding 18 | filename: '{epoch:02d}-{val_loss:.4f}-{val_ap:.5f}' 19 | save_top_k: 10 20 | monitor: val_loss 21 | mode: min 22 | # - class_path: callbacks.submit_callback.SubmissonWriter 23 | # init_args: 24 | # write_interval: epoch 25 | # sample_submission_path: data/sample_submission.csv 26 | # submission_path: save/submission.csv 27 | # submission_zip_path: save/submission.zip 28 | # submission_code_dir: data/code 29 | # pred_df_path: data/development.csv 30 | - class_path: callbacks.fusion_submit_callback.FusionSubmissonWriter 31 | init_args: 32 | write_interval: epoch 33 | sample_submission_path: data/sample_submission.csv 34 | submission_path: save/submission.csv 35 | submission_zip_path: save/submission.zip 36 | submission_code_dir: data/code 37 | pred_df_path: data/development.csv 38 | # - class_path: callbacks.vinbig_pseudo_callback.VinBigWriter 39 | # init_args: 40 | # write_interval: epoch 41 | # vinbig_train_df_path: data/vinbig/train_processed.csv 42 | # vinbig_pseudo_train_df_path: data/vinbig/train_processed_pseudo_labeled.csv 43 | # - class_path: callbacks.nih_pseudo_callback.NihWriter 44 | # init_args: 45 | # write_interval: epoch 46 | # nih_train_df_path: data/nih/train_processed.csv 47 | # nih_pseudo_train_df_path: data/nih/train_processed_pseudo_labeled.csv 48 | # - class_path: callbacks.chexpert_pseudo_callback.ChexpertWriter 49 | # init_args: 50 | # write_interval: epoch 51 | # chexpert_train_df_path: data/chexpert/train_processed.csv 52 | # chexpert_pseudo_train_df_path: data/chexpert/train_processed_pseudo_labeled.csv 53 | # - class_path: lightning.pytorch.callbacks.StochasticWeightAveraging 54 | # init_args: 55 | # swa_lrs: 1e-2 56 | # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping 57 | # init_args: 58 | # monitor: val_ap 59 | # mode: max 60 | # patience: 10 61 | fast_dev_run: false 62 | overfit_batches: 0.0 63 | 
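# run validation 4 times per training epoch (every quarter epoch)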
val_check_interval: 0.25 64 | num_sanity_val_steps: 2 65 | accumulate_grad_batches: 1 66 | gradient_clip_val: null 67 | deterministic: null 68 | benchmark: true 69 | max_epochs: 1000 70 | model: 71 | lr: 3e-5 72 | classes: 73 | - Atelectasis 74 | - Calcification of the Aorta 75 | - Cardiomegaly 76 | - Consolidation 77 | - Edema 78 | - Emphysema 79 | - Enlarged Cardiomediastinum 80 | - Fibrosis 81 | - Fracture 82 | - Hernia 83 | - Infiltration 84 | - Lung Lesion 85 | - Lung Opacity 86 | - Mass 87 | - No Finding 88 | - Nodule 89 | - Pleural Effusion 90 | - Pleural Other 91 | - Pleural Thickening 92 | - Pneumomediastinum 93 | - Pneumonia 94 | - Pneumoperitoneum 95 | - Pneumothorax 96 | - Subcutaneous Emphysema 97 | - Support Devices 98 | - Tortuous Aorta 99 | loss_init_args: 100 | type: asl 101 | class_instance_nums: 102 | - 67597 103 | - 4361 104 | - 76900 105 | - 16038 106 | - 38574 107 | - 4255 108 | - 30119 109 | - 1158 110 | - 11883 111 | - 4049 112 | - 10218 113 | - 2533 114 | - 79931 115 | - 5529 116 | - 41869 117 | - 7663 118 | - 69240 119 | - 675 120 | - 3369 121 | - 788 122 | - 48093 123 | - 543 124 | - 14983 125 | - 2453 126 | - 89140 127 | - 3499 128 | total_instance_num: 264849 129 | timm_init_args: 130 | num_classes: 26 131 | model_name: convnext_small.fb_in22k_ft_in1k 132 | pretrained: true 133 | data: 134 | dataloader_init_args: 135 | batch_size: 8 136 | num_workers: 8 137 | pin_memory: true 138 | persistent_workers: true 139 | datamodule_cfg: 140 | predict_pseudo_label: null #chexpert 141 | use_pseudo_label: false #true 142 | data_dir: data 143 | train_df_path: data/train.csv 144 | pred_df_path: data/development.csv 145 | vinbig_train_df_path: data/vinbig/train_processed.csv 146 | vinbig_pseudo_train_df_path: data/vinbig/train_processed_pseudo_labeled.csv 147 | nih_train_df_path: data/nih/train_processed.csv 148 | nih_pseudo_train_df_path: data/nih/train_processed_pseudo_labeled.csv 149 | chexpert_train_df_path: data/chexpert/train_processed.csv 150 | chexpert_pseudo_train_df_path: data/chexpert/train_processed_pseudo_labeled.csv 151 | save_dir: save 152 | val_split: 0.1 153 | seed: 42 154 | size: 1024 155 | classes: 156 | - Atelectasis 157 | - Calcification of the Aorta 158 | - Cardiomegaly 159 | - Consolidation 160 | - Edema 161 | - Emphysema 162 | - Enlarged Cardiomediastinum 163 | - Fibrosis 164 | - Fracture 165 | - Hernia 166 | - Infiltration 167 | - Lung Lesion 168 | - Lung Opacity 169 | - Mass 170 | - No Finding 171 | - Nodule 172 | - Pleural Effusion 173 | - Pleural Other 174 | - Pleural Thickening 175 | - Pneumomediastinum 176 | - Pneumonia 177 | - Pneumoperitoneum 178 | - Pneumothorax 179 | - Subcutaneous Emphysema 180 | - Support Devices 181 | - Tortuous Aorta 182 | 183 | ckpt_path: /home/paperspace/Desktop/CXR-LT/save/transformer-fusion-padding-segembedding/epoch=07-val_loss=0.0600-val_ap=0.36752.ckpt #/home/paperspace/Desktop/CXR-LT/save/transformer-fusion-padding-segembedding/epoch=06-val_loss=0.0600-val_ap=0.36612.ckpt #/home/paperspace/Desktop/CXR-LT/save/transformer-fusion-new/epoch=04-val_loss=0.0601-val_ap=0.36637.ckpt #/home/paperspace/Desktop/CXR-LT/save/pseudo-all-recent/epoch=05-val_loss=0.0594-val_ap=0.36696.ckpt #/home/paperspace/Desktop/CXR-LT/save/pseudo-all/epoch=06-val_loss=0.0597-val_ap=0.36819.ckpt #null #/home/paperspace/Desktop/CXR-LT/save/asl-mldecoder/epoch=06-val_loss=0.0600-val_ap=0.35800.ckpt # 184 | -------------------------------------------------------------------------------- /model/ml_decoder.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | from torch import nn, Tensor 4 | from torch.nn.modules.transformer import _get_activation_fn 5 | 6 | 7 | def add_ml_decoder_head(model, num_classes=-1, num_of_groups=-1, decoder_embedding=768, zsl=0): 8 | if num_classes == -1: 9 | num_classes = model.num_classes 10 | num_features = model.num_features 11 | if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # resnet50 12 | model.global_pool = nn.Identity() 13 | del model.fc 14 | model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features, num_of_groups=num_of_groups, 15 | decoder_embedding=decoder_embedding, zsl=zsl) 16 | elif hasattr(model, 'head'): # tresnet 17 | if hasattr(model, 'global_pool'): 18 | model.global_pool = nn.Identity() 19 | del model.head 20 | model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features, num_of_groups=num_of_groups, 21 | decoder_embedding=decoder_embedding, zsl=zsl) 22 | else: 23 | print("model is not suited for ml-decoder") 24 | exit(-1) 25 | 26 | return model 27 | 28 | 29 | class TransformerDecoderLayerOptimal(nn.Module): 30 | def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu", 31 | layer_norm_eps=1e-5) -> None: 32 | super(TransformerDecoderLayerOptimal, self).__init__() 33 | self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) 34 | self.dropout = nn.Dropout(dropout) 35 | self.dropout1 = nn.Dropout(dropout) 36 | self.dropout2 = nn.Dropout(dropout) 37 | self.dropout3 = nn.Dropout(dropout) 38 | 39 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 40 | 41 | # Implementation of Feedforward model 42 | self.linear1 = nn.Linear(d_model, dim_feedforward) 43 | self.linear2 = nn.Linear(dim_feedforward, d_model) 44 | 45 | self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) 46 | self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps) 47 | 48 | self.activation = _get_activation_fn(activation) 49 | 50 | def __setstate__(self, state): 51 | if 'activation' not in state: 52 | state['activation'] = torch.nn.functional.relu 53 | super(TransformerDecoderLayerOptimal, self).__setstate__(state) 54 | 55 | def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, 56 | memory_mask: Optional[Tensor] = None, 57 | tgt_key_padding_mask: Optional[Tensor] = None, 58 | memory_key_padding_mask: Optional[Tensor] = None) -> Tensor: 59 | tgt = tgt + self.dropout1(tgt) 60 | tgt = self.norm1(tgt) 61 | tgt2 = self.multihead_attn(tgt, memory, memory)[0] 62 | tgt = tgt + self.dropout2(tgt2) 63 | tgt = self.norm2(tgt) 64 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 65 | tgt = tgt + self.dropout3(tgt2) 66 | tgt = self.norm3(tgt) 67 | return tgt 68 | 69 | 70 | # @torch.jit.script 71 | # class ExtrapClasses(object): 72 | # def __init__(self, num_queries: int, group_size: int): 73 | # self.num_queries = num_queries 74 | # self.group_size = group_size 75 | # 76 | # def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap: 77 | # torch.Tensor): 78 | # # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size) 79 | # h = h[..., None].repeat(1, 1, 1, self.group_size) # torch.Size([bs, 5, 768, groups]) 80 | # w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size)) 81 | # out = (h * w).sum(dim=2) + class_embed_b 82 | # out = out.view((h.shape[0], self.group_size * self.num_queries)) 83 | # 
return out 84 | 85 | @torch.jit.script 86 | class GroupFC(object): 87 | def __init__(self, embed_len_decoder: int): 88 | self.embed_len_decoder = embed_len_decoder 89 | 90 | def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor): 91 | for i in range(h.shape[1]): 92 | h_i = h[:, i, :] 93 | if len(duplicate_pooling.shape)==3: 94 | w_i = duplicate_pooling[i, :, :] 95 | else: 96 | w_i = duplicate_pooling 97 | out_extrap[:, i, :] = torch.matmul(h_i, w_i) 98 | 99 | 100 | class MLDecoder(nn.Module): 101 | def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, 102 | initial_num_features=2048, zsl=0): 103 | super(MLDecoder, self).__init__() 104 | embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups 105 | if embed_len_decoder > num_classes: 106 | embed_len_decoder = num_classes 107 | 108 | # switching to 768 initial embeddings 109 | decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding 110 | embed_standart = nn.Linear(initial_num_features, decoder_embedding) 111 | 112 | # non-learnable queries 113 | if not zsl: 114 | query_embed = nn.Embedding(embed_len_decoder, decoder_embedding) 115 | query_embed.requires_grad_(False) 116 | else: 117 | query_embed = None 118 | 119 | # decoder 120 | decoder_dropout = 0.1 121 | num_layers_decoder = 1 122 | dim_feedforward = 2048 123 | layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding, 124 | dim_feedforward=dim_feedforward, dropout=decoder_dropout) 125 | self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder) 126 | self.decoder.embed_standart = embed_standart 127 | self.decoder.query_embed = query_embed 128 | self.zsl = zsl 129 | 130 | if self.zsl: 131 | if decoder_embedding != 300: 132 | self.wordvec_proj = nn.Linear(300, decoder_embedding) 133 | else: 134 | self.wordvec_proj = nn.Identity() 135 | self.decoder.duplicate_pooling = torch.nn.Parameter(torch.Tensor(decoder_embedding, 1)) 136 | self.decoder.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(1)) 137 | self.decoder.duplicate_factor = 1 138 | else: 139 | # group fully-connected 140 | self.decoder.num_classes = num_classes 141 | self.decoder.duplicate_factor = int(num_classes / embed_len_decoder + 0.999) 142 | self.decoder.duplicate_pooling = torch.nn.Parameter( 143 | torch.Tensor(embed_len_decoder, decoder_embedding, self.decoder.duplicate_factor)) 144 | self.decoder.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes)) 145 | torch.nn.init.xavier_normal_(self.decoder.duplicate_pooling) 146 | torch.nn.init.constant_(self.decoder.duplicate_pooling_bias, 0) 147 | self.decoder.group_fc = GroupFC(embed_len_decoder) 148 | self.train_wordvecs = None 149 | self.test_wordvecs = None 150 | 151 | def forward(self, x, mask=None): 152 | if len(x.shape) == 4: # [bs,2048, 7,7] 153 | embedding_spatial = x.flatten(2).transpose(1, 2) 154 | else: # [bs, 197,468] 155 | embedding_spatial = x 156 | embedding_spatial_786 = self.decoder.embed_standart(embedding_spatial) 157 | embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True) 158 | 159 | bs = embedding_spatial_786.shape[0] 160 | if self.zsl: 161 | query_embed = torch.nn.functional.relu(self.wordvec_proj(self.decoder.query_embed)) 162 | else: 163 | query_embed = self.decoder.query_embed.weight 164 | # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1) 165 | tgt = query_embed.unsqueeze(1).expand(-1, bs, -1) # no allocation of memory with expand 166 | h = self.decoder(tgt, 
embedding_spatial_786.transpose(0, 1), memory_key_padding_mask=mask) # [embed_len_decoder, batch, 768] 167 | h = h.transpose(0, 1) 168 | 169 | out_extrap = torch.zeros(h.shape[0], h.shape[1], self.decoder.duplicate_factor, device=h.device, dtype=h.dtype) 170 | self.decoder.group_fc(h, self.decoder.duplicate_pooling, out_extrap) 171 | if not self.zsl: 172 | h_out = out_extrap.flatten(1)[:, :self.decoder.num_classes] 173 | else: 174 | h_out = out_extrap.flatten(1) 175 | h_out += self.decoder.duplicate_pooling_bias 176 | logits = h_out 177 | return logits 178 | --------------------------------------------------------------------------------
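A minimal smoke test for the MLDecoder head above (not part of the repo; the batch size and the 12x12 feature-map size are illustrative assumptions, while 26 classes and 768 features follow model/layers.py):

import torch
from model.ml_decoder import MLDecoder

decoder = MLDecoder(num_classes=26, initial_num_features=768)

# 4-D input [batch, channels, h, w], as a CNN backbone feature map
feature_map = torch.randn(2, 768, 12, 12)
print(decoder(feature_map).shape)  # torch.Size([2, 26])

# 3-D input [batch, tokens, channels], as produced by the fusion transformer
tokens = torch.randn(2, 4 * 12 * 12, 768)
print(decoder(tokens).shape)  # torch.Size([2, 26])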