├── Pre-Training
│   ├── Contrastive_Learning
│   │   ├── pl_bolts
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   ├── object_detection.py
│   │   │   │   └── rl.py
│   │   │   ├── transforms
│   │   │   │   ├── __init__.py
│   │   │   │   ├── self_supervised
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   └── ssl_transforms.py
│   │   │   │   └── dataset_normalizations.py
│   │   │   ├── callbacks
│   │   │   │   ├── vision
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── sr_image_logger.py
│   │   │   │   │   └── image_generation.py
│   │   │   │   ├── verification
│   │   │   │   │   └── __init__.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── torch_ort.py
│   │   │   │   ├── byol_updates.py
│   │   │   │   └── variational.py
│   │   │   ├── models
│   │   │   │   ├── gans
│   │   │   │   │   ├── basic
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── components.py
│   │   │   │   │   ├── dcgan
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── components.py
│   │   │   │   │   ├── srgan
│   │   │   │   │   │   └── __init__.py
│   │   │   │   │   ├── pix2pix
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── pix2pix_module.py
│   │   │   │   │   └── __init__.py
│   │   │   │   ├── rl
│   │   │   │   │   ├── common
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── cli.py
│   │   │   │   │   │   └── distributions.py
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── dueling_dqn_model.py
│   │   │   │   │   ├── double_dqn_model.py
│   │   │   │   │   └── noisy_dqn_model.py
│   │   │   │   ├── detection
│   │   │   │   │   ├── yolo
│   │   │   │   │   │   └── __init__.py
│   │   │   │   │   ├── components
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── _supported_models.py
│   │   │   │   │   │   └── torchvision_backbones.py
│   │   │   │   │   ├── retinanet
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── backbones.py
│   │   │   │   │   ├── faster_rcnn
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── backbones.py
│   │   │   │   │   └── __init__.py
│   │   │   │   ├── vision
│   │   │   │   │   ├── image_gpt
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── gpt2.py
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── pixel_cnn.py
│   │   │   │   │   ├── unet.py
│   │   │   │   │   └── segmentation.py
│   │   │   │   ├── self_supervised
│   │   │   │   │   ├── byol
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── logo_data_set.py
│   │   │   │   │   │   ├── models.py
│   │   │   │   │   │   └── custom_wandb_logger.py
│   │   │   │   │   ├── moco
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── moco_data_set.py
│   │   │   │   │   │   ├── callbacks.py
│   │   │   │   │   │   └── custom_wandb_logger.py
│   │   │   │   │   ├── swav
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── PT_Dataset.py
│   │   │   │   │   ├── evaluator.py
│   │   │   │   │   └── __init__.py
│   │   │   │   ├── regression
│   │   │   │   │   └── __init__.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── autoencoders
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── basic_vae
│   │   │   │   │   │   └── __init__.py
│   │   │   │   │   └── basic_ae
│   │   │   │   │       └── __init__.py
│   │   │   │   └── mnist_module.py
│   │   │   ├── metrics
│   │   │   │   ├── __init__.py
│   │   │   │   ├── aggregation.py
│   │   │   │   └── object_detection.py
│   │   │   ├── optimizers
│   │   │   │   └── __init__.py
│   │   │   ├── utils
│   │   │   │   ├── shaping.py
│   │   │   │   ├── self_supervised.py
│   │   │   │   ├── pretrained_weights.py
│   │   │   │   ├── warnings.py
│   │   │   │   ├── __init__.py
│   │   │   │   └── semi_supervised.py
│   │   │   ├── datasets
│   │   │   │   ├── concat_dataset.py
│   │   │   │   ├── sr_mnist_dataset.py
│   │   │   │   ├── sr_stl10_dataset.py
│   │   │   │   ├── sr_celeba_dataset.py
│   │   │   │   ├── emnist_dataset.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_dataset.py
│   │   │   │   ├── utils.py
│   │   │   │   ├── sr_dataset_mixin.py
│   │   │   │   ├── mnist_dataset.py
│   │   │   │   ├── kitti_dataset.py
│   │   │   │   └── ssl_amdim_datasets.py
│   │   │   ├── __init__.py
│   │   │   ├── __about__.py
│   │   │   ├── datamodules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── sr_datamodule.py
│   │   │   │   ├── binary_emnist_datamodule.py
│   │   │   │   ├── mnist_datamodule.py
│   │   │   │   ├── binary_mnist_datamodule.py
│   │   │   │   └── fashion_mnist_datamodule.py
│   │   │   └── setup_tools.py
│   │   ├── requirements.txt
│   │   └── README.md
│   ├── Data_Preprocessing
│   │   ├── requirements.txt
│   │   └── README.md
│   ├── Masked_Autoencoder
│   │   ├── requirements.txt
│   │   ├── LICENSE
│   │   ├── main.sh
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── resnet.py
│   │   │   └── custom.py
│   │   ├── README.md
│   │   ├── utils
│   │   │   ├── lr_control.py
│   │   │   ├── imagenet.py
│   │   │   └── arg_util.py
│   │   ├── sampler.py
│   │   ├── decoder.py
│   │   ├── launch.py
│   │   └── dist.py
│   └── README.md
└── Downstream
    ├── requirements.txt
    └── README.md
/Pre-Training/Contrastive_Learning/pl_bolts/losses/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/callbacks/vision/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/gans/basic/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/gans/dcgan/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/gans/srgan/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/rl/common/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/callbacks/verification/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/detection/yolo/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/gans/pix2pix/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/vision/image_gpt/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/byol/__init__.py: -------------------------------------------------------------------------------- 1 | --------------------------------------------------------------------------------
/Pre-Training/Data_Preprocessing/requirements.txt: -------------------------------------------------------------------------------- 1 | pydicom~=1.4.2 2 | numpy~=1.21.2 3 | opencv-python~=4.5.5 --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/transforms/self_supervised/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.transforms.self_supervised.ssl_transforms
import Patchify, RandomTranslateWithReflect # noqa: F401 2 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.metrics.aggregation import accuracy, mean, precision_at_k 2 | 3 | __all__ = [ 4 | "accuracy", 5 | "mean", 6 | "precision_at_k", 7 | ] 8 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/detection/components/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.detection.components.torchvision_backbones import create_torchvision_backbone 2 | 3 | __all__ = ["create_torchvision_backbone"] 4 | -------------------------------------------------------------------------------- /Pre-Training/Masked_Autoencoder/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | torchvision==0.11.0 4 | torchaudio==0.10.0 5 | Pillow 6 | typed-argument-parser 7 | timm==0.5.4 8 | setuptools==59.5.0 9 | tensorboard 10 | tensorboardx 11 | -------------------------------------------------------------------------------- /Downstream/requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning~=1.7.7 2 | matplotlib~=3.5.2 3 | seaborn~=0.11.2 4 | wandb~=0.13.4 5 | numpy~=1.22.3 6 | scikit-learn~=1.1.1 7 | pillow~=9.2.0 8 | torchmetrics~=0.10.0 9 | medmnist~=2.1.0 10 | monai~=1.0.0 11 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/detection/retinanet/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.detection.retinanet.backbones import create_retinanet_backbone 2 | from pl_bolts.models.detection.retinanet.retinanet_module import RetinaNet 3 | 4 | __all__ = ["create_retinanet_backbone", "RetinaNet"] 5 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/regression/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.regression.linear_regression import LinearRegression 2 | from pl_bolts.models.regression.logistic_regression import LogisticRegression 3 | 4 | __all__ = [ 5 | "LinearRegression", 6 | "LogisticRegression", 7 | ] 8 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/detection/faster_rcnn/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.detection.faster_rcnn.backbones import create_fasterrcnn_backbone 2 | from pl_bolts.models.detection.faster_rcnn.faster_rcnn_module import FasterRCNN 3 | 4 | __all__ = ["create_fasterrcnn_backbone", "FasterRCNN"] 5 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.optimizers.lars import LARS 2 | from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR, linear_warmup_decay 3 | 4 | __all__ = [ 5 | "LARS", 6 | "LinearWarmupCosineAnnealingLR", 7 | 
"linear_warmup_decay", 8 | ] 9 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/moco/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.self_supervised.moco.transforms import ( # noqa: F401 2 | Moco2EvalCIFAR10Transforms, 3 | Moco2EvalImagenetTransforms, 4 | Moco2EvalSTL10Transforms, 5 | Moco2TrainCIFAR10Transforms, 6 | Moco2TrainImagenetTransforms, 7 | Moco2TrainSTL10Transforms, 8 | ) 9 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/requirements.txt: -------------------------------------------------------------------------------- 1 | #torch>=1.7.1 2 | torchmetrics>=0.4.1 3 | pytorch-lightning~=1.6.0 4 | packaging~=21.3 5 | lightning-bolts~=0.5.0 6 | #dataclasses~=0.8 7 | numpy~=1.21.2 8 | #torchvision~=0.12.0 9 | pillow~=9.0.1 10 | wandb~=0.12.11 11 | #setuptools~=58.0.4 12 | pydicom~=1.4.2 13 | nibabel~=3.0.2 14 | opencv-python 15 | scikit-image~=0.19.3 16 | 17 | 18 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.vision.image_gpt.gpt2 import GPT2 2 | from pl_bolts.models.vision.image_gpt.igpt_module import ImageGPT 3 | from pl_bolts.models.vision.pixel_cnn import PixelCNN 4 | from pl_bolts.models.vision.segmentation import SemSegment 5 | from pl_bolts.models.vision.unet import UNet 6 | 7 | __all__ = [ 8 | "GPT2", 9 | "ImageGPT", 10 | "PixelCNN", 11 | "SemSegment", 12 | "UNet", 13 | ] 14 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/utils/shaping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import Tensor 4 | 5 | 6 | def tile(a: Tensor, dim: int, n_tile: int) -> Tensor: 7 | init_dim = a.size(dim) 8 | repeat_idx = [1] * a.dim() 9 | repeat_idx[dim] = n_tile 10 | a = a.repeat(*repeat_idx) 11 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) 12 | return torch.index_select(a, dim, order_index) 13 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/gans/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.gans.basic.basic_gan_module import GAN 2 | from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN 3 | from pl_bolts.models.gans.pix2pix.pix2pix_module import Pix2Pix 4 | from pl_bolts.models.gans.srgan.srgan_module import SRGAN 5 | from pl_bolts.models.gans.srgan.srresnet_module import SRResNet 6 | 7 | __all__ = [ 8 | "GAN", 9 | "DCGAN", 10 | "Pix2Pix", 11 | "SRGAN", 12 | "SRResNet", 13 | ] 14 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Collection of PyTorchLightning models.""" 2 | from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE 3 | from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE 4 | from pl_bolts.models.mnist_module import LitMNIST 5 | 
from pl_bolts.models.regression import LinearRegression, LogisticRegression 6 | 7 | __all__ = [ 8 | "AE", 9 | "VAE", 10 | "LitMNIST", 11 | "LinearRegression", 12 | "LogisticRegression", 13 | ] 14 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.detection import components 2 | from pl_bolts.models.detection.faster_rcnn import FasterRCNN 3 | from pl_bolts.models.detection.retinanet import RetinaNet 4 | from pl_bolts.models.detection.yolo.yolo_config import YOLOConfiguration 5 | from pl_bolts.models.detection.yolo.yolo_module import YOLO 6 | 7 | __all__ = [ 8 | "components", 9 | "FasterRCNN", 10 | "YOLOConfiguration", 11 | "YOLO", 12 | "RetinaNet", 13 | ] 14 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datasets/concat_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class ConcatDataset(Dataset): 5 | def __init__(self, *datasets): 6 | self.datasets = datasets 7 | 8 | def __getitem__(self, i): 9 | result = [] 10 | for dataset in self.datasets: 11 | cycled_i = i % len(dataset) 12 | result.append(dataset[cycled_i]) 13 | 14 | return tuple(result) 15 | 16 | def __len__(self): 17 | return max(len(d) for d in self.datasets) 18 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/utils/self_supervised.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | 3 | from pl_bolts.models.self_supervised import resnets 4 | from pl_bolts.utils.semi_supervised import Identity 5 | 6 | 7 | def torchvision_ssl_encoder( 8 | name: str, 9 | pretrained: bool = False, 10 | return_all_feature_maps: bool = False, 11 | ) -> Module: 12 | pretrained_model = getattr(resnets, name)(pretrained=pretrained, return_all_feature_maps=return_all_feature_maps) 13 | pretrained_model.fc = Identity() 14 | return pretrained_model 15 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/swav/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.self_supervised.swav.swav_module_cifar import SwAV 2 | from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50 3 | from pl_bolts.models.self_supervised.swav.transforms import ( 4 | SwAVEvalDataTransform, 5 | SwAVFinetuneTransform, 6 | SwAVTrainDataTransform, 7 | ) 8 | 9 | __all__ = [ 10 | "SwAV", 11 | "resnet18", 12 | "resnet50", 13 | "SwAVEvalDataTransform", 14 | "SwAVFinetuneTransform", 15 | "SwAVTrainDataTransform", 16 | ] 17 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/autoencoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Basic AE and VAE templates.""" 2 | 3 | from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE 4 | from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE 5 | from pl_bolts.models.autoencoders.components import ( 6 | resnet18_decoder, 7 | resnet18_encoder, 8 | resnet50_decoder, 9 | resnet50_encoder, 10
| ) 11 | 12 | __all__ = [ 13 | "AE", 14 | "VAE", 15 | "resnet18_decoder", 16 | "resnet18_encoder", 17 | "resnet50_decoder", 18 | "resnet50_encoder", 19 | ] 20 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/__init__.py: -------------------------------------------------------------------------------- 1 | """Root package crossroad.""" 2 | 3 | import os 4 | 5 | from pl_bolts.__about__ import * # noqa: F401, F403 6 | 7 | _PACKAGE_ROOT = os.path.dirname(__file__) 8 | _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) 9 | _HTTPS_AWS_HUB = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com" 10 | 11 | from pl_bolts import ( # noqa: E402 12 | callbacks, 13 | datamodules, 14 | datasets, 15 | losses, 16 | metrics, 17 | models, 18 | optimizers, 19 | transforms, 20 | utils, 21 | ) 22 | 23 | __all__ = [ 24 | "callbacks", 25 | "datamodules", 26 | "datasets", 27 | "losses", 28 | "metrics", 29 | "models", 30 | "optimizers", 31 | "transforms", 32 | "utils", 33 | ] 34 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/rl/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic 2 | from pl_bolts.models.rl.double_dqn_model import DoubleDQN 3 | from pl_bolts.models.rl.dqn_model import DQN 4 | from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN 5 | from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN 6 | from pl_bolts.models.rl.per_dqn_model import PERDQN 7 | from pl_bolts.models.rl.reinforce_model import Reinforce 8 | from pl_bolts.models.rl.sac_model import SAC 9 | from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient 10 | 11 | __all__ = [ 12 | "AdvantageActorCritic", 13 | "DoubleDQN", 14 | "DQN", 15 | "DuelingDQN", 16 | "NoisyDQN", 17 | "PERDQN", 18 | "Reinforce", 19 | "SAC", 20 | "VanillaPolicyGradient", 21 | ] 22 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/utils/pretrained_weights.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from pytorch_lightning import LightningModule 4 | 5 | vae_imagenet2012 = ( 6 | "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/" "vae/imagenet_06_22_2019/checkpoints/epoch%3D63.ckpt" 7 | ) 8 | 9 | cpcv2_resnet18 = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com/" "cpc/resnet18-v6/epoch%3D85.ckpt" 10 | urls = {"vae-imagenet2012": vae_imagenet2012, "CPC_v2-resnet18": cpcv2_resnet18} 11 | 12 | 13 | def load_pretrained(model: LightningModule, class_name: Optional[str] = None) -> None: # pragma: no cover 14 | if class_name is None: 15 | class_name = model.__class__.__name__ 16 | ckpt_url = urls[class_name] 17 | weights_model = model.__class__.load_from_checkpoint(ckpt_url) 18 | model.load_state_dict(weights_model.state_dict()) 19 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/autoencoders/basic_vae/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | VAE Template 3 | ============ 4 | 5 | This is a basic template for implementing a Variational Autoencoder in PyTorch Lightning. 
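A quick usage sketch, mirroring the example at the end of the AE template below (``input_height`` is an assumed constructor argument, matching the 32-pixel CIFAR10 default):

.. code-block:: python

    import pytorch_lightning as pl
    from pl_bolts.models.autoencoders import VAE

    model = VAE(input_height=32)  # height of the input images; assumed required here
    trainer = pl.Trainer()
    trainer.fit(model)  # in practice, pass a dataloader or datamodule as well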
6 | 7 | A default encoder and decoder have been provided but can easily be replaced by custom models. 8 | 9 | This template uses the CIFAR10 dataset but image data of any dimension can be fed in as long as the image 10 | width and image height are even values. For other types of data, such as sound, it will be necessary 11 | to change the Encoder and Decoder. 12 | 13 | The default encoder is a resnet18 backbone followed by linear layers which map representations 14 | to mu and var. The default decoder mirrors the encoder architecture and is similar to an inverted 15 | resnet18. The model also assumes a Gaussian prior and a Gaussian approximate posterior distribution. 16 | """ 17 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/moco/moco_data_set.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | import torchvision.datasets as datasets 3 | 4 | 5 | logger = getLogger() 6 | 7 | 8 | class MoCoDataset(datasets.ImageFolder): 9 | def __init__( 10 | self, 11 | data_path, 12 | tau_g, 13 | ): 14 | super(MoCoDataset, self).__init__(data_path) 15 | self.tau_g = tau_g # set of global transforms 16 | 17 | def __len__(self): 18 | return len(self.samples) 19 | 20 | def __str__(self): 21 | return f"MoCoDataset with {self.__len__()} images" 22 | 23 | def __getitem__(self, idx): 24 | path, _ = self.samples[idx] 25 | image = self.loader(path) 26 | 27 | global_crops = list(map(lambda transform: transform(image), self.tau_g)) 28 | 29 | return (global_crops[0], global_crops[1]), 0 30 | --------------------------------------------------------------------------------
/Pre-Training/Data_Preprocessing/README.md: -------------------------------------------------------------------------------- 1 | # Preprocessing 2 | 3 | Download the LIDC-IDRI Dataset from here: [https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=1966254](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=1966254) 4 | 5 | The data comes as DICOM images. We save each slice of each CT volume as a png file. We do not apply any windowing since we want to use all data for pre-training. 6 | 7 | ### Start: 8 | If you are using Conda on Linux, here is how to get started: 9 | 1. Open your terminal and follow these steps: 10 | 1. conda create --name SSL_Preprocessing python==3.10 11 | 2. conda activate SSL_Preprocessing 12 | 3. cd ...SSL-MedicalImagining-CL-MAE/Pre-Training/Data_Preprocessing 13 | 4. pip install -r requirements.txt 14 | 2. Open LIDC_3DDICOM_to_2Dpng.py and adjust the folder paths in the main method 15 | 16 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/autoencoders/basic_ae/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | AE Template 3 | ============ 4 | 5 | This is a basic template for implementing an Autoencoder in PyTorch Lightning. 6 | 7 | A default encoder and decoder have been provided but can easily be replaced by custom models. 8 | 9 | This template uses the CIFAR10 dataset but image data of any dimension can be fed in as long as the image 10 | width and image height are even values. For other types of data, such as sound, it will be necessary 11 | to change the Encoder and Decoder.
12 | 13 | The default encoder is a resnet18 backbone followed by linear layers which map representations to latent space. 14 | The default decoder mirrors the encoder architecture and is similar to an inverted resnet18. 15 | 16 | .. code-block:: python 17 | 18 | from pl_bolts.models.autoencoders import AE 19 | 20 | model = AE() 21 | trainer = pl.Trainer() 22 | trainer.fit(model) 23 | """ 24 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/moco/callbacks.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from pytorch_lightning import Callback 4 | 5 | 6 | class MocoLRScheduler(Callback): 7 | def __init__(self, initial_lr=0.03, use_cosine_scheduler=False, schedule=(120, 160), max_epochs=200): 8 | super().__init__() 9 | self.lr = initial_lr 10 | self.use_cosine_scheduler = use_cosine_scheduler 11 | self.schedule = schedule 12 | self.max_epochs = max_epochs 13 | 14 | def on_epoch_start(self, trainer, pl_module): 15 | epoch = trainer.current_epoch 16 | lr = self.lr 17 | 18 | if self.use_cosine_scheduler: # cosine lr schedule 19 | lr *= 0.5 * (1.0 + math.cos(math.pi * epoch / self.max_epochs)) 20 | else: # stepwise lr schedule 21 | for milestone in self.schedule: 22 | lr *= 0.1 if epoch >= milestone else 1.0 23 | 24 | optimizer = trainer.optimizers[0] 25 | for param_group in optimizer.param_groups: 26 | param_group["lr"] = lr 27 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/datasets/sr_mnist_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pl_bolts.datasets.mnist_dataset import MNIST 4 | from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin 5 | from pl_bolts.utils import _PIL_AVAILABLE 6 | from pl_bolts.utils.warnings import warn_missing_pkg 7 | 8 | if _PIL_AVAILABLE: 9 | from PIL import Image 10 | else: # pragma: no cover 11 | warn_missing_pkg("PIL", pypi_name="Pillow") 12 | 13 | 14 | class SRMNIST(SRDatasetMixin, MNIST): 15 | """MNIST dataset that can be used to train Super Resolution models. 16 | 17 | Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image. 
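A minimal usage sketch (assuming the torchvision-style ``root``/``download`` keyword arguments are forwarded to the underlying MNIST dataset):

.. code-block:: python

    from pl_bolts.datasets.sr_mnist_dataset import SRMNIST

    dataset = SRMNIST(scale_factor=4, root=".", download=True)  # kwargs assumed to reach MNIST
    hr_image, lr_image = dataset[0]  # a 28x28 high-res / 7x7 low-res pair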
18 | """ 19 | 20 | def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None: 21 | hr_image_size = 28 22 | lr_image_size = hr_image_size // scale_factor 23 | self.image_channels = 1 24 | super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs) 25 | 26 | def _get_image(self, index: int): 27 | return Image.fromarray(self.data[index].numpy(), mode="L") 28 | -------------------------------------------------------------------------------- /Pre-Training/Masked_Autoencoder/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Keyu Tian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/metrics/aggregation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mean(res, key): 5 | # recursive mean for multilevel dicts 6 | return torch.stack([x[key] if isinstance(x, dict) else mean(x, key) for x in res]).mean() 7 | 8 | 9 | def accuracy(preds, labels): 10 | preds = preds.float() 11 | max_lgt = torch.max(preds, 1)[1] 12 | num_correct = (max_lgt == labels).sum().item() 13 | num_correct = torch.tensor(num_correct).float() 14 | acc = num_correct / len(labels) 15 | 16 | return acc 17 | 18 | 19 | def precision_at_k(output, target, top_k=(1,)): 20 | """Computes the accuracy over the k top predictions for the specified values of k.""" 21 | with torch.no_grad(): 22 | maxk = max(top_k) 23 | batch_size = target.size(0) 24 | 25 | _, pred = output.topk(maxk, 1, True, True) 26 | pred = pred.t() 27 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 28 | 29 | res = [] 30 | for k in top_k: 31 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 32 | res.append(correct_k.mul_(100.0 / batch_size)) 33 | return res 34 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/byol/logo_data_set.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision.datasets as datasets 6 | import torchvision.transforms as transforms 7 | from PIL import Image, ImageFilter 8 | 9 | logger = getLogger() 10 | 11 | 12 | class 
LoGoDataset(datasets.ImageFolder): 13 | def __init__( 14 | self, 15 | data_path, 16 | tau_g, 17 | tau_l, 18 | ): 19 | super(LoGoDataset, self).__init__(data_path) 20 | self.tau_g = tau_g # set of global transforms 21 | self.tau_l = tau_l # set of local transforms 22 | 23 | def __len__(self): 24 | return len(self.samples) 25 | 26 | def __str__(self): 27 | return f"LoGoDataset with {self.__len__()} images" 28 | 29 | def __getitem__(self, idx): 30 | path, _ = self.samples[idx] 31 | image = self.loader(path) 32 | 33 | global_crops = list(map(lambda transform: transform(image), self.tau_g)) 34 | local_crops = list(map(lambda transform: transform(image), self.tau_l)) 35 | 36 | # breakpoint() 37 | 38 | return (global_crops[0], global_crops[1]), 0 39 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/detection/components/_supported_models.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 2 | from pl_bolts.utils.warnings import warn_missing_pkg 3 | 4 | if _TORCHVISION_AVAILABLE: 5 | import torchvision 6 | 7 | TORCHVISION_MODEL_ZOO = { 8 | "vgg11": torchvision.models.vgg11, 9 | "vgg13": torchvision.models.vgg13, 10 | "vgg16": torchvision.models.vgg16, 11 | "vgg19": torchvision.models.vgg19, 12 | "resnet18": torchvision.models.resnet18, 13 | "resnet34": torchvision.models.resnet34, 14 | "resnet50": torchvision.models.resnet50, 15 | "resnet101": torchvision.models.resnet101, 16 | "resnet152": torchvision.models.resnet152, 17 | "resnext50_32x4d": torchvision.models.resnext50_32x4d, 18 | "resnext101_32x8d": torchvision.models.resnext101_32x8d, 19 | "mnasnet0_5": torchvision.models.mnasnet0_5, 20 | "mnasnet0_75": torchvision.models.mnasnet0_75, 21 | "mnasnet1_0": torchvision.models.mnasnet1_0, 22 | "mnasnet1_3": torchvision.models.mnasnet1_3, 23 | "mobilenet_v2": torchvision.models.mobilenet_v2, 24 | } 25 | 26 | else: # pragma: no cover 27 | warn_missing_pkg("torchvision") 28 | TORCHVISION_MODEL_ZOO = {} 29 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/evaluator.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SSLEvaluator(nn.Module): 5 | def __init__(self, n_input, n_classes, n_hidden=512, p=0.1): 6 | super().__init__() 7 | self.n_input = n_input 8 | self.n_classes = n_classes 9 | self.n_hidden = n_hidden 10 | if n_hidden is None: 11 | # use linear classifier 12 | self.block_forward = nn.Sequential(Flatten(), nn.Dropout(p=p), nn.Linear(n_input, n_classes, bias=True)) 13 | else: 14 | # use simple MLP classifier 15 | self.block_forward = nn.Sequential( 16 | Flatten(), 17 | nn.Dropout(p=p), 18 | nn.Linear(n_input, n_hidden, bias=False), 19 | nn.BatchNorm1d(n_hidden), 20 | nn.ReLU(inplace=True), 21 | nn.Dropout(p=p), 22 | nn.Linear(n_hidden, n_classes, bias=True), 23 | ) 24 | 25 | def forward(self, x): 26 | logits = self.block_forward(x) 27 | return logits 28 | 29 | 30 | class Flatten(nn.Module): 31 | def __init__(self): 32 | super().__init__() 33 | 34 | def forward(self, input_tensor): 35 | return input_tensor.view(input_tensor.size(0), -1) 36 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datasets/sr_stl10_dataset.py:
-------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | 5 | from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin 6 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE 7 | from pl_bolts.utils.warnings import warn_missing_pkg 8 | 9 | if _PIL_AVAILABLE: 10 | import PIL 11 | else: # pragma: no cover 12 | warn_missing_pkg("PIL", pypi_name="Pillow") 13 | 14 | if _TORCHVISION_AVAILABLE: 15 | from torchvision.datasets import STL10 16 | else: # pragma: no cover 17 | warn_missing_pkg("torchvision") 18 | STL10 = object 19 | 20 | 21 | class SRSTL10(SRDatasetMixin, STL10): 22 | """STL10 dataset that can be used to train Super Resolution models. 23 | 24 | Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image. 25 | """ 26 | 27 | def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None: 28 | hr_image_size = 96 29 | lr_image_size = hr_image_size // scale_factor 30 | self.image_channels = 3 31 | super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs) 32 | 33 | def _get_image(self, index: int): 34 | return PIL.Image.fromarray(np.transpose(self.data[index], (1, 2, 0))) 35 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/datasets/sr_celeba_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any 3 | 4 | from pl_bolts.datasets.sr_dataset_mixin import SRDatasetMixin 5 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE 6 | from pl_bolts.utils.warnings import warn_missing_pkg 7 | 8 | if _PIL_AVAILABLE: 9 | from PIL import Image 10 | else: # pragma: no cover 11 | warn_missing_pkg("PIL", pypi_name="Pillow") 12 | 13 | if _TORCHVISION_AVAILABLE: 14 | from torchvision.datasets import CelebA 15 | else: # pragma: no cover 16 | warn_missing_pkg("torchvision") 17 | CelebA = object 18 | 19 | 20 | class SRCelebA(SRDatasetMixin, CelebA): 21 | """CelebA dataset that can be used to train Super Resolution models. 22 | 23 | Function __getitem__ (implemented in SRDatasetMixin) returns tuple of high and low resolution image. 24 | """ 25 | 26 | def __init__(self, scale_factor: int, *args: Any, **kwargs: Any) -> None: 27 | hr_image_size = 128 28 | lr_image_size = hr_image_size // scale_factor 29 | self.image_channels = 3 30 | super().__init__(hr_image_size, lr_image_size, self.image_channels, *args, **kwargs) 31 | 32 | def _get_image(self, index: int): 33 | return Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 34 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/__about__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.0dev" 2 | __author__ = "PyTorchLightning et al." 3 | __author_email__ = "name@pytorchlightning.ai" 4 | __license__ = "Apache-2.0" 5 | __copyright__ = f"Copyright (c) 2020-2021, {__author__}" 6 | __homepage__ = "https://github.com/PyTorchLightning/lightning-bolts" 7 | __docs__ = "PyTorch Lightning Bolts is a community contribution for ML researchers." 8 | __long_doc__ = """ 9 | What is it? 10 | ----------- 11 | Bolts is a collection of useful models and templates to bootstrap your DL research even faster. 
12 | It's designed to work with PyTorch Lightning 13 | 14 | Subclass Example 15 | ---------------- 16 | Use `pl_bolts` models to remove boilerplate for common approaches and architectures. 17 | Because it uses LightningModules under the hood, you just need to overwrite 18 | the relevant parts to your research. 19 | 20 | How to add a model 21 | ------------------ 22 | This repository is meant for model contributions from the community. 23 | To add a model, you can start with the MNIST template (or any other model in the repo). 24 | Please organize the functions of your lightning module. 25 | """ 26 | 27 | __all__ = [ 28 | "__author__", 29 | "__author_email__", 30 | "__copyright__", 31 | "__docs__", 32 | "__homepage__", 33 | "__license__", 34 | "__version__", 35 | ] 36 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | """Collection of PyTorchLightning callbacks.""" 2 | from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate 3 | from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor 4 | from pl_bolts.callbacks.printing import PrintTableMetricsCallback 5 | from pl_bolts.callbacks.sparseml import SparseMLCallback 6 | from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator 7 | from pl_bolts.callbacks.torch_ort import ORTCallback 8 | from pl_bolts.callbacks.variational import LatentDimInterpolator 9 | from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerificationCallback # type: ignore 10 | from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback 11 | from pl_bolts.callbacks.vision.image_generation import TensorboardGenerativeModelImageSampler 12 | from pl_bolts.callbacks.vision.sr_image_logger import SRImageLoggerCallback 13 | 14 | __all__ = [ 15 | "BatchGradientVerificationCallback", 16 | "BYOLMAWeightUpdate", 17 | "ModuleDataMonitor", 18 | "TrainingDataMonitor", 19 | "PrintTableMetricsCallback", 20 | "SSLOnlineEvaluator", 21 | "LatentDimInterpolator", 22 | "ConfusedLogitCallback", 23 | "TensorboardGenerativeModelImageSampler", 24 | "SRImageLoggerCallback", 25 | "ORTCallback", 26 | "SparseMLCallback", 27 | ] 28 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/byol/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from pl_bolts.utils.self_supervised import torchvision_ssl_encoder 4 | 5 | 6 | class MLP(nn.Module): 7 | def __init__(self, input_dim=2048, hidden_size=4096, output_dim=256): 8 | super().__init__() 9 | self.output_dim = output_dim 10 | self.input_dim = input_dim 11 | self.model = nn.Sequential( 12 | nn.Linear(input_dim, hidden_size, bias=False), 13 | nn.BatchNorm1d(hidden_size), 14 | nn.ReLU(inplace=True), 15 | nn.Linear(hidden_size, output_dim, bias=True), 16 | ) 17 | 18 | def forward(self, x): 19 | x = self.model(x) 20 | return x 21 | 22 | 23 | class SiameseArm(nn.Module): 24 | def __init__(self, encoder="resnet50", encoder_out_dim=2048, projector_hidden_size=4096, projector_out_dim=256): 25 | super().__init__() 26 | 27 | if isinstance(encoder, str): 28 | encoder = torchvision_ssl_encoder(encoder) 29 | # Encoder 30 | self.encoder = encoder 31 | # Projector 32 | self.projector = MLP(encoder_out_dim, projector_hidden_size, projector_out_dim) 33 | # 
Predictor 34 | self.predictor = MLP(projector_out_dim, projector_hidden_size, projector_out_dim) 35 | 36 | def forward(self, x): 37 | y = self.encoder(x)[0] 38 | z = self.projector(y) 39 | h = self.predictor(z) 40 | return y, z, h 41 | --------------------------------------------------------------------------------
/Pre-Training/Masked_Autoencoder/main.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # an example to do pre-training: (note that `/path/to/imagenet` should contain directories named `train` and `val`) 4 | # > cd /path/to/SparK 5 | # > bash ./main.sh experiment_name --num_nodes=1 --ngpu_per_node=8 --node_rank=0 --master_address=128.0.0.0 --master_port=30000 --data_path=/path/to/imagenet --model=convnext_small --ep=400 --wp_ep=10 6 | 7 | ####### template begins ####### 8 | SCRIPTS_DIR=$(cd $(dirname $0); pwd) 9 | cd "${SCRIPTS_DIR}" 10 | SPARK_DIR=$(pwd) 11 | echo "SPARK_DIR=${SPARK_DIR}" 12 | 13 | shopt -s expand_aliases 14 | alias python=python3 15 | alias to_scripts_dir='cd "${SCRIPTS_DIR}"' 16 | alias to_spark_dir='cd "${SPARK_DIR}"' 17 | alias print='echo "$(date +"[%m-%d %H:%M:%S]") (exp.sh)=>"' 18 | function mkd() { 19 | mkdir -p "$1" >/dev/null 2>&1 20 | } 21 | ####### template ends ####### 22 | 23 | 24 | EXP_NAME=$1 25 | 26 | EXP_DIR="${SPARK_DIR}/output_${EXP_NAME}" 27 | 28 | 29 | print "===================== Args =====================" 30 | print "EXP_NAME: ${EXP_NAME}" 31 | print "[other_args sent to launch.py]: ${*:2}" 32 | print "================================================" 33 | print "" 34 | 35 | 36 | print "============== Pretraining starts ==============" 37 | to_spark_dir 38 | touch ~/wait1 39 | python launch.py \ 40 | --main_py_relpath main.py \ 41 | --exp_name "${EXP_NAME}" \ 42 | --exp_dir "${EXP_DIR}" \ 43 | "${*:2}" 44 | print "============== Pretraining ends ==============" 45 | rm ~/wait1 46 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datasets/emnist_dataset.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE 2 | from pl_bolts.utils.warnings import warn_missing_pkg 3 | 4 | if _TORCHVISION_AVAILABLE: 5 | from torchvision.datasets import EMNIST 6 | else: # pragma: no cover 7 | warn_missing_pkg("torchvision") 8 | EMNIST = object 9 | 10 | if _PIL_AVAILABLE: 11 | from PIL import Image 12 | else: # pragma: no cover 13 | warn_missing_pkg("PIL", pypi_name="Pillow") 14 | 15 | 16 | class BinaryEMNIST(EMNIST): 17 | def __getitem__(self, idx): 18 | """ 19 | Args: 20 | idx: Index 21 | 22 | Returns: 23 | tuple: (image, target) where target is index of the target class.
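Note: after ``transform`` runs, the image is binarized with a 0.5 threshold (see the end of this method), so ``transform`` is expected to produce a tensor scaled to [0, 1], e.g. via ``transforms.ToTensor()``.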
24 | """ 25 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 26 | raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.") 27 | 28 | img, target = self.data[idx], int(self.targets[idx]) 29 | 30 | # doing this so that it is consistent with all other datasets 31 | # to return a PIL Image 32 | img = Image.fromarray(img.numpy(), mode="L") 33 | 34 | if self.transform is not None: 35 | img = self.transform(img) 36 | 37 | if self.target_transform is not None: 38 | target = self.target_transform(target) 39 | 40 | # binary 41 | img[img < 0.5] = 0.0 42 | img[img >= 0.5] = 1.0 43 | 44 | return img, target 45 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/utils/warnings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import Callable, Dict, Optional 4 | 5 | MISSING_PACKAGE_WARNINGS: Dict[str, int] = {} 6 | 7 | WARN_MISSING_PACKAGE = int(os.environ.get("WARN_MISSING_PACKAGE", False)) 8 | 9 | 10 | def warn_missing_pkg( 11 | pkg_name: str, 12 | pypi_name: Optional[str] = None, 13 | extra_text: Optional[str] = None, 14 | stdout_func: Callable = warnings.warn, 15 | ) -> int: 16 | """Template for warning on missing packages, show them just once. 17 | 18 | Args: 19 | pkg_name: Name of missing package 20 | pypi_name: In case that package name differ from PyPI name 21 | extra_text: Additional text after the base warning 22 | stdout_func: Define used function for streaming warning, use ``warnings.warn`` or ``logging.warning`` 23 | 24 | Returns: 25 | Number of warning calls 26 | """ 27 | if not WARN_MISSING_PACKAGE: 28 | return -1 29 | 30 | if pkg_name not in MISSING_PACKAGE_WARNINGS: 31 | extra_text = os.linesep + extra_text if extra_text else "" 32 | if not pypi_name: 33 | pypi_name = pkg_name 34 | stdout_func( 35 | f"You want to use `{pkg_name}` which is not installed yet," 36 | f" install it with `pip install {pypi_name}`." 
+ extra_text 37 | ) 38 | MISSING_PACKAGE_WARNINGS[pkg_name] = 1 39 | else: 40 | MISSING_PACKAGE_WARNINGS[pkg_name] += 1 41 | 42 | return MISSING_PACKAGE_WARNINGS[pkg_name] 43 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | 3 | from pl_bolts.datasets.base_dataset import LightDataset 4 | from pl_bolts.datasets.cifar10_dataset import CIFAR10, TrialCIFAR10 5 | from pl_bolts.datasets.concat_dataset import ConcatDataset 6 | from pl_bolts.datasets.dummy_dataset import ( 7 | DummyDataset, 8 | DummyDetectionDataset, 9 | RandomDataset, 10 | RandomDictDataset, 11 | RandomDictStringDataset, 12 | ) 13 | from pl_bolts.datasets.emnist_dataset import BinaryEMNIST 14 | from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet, extract_archive, parse_devkit_archive 15 | from pl_bolts.datasets.kitti_dataset import KittiDataset 16 | from pl_bolts.datasets.mnist_dataset import MNIST, BinaryMNIST 17 | from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed, SSLDatasetMixin 18 | 19 | __all__ = [ 20 | "LightDataset", 21 | "CIFAR10", 22 | "TrialCIFAR10", 23 | "ConcatDataset", 24 | "DummyDataset", 25 | "DummyDetectionDataset", 26 | "MNIST", 27 | "RandomDataset", 28 | "RandomDictDataset", 29 | "RandomDictStringDataset", 30 | "extract_archive", 31 | "parse_devkit_archive", 32 | "UnlabeledImagenet", 33 | "KittiDataset", 34 | "BinaryMNIST", 35 | "CIFAR10Mixed", 36 | "SSLDatasetMixin", 37 | "BinaryEMNIST", 38 | ] 39 | 40 | # TorchVision hotfix https://github.com/pytorch/vision/issues/1938 41 | opener = urllib.request.build_opener() 42 | opener.addheaders = [("User-agent", "Mozilla/5.0")] 43 | urllib.request.install_opener(opener) 44 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/__init__.py: -------------------------------------------------------------------------------- 1 | """These models have been pre-trained using self-supervised learning. The models can also be used without pre- 2 | training and overwritten for your own research. 3 | 4 | Here's an example for using these as pretrained models. 5 | 6 | .. 
code-block:: 7 | 8 | from pl_bolts.models.self_supervised import CPC_v2 9 | 10 | images = get_imagenet_batch() 11 | 12 | # extract unsupervised representations 13 | pretrained = CPC_v2(pretrained=True) 14 | representations = pretrained(images) 15 | 16 | # use these in classification or any downstream task 17 | classifications = classifier(representations) 18 | """ 19 | from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM 20 | from pl_bolts.models.self_supervised.byol.byol_module import BYOL 21 | from pl_bolts.models.self_supervised.cpc.cpc_module import CPC_v2 22 | from pl_bolts.models.self_supervised.evaluator import SSLEvaluator 23 | from pl_bolts.models.self_supervised.moco.moco2_module import Moco_v2 24 | from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR 25 | from pl_bolts.models.self_supervised.simsiam.simsiam_module import SimSiam 26 | from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner 27 | from pl_bolts.models.self_supervised.swav.swav_module_cifar import SwAV 28 | 29 | __all__ = [ 30 | "AMDIM", 31 | "BYOL", 32 | "CPC_v2", 33 | "SSLEvaluator", 34 | "Moco_v2", 35 | "SimCLR", 36 | "SimSiam", 37 | "SSLFineTuner", 38 | "SwAV", 39 | ] 40 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/byol/custom_wandb_logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pytorch_lightning.utilities import rank_zero_only 3 | log = logging.getLogger(__name__)  # module-level logger used below 4 | @rank_zero_only 5 | def log_hyperparameters(object_dict: dict) -> None: 6 | """Controls which config parts are saved by lightning loggers. 7 | Additionally saves: 8 | - Number of model parameters 9 | """ 10 | 11 | hparams = {} 12 | 13 | cfg = object_dict["cfg"] 14 | model = object_dict["model"] 15 | trainer = object_dict["trainer"] 16 | 17 | if not trainer.logger: 18 | log.warning("Logger not found! Skipping hyperparameter logging...") 19 | return 20 | 21 | # hparams["model"] = cfg["model"] 22 | 23 | # save number of model parameters 24 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 25 | hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) 26 | hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) 27 | 28 | # hparams["datamodule"] = cfg["datamodule"] 29 | # hparams["trainer"] = cfg["trainer"] 30 | 31 | # hparams["callbacks"] = cfg.get("callbacks") 32 | # hparams["extras"] = cfg.get("extras") 33 | 34 | # TODO: 35 | # hparams["task_name"] = cfg.get("task_name") 36 | hparams["tags"] = cfg.get("tags") 37 | # hparams["ckpt_path"] = cfg.get("ckpt_path") 38 | # hparams["seed"] = cfg.get("seed") 39 | 40 | hparams["cfg"] = cfg 41 | # send hparams to all loggers 42 | trainer.logger.log_hyperparams(hparams) 43 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/moco/custom_wandb_logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pytorch_lightning.utilities import rank_zero_only 3 | log = logging.getLogger(__name__)  # module-level logger used below 4 | @rank_zero_only 5 | def log_hyperparameters(object_dict: dict) -> None: 6 | """Controls which config parts are saved by lightning loggers.
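Expected call shape (a sketch): ``log_hyperparameters({"cfg": cfg, "model": model, "trainer": trainer})``, where the three keys are read back out below.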
7 | Additionally saves: 8 | - Number of model parameters 9 | """ 10 | 11 | hparams = {} 12 | 13 | cfg = object_dict["cfg"] 14 | model = object_dict["model"] 15 | trainer = object_dict["trainer"] 16 | 17 | if not trainer.logger: 18 | log.warning("Logger not found! Skipping hyperparameter logging...") 19 | return 20 | 21 | # hparams["model"] = cfg["model"] 22 | 23 | # save number of model parameters 24 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 25 | hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) 26 | hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) 27 | 28 | # hparams["datamodule"] = cfg["datamodule"] 29 | # hparams["trainer"] = cfg["trainer"] 30 | 31 | # hparams["callbacks"] = cfg.get("callbacks") 32 | # hparams["extras"] = cfg.get("extras") 33 | 34 | # TODO: 35 | # hparams["task_name"] = cfg.get("task_name") 36 | hparams["tags"] = cfg.get("tags") 37 | # hparams["ckpt_path"] = cfg.get("ckpt_path") 38 | # hparams["seed"] = cfg.get("seed") 39 | 40 | hparams["cfg"] = cfg 41 | # send hparams to all loggers 42 | trainer.logger.log_hyperparams(hparams) 43 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/rl/dueling_dqn_model.py: -------------------------------------------------------------------------------- 1 | """Dueling DQN.""" 2 | import argparse 3 | 4 | from pytorch_lightning import Trainer 5 | 6 | from pl_bolts.models.rl.common.networks import DuelingCNN 7 | from pl_bolts.models.rl.dqn_model import DQN 8 | 9 | 10 | class DuelingDQN(DQN): 11 | """PyTorch Lightning implementation of `Dueling DQN <https://arxiv.org/abs/1511.06581>`_ 12 | 13 | Paper authors: Ziyu Wang, Tom Schaul, Matteo Hessel, Hado van Hasselt, Marc Lanctot, Nando de Freitas 14 | 15 | Model implemented by: 16 | 17 | - `Donal Byrne <https://github.com/djbyrne>`_ 18 | 19 | Example: 20 | 21 | >>> from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN 22 | ... 23 | >>> model = DuelingDQN("PongNoFrameskip-v4") 24 | 25 | Train:: 26 | 27 | trainer = Trainer() 28 | trainer.fit(model) 29 | 30 | ..
note:: Currently only supports CPU and single GPU training with `accelerator=dp` 31 | """ 32 | 33 | def build_networks(self) -> None: 34 | """Initializes the Dueling DQN train and target networks.""" 35 | self.net = DuelingCNN(self.obs_shape, self.n_actions) 36 | self.target_net = DuelingCNN(self.obs_shape, self.n_actions) 37 | 38 | 39 | def cli_main(): 40 | parser = argparse.ArgumentParser(add_help=False) 41 | 42 | # trainer args 43 | parser = Trainer.add_argparse_args(parser) 44 | 45 | # model args 46 | parser = DuelingDQN.add_model_specific_args(parser) 47 | args = parser.parse_args() 48 | 49 | model = DuelingDQN(**args.__dict__) 50 | 51 | trainer = Trainer.from_argparse_args(args) 52 | trainer.fit(model) 53 | 54 | 55 | if __name__ == "__main__": 56 | cli_main() 57 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/detection/faster_rcnn/backbones.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import torch.nn as nn 4 | 5 | from pl_bolts.models.detection.components import create_torchvision_backbone 6 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 7 | from pl_bolts.utils.warnings import warn_missing_pkg 8 | 9 | if _TORCHVISION_AVAILABLE: 10 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone 11 | else: # pragma: no cover 12 | warn_missing_pkg("torchvision") 13 | 14 | 15 | def create_fasterrcnn_backbone( 16 | backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any 17 | ) -> nn.Module: 18 | """ 19 | Args: 20 | backbone: 21 | Supported backbones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152", 22 | "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2", 23 | as resnets with fpn backbones. 24 | Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101", 25 | "resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19", 26 | fpn: If True then constructs fpn as well. 27 | pretrained: If None creates imagenet weights backbone. 28 | trainable_backbone_layers: number of trainable resnet layers starting from final block. 29 | """ 30 | 31 | if fpn: 32 | # Creates a torchvision resnet model with fpn added.
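(Note: `pretrained=True` is hard-coded in this FPN branch, so ImageNet weights are always loaded here; the function's `pretrained` argument is only consulted in the non-FPN branch below.)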
33 | backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs) 34 | else: 35 | # This does not create fpn backbone, it is supported for all models 36 | backbone, _ = create_torchvision_backbone(backbone, pretrained) 37 | return backbone 38 | --------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/detection/retinanet/backbones.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import torch.nn as nn 4 | 5 | from pl_bolts.models.detection.components import create_torchvision_backbone 6 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 7 | from pl_bolts.utils.warnings import warn_missing_pkg 8 | 9 | if _TORCHVISION_AVAILABLE: 10 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone 11 | else: # pragma: no cover 12 | warn_missing_pkg("torchvision") 13 | 14 | 15 | def create_retinanet_backbone( 16 | backbone: str, fpn: bool = True, pretrained: Optional[str] = None, trainable_backbone_layers: int = 3, **kwargs: Any 17 | ) -> nn.Module: 18 | """ 19 | Args: 20 | backbone: 21 | Supported backbones are: "resnet18", "resnet34","resnet50", "resnet101", "resnet152", 22 | "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2", 23 | as resnets with fpn backbones. 24 | Without fpn backbones supported are: "resnet18", "resnet34", "resnet50","resnet101", 25 | "resnet152", "resnext101_32x8d", "mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19", 26 | fpn: If True then constructs fpn as well. 27 | pretrained: If None creates imagenet weights backbone. 28 | trainable_backbone_layers: number of trainable resnet layers starting from final block. 29 | """ 30 | 31 | if fpn: 32 | # Creates a torchvision resnet model with fpn added.
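(Note: as in the Faster R-CNN helper above, `pretrained=True` is hard-coded in this FPN branch; the `pretrained` argument only affects the non-FPN branch.)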
33 |         backbone = resnet_fpn_backbone(backbone, pretrained=True, trainable_layers=trainable_backbone_layers, **kwargs)
34 |     else:
35 |         # This path does not create an fpn backbone; it is supported for all listed models.
36 |         backbone, _ = create_torchvision_backbone(backbone, pretrained)
37 |     return backbone
38 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/gans/basic/components.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | 
6 | 
7 | class Generator(nn.Module):
8 |     def __init__(self, latent_dim, img_shape, hidden_dim=256):
9 |         super().__init__()
10 |         feats = int(np.prod(img_shape))
11 |         self.img_shape = img_shape
12 |         self.fc1 = nn.Linear(latent_dim, hidden_dim)
13 |         self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features * 2)
14 |         self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features * 2)
15 |         self.fc4 = nn.Linear(self.fc3.out_features, feats)
16 | 
17 |     # forward method
18 |     def forward(self, z):
19 |         z = F.leaky_relu(self.fc1(z), 0.2)
20 |         z = F.leaky_relu(self.fc2(z), 0.2)
21 |         z = F.leaky_relu(self.fc3(z), 0.2)
22 |         img = torch.tanh(self.fc4(z))
23 |         img = img.view(img.size(0), *self.img_shape)
24 |         return img
25 | 
26 | 
27 | class Discriminator(nn.Module):
28 |     def __init__(self, img_shape, hidden_dim=1024):
29 |         super().__init__()
30 |         in_dim = int(np.prod(img_shape))
31 |         self.fc1 = nn.Linear(in_dim, hidden_dim)
32 |         self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
33 |         self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
34 |         self.fc4 = nn.Linear(self.fc3.out_features, 1)
35 | 
36 |     # forward method
37 |     def forward(self, img):
38 |         x = img.view(img.size(0), -1)
39 |         x = F.leaky_relu(self.fc1(x), 0.2)
40 |         x = F.dropout(x, 0.3)
41 |         x = F.leaky_relu(self.fc2(x), 0.2)
42 |         x = F.dropout(x, 0.3)
43 |         x = F.leaky_relu(self.fc3(x), 0.2)
44 |         x = F.dropout(x, 0.3)
45 |         return torch.sigmoid(self.fc4(x))
46 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/rl/common/cli.py:
--------------------------------------------------------------------------------
1 | """Contains generic arguments used for all models."""
2 | 
3 | import argparse
4 | 
5 | 
6 | def add_base_args(parent) -> argparse.ArgumentParser:
7 |     """Adds arguments for DQN model.
8 | 
9 |     Note:
10 |         These params are fine-tuned for the Pong env.
11 | 
12 |     Args:
13 |         parent: parent argument parser to inherit the arguments from
14 |     """
15 |     arg_parser = argparse.ArgumentParser(parents=[parent])
16 | 
17 |     arg_parser.add_argument("--algo", type=str, default="dqn", help="algorithm to use for training")
18 |     arg_parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
19 |     arg_parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
20 | 
21 |     arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag")
22 |     arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")
23 | 
24 |     arg_parser.add_argument("--episode_length", type=int, default=500, help="max length of an episode")
25 |     arg_parser.add_argument("--max_episode_reward", type=int, default=18, help="max episode reward in the environment")
26 |     arg_parser.add_argument(
27 |         "--n_steps",
28 |         type=int,
29 |         default=4,
30 |         help="how many steps to unroll for each update",
31 |     )
32 |     arg_parser.add_argument("--seed", type=int, default=123, help="seed for training run")
33 |     arg_parser.add_argument("--epoch_len", type=int, default=1000, help="how many batches per epoch")
34 |     arg_parser.add_argument("--num_envs", type=int, default=1, help="number of environments to run at once")
35 |     arg_parser.add_argument(
36 |         "--avg_reward_len", type=int, default=100, help="how many episodes to include in avg reward"
37 |     )
38 | 
39 |     return arg_parser
40 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import urllib.request
4 | from abc import ABC
5 | from typing import Sequence, Tuple
6 | from urllib.error import HTTPError
7 | 
8 | from torch import Tensor
9 | from torch.utils.data import Dataset
10 | 
11 | 
12 | class LightDataset(ABC, Dataset):
13 | 
14 |     data: Tensor
15 |     targets: Tensor
16 |     normalize: tuple
17 |     dir_path: str
18 |     cache_folder_name: str
19 |     DATASET_NAME = "light"
20 | 
21 |     def __len__(self) -> int:
22 |         return len(self.data)
23 | 
24 |     @property
25 |     def cached_folder_path(self) -> str:
26 |         return os.path.join(self.dir_path, self.DATASET_NAME, self.cache_folder_name)
27 | 
28 |     @staticmethod
29 |     def _prepare_subset(
30 |         full_data: Tensor,
31 |         full_targets: Tensor,
32 |         num_samples: int,
33 |         labels: Sequence,
34 |     ) -> Tuple[Tensor, Tensor]:
35 |         """Prepare a subset of a common dataset."""
36 |         classes = {d: 0 for d in labels}
37 |         indexes = []
38 |         for idx, target in enumerate(full_targets):
39 |             label = target.item()
40 |             if classes.get(label, float("inf")) >= num_samples:
41 |                 continue
42 |             indexes.append(idx)
43 |             classes[label] += 1
44 |             if all(classes[k] >= num_samples for k in classes):
45 |                 break
46 |         data = full_data[indexes]
47 |         targets = full_targets[indexes]
48 |         return data, targets
49 | 
50 |     def _download_from_url(self, base_url: str, data_folder: str, file_name: str):
51 |         url = os.path.join(base_url, file_name)
52 |         logging.info(f"Downloading {url}")
53 |         fpath = os.path.join(data_folder, file_name)
54 |         try:
55 |             urllib.request.urlretrieve(url, fpath)
56 |         except HTTPError as err:
57 |             raise RuntimeError(f"Failed to download from {url}") from err
58 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datasets/utils.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataset import random_split
2 | 
3 | from pl_bolts.datasets.sr_celeba_dataset import SRCelebA
4 | from pl_bolts.datasets.sr_mnist_dataset import SRMNIST
5 | from pl_bolts.datasets.sr_stl10_dataset import SRSTL10
6 | 
7 | 
8 | def prepare_sr_datasets(dataset: str, scale_factor: int, data_dir: str):
9 |     """Creates train, val, and test datasets for training a Super Resolution GAN.
10 | 
11 |     Args:
12 |         dataset: string indicating which dataset class to use (celeba, mnist, or stl10).
13 |         scale_factor: scale factor between low- and high-resolution images.
14 |         data_dir: root dir of the dataset.
15 | 
16 |     Returns:
17 |         sr_datasets: tuple containing the train, val, and test datasets.
18 |     """
19 |     assert dataset in ["celeba", "mnist", "stl10"]
20 | 
21 |     if dataset == "celeba":
22 |         dataset_cls = SRCelebA
23 |         dataset_train = dataset_cls(scale_factor, root=data_dir, split="train", download=True)
24 |         dataset_val = dataset_cls(scale_factor, root=data_dir, split="valid", download=True)
25 |         dataset_test = dataset_cls(scale_factor, root=data_dir, split="test", download=True)
26 | 
27 |     elif dataset == "mnist":
28 |         dataset_cls = SRMNIST
29 |         dataset_dev = dataset_cls(scale_factor, root=data_dir, train=True, download=True)
30 |         dataset_train, dataset_val = random_split(dataset_dev, lengths=[55_000, 5_000])
31 |         dataset_test = dataset_cls(scale_factor, root=data_dir, train=False, download=True)
32 | 
33 |     elif dataset == "stl10":
34 |         dataset_cls = SRSTL10
35 |         dataset_dev = dataset_cls(scale_factor, root=data_dir, split="train", download=True)
36 |         dataset_train, dataset_val = random_split(dataset_dev, lengths=[4_500, 500])
37 |         dataset_test = dataset_cls(scale_factor, root=data_dir, split="test", download=True)
38 | 
39 |     return (dataset_train, dataset_val, dataset_test)
40 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datasets/sr_dataset_mixin.py:
--------------------------------------------------------------------------------
1 | """Adapted from: https://github.com/https-deeplearning-ai/GANs-Public."""
2 | from typing import Any, Tuple
3 | 
4 | import torch
5 | 
6 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
7 | from pl_bolts.utils.warnings import warn_missing_pkg
8 | 
9 | if _PIL_AVAILABLE:
10 |     from PIL import Image
11 | else:  # pragma: no cover
12 |     warn_missing_pkg("PIL", pypi_name="Pillow")
13 | 
14 | if _TORCHVISION_AVAILABLE:
15 |     from torchvision import transforms as transform_lib
16 | else:  # pragma: no cover
17 |     warn_missing_pkg("torchvision")
18 | 
19 | 
20 | class SRDatasetMixin:
21 |     """Mixin for Super Resolution datasets.
22 | 
23 |     Scales the range of high-resolution images to [-1, 1] and the range of low-resolution images to [0, 1].
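    (The low-res pipeline first inverts the [-1, 1] normalization, since
    Normalize(mean=-1, std=2) computes (x + 1) / 2, before bicubic downsampling.)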
24 | """ 25 | 26 | def __init__(self, hr_image_size: int, lr_image_size: int, image_channels: int, *args: Any, **kwargs: Any) -> None: 27 | super().__init__(*args, **kwargs) 28 | 29 | self.hr_transforms = transform_lib.Compose( 30 | [ 31 | transform_lib.RandomCrop(hr_image_size), 32 | transform_lib.ToTensor(), 33 | transform_lib.Normalize(mean=(0.5,) * image_channels, std=(0.5,) * image_channels), 34 | ] 35 | ) 36 | 37 | self.lr_transforms = transform_lib.Compose( 38 | [ 39 | transform_lib.Normalize(mean=(-1.0,) * image_channels, std=(2.0,) * image_channels), 40 | transform_lib.ToPILImage(), 41 | transform_lib.Resize(lr_image_size, Image.BICUBIC), 42 | transform_lib.ToTensor(), 43 | ] 44 | ) 45 | 46 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: 47 | image = self._get_image(index) 48 | 49 | hr_image = self.hr_transforms(image) 50 | lr_image = self.lr_transforms(hr_image) 51 | 52 | return hr_image, lr_image 53 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.datamodules.async_dataloader import AsynchronousLoader 2 | from pl_bolts.datamodules.binary_emnist_datamodule import BinaryEMNISTDataModule 3 | from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule 4 | from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule 5 | from pl_bolts.datamodules.cityscapes_datamodule import CityscapesDataModule 6 | from pl_bolts.datamodules.emnist_datamodule import EMNISTDataModule 7 | from pl_bolts.datamodules.experience_source import DiscountedExperienceSource, ExperienceSource, ExperienceSourceDataset 8 | from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule 9 | from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule 10 | from pl_bolts.datamodules.kitti_datamodule import KittiDataModule 11 | from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule 12 | from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule, SklearnDataset, TensorDataset 13 | from pl_bolts.datamodules.sr_datamodule import TVTDataModule 14 | from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule 15 | from pl_bolts.datamodules.stl10_datamodule import STL10DataModule 16 | from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule 17 | from pl_bolts.datasets.kitti_dataset import KittiDataset 18 | 19 | __all__ = [ 20 | "AsynchronousLoader", 21 | "BinaryMNISTDataModule", 22 | "CIFAR10DataModule", 23 | "TinyCIFAR10DataModule", 24 | "CityscapesDataModule", 25 | "DiscountedExperienceSource", 26 | "ExperienceSource", 27 | "ExperienceSourceDataset", 28 | "FashionMNISTDataModule", 29 | "ImagenetDataModule", 30 | "KittiDataModule", 31 | "MNISTDataModule", 32 | "SklearnDataModule", 33 | "SklearnDataset", 34 | "TensorDataset", 35 | "TVTDataModule", 36 | "SSLImagenetDataModule", 37 | "STL10DataModule", 38 | "VOCDetectionDataModule", 39 | "KittiDataset", 40 | "EMNISTDataModule", 41 | "BinaryEMNISTDataModule", 42 | ] 43 | -------------------------------------------------------------------------------- /Pre-Training/README.md: -------------------------------------------------------------------------------- 1 | # Pre-Training 2 | 3 | ## 1) Data Preprocessing 4 | We use the [LIDC-IDRI](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=1966254) dataset 
for self-supervised pre-training. Only the CT images are used, without any labels or other information. \
5 | We perform the pre-training on the 2D slices of the CT volumes.
6 | 
7 | Go to the folder [Data_Preprocessing](https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main/Pre-Training/Data_Preprocessing) for the preprocessing code and further explanations.
8 | 
9 | ## 2) Pre-Training
10 | We compare two types of self-supervised pre-training: Contrastive Learning and Masked Autoencoder.
11 | 
12 | ### a) Contrastive Learning
13 | Three state-of-the-art, top-performing contrastive learning methods for convolutional networks are:
14 | [SwAV](https://proceedings.neurips.cc/paper/2020/hash/70feb62b69f16e0238f741fab228fec2-Abstract.html), [MoCo](https://openaccess.thecvf.com/content_CVPR_2020/html/He_Momentum_Contrast_for_Unsupervised_Visual_Representation_Learning_CVPR_2020_paper.html), and [BYOL](https://proceedings.neurips.cc/paper_files/paper/2020/file/f3ada80d5c4ee70142b17b8192b2958e-Paper.pdf)
15 | 
16 | Go to the folder [Contrastive_Learning](https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main/Pre-Training/Contrastive_Learning) for the Contrastive Learning code and further explanations.
17 | 
18 | ### b) Masked Autoencoder
19 | In a recent study published at the Eleventh International Conference on Learning Representations (ICLR) 2023, Tian et al. demonstrate that masked autoencoder pre-training can be adapted to convolutional models using sparse convolutions. Their new approach, called [SparK](https://openreview.net/forum?id=NRxydtWup1S), outperforms all state-of-the-art contrastive methods on a convolutional model, using natural images from ImageNet for self-supervised pre-training. We apply the SparK pre-training method to CT images and investigate its performance.
20 | 
21 | Go to the folder [Masked_Autoencoder](https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main/Pre-Training/Masked_Autoencoder) for the SparK code and further explanations.
22 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/losses/object_detection.py:
--------------------------------------------------------------------------------
1 | """Loss functions for the Object Detection task."""
2 | 
3 | from torch import Tensor
4 | 
5 | from pl_bolts.metrics.object_detection import giou, iou
6 | 
7 | 
8 | def iou_loss(preds: Tensor, target: Tensor) -> Tensor:
9 |     """Calculates the intersection over union loss.
10 | 
11 |     Args:
12 |         preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
13 |         target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
14 | 
15 |     Example:
16 | 
17 |         >>> import torch
18 |         >>> from pl_bolts.losses.object_detection import iou_loss
19 |         >>> preds = torch.tensor([[100, 100, 200, 200]])
20 |         >>> target = torch.tensor([[150, 150, 250, 250]])
21 |         >>> iou_loss(preds, target)
22 |         tensor([[0.8571]])
23 | 
24 |     Returns:
25 |         IoU loss
26 |     """
27 |     loss = 1 - iou(preds, target)
28 |     return loss
29 | 
30 | 
31 | def giou_loss(preds: Tensor, target: Tensor) -> Tensor:
32 |     """Calculates the generalized intersection over union loss.
33 | 
34 |     It has been proposed in `Generalized Intersection over Union: A Metric and A
35 |     Loss for Bounding Box Regression <https://arxiv.org/abs/1902.09630>`_.
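    As a sketch of the underlying math: GIoU = IoU - |C \ (A U B)| / |C|, where C is the
    smallest box enclosing both boxes, so this loss, ``1 - GIoU``, lies in [0, 2].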
36 | 
37 |     Args:
38 |         preds: an Nx4 batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
39 |         target: an Mx4 batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
40 | 
41 |     Example:
42 | 
43 |         >>> import torch
44 |         >>> from pl_bolts.losses.object_detection import giou_loss
45 |         >>> preds = torch.tensor([[100, 100, 200, 200]])
46 |         >>> target = torch.tensor([[150, 150, 250, 250]])
47 |         >>> giou_loss(preds, target)
48 |         tensor([[1.0794]])
49 | 
50 |     Returns:
51 |         GIoU loss in an NxM tensor containing the pairwise GIoU loss for every element in preds and target,
52 |         where N is the number of prediction bounding boxes and M is the number of target bounding boxes
53 |     """
54 |     loss = 1 - giou(preds, target)
55 |     return loss
56 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/vision/pixel_cnn.py:
--------------------------------------------------------------------------------
1 | """
2 | PixelCNN
3 | Implemented by: William Falcon
4 | Reference: https://arxiv.org/pdf/1905.09272.pdf (page 15)
5 | Accessed: May 14, 2020
6 | """
7 | from torch import nn
8 | from torch.nn import functional as F
9 | 
10 | 
11 | class PixelCNN(nn.Module):
12 |     """Implementation of `Pixel CNN <https://arxiv.org/pdf/1905.09272.pdf>`_.
13 | 
14 |     Paper authors: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex Graves,
15 |     Koray Kavukcuoglu
16 | 
17 |     Implemented by:
18 | 
19 |         - William Falcon
20 | 
21 |     Example::
22 | 
23 |         >>> from pl_bolts.models.vision import PixelCNN
24 |         >>> import torch
25 |         ...
26 |         >>> model = PixelCNN(input_channels=3)
27 |         >>> x = torch.rand(5, 3, 64, 64)
28 |         >>> out = model(x)
29 |         ...
30 |         >>> out.shape
31 |         torch.Size([5, 3, 64, 64])
32 |     """
33 | 
34 |     def __init__(self, input_channels: int, hidden_channels: int = 256, num_blocks=5):
35 |         super().__init__()
36 |         self.input_channels = input_channels
37 |         self.hidden_channels = hidden_channels
38 | 
39 |         self.blocks = nn.ModuleList([self.conv_block(input_channels) for _ in range(num_blocks)])
40 | 
41 |     def conv_block(self, input_channels):
42 |         c1 = nn.Conv2d(in_channels=input_channels, out_channels=self.hidden_channels, kernel_size=(1, 1))
43 |         act1 = nn.ReLU()
44 |         c2 = nn.Conv2d(in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=(1, 3))
45 |         pad = nn.ConstantPad2d((0, 0, 1, 0, 0, 0, 0, 0), 1)
46 |         c3 = nn.Conv2d(
47 |             in_channels=self.hidden_channels, out_channels=self.hidden_channels, kernel_size=(2, 1), padding=(0, 1)
48 |         )
49 |         act2 = nn.ReLU()
50 |         c4 = nn.Conv2d(in_channels=self.hidden_channels, out_channels=input_channels, kernel_size=(1, 1))
51 | 
52 |         block = nn.Sequential(c1, act1, c2, pad, c3, act2, c4)
53 |         return block
54 | 
55 |     def forward(self, z):
56 |         c = z
57 |         for conv_block in self.blocks:
58 |             c = c + conv_block(c)
59 | 
60 |         c = F.relu(c)
61 |         return c
62 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import operator
3 | from typing import Callable
4 | 
5 | import torch
6 | from packaging.version import Version
7 | from pkg_resources import DistributionNotFound
8 | from pytorch_lightning.utilities import _module_available
9 | 
10 | from pl_bolts.callbacks.verification.batch_gradient import BatchGradientVerification  # type: ignore
11 | 
12 | 
13 | # Ported from https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/imports.py
14 | def _compare_version(package: str, op: Callable, version: str) -> bool:
15 |     """Compare package version with some requirements.
16 | 
17 |     >>> _compare_version("torch", operator.ge, "0.1")
18 |     True
19 |     """
20 |     try:
21 |         pkg = importlib.import_module(package)
22 |     except (ModuleNotFoundError, DistributionNotFound):
23 |         return False
24 |     try:
25 |         pkg_version = Version(pkg.__version__)
26 |     except TypeError:
27 |         # this is mocked by sphinx, so it shall return True to generate all summaries
28 |         return True
29 |     return op(pkg_version, Version(version))
30 | 
31 | 
32 | _NATIVE_AMP_AVAILABLE: bool = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
33 | 
34 | _TORCHVISION_AVAILABLE: bool = _module_available("torchvision")
35 | _GYM_AVAILABLE: bool = _module_available("gym")
36 | _SKLEARN_AVAILABLE: bool = _module_available("sklearn")
37 | _PIL_AVAILABLE: bool = _module_available("PIL")
38 | _OPENCV_AVAILABLE: bool = _module_available("cv2")
39 | _WANDB_AVAILABLE: bool = _module_available("wandb")
40 | _MATPLOTLIB_AVAILABLE: bool = _module_available("matplotlib")
41 | _TORCHVISION_LESS_THAN_0_9_1: bool = _compare_version("torchvision", operator.lt, "0.9.1")
42 | _PL_GREATER_EQUAL_1_4 = _compare_version("pytorch_lightning", operator.ge, "1.4.0")
43 | _PL_GREATER_EQUAL_1_4_5 = _compare_version("pytorch_lightning", operator.ge, "1.4.5")
44 | _TORCH_ORT_AVAILABLE = _module_available("torch_ort")
45 | _TORCH_MAX_VERSION_SPARSEML = _compare_version("torch", operator.lt, "1.10.0")
46 | _SPARSEML_AVAILABLE = _module_available("sparseml") and _PL_GREATER_EQUAL_1_4_5 and _TORCH_MAX_VERSION_SPARSEML
47 | 
48 | __all__ = ["BatchGradientVerification"]
49 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/transforms/dataset_normalizations.py:
--------------------------------------------------------------------------------
1 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
2 | from pl_bolts.utils.warnings import warn_missing_pkg
3 | 
4 | if _TORCHVISION_AVAILABLE:
5 |     from torchvision import transforms
6 | else:  # pragma: no cover
7 |     warn_missing_pkg("torchvision")
8 | 
9 | 
10 | def imagenet_normalization():
11 |     if not _TORCHVISION_AVAILABLE:  # pragma: no cover
12 |         raise ModuleNotFoundError(
13 |             "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`."
14 |         )
15 | 
16 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
17 |     return normalize
18 | 
19 | 
20 | def cifar10_normalization():
21 |     if not _TORCHVISION_AVAILABLE:  # pragma: no cover
22 |         raise ModuleNotFoundError(
23 |             "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`."
24 |         )
25 | 
26 |     normalize = transforms.Normalize(
27 |         mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
28 |         std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
29 |     )
30 |     return normalize
31 | 
32 | 
33 | def stl10_normalization():
34 |     if not _TORCHVISION_AVAILABLE:  # pragma: no cover
35 |         raise ModuleNotFoundError(
36 |             "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`."
37 | ) 38 | 39 | normalize = transforms.Normalize(mean=(0.43, 0.42, 0.39), std=(0.27, 0.26, 0.27)) 40 | return normalize 41 | 42 | 43 | def emnist_normalization(split: str): 44 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 45 | raise ModuleNotFoundError( 46 | "You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`." 47 | ) 48 | 49 | # `stats` contains mean and std for each `split`. 50 | stats = { 51 | "balanced": (0.175, 0.333), 52 | "byclass": (0.174, 0.332), 53 | "bymerge": (0.174, 0.332), 54 | "digits": (0.173, 0.332), 55 | "letters": (0.172, 0.331), 56 | "mnist": (0.173, 0.332), 57 | } 58 | 59 | return transforms.Normalize(mean=stats[split][0], std=stats[split][1]) 60 | -------------------------------------------------------------------------------- /Pre-Training/Masked_Autoencoder/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from timm import create_model 9 | from timm.loss import SoftTargetCrossEntropy 10 | from timm.models.layers import drop 11 | 12 | 13 | from models.convnext import ConvNeXt 14 | from models.resnet import ResNet 15 | from models.custom import YourConvNet 16 | _import_resnets_for_timm_registration = (ResNet,) 17 | 18 | 19 | # log more 20 | def _ex_repr(self): 21 | return ', '.join( 22 | f'{k}=' + (f'{v:g}' if isinstance(v, float) else str(v)) 23 | for k, v in vars(self).items() 24 | if not k.startswith('_') and k != 'training' 25 | and not isinstance(v, (torch.nn.Module, torch.Tensor)) 26 | ) 27 | for clz in (torch.nn.CrossEntropyLoss, SoftTargetCrossEntropy, drop.DropPath): 28 | if hasattr(clz, 'extra_repr'): 29 | clz.extra_repr = _ex_repr 30 | else: 31 | clz.__repr__ = lambda self: f'{type(self).__name__}({_ex_repr(self)})' 32 | 33 | 34 | pretrain_default_model_kwargs = { 35 | 'your_convnet': dict(), 36 | 'resnet50': dict(drop_path_rate=0.05), 37 | 'resnet101': dict(drop_path_rate=0.08), 38 | 'resnet152': dict(drop_path_rate=0.10), 39 | 'resnet200': dict(drop_path_rate=0.15), 40 | 'convnext_small': dict(sparse=True, drop_path_rate=0.2), 41 | 'convnext_base': dict(sparse=True, drop_path_rate=0.3), 42 | 'convnext_large': dict(sparse=True, drop_path_rate=0.4), 43 | } 44 | for kw in pretrain_default_model_kwargs.values(): 45 | kw['pretrained'] = False 46 | kw['num_classes'] = 0 47 | kw['global_pool'] = '' 48 | 49 | 50 | def build_sparse_encoder(name: str, input_size: int, sbn=False, drop_path_rate=0.0, verbose=False): 51 | from encoder import SparseEncoder 52 | 53 | kwargs = pretrain_default_model_kwargs[name] 54 | if drop_path_rate != 0: 55 | kwargs['drop_path_rate'] = drop_path_rate 56 | print(f'[build_sparse_encoder] model kwargs={kwargs}') 57 | cnn = create_model(name, **kwargs) 58 | 59 | return SparseEncoder(cnn, input_size=input_size, sbn=sbn, verbose=verbose) 60 | 61 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/callbacks/torch_ort.py: -------------------------------------------------------------------------------- 1 | # Copyright The PyTorch Lightning team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from pytorch_lightning import Callback, LightningModule, Trainer
15 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
16 | 
17 | from pl_bolts.utils import _TORCH_ORT_AVAILABLE
18 | 
19 | if _TORCH_ORT_AVAILABLE:
20 |     from torch_ort import ORTModule
21 | 
22 | 
23 | class ORTCallback(Callback):
24 |     """Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime.
25 | 
26 |     Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for
27 |     training and inference.
28 | 
29 |     Usage:
30 | 
31 |         # via Transformer Tasks
32 |         model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
33 | 
34 |         # or via the trainer
35 |         trainer = flash.Trainer(callbacks=ORTCallback())
36 |     """
37 | 
38 |     def __init__(self) -> None:
39 |         if not _TORCH_ORT_AVAILABLE:
40 |             raise MisconfigurationException(
41 |                 "Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort"
42 |             )
43 | 
44 |     def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None:
45 |         if not hasattr(pl_module, "model"):
46 |             raise MisconfigurationException(
47 |                 "Torch ORT requires to wrap a single model that defines a forward function "
48 |                 "assigned as `model` inside the `LightningModule`."
49 |             )
50 |         if not isinstance(pl_module.model, ORTModule):
51 |             pl_module.model = ORTModule(pl_module.model)
52 | 
--------------------------------------------------------------------------------
/Pre-Training/Masked_Autoencoder/README.md:
--------------------------------------------------------------------------------
1 | # Pre-Training with SparK
2 | 
3 | SparK is the first successful adaptation of masked autoencoder self-supervised pre-training to convolutional neural networks (CNNs).
4 | 
5 | This code is from the official implementation of SparK: [https://github.com/keyu-tian/SparK](https://github.com/keyu-tian/SparK) (MIT license)
6 | 
7 | ### How to Start:
8 | 1. Download the LIDC data and run the preprocessing script as explained here: [https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main/Pre-Training/Data_Preprocessing](https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main/Pre-Training/Data_Preprocessing)
9 | 2. Change the folder structure of the preprocessed data to the following (take part of the images as validation):
10 | ```bash
11 | LIDC-Data
12 |   /    \
13 | train   val
14 |  /       \
15 |  1        1
16 | ```
17 | 3. Open your terminal and follow these steps:
18 |    1. conda create --name SSL_Masked_Autoencoder python==3.8
19 |    2. conda activate SSL_Masked_Autoencoder
20 |    3. conda install pytorch==1.10.0 torchvision==0.11.0 torchaudio==0.10.0 cudatoolkit=11.3 -c pytorch
21 |    4. cd .../SSL-MedicalImagining-CL-MAE/Pre-Training/Masked_Autoencoder/
22 |    5. pip install -r requirements.txt
23 | 4. Start the pre-training with a bash script:
24 | ```bash
25 | #!/bin/bash
26 | 
27 | python ./main.py \
28 |   --exp_name=ResNet50_1 \
29 |   --data_path=/path/to/LIDC-Data \
30 |   --model=resnet50 \
31 |   --bs=32 \
32 |   --exp_dir=/path/to/where/results/should/be/saved \
33 |   --ep=1600
34 | ```
35 | For further information and other settings, please refer to the SparK GitHub: [https://github.com/keyu-tian/SparK](https://github.com/keyu-tian/SparK)
36 | 
37 | 
38 | ### SparK Paper
39 | Please also cite the SparK paper:
40 | 
41 | ```latex
42 | @inproceedings{
43 | tian2023designing,
44 | title={Designing {BERT} for Convolutional Networks: Sparse and Hierarchical Masked Modeling},
45 | author={Keyu Tian and Yi Jiang and qishuai diao and Chen Lin and Liwei Wang and Zehuan Yuan},
46 | booktitle={The Eleventh International Conference on Learning Representations},
47 | year={2023},
48 | url={https://openreview.net/forum?id=NRxydtWup1S}
49 | }
50 | ```
51 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/rl/common/distributions.py:
--------------------------------------------------------------------------------
1 | """Distributions used in some continuous RL algorithms."""
2 | import torch
3 | 
4 | 
5 | class TanhMultivariateNormal(torch.distributions.MultivariateNormal):
6 |     """The distribution of X is an affine transformation of tanh applied to a normal distribution.
7 | 
8 |     X = action_scale * tanh(Z) + action_bias
9 |     Z ~ Normal(mean, variance)
10 |     """
11 | 
12 |     def __init__(self, action_bias, action_scale, **kwargs):
13 |         super().__init__(**kwargs)
14 | 
15 |         self.action_bias = action_bias
16 |         self.action_scale = action_scale
17 | 
18 |     def rsample_with_z(self, sample_shape=torch.Size()):
19 |         """Samples X using the reparametrization trick with the intermediate variable Z.
20 | 
21 |         Returns:
22 |             Sampled X and Z
23 |         """
24 |         z = super().rsample()
25 |         return self.action_scale * torch.tanh(z) + self.action_bias, z
26 | 
27 |     def log_prob_with_z(self, value, z):
28 |         """Computes the log probability of a sampled X.
29 | 
30 |         Refer to equations (20) and (21) in the original SAC paper for more details.
31 | 
32 |         Args:
33 |             value: the value of X
34 |             z: the value of Z
35 |         Returns:
36 |             Log probability of the sample
37 |         """
38 |         value = (value - self.action_bias) / self.action_scale
39 |         z_logprob = super().log_prob(z)
40 |         correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1)
41 |         return z_logprob - correction
42 | 
43 |     def rsample_and_log_prob(self, sample_shape=torch.Size()):
44 |         """Samples X and computes the log probability of the sample.
45 | 
46 |         Returns:
47 |             Sampled X and log probability
48 |         """
49 |         z = super().rsample()
50 |         z_logprob = super().log_prob(z)
51 |         value = torch.tanh(z)
52 |         correction = torch.log(self.action_scale * (1 - value ** 2) + 1e-7).sum(1)
53 |         return self.action_scale * value + self.action_bias, z_logprob - correction
54 | 
55 |     def rsample(self, sample_shape=torch.Size()):
56 |         fz, z = self.rsample_with_z(sample_shape)
57 |         return fz
58 | 
59 |     def log_prob(self, value):
60 |         value = (value - self.action_bias) / self.action_scale
61 |         z = torch.log(1 + value) / 2 - torch.log(1 - value) / 2
62 |         return self.log_prob_with_z(value, z)
63 | 
--------------------------------------------------------------------------------
/Pre-Training/Masked_Autoencoder/utils/lr_control.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc.
and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from pprint import pformat 9 | 10 | 11 | def lr_wd_annealing(optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it): 12 | wp_it = round(wp_it) 13 | if cur_it < wp_it: 14 | cur_lr = 0.005 * peak_lr + 0.995 * peak_lr * cur_it / wp_it 15 | else: 16 | ratio = (cur_it - wp_it) / (max_it - 1 - wp_it) 17 | cur_lr = 0.001 * peak_lr + 0.999 * peak_lr * (0.5 + 0.5 * math.cos(math.pi * ratio)) 18 | 19 | ratio = cur_it / (max_it - 1) 20 | cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * ratio)) 21 | 22 | min_lr, max_lr = cur_lr, cur_lr 23 | min_wd, max_wd = cur_wd, cur_wd 24 | for param_group in optimizer.param_groups: 25 | scaled_lr = param_group['lr'] = cur_lr * param_group.get('lr_scale', 1) # 'lr_scale' could be assigned 26 | min_lr, max_lr = min(min_lr, scaled_lr), max(max_lr, scaled_lr) 27 | scaled_wd = param_group['weight_decay'] = cur_wd * param_group.get('weight_decay_scale', 1) # 'weight_decay_scale' could be assigned 28 | min_wd, max_wd = min(min_wd, scaled_wd), max(max_wd, scaled_wd) 29 | return min_lr, max_lr, min_wd, max_wd 30 | 31 | 32 | def get_param_groups(model, nowd_keys=()): 33 | para_groups, para_groups_dbg = {}, {} 34 | 35 | for name, para in model.named_parameters(): 36 | if not para.requires_grad: 37 | continue # frozen weights 38 | if len(para.shape) == 1 or name.endswith('.bias') or any(k in name for k in nowd_keys): 39 | wd_scale, group_name = 0., 'no_decay' 40 | else: 41 | wd_scale, group_name = 1., 'decay' 42 | 43 | if group_name not in para_groups: 44 | para_groups[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': 1.} 45 | para_groups_dbg[group_name] = {'params': [], 'weight_decay_scale': wd_scale, 'lr_scale': 1.} 46 | para_groups[group_name]['params'].append(para) 47 | para_groups_dbg[group_name]['params'].append(name) 48 | 49 | for g in para_groups_dbg.values(): 50 | g['params'] = pformat(', '.join(g['params']), width=200) 51 | 52 | print(f'[get_ft_param_groups] param groups = \n{pformat(para_groups_dbg, indent=2, width=250)}\n') 53 | return list(para_groups.values()) 54 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/datasets/mnist_dataset.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.utils import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE, _TORCHVISION_LESS_THAN_0_9_1 2 | from pl_bolts.utils.warnings import warn_missing_pkg 3 | 4 | if _TORCHVISION_AVAILABLE: 5 | from torchvision.datasets import MNIST 6 | else: # pragma: no cover 7 | warn_missing_pkg("torchvision") 8 | MNIST = object 9 | 10 | if _PIL_AVAILABLE: 11 | from PIL import Image 12 | else: # pragma: no cover 13 | warn_missing_pkg("PIL", pypi_name="Pillow") 14 | 15 | # TODO(akihironitta): This is needed to avoid 503 error when downloading MNIST dataset 16 | # from http://yann.lecun.com/exdb/mnist/ and can be removed after `torchvision==0.9.1`. 17 | # See https://github.com/pytorch/vision/issues/3549 for details. 
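# Each tuple below is (mirror URL, md5 checksum), the same format as torchvision's
# built-in `MNIST.resources`, so overriding the attribute simply redirects downloads
# to the ossci-datasets S3 mirror.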
18 | if _TORCHVISION_AVAILABLE and _TORCHVISION_LESS_THAN_0_9_1: 19 | MNIST.resources = [ 20 | ( 21 | "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", 22 | "f68b3c2dcbeaaa9fbdd348bbdeb94873", 23 | ), 24 | ( 25 | "https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz", 26 | "d53e105ee54ea40749a09fcbcd1e9432", 27 | ), 28 | ( 29 | "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz", 30 | "9fb629c4189551a2d022fa330f9573f3", 31 | ), 32 | ( 33 | "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz", 34 | "ec29112dd5afa0611ce80d1b7f02629c", 35 | ), 36 | ] 37 | 38 | 39 | class BinaryMNIST(MNIST): 40 | def __getitem__(self, idx): 41 | """ 42 | Args: 43 | index (int): Index 44 | Returns: 45 | tuple: (image, target) where target is index of the target class. 46 | """ 47 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 48 | raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.") 49 | 50 | img, target = self.data[idx], int(self.targets[idx]) 51 | 52 | # doing this so that it is consistent with all other datasets 53 | # to return a PIL Image 54 | img = Image.fromarray(img.numpy(), mode="L") 55 | 56 | if self.transform is not None: 57 | img = self.transform(img) 58 | 59 | if self.target_transform is not None: 60 | target = self.target_transform(target) 61 | 62 | # binary 63 | img[img < 0.5] = 0.0 64 | img[img >= 0.5] = 1.0 65 | 66 | return img, target 67 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/callbacks/byol_updates.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Sequence, Union 3 | 4 | from pytorch_lightning import Callback, LightningModule, Trainer 5 | from torch import Tensor 6 | from torch.nn import Module 7 | 8 | 9 | class BYOLMAWeightUpdate(Callback): 10 | """Weight update rule from BYOL. 11 | 12 | Your model should have: 13 | 14 | - ``self.online_network`` 15 | - ``self.target_network`` 16 | 17 | Updates the target_network params using an exponential moving average update rule weighted by tau. 18 | BYOL claims this keeps the online_network from collapsing. 19 | 20 | .. note:: Automatically increases tau from ``initial_tau`` to 1.0 with every training step 21 | 22 | Example:: 23 | 24 | # model must have 2 attributes 25 | model = Model() 26 | model.online_network = ... 27 | model.target_network = ... 28 | 29 | trainer = Trainer(callbacks=[BYOLMAWeightUpdate()]) 30 | """ 31 | 32 | def __init__(self, initial_tau: float = 0.996): 33 | """ 34 | Args: 35 | initial_tau: starting tau. 
Auto-updates with every training step
36 |         """
37 |         super().__init__()
38 |         self.initial_tau = initial_tau
39 |         self.current_tau = initial_tau
40 | 
41 |     def on_train_batch_end(
42 |         self,
43 |         trainer: Trainer,
44 |         pl_module: LightningModule,
45 |         outputs: Sequence,
46 |         batch: Sequence,
47 |         batch_idx: int,
48 |         dataloader_idx: int,
49 |     ) -> None:
50 |         # get networks
51 |         online_net = pl_module.online_network
52 |         target_net = pl_module.target_network
53 | 
54 |         # update weights
55 |         self.update_weights(online_net, target_net)
56 | 
57 |         # update tau after
58 |         self.current_tau = self.update_tau(pl_module, trainer)
59 | 
60 |     def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float:
61 |         max_steps = len(trainer.train_dataloader) * trainer.max_epochs
62 |         tau = 1 - (1 - self.initial_tau) * (math.cos(math.pi * pl_module.global_step / max_steps) + 1) / 2
63 |         return tau
64 | 
65 |     def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None:
66 |         # apply MA weight update
67 |         for (name, online_p), (_, target_p) in zip(
68 |             online_net.named_parameters(),
69 |             target_net.named_parameters(),
70 |         ):
71 |             target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
72 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/callbacks/vision/sr_image_logger.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | 
3 | import pytorch_lightning as pl
4 | import torch
5 | import torch.nn.functional as F
6 | from pytorch_lightning import Callback
7 | 
8 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
9 | from pl_bolts.utils.warnings import warn_missing_pkg
10 | 
11 | if _TORCHVISION_AVAILABLE:
12 |     from torchvision.utils import make_grid
13 | else:  # pragma: no cover
14 |     warn_missing_pkg("torchvision")
15 | 
16 | 
17 | class SRImageLoggerCallback(Callback):
18 |     """Logs low-res, generated high-res, and ground truth high-res images to TensorBoard. Your model must
19 |     implement the ``forward`` function for generation.
20 | 
21 |     Requirements::
22 | 
23 |         # model forward must work generating high-res from low-res image
24 |         hr_fake = pl_module(lr_image)
25 | 
26 |     Example::
27 | 
28 |         from pl_bolts.callbacks import SRImageLoggerCallback
29 | 
30 |         trainer = Trainer(callbacks=[SRImageLoggerCallback()])
31 |     """
32 | 
33 |     def __init__(self, log_interval: int = 1000, scale_factor: int = 4, num_samples: int = 5) -> None:
34 |         """
35 |         Args:
36 |             log_interval: Number of steps between logging. Default: ``1000``.
37 |             scale_factor: Scale factor used for downsampling the high-res images. Default: ``4``.
38 |             num_samples: Number of images displayed in the grid. Default: ``5``.
39 |         """
40 |         super().__init__()
41 |         self.log_interval = log_interval
42 |         self.scale_factor = scale_factor
43 |         self.num_samples = num_samples
44 | 
45 |     def on_train_batch_end(
46 |         self,
47 |         trainer: pl.Trainer,
48 |         pl_module: pl.LightningModule,
49 |         outputs: torch.Tensor,
50 |         batch: Tuple[torch.Tensor, torch.Tensor],
51 |         batch_idx: int,
52 |         dataloader_idx: int,
53 |     ) -> None:
54 |         global_step = trainer.global_step
55 |         if global_step % self.log_interval == 0:
56 |             hr_image, lr_image = batch
57 |             hr_image, lr_image = hr_image.to(pl_module.device), lr_image.to(pl_module.device)
58 |             hr_fake = pl_module(lr_image)
59 |             lr_image = F.interpolate(lr_image, scale_factor=self.scale_factor)
60 | 
61 |             lr_image_grid = make_grid(lr_image[: self.num_samples], nrow=1, normalize=True)
62 |             hr_fake_grid = make_grid(hr_fake[: self.num_samples], nrow=1, normalize=True)
63 |             hr_image_grid = make_grid(hr_image[: self.num_samples], nrow=1, normalize=True)
64 | 
65 |             grid = torch.cat((lr_image_grid, hr_fake_grid, hr_image_grid), -1)
66 |             title = "sr_images"
67 |             trainer.logger.experiment.add_image(title, grid, global_step=global_step)
68 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datamodules/sr_datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | 
3 | from pytorch_lightning import LightningDataModule
4 | from torch.utils.data import DataLoader, Dataset
5 | 
6 | 
7 | class TVTDataModule(LightningDataModule):
8 |     """Simple DataModule creating train, val, and test dataloaders from the given train, val, and test datasets.
9 | 
10 |     Example::
11 |         from pl_bolts.datamodules import TVTDataModule
12 |         from pl_bolts.datasets.sr_mnist_dataset import SRMNIST
13 | 
14 |         dataset_dev = SRMNIST(scale_factor=4, root=".", train=True)
15 |         dataset_train, dataset_val = random_split(dataset_dev, lengths=[55_000, 5_000])
16 |         dataset_test = SRMNIST(scale_factor=4, root=".", train=False)
17 |         dm = TVTDataModule(dataset_train, dataset_val, dataset_test)
18 |     """
19 | 
20 |     def __init__(
21 |         self,
22 |         dataset_train: Dataset,
23 |         dataset_val: Dataset,
24 |         dataset_test: Dataset,
25 |         batch_size: int = 16,
26 |         shuffle: bool = True,
27 |         num_workers: int = 8,
28 |         pin_memory: bool = True,
29 |         drop_last: bool = True,
30 |         *args: Any,
31 |         **kwargs: Any,
32 |     ) -> None:
33 |         """
34 |         Args:
35 |             dataset_train: Train dataset
36 |             dataset_val: Val dataset
37 |             dataset_test: Test dataset
38 |             batch_size: How many samples per batch to load
39 |             num_workers: How many workers to use for loading data
40 |             shuffle: If true, shuffles the train data every epoch
41 |             pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
42 |                 returning them
43 |             drop_last: If true, drops the last incomplete batch
44 |         """
45 |         super().__init__()
46 | 
47 |         self.dataset_train = dataset_train
48 |         self.dataset_val = dataset_val
49 |         self.dataset_test = dataset_test
50 |         self.num_workers = num_workers
51 |         self.batch_size = batch_size
52 |         self.shuffle = shuffle
53 |         self.pin_memory = pin_memory
54 |         self.drop_last = drop_last
55 | 
56 |     def train_dataloader(self) -> DataLoader:
57 |         return self._dataloader(self.dataset_train, shuffle=self.shuffle)
58 | 
59 |     def val_dataloader(self) -> DataLoader:
60 |         return self._dataloader(self.dataset_val, shuffle=False)
61 | 
62 |     def test_dataloader(self) -> DataLoader:
63 |         return self._dataloader(self.dataset_test, shuffle=False)
64 | 
65 |     def _dataloader(self, dataset: Dataset, shuffle: bool = True) -> DataLoader:
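        # Shared loader factory: the train loader passes its configurable `shuffle`,
        # while the val/test loaders above always pass shuffle=False.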
66 |         return DataLoader(
67 |             dataset,
68 |             batch_size=self.batch_size,
69 |             shuffle=shuffle,
70 |             num_workers=self.num_workers,
71 |             drop_last=self.drop_last,
72 |             pin_memory=self.pin_memory,
73 |         )
74 | 
--------------------------------------------------------------------------------
/Pre-Training/Masked_Autoencoder/sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import random
8 | 
9 | import numpy as np
10 | import torch
11 | from torch.utils.data.sampler import Sampler
12 | 
13 | 
14 | def worker_init_fn(worker_id):
15 |     # https://pytorch.org/docs/stable/notes/randomness.html#dataloader
16 |     worker_seed = torch.initial_seed() % 2 ** 32
17 |     np.random.seed(worker_seed)
18 |     random.seed(worker_seed)
19 | 
20 | 
21 | class DistInfiniteBatchSampler(Sampler):
22 |     def __init__(self, world_size, rank, dataset_len, glb_batch_size, seed=1, filling=False, shuffle=True):
23 |         assert glb_batch_size % world_size == 0
24 |         self.world_size, self.rank = world_size, rank
25 |         self.dataset_len = dataset_len
26 |         self.glb_batch_size = glb_batch_size
27 |         self.batch_size = glb_batch_size // world_size
28 | 
29 |         self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
30 |         self.filling = filling
31 |         self.shuffle = shuffle
32 |         self.epoch = 0
33 |         self.seed = seed
34 |         self.indices = self.gener_indices()
35 | 
36 |     def gener_indices(self):
37 |         global_max_p = self.iters_per_ep * self.glb_batch_size  # global_max_p % world_size must be 0 because glb_batch_size % world_size == 0
38 |         if self.shuffle:
39 |             g = torch.Generator()
40 |             g.manual_seed(self.epoch + self.seed)
41 |             global_indices = torch.randperm(self.dataset_len, generator=g)
42 |         else:
43 |             global_indices = torch.arange(self.dataset_len)
44 |         filling = global_max_p - global_indices.shape[0]
45 |         if filling > 0 and self.filling:
46 |             global_indices = torch.cat((global_indices, global_indices[:filling]))
47 |         global_indices = tuple(global_indices.numpy().tolist())
48 | 
49 |         seps = torch.linspace(0, len(global_indices), self.world_size + 1, dtype=torch.int)
50 |         local_indices = global_indices[seps[self.rank]:seps[self.rank + 1]]
51 |         self.max_p = len(local_indices)
52 |         return local_indices
53 | 
54 |     def __iter__(self):
55 |         self.epoch = 0
56 |         while True:
57 |             self.epoch += 1
58 |             p, q = 0, 0
59 |             while p < self.max_p:
60 |                 q = p + self.batch_size
61 |                 yield self.indices[p:q]
62 |                 p = q
63 |             if self.shuffle:
64 |                 self.indices = self.gener_indices()
65 | 
66 |     def __len__(self):
67 |         return self.iters_per_ep
68 | 
69 | 
70 | if __name__ == '__main__':
71 |     W = 16
72 |     for rk in range(W):
73 |         ind = DistInfiniteBatchSampler(W, rk, 5024, 5024).gener_indices()
74 |         print(rk, len(ind))
75 | 
--------------------------------------------------------------------------------
/Pre-Training/Masked_Autoencoder/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
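# Orientation note (added annotation): this module monkey-patches timm's ResNet with
# the hooks SparK's sparse encoder expects: `get_downsample_ratio`,
# `get_feature_map_channels`, and a `forward(..., hierarchical=True)` that returns the
# four stage feature maps instead of classification logits.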
6 | from typing import List
7 | 
8 | import torch
9 | import torch.nn.functional as F
10 | from timm.models.resnet import ResNet
11 | 
12 | 
13 | # hack: inject the `get_downsample_ratio` function into `timm.models.resnet.ResNet`
14 | def get_downsample_ratio(self: ResNet) -> int:
15 |     return 32
16 | 
17 | 
18 | # hack: inject the `get_feature_map_channels` function into `timm.models.resnet.ResNet`
19 | def get_feature_map_channels(self: ResNet) -> List[int]:
20 |     # `self.feature_info` is maintained by `timm`
21 |     return [info['num_chs'] for info in self.feature_info[1:]]
22 | 
23 | 
24 | # hack: override the forward function of `timm.models.resnet.ResNet`
25 | def forward(self, x, hierarchical=False):
26 |     """ this forward function is a modified version of `timm.models.resnet.ResNet.forward`
27 |     >>> ResNet.forward
28 |     """
29 |     x = self.conv1(x)
30 |     x = self.bn1(x)
31 |     x = self.act1(x)
32 |     x = self.maxpool(x)
33 | 
34 |     if hierarchical:
35 |         ls = []
36 |         x = self.layer1(x); ls.append(x)
37 |         x = self.layer2(x); ls.append(x)
38 |         x = self.layer3(x); ls.append(x)
39 |         x = self.layer4(x); ls.append(x)
40 |         return ls
41 |     else:
42 |         x = self.global_pool(x)
43 |         if self.drop_rate:
44 |             x = F.dropout(x, p=float(self.drop_rate), training=self.training)
45 |         x = self.fc(x)
46 |         return x
47 | 
48 | 
49 | ResNet.get_downsample_ratio = get_downsample_ratio
50 | ResNet.get_feature_map_channels = get_feature_map_channels
51 | ResNet.forward = forward
52 | 
53 | 
54 | @torch.no_grad()
55 | def convnet_test():
56 |     from timm.models import create_model
57 |     cnn = create_model('resnet50')
58 |     print('get_downsample_ratio:', cnn.get_downsample_ratio())
59 |     print('get_feature_map_channels:', cnn.get_feature_map_channels())
60 | 
61 |     downsample_ratio = cnn.get_downsample_ratio()
62 |     feature_map_channels = cnn.get_feature_map_channels()
63 | 
64 |     # check the forward function
65 |     B, C, H, W = 4, 3, 224, 224
66 |     inp = torch.rand(B, C, H, W)
67 |     feats = cnn(inp, hierarchical=True)
68 |     assert isinstance(feats, list)
69 |     assert len(feats) == len(feature_map_channels)
70 |     print([tuple(t.shape) for t in feats])
71 | 
72 |     # check the downsample ratio
73 |     feats = cnn(inp, hierarchical=True)
74 |     assert feats[-1].shape[-2] == H // downsample_ratio
75 |     assert feats[-1].shape[-1] == W // downsample_ratio
76 | 
77 |     # check the channel number
78 |     for feat, ch in zip(feats, feature_map_channels):
79 |         assert feat.ndim == 4
80 |         assert feat.shape[1] == ch
81 | 
82 | 
83 | if __name__ == '__main__':
84 |     convnet_test()
85 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/swav/PT_Dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import re
4 | 
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset
8 | 
9 | 
10 | class TorchDataset(Dataset):
11 |     """
12 |     Loading the Datasets
13 |     """
14 | 
15 |     def __init__(self, directory, augmentations=False):
16 |         self.directory = directory  # 1) root directory of the saved tensor files
17 |         self.augmentations = augmentations
18 | 
19 |         self.images = os.listdir(directory)  # 2) list of all files in directory (all images)
20 | 
21 | 
22 |     def augment_gaussian_noise(self, data_sample, noise_variance=(0.001, 0.05)):
23 |         # https://github.com/MIC-DKFZ/batchgenerators
24 |         if noise_variance[0] == noise_variance[1]:
25 |             variance = noise_variance[0]
26 |         else:
27 |             variance = random.uniform(noise_variance[0], noise_variance[1])
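            # Note: np.random.normal below takes this value as its scale (standard
            # deviation), following the referenced batchgenerators implementation.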
28 |         data_sample = data_sample + np.random.normal(0.0, variance, size=data_sample.shape)
29 |         return data_sample
30 | 
31 |     def __len__(self):
32 |         return len(os.listdir(self.directory))  # 3) number of files in directory (number of images)
33 | 
34 |     def __getitem__(self, idx):  # idx = iteration count (0, 1, 2, 3, ...)
35 |         if torch.is_tensor(idx):
36 |             idx = idx.tolist()
37 | 
38 |         # load image + label
39 |         name = self.images[idx]  # 5) images: list of all files in directory; idx walks the list, loading one file per call
40 |         file = torch.load(os.path.join(self.directory, name))  # 6) loads one file per call
41 | 
42 | 
43 |         # separate image / label
44 |         image = file["vol"]
45 |         lable = file["class"]
46 | 
47 | 
48 |         # in case it was saved as a torch tensor rather than as float:
49 |         image = image.to(torch.float32)
50 | 
51 | 
52 |         # do augmentations
53 |         if self.augmentations:
54 |             random_number = random.randint(1, 10)
55 |             image = image.numpy()
56 |             if random_number >= 7:
57 |                 # do for each layer
58 |                 image = self.augment_gaussian_noise(image)
59 |             image = torch.from_numpy(image)
60 | 
61 |         # in case the data was not already saved with an extra dim (lymphoma (channels) / patients)
62 |         #image = image.unsqueeze(0)  # [1,512,512]
63 |         #print(image.shape)
64 | 
65 |         # in case it was not already saved as float:
66 |         image = image.float()
67 |         #print("float", image.shape)
68 |         #print("lable", lable)
69 | 
70 |         return image, lable, name
71 | 
72 | 
73 | 
74 | # For testing only:
75 | 
76 | if __name__ == '__main__':
77 |     dataset = TorchDataset("/home/wolfda/Clinic_Data/Challenge/Challenge_COVID-19-20_v2/Train_tensor_slices_filter", augmentations=True)
78 |     img, lable, name = dataset[1]
79 |     img, lable, name = dataset[2]
80 | 
81 |     from batchviewer import view_batch
82 | 
83 |     view_batch(img, width=512, height=512)
84 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/gans/pix2pix/pix2pix_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_lightning import LightningModule
3 | from torch import nn
4 | 
5 | from pl_bolts.models.gans.pix2pix.components import Generator, PatchGAN
6 | 
7 | 
8 | def _weights_init(m):
9 |     if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
10 |         torch.nn.init.normal_(m.weight, 0.0, 0.02)
11 |     if isinstance(m, nn.BatchNorm2d):
12 |         torch.nn.init.normal_(m.weight, 0.0, 0.02)
13 |         torch.nn.init.constant_(m.bias, 0)
14 | 
15 | 
16 | class Pix2Pix(LightningModule):
17 |     def __init__(self, in_channels, out_channels, learning_rate=0.0002, lambda_recon=200):
18 | 
19 |         super().__init__()
20 |         self.save_hyperparameters()
21 | 
22 |         self.gen = Generator(in_channels, out_channels)
23 |         self.patch_gan = PatchGAN(in_channels + out_channels)
24 | 
25 |         # initializing weights
26 |         self.gen = self.gen.apply(_weights_init)
27 |         self.patch_gan = self.patch_gan.apply(_weights_init)
28 | 
29 |         self.adversarial_criterion = nn.BCEWithLogitsLoss()
30 |         self.recon_criterion = nn.L1Loss()
31 | 
32 |     def _gen_step(self, real_images, conditioned_images):
33 |         # Pix2Pix has an adversarial and a reconstruction loss
34 |         # First calculate the adversarial loss
35 |         fake_images = self.gen(conditioned_images)
36 |         disc_logits = self.patch_gan(fake_images, conditioned_images)
37 |         adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))
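        # (non-saturating generator objective: push the discriminator's logits on fakes toward "real")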
38 | 
39 |         # calculate reconstruction loss
40 |         recon_loss = self.recon_criterion(fake_images, real_images)
41 |         lambda_recon = self.hparams.lambda_recon
42 | 
43 |         return adversarial_loss + lambda_recon * recon_loss
44 | 
45 |     def _disc_step(self, real_images, conditioned_images):
46 |         fake_images = self.gen(conditioned_images).detach()
47 |         fake_logits = self.patch_gan(fake_images, conditioned_images)
48 | 
49 |         real_logits = self.patch_gan(real_images, conditioned_images)
50 | 
51 |         fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
52 |         real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
53 |         return (real_loss + fake_loss) / 2
54 | 
55 |     def configure_optimizers(self):
56 |         lr = self.hparams.learning_rate
57 |         gen_opt = torch.optim.Adam(self.gen.parameters(), lr=lr)
58 |         disc_opt = torch.optim.Adam(self.patch_gan.parameters(), lr=lr)
59 |         return disc_opt, gen_opt
60 | 
61 |     def training_step(self, batch, batch_idx, optimizer_idx):
62 |         real, condition = batch
63 | 
64 |         loss = None
65 |         if optimizer_idx == 0:
66 |             loss = self._disc_step(real, condition)
67 |             self.log("PatchGAN Loss", loss)
68 |         elif optimizer_idx == 1:
69 |             loss = self._gen_step(real, condition)
70 |             self.log("Generator Loss", loss)
71 | 
72 |         return loss
73 | 
74 | 
75 | if __name__ == "__main__":
76 |     pix2pix = Pix2Pix(3, 3)
77 |     print(pix2pix.gen(torch.randn(1, 3, 256, 256)).shape)
78 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/transforms/self_supervised/ssl_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.nn import functional as F
3 | 
4 | from pl_bolts.utils import _PIL_AVAILABLE
5 | from pl_bolts.utils.warnings import warn_missing_pkg
6 | 
7 | if _PIL_AVAILABLE:
8 |     from PIL import Image
9 | else:  # pragma: no cover
10 |     warn_missing_pkg("PIL", pypi_name="Pillow")
11 | 
12 | 
13 | class RandomTranslateWithReflect:
14 |     """Translate image randomly. Translates vertically and horizontally by n pixels, where n is an
15 |     integer drawn uniformly and independently for each axis from [-max_translation, max_translation].
16 | 
17 |     Fill the uncovered blank area with reflect padding.
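    For example, a sampled translation of (+3, -2) on a 32x32 image pastes mirrored copies
    around the original and crops the 32x32 window shifted by (3, -2), so no blank pixels remain.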
18 | """ 19 | 20 | def __init__(self, max_translation): 21 | if not _PIL_AVAILABLE: # pragma: no cover 22 | raise ModuleNotFoundError("You want to use `Pillow` which is not installed yet.") 23 | 24 | self.max_translation = max_translation 25 | 26 | def __call__(self, old_image): 27 | xtranslation, ytranslation = np.random.randint(-self.max_translation, self.max_translation + 1, size=2) 28 | xpad, ypad = abs(xtranslation), abs(ytranslation) 29 | xsize, ysize = old_image.size 30 | 31 | flipped_lr = old_image.transpose(Image.FLIP_LEFT_RIGHT) 32 | flipped_tb = old_image.transpose(Image.FLIP_TOP_BOTTOM) 33 | flipped_both = old_image.transpose(Image.ROTATE_180) 34 | 35 | new_image = Image.new("RGB", (xsize + 2 * xpad, ysize + 2 * ypad)) 36 | 37 | new_image.paste(old_image, (xpad, ypad)) 38 | 39 | new_image.paste(flipped_lr, (xpad + xsize - 1, ypad)) 40 | new_image.paste(flipped_lr, (xpad - xsize + 1, ypad)) 41 | 42 | new_image.paste(flipped_tb, (xpad, ypad + ysize - 1)) 43 | new_image.paste(flipped_tb, (xpad, ypad - ysize + 1)) 44 | 45 | new_image.paste(flipped_both, (xpad - xsize + 1, ypad - ysize + 1)) 46 | new_image.paste(flipped_both, (xpad + xsize - 1, ypad - ysize + 1)) 47 | new_image.paste(flipped_both, (xpad - xsize + 1, ypad + ysize - 1)) 48 | new_image.paste(flipped_both, (xpad + xsize - 1, ypad + ysize - 1)) 49 | 50 | new_image = new_image.crop( 51 | (xpad - xtranslation, ypad - ytranslation, xpad + xsize - xtranslation, ypad + ysize - ytranslation) 52 | ) 53 | return new_image 54 | 55 | 56 | class Patchify: 57 | def __init__(self, patch_size, overlap_size): 58 | self.patch_size = patch_size 59 | self.overlap_size = self.patch_size - overlap_size 60 | 61 | def __call__(self, x): 62 | x = x.unsqueeze(0) 63 | b, c, h, w = x.size() 64 | 65 | # patch up the images 66 | # (b, c, h, w) -> (b, c*patch_size**2, L) 67 | x = F.unfold(x, kernel_size=self.patch_size, stride=self.overlap_size) 68 | 69 | # (b, c*patch_size**2, L) -> (b, nb_patches, patch_size, patch_size) 70 | x = x.transpose(2, 1).contiguous().view(b, -1, self.patch_size, self.patch_size) 71 | 72 | # reshape to have (b x patches, c, h, w) 73 | x = x.view(-1, c, self.patch_size, self.patch_size) 74 | 75 | x = x.squeeze(0) 76 | 77 | return x 78 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/rl/double_dqn_model.py: -------------------------------------------------------------------------------- 1 | """Double DQN.""" 2 | import argparse 3 | from collections import OrderedDict 4 | from typing import Tuple 5 | 6 | from pytorch_lightning import Trainer 7 | from torch import Tensor 8 | 9 | from pl_bolts.losses.rl import double_dqn_loss 10 | from pl_bolts.models.rl.dqn_model import DQN 11 | 12 | 13 | class DoubleDQN(DQN): 14 | """Double Deep Q-network (DDQN) PyTorch Lightning implementation of `Double DQN`_. 15 | 16 | Paper authors: Hado van Hasselt, Arthur Guez, David Silver 17 | 18 | Model implemented by: 19 | 20 | - `Donal Byrne ` 21 | 22 | Example: 23 | 24 | >>> from pl_bolts.models.rl.double_dqn_model import DoubleDQN 25 | ... 26 | >>> model = DoubleDQN("PongNoFrameskip-v4") 27 | 28 | Train:: 29 | 30 | trainer = Trainer() 31 | trainer.fit(model) 32 | 33 | Note: 34 | This example is based on 35 | https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter08/03_dqn_double.py 36 | 37 | Note: 38 | Currently only supports CPU and single GPU training with `accelerator=dp` 39 | 40 | .. _`Double DQN`: https://arxiv.org/pdf/1509.06461.pdf
41 | """ 42 | 43 | def training_step(self, batch: Tuple[Tensor, Tensor], _) -> OrderedDict: 44 | """Carries out a single step through the environment to update the replay buffer. Then calculates loss 45 | based on the minibatch received. 46 | 47 | Args: 48 | batch: current mini batch of replay data 49 | _: batch number, not used 50 | 51 | Returns: 52 | Training loss and log metrics 53 | """ 54 | 55 | # calculates training loss 56 | loss = double_dqn_loss(batch, self.net, self.target_net, self.gamma) 57 | 58 | if self._use_dp_or_ddp2(self.trainer): 59 | loss = loss.unsqueeze(0) 60 | 61 | # Periodic hard update of the target network (full copy every `sync_rate` steps) 62 | if self.global_step % self.sync_rate == 0: 63 | self.target_net.load_state_dict(self.net.state_dict()) 64 | 65 | self.log_dict( 66 | { 67 | "total_reward": self.total_rewards[-1], 68 | "avg_reward": self.avg_rewards, 69 | "train_loss": loss, 70 | # "episodes": self.total_episode_steps, 71 | } 72 | ) 73 | 74 | return OrderedDict( 75 | { 76 | "loss": loss, 77 | "avg_reward": self.avg_rewards, 78 | } 79 | ) 80 | 81 | 82 | def cli_main(): 83 | parser = argparse.ArgumentParser(add_help=False) 84 | 85 | # trainer args 86 | parser = Trainer.add_argparse_args(parser) 87 | 88 | # model args 89 | parser = DoubleDQN.add_model_specific_args(parser) 90 | args = parser.parse_args() 91 | 92 | model = DoubleDQN(**args.__dict__) 93 | 94 | trainer = Trainer.from_argparse_args(args) 95 | trainer.fit(model) 96 | 97 | 98 | if __name__ == "__main__": 99 | cli_main() 100 |
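# Sketch of the double-Q target that `double_dqn_loss` computes (illustrative;
# the actual implementation lives in pl_bolts.losses.rl). The online net selects
# the next action, the target net evaluates it:
#
#   next_actions = net(next_states).argmax(dim=1)
#   next_q = target_net(next_states).gather(1, next_actions.unsqueeze(-1)).squeeze(-1)
#   next_q = next_q * (1 - dones.float())          # terminal states have no bootstrap value
#   expected_q = rewards + gamma * next_q.detach()
#   loss = nn.MSELoss()(q_values, expected_q)      # q_values = Q(s, a) from the online net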
-------------------------------------------------------------------------------- /Pre-Training/Masked_Autoencoder/decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import List 9 | 10 | import torch 11 | import torch.nn as nn 12 | from timm.models.layers import trunc_normal_ 13 | 14 | from utils.misc import is_pow2n 15 | 16 | 17 | class UNetBlock(nn.Module): 18 | def __init__(self, cin, cout, bn2d): 19 | """ 20 | a UNet block with 2x up sampling 21 | """ 22 | super().__init__() 23 | self.up_sample = nn.ConvTranspose2d(cin, cin, kernel_size=4, stride=2, padding=1, bias=True) 24 | self.conv = nn.Sequential( 25 | nn.Conv2d(cin, cin, kernel_size=3, stride=1, padding=1, bias=False), bn2d(cin), nn.ReLU6(inplace=True), 26 | nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False), bn2d(cout), 27 | ) 28 | 29 | def forward(self, x): 30 | x = self.up_sample(x) 31 | return self.conv(x) 32 | 33 | 34 | class LightDecoder(nn.Module): 35 | def __init__(self, up_sample_ratio, width=768, sbn=True): # todo: the decoder's width follows a simple halving rule; you can change it to any other rule 36 | super().__init__() 37 | self.width = width 38 | assert is_pow2n(up_sample_ratio) 39 | n = round(math.log2(up_sample_ratio)) 40 | channels = [self.width // 2 ** i for i in range(n + 1)] # todo: the decoder's width follows a simple halving rule; you can change it to any other rule 41 | bn2d = nn.SyncBatchNorm if sbn else nn.BatchNorm2d 42 | self.dec = nn.ModuleList([UNetBlock(cin, cout, bn2d) for (cin, cout) in zip(channels[:-1], channels[1:])]) 43 | self.proj = nn.Conv2d(channels[-1], 3, kernel_size=1, stride=1, bias=True) 44 | 45 | self.initialize() 46 | 47 | def forward(self, to_dec: List[torch.Tensor]): 48 | x = 0 49 | for i, d in enumerate(self.dec): 50 | if i < len(to_dec) and to_dec[i] is not None: 51 | x = x + to_dec[i] 52 | x = d(x) 53 | return self.proj(x) 54 | 55 | def extra_repr(self) -> str: 56 | return f'width={self.width}' 57 | 58 | def initialize(self): 59 | for m in self.modules(): 60 | if isinstance(m, nn.Linear): 61 | trunc_normal_(m.weight, std=.02) 62 | if m.bias is not None: 63 | nn.init.constant_(m.bias, 0) 64 | elif isinstance(m, nn.Conv2d): 65 | trunc_normal_(m.weight, std=.02) 66 | if m.bias is not None: 67 | nn.init.constant_(m.bias, 0) 68 | elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): # nn.Conv2d was matched above, so only nn.ConvTranspose2d reaches this branch 69 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 70 | if m.bias is not None: 71 | nn.init.constant_(m.bias, 0.) 72 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.SyncBatchNorm)): 73 | nn.init.constant_(m.bias, 0) 74 | nn.init.constant_(m.weight, 1.0) 75 |
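# Worked example of the halving rule above (illustration): with width=768 and
# up_sample_ratio=32 (so n = 5), the per-stage channels are
#   [768 // 2 ** i for i in range(6)]  ->  [768, 384, 192, 96, 48, 24]
# i.e. five UNetBlocks, each doubling the spatial size and halving the width,
# followed by the 1x1 `proj` conv that maps the final 24 channels to RGB.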
-------------------------------------------------------------------------------- /Pre-Training/Masked_Autoencoder/utils/imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from typing import Any, Callable, Optional, Tuple 9 | 10 | import PIL.Image as PImage 11 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 12 | from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS 13 | from torchvision.transforms import transforms 14 | from torch.utils.data import Dataset 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | interpolation = InterpolationMode.BICUBIC 19 | except ImportError: # older torchvision versions have no InterpolationMode 20 | import PIL 21 | interpolation = PIL.Image.BICUBIC 22 | 23 | 24 | def pil_loader(path): 25 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 26 | with open(path, 'rb') as f: img: PImage.Image = PImage.open(f).convert('RGB') 27 | return img 28 | 29 | 30 | class ImageNetDataset(DatasetFolder): 31 | def __init__( 32 | self, 33 | imagenet_folder: str, 34 | train: bool, 35 | transform: Callable, 36 | is_valid_file: Optional[Callable[[str], bool]] = None, 37 | ): 38 | imagenet_folder = os.path.join(imagenet_folder, 'train' if train else 'val') 39 | super(ImageNetDataset, self).__init__( 40 | imagenet_folder, 41 | loader=pil_loader, 42 | extensions=IMG_EXTENSIONS if is_valid_file is None else None, 43 | transform=transform, target_transform=None, is_valid_file=is_valid_file 44 | ) 45 | 46 | self.samples = tuple(self.samples) 47 | self.targets = tuple([s[1] for s in self.samples]) 48 | 49 | def __getitem__(self, index: int) -> Tuple[Any, int]: 50 | path, target = self.samples[index] 51 | return self.transform(self.loader(path)), target 52 | 53 | 54 | def build_dataset_to_pretrain(dataset_path, input_size) -> Dataset: 55 | """ 56 | You may need to modify this function to fit your own dataset. 57 | :param dataset_path: the folder of the dataset 58 | :param input_size: the input size (image resolution) 59 | :return: the dataset used for pretraining 60 | """ 61 | trans_train = transforms.Compose([ 62 | transforms.RandomResizedCrop(input_size, scale=(0.67, 1.0), interpolation=interpolation), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.ToTensor(), 65 | transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 66 | ]) 67 | 68 | dataset_path = os.path.abspath(dataset_path) 69 | for postfix in ('train', 'val'): 70 | if dataset_path.endswith(postfix): 71 | dataset_path = dataset_path[:-len(postfix)] 72 | 73 | dataset_train = ImageNetDataset(imagenet_folder=dataset_path, transform=trans_train, train=True) 74 | print_transform(trans_train, '[pre-train]') 75 | return dataset_train 76 | 77 | 78 | def print_transform(transform, s): 79 | print(f'Transform {s} = ') 80 | for t in transform.transforms: 81 | print(t) 82 | print('---------------------------\n') 83 |
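# Minimal usage sketch (assumes an ImageNet-style folder that contains `train`/`val`
# subfolders; the path below is a placeholder):
#
#   dataset = build_dataset_to_pretrain('/path/to/imagenet', input_size=224)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)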
-------------------------------------------------------------------------------- /Pre-Training/Masked_Autoencoder/launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | # Copyright (c) ByteDance, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import argparse 10 | import functools 11 | import os 12 | import socket 13 | import subprocess 14 | import sys 15 | from typing import List 16 | 17 | os_system = functools.partial(subprocess.call, shell=True) 18 | echo = lambda info: os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"') 19 | 20 | 21 | def __find_free_port(): 22 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 23 | sock.bind(("", 0)) 24 | port = sock.getsockname()[1] 25 | sock.close() 26 | return port 27 | 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser(description='PyTorch Distributed Launcher') 31 | parser.add_argument('--main_py_relpath', type=str, default='main.py', 32 | help='specify launcher script.') 33 | 34 | # distributed environment 35 | parser.add_argument('--num_nodes', type=int, default=1) 36 | parser.add_argument('--ngpu_per_node', type=int, default=1) 37 | parser.add_argument('--node_rank', type=int, default=0, 38 | help='node rank, ranging from 0 to [dist_num_nodes]-1') 39 | parser.add_argument('--master_address', type=str, default='128.0.0.0', 40 | help='master address for distributed communication') 41 | parser.add_argument('--master_port', type=int, default=30001, 42 | help='master port for distributed communication') 43 | 44 | args_for_this, args_for_python = parser.parse_known_args() 45 | args_for_python: List[str] 46 | 47 | echo(f'[initial args_for_python]: {args_for_python}') 48 | # auto-complete: update args like `--sbn` to `--sbn=1` 49 | kwargs = args_for_python[-1] 50 | kwargs = '='.join(map(str.strip, kwargs.split('='))) 51 | kwargs = kwargs.split(' ') 52 | for i, a in enumerate(kwargs): 53 | if len(a) and '=' not in a: 54 | kwargs[i] = f'{a}=1' 55 | args_for_python[-1] = ' '.join(kwargs) 56 | echo(f'[final args_for_python]: {args_for_python}') 57 | 58 | if args_for_this.num_nodes > 1: # distributed 59 | os.environ['NPROC_PER_NODE'] = str(args_for_this.ngpu_per_node) 60 | cmd = ( 61 | f'python3 -m torch.distributed.launch' 62 | f' --nnodes={args_for_this.num_nodes}' 63 | f' --nproc_per_node={args_for_this.ngpu_per_node}' 64 | f' --node_rank={args_for_this.node_rank}' 65 | f' --master_addr={args_for_this.master_address}' 66 | f' --master_port={args_for_this.master_port}' 67 | f' {args_for_this.main_py_relpath}' 68 | f' {" ".join(args_for_python)}' 69 | ) 70 | else: # single machine with multiple GPUs 71 | cmd = ( 72 | f'python3 -m torch.distributed.launch' 73 | f' --nproc_per_node={args_for_this.ngpu_per_node}' 74 | f' --master_port={__find_free_port()}' 75 | f' {args_for_this.main_py_relpath}' 76 | f' {" ".join(args_for_python)}' 77 | ) 78 | 79 | exit_code = subprocess.call(cmd, shell=True) 80 | sys.exit(exit_code) 81 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/datamodules/binary_emnist_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union 2 | 3 | from pl_bolts.datamodules.emnist_datamodule import EMNISTDataModule 4 | from pl_bolts.datasets import BinaryEMNIST 5 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 6 | 7 | 8 | class BinaryEMNISTDataModule(EMNISTDataModule): 9 | """ 10 | .. figure:: https://user-images.githubusercontent.com/4632336/123210742-4d6b3380-d477-11eb-80da-3e9a74a18a07.png
11 | :width: 400 12 | :alt: EMNIST 13 | 14 | Please see :class:`~pl_bolts.datamodules.emnist_datamodule.EMNISTDataModule` for more details. 15 | 16 | Example:: 17 | 18 | from pl_bolts.datamodules import BinaryEMNISTDataModule 19 | dm = BinaryEMNISTDataModule('.') 20 | model = LitModel() 21 | Trainer().fit(model, datamodule=dm) 22 | """ 23 | 24 | name = "binary_emnist" 25 | dataset_cls = BinaryEMNIST 26 | dims = (1, 28, 28) 27 | 28 | def __init__( 29 | self, 30 | data_dir: Optional[str] = None, 31 | split: str = "mnist", 32 | val_split: Union[int, float] = 0.2, 33 | num_workers: int = 0, 34 | normalize: bool = False, 35 | batch_size: int = 32, 36 | seed: int = 42, 37 | shuffle: bool = True, 38 | pin_memory: bool = True, 39 | drop_last: bool = False, 40 | strict_val_split: bool = False, 41 | *args: Any, 42 | **kwargs: Any, 43 | ) -> None: 44 | """ 45 | Args: 46 | data_dir: Where to save/load the data. 47 | split: The dataset has 6 different splits: ``byclass``, ``bymerge``, 48 | ``balanced``, ``letters``, ``digits`` and ``mnist``. 49 | This argument is passed to :class:`torchvision.datasets.EMNIST`. 50 | val_split: Percent (float) or number (int) of samples 51 | to use for the validation split. 52 | num_workers: How many workers to use for loading data 53 | normalize: If ``True``, applies image normalize. 54 | batch_size: How many samples per batch to load. 55 | seed: Random seed to be used for train/val/test splits. 56 | shuffle: If ``True``, shuffles the train data every epoch. 57 | pin_memory: If ``True``, the data loader will copy Tensors into 58 | CUDA pinned memory before returning them. 59 | drop_last: If ``True``, drops the last incomplete batch. 60 | strict_val_split: If ``True``, uses the validation split defined in the paper and ignores ``val_split``. 61 | Note that it only works with ``"balanced"``, ``"digits"``, ``"letters"``, ``"mnist"`` splits. 62 | """ 63 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 64 | raise ModuleNotFoundError( 65 | "You want to use EMNIST dataset loaded from `torchvision` which is not installed yet." 66 | ) 67 | 68 | super().__init__( # type: ignore[misc] 69 | data_dir=data_dir, 70 | split=split, 71 | val_split=val_split, 72 | num_workers=num_workers, 73 | normalize=normalize, 74 | batch_size=batch_size, 75 | seed=seed, 76 | shuffle=shuffle, 77 | pin_memory=pin_memory, 78 | drop_last=drop_last, 79 | strict_val_split=strict_val_split, 80 | *args, 81 | **kwargs, 82 | ) 83 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/detection/components/torchvision_backbones.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch.nn as nn 4 | 5 | from pl_bolts.models.detection.components._supported_models import TORCHVISION_MODEL_ZOO 6 | from pl_bolts.utils import _TORCHVISION_AVAILABLE # noqa: F401 7 | from pl_bolts.utils.warnings import warn_missing_pkg # noqa: F401 8 | 9 | 10 | def _create_backbone_generic(model: nn.Module, out_channels: int) -> nn.Module: 11 | """Generic backbone creator. It removes the last linear layer. 12 | 13 | Args: 14 | model: torch.nn model 15 | out_channels: Number of out_channels in last layer. 16 | """ 17 | modules_total = list(model.children()) 18 | modules = modules_total[:-1] 19 | ft_backbone = nn.Sequential(*modules) 20 | ft_backbone.out_channels = out_channels 21 | return ft_backbone 22 | 23 | 24 | # Use this when the model has an adaptive pooling layer at the end 25 | # and `model.features` is not applicable. 26 | def _create_backbone_adaptive(model: nn.Module, out_channels: Optional[int] = None) -> nn.Module: 27 | """Creates backbone by removing linear after Adaptive Pooling layer. 28 | 29 | Args: 30 | model: torch.nn model with adaptive pooling layer 31 | out_channels: Number of out_channels in last layer 32 | """ 33 | if out_channels is None: 34 | modules_total = list(model.children()) 35 | out_channels = modules_total[-1].in_features 36 | return _create_backbone_generic(model, out_channels=out_channels) 37 | 38 | 39 | def _create_backbone_features(model: nn.Module, out_channels: int) -> nn.Module: 40 | """Creates backbone from feature sequential block. 41 | 42 | Args: 43 | model: torch.nn model with features as sequential block. 44 | out_channels: Number of out_channels in last layer. 45 | """ 46 | ft_backbone = model.features 47 | ft_backbone.out_channels = out_channels 48 | return ft_backbone 49 | 50 | 51 | def create_torchvision_backbone(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: 52 | """Creates CNN backbone from Torchvision. 53 | 54 | Args: 55 | model_name: Name of the model. E.g. resnet18 56 | pretrained: If ``True``, loads ImageNet-pretrained weights. 57 | """ 58 | 59 | model_selected = TORCHVISION_MODEL_ZOO[model_name] 60 | net = model_selected(pretrained=pretrained) 61 | 62 | if model_name == "mobilenet_v2": 63 | out_channels = 1280 64 | ft_backbone = _create_backbone_features(net, 1280) 65 | return ft_backbone, out_channels 66 | 67 | if model_name in ["vgg11", "vgg13", "vgg16", "vgg19"]: 68 | out_channels = 512 69 | ft_backbone = _create_backbone_features(net, out_channels) 70 | return ft_backbone, out_channels 71 | 72 | if model_name in ["resnet18", "resnet34"]: 73 | out_channels = 512 74 | ft_backbone = _create_backbone_adaptive(net, out_channels) 75 | return ft_backbone, out_channels 76 | 77 | if model_name in [ 78 | "resnet50", 79 | "resnet101", 80 | "resnet152", 81 | "resnext50_32x4d", 82 | "resnext101_32x8d", 83 | ]: 84 | out_channels = 2048 85 | ft_backbone = _create_backbone_adaptive(net, out_channels) 86 | return ft_backbone, out_channels 87 | 88 | if model_name in ["mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"]: 89 | out_channels = 1280 90 | ft_backbone = _create_backbone_adaptive(net, out_channels) 91 | return ft_backbone, out_channels 92 | raise ValueError(f"Unsupported model: '{model_name}'") 93 | -------------------------------------------------------------------------------- /Pre-Training/Masked_Autoencoder/dist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import os 8 | from typing import List 9 | from typing import Union 10 | 11 | import sys 12 | import torch 13 | import torch.distributed as tdist 14 | import torch.multiprocessing as mp 15 | 16 | __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cpu' 17 | __initialized = False 18 | 19 | 20 | def initialized(): 21 | return __initialized 22 | 23 | 24 | def initialize(backend='nccl'): 25 | global __device 26 | if not torch.cuda.is_available(): 27 | print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr) 28 | return 29 | elif 'RANK' not in os.environ: 30 | __device = torch.empty(1).cuda().device 31 | print(f'[dist initialize] RANK is not set, use 1 GPU instead', file=sys.stderr) 32 | return 33 | 34 | # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29 35 | if mp.get_start_method(allow_none=True) is None: 36 | mp.set_start_method('spawn') 37 | global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count() 38 | local_rank = global_rank % num_gpus 39 | torch.cuda.set_device(local_rank) 40 | tdist.init_process_group(backend=backend) 41 | 42 | global __rank, __local_rank, __world_size, __initialized 43 | __local_rank = local_rank 44 | __rank, __world_size = tdist.get_rank(), tdist.get_world_size() 45 | __device = torch.empty(1).cuda().device 46 | __initialized = True 47 | 48 | assert tdist.is_initialized(), 'torch.distributed is not initialized!' 49 | 50 | 51 | def get_rank(): 52 | return __rank 53 | 54 | 55 | def get_local_rank(): 56 | return __local_rank 57 | 58 | 59 | def get_world_size(): 60 | return __world_size 61 | 62 | 63 | def get_device(): 64 | return __device 65 | 66 | 67 | def is_master(): 68 | return __rank == 0 69 | 70 | 71 | def is_local_master(): 72 | return __local_rank == 0 73 | 74 | 75 | def barrier(): 76 | if __initialized: 77 | tdist.barrier() 78 | 79 | 80 | def parallelize(net, syncbn=False): 81 | if syncbn: 82 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) 83 | net = net.cuda() 84 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[get_local_rank()], find_unused_parameters=False, broadcast_buffers=False) 85 | return net 86 | 87 | 88 | def allreduce(t: torch.Tensor) -> None: 89 | if __initialized: 90 | if not t.is_cuda: 91 | cu = t.detach().cuda() 92 | tdist.all_reduce(cu) 93 | t.copy_(cu.cpu()) 94 | else: 95 | tdist.all_reduce(t) 96 | 97 | 98 | def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]: 99 | if __initialized: 100 | if not t.is_cuda: 101 | t = t.cuda() 102 | ls = [torch.empty_like(t) for _ in range(__world_size)] 103 | tdist.all_gather(ls, t) 104 | else: 105 | ls = [t] 106 | if cat: 107 | ls = torch.cat(ls, dim=0) 108 | return ls 109 | 110 | 111 | def broadcast(t: torch.Tensor, src_rank) -> None: 112 | if __initialized: 113 | if not t.is_cuda: 114 | cu = t.detach().cuda() 115 | tdist.broadcast(cu, src=src_rank) 116 | t.copy_(cu.cpu()) 117 | else: 118 | tdist.broadcast(t, src=src_rank) 119 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/metrics/object_detection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def iou(preds: Tensor, target: Tensor) -> Tensor: 6 | """Calculates the intersection over union. 
7 | 8 | Args: 9 | preds: an Nx4 batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` 10 | target: an Mx4 batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` 11 | 12 | Example: 13 | 14 | >>> import torch 15 | >>> from pl_bolts.metrics.object_detection import iou 16 | >>> preds = torch.tensor([[100, 100, 200, 200]]) 17 | >>> target = torch.tensor([[150, 150, 250, 250]]) 18 | >>> iou(preds, target) 19 | tensor([[0.1429]]) 20 | 21 | Returns: 22 | IoU tensor: an NxM tensor containing the pairwise IoU values for every element in preds and target, 23 | where N is the number of prediction bounding boxes and M is the number of target bounding boxes 24 | """ 25 | x_min = torch.max(preds[:, None, 0], target[:, 0]) 26 | y_min = torch.max(preds[:, None, 1], target[:, 1]) 27 | x_max = torch.min(preds[:, None, 2], target[:, 2]) 28 | y_max = torch.min(preds[:, None, 3], target[:, 3]) 29 | intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0) 30 | pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1]) 31 | target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) 32 | union = pred_area[:, None] + target_area - intersection 33 | iou = torch.true_divide(intersection, union) 34 | return iou 35 | 36 | 37 | def giou(preds: Tensor, target: Tensor) -> Tensor: 38 | """Calculates the generalized intersection over union. 39 | 40 | It has been proposed in `Generalized Intersection over Union: A Metric and A 41 | Loss for Bounding Box Regression `_. 42 | 43 | Args: 44 | preds: an Nx4 batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` 45 | target: an Mx4 batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]`` 46 | 47 | Example: 48 | 49 | >>> import torch 50 | >>> from pl_bolts.metrics.object_detection import giou 51 | >>> preds = torch.tensor([[100, 100, 200, 200]]) 52 | >>> target = torch.tensor([[150, 150, 250, 250]]) 53 | >>> giou(preds, target) 54 | tensor([[-0.0794]]) 55 | 56 | Returns: 57 | GIoU in an NxM tensor containing the pairwise GIoU values for every element in preds and target, 58 | where N is the number of prediction bounding boxes and M is the number of target bounding boxes 59 | """ 60 | x_min = torch.max(preds[:, None, 0], target[:, 0]) 61 | y_min = torch.max(preds[:, None, 1], target[:, 1]) 62 | x_max = torch.min(preds[:, None, 2], target[:, 2]) 63 | y_max = torch.min(preds[:, None, 3], target[:, 3]) 64 | intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0) 65 | pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1]) 66 | target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1]) 67 | union = pred_area[:, None] + target_area - intersection 68 | C_x_min = torch.min(preds[:, None, 0], target[:, 0]) 69 | C_y_min = torch.min(preds[:, None, 1], target[:, 1]) 70 | C_x_max = torch.max(preds[:, None, 2], target[:, 2]) 71 | C_y_max = torch.max(preds[:, None, 3], target[:, 3]) 72 | C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0) 73 | iou = torch.true_divide(intersection, union) 74 | giou = iou - torch.true_divide((C_area - union), C_area) 75 | return giou 76 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/setup_tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright The PyTorch Lightning team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import os 16 | import re 17 | from typing import List 18 | 19 | _PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) 20 | 21 | 22 | def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_char: str = "#") -> List[str]: 23 | """Load requirements from a file. 24 | 25 | >>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE 26 | ['torch...', 'pytorch-lightning...'...] 27 | """ 28 | with open(os.path.join(path_dir, file_name)) as file: 29 | lines = [ln.strip() for ln in file.readlines()] 30 | reqs = [] 31 | for ln in lines: 32 | # filter out comments 33 | if comment_char in ln: 34 | ln = ln[: ln.index(comment_char)].strip() 35 | # skip directly installed dependencies 36 | if ln.startswith("http"): 37 | continue 38 | if ln: # if requirement is not empty 39 | reqs.append(ln) 40 | return reqs 41 | 42 | 43 | def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str: 44 | """Load readme as description. 45 | 46 | >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE 47 | '<div align="center">
...' 48 | """ 49 | path_readme = os.path.join(path_dir, "README.md") 50 | text = open(path_readme, encoding="utf-8").read() 51 | 52 | # drop images from readme 53 | text = text.replace("![PT to PL](docs/source/_images/general/pl_quick_start_full_compressed.gif)", "") 54 | 55 | # https://github.com/PyTorchLightning/pytorch-lightning/raw/master/docs/source/_images/lightning_module/pt_to_png 56 | github_source_url = os.path.join(homepage, "raw", ver) 57 | # replace the relative repository path with an absolute link to the release 58 | # do not replace all "docs", as in the readme we refer to some other sources with a particular path to the docs 59 | text = text.replace("docs/source/_images/", f"{os.path.join(github_source_url, 'docs/source/_images/')}") 60 | 61 | # readthedocs badge 62 | text = text.replace("badge/?version=stable", f"badge/?version={ver}") 63 | text = text.replace("lightning-bolts.readthedocs.io/en/stable/", f"lightning-bolts.readthedocs.io/en/{ver}") 64 | # codecov badge 65 | text = text.replace("/branch/master/graph/badge.svg", f"/release/{ver}/graph/badge.svg") 66 | # replace github badges for release ones 67 | text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={ver}") 68 | 69 | skip_begin = r"" 70 | skip_end = r"" 71 | # todo: wrap content as commented description 72 | text = re.sub(rf"{skip_begin}.+?{skip_end}", "", text, flags=re.IGNORECASE + re.DOTALL) 73 | 74 | # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png 75 | # github_release_url = os.path.join(homepage, "releases", "download", ver) 76 | # # download badge and replace url with local file 77 | # text = _parse_for_badge(text, github_release_url) 78 | return text 79 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/gans/dcgan/components.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py 2 | from torch import Tensor, nn 3 | 4 | 5 | class DCGANGenerator(nn.Module): 6 | def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None: 7 | """ 8 | Args: 9 | latent_dim: Dimension of the latent space 10 | feature_maps: Number of feature maps to use 11 | image_channels: Number of channels of the images from the dataset 12 | """ 13 | super().__init__() 14 | self.gen = nn.Sequential( 15 | self._make_gen_block(latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0), 16 | self._make_gen_block(feature_maps * 8, feature_maps * 4), 17 | self._make_gen_block(feature_maps * 4, feature_maps * 2), 18 | self._make_gen_block(feature_maps * 2, feature_maps), 19 | self._make_gen_block(feature_maps, image_channels, last_block=True), 20 | ) 21 | 22 | @staticmethod 23 | def _make_gen_block( 24 | in_channels: int, 25 | out_channels: int, 26 | kernel_size: int = 4, 27 | stride: int = 2, 28 | padding: int = 1, 29 | bias: bool = False, 30 | last_block: bool = False, 31 | ) -> nn.Sequential: 32 | if not last_block: 33 | gen_block = nn.Sequential( 34 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias), 35 | nn.BatchNorm2d(out_channels), 36 | nn.ReLU(True), 37 | ) 38 | else: 39 | gen_block = nn.Sequential( 40 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias), 41 | nn.Tanh(), 42 | ) 43 | 44 | return gen_block 45 | 46 | def forward(self, noise: Tensor) -> Tensor: 47 | return self.gen(noise) 48 | 49 | 50 |
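# A quick shape walk-through for the generator above (illustrative sketch): with
# kernel_size=4, the first block maps a (B, latent_dim, 1, 1) noise tensor to 4x4,
# and each following stride-2 block doubles the spatial size,
# 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64, ending in a Tanh-activated image:
#
#   import torch
#   gen = DCGANGenerator(latent_dim=100, feature_maps=64, image_channels=3)
#   img = gen(torch.randn(2, 100, 1, 1))  # -> torch.Size([2, 3, 64, 64])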
class DCGANDiscriminator(nn.Module): 51 | def __init__(self, feature_maps: int, image_channels: int) -> None: 52 | """ 53 | Args: 54 | feature_maps: Number of feature maps to use 55 | image_channels: Number of channels of the images from the dataset 56 | """ 57 | super().__init__() 58 | self.disc = nn.Sequential( 59 | self._make_disc_block(image_channels, feature_maps, batch_norm=False), 60 | self._make_disc_block(feature_maps, feature_maps * 2), 61 | self._make_disc_block(feature_maps * 2, feature_maps * 4), 62 | self._make_disc_block(feature_maps * 4, feature_maps * 8), 63 | self._make_disc_block(feature_maps * 8, 1, kernel_size=4, stride=1, padding=0, last_block=True), 64 | ) 65 | 66 | @staticmethod 67 | def _make_disc_block( 68 | in_channels: int, 69 | out_channels: int, 70 | kernel_size: int = 4, 71 | stride: int = 2, 72 | padding: int = 1, 73 | bias: bool = False, 74 | batch_norm: bool = True, 75 | last_block: bool = False, 76 | ) -> nn.Sequential: 77 | if not last_block: 78 | disc_block = nn.Sequential( 79 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias), 80 | nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(), 81 | nn.LeakyReLU(0.2, inplace=True), 82 | ) 83 | else: 84 | disc_block = nn.Sequential( 85 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias), 86 | nn.Sigmoid(), 87 | ) 88 | 89 | return disc_block 90 | 91 | def forward(self, x: Tensor) -> Tensor: 92 | return self.disc(x).view(-1, 1).squeeze(1) 93 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/callbacks/variational.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | from pytorch_lightning import LightningModule, Trainer 6 | from pytorch_lightning.callbacks import Callback 7 | from torch import Tensor 8 | 9 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 10 | from pl_bolts.utils.warnings import warn_missing_pkg 11 | 12 | if _TORCHVISION_AVAILABLE: 13 | import torchvision 14 | else: # pragma: no cover 15 | warn_missing_pkg("torchvision") 16 | 17 | 18 | class LatentDimInterpolator(Callback): 19 | """Interpolates the latent space for a model by setting all dims to zero and stepping through the first two 20 | dims increasing one unit at a time. 
21 | 22 | Default interpolates between [-5, 5] (-5, -4, -3, ..., 3, 4, 5) 23 | 24 | Example:: 25 | 26 | from pl_bolts.callbacks import LatentDimInterpolator 27 | 28 | Trainer(callbacks=[LatentDimInterpolator()]) 29 | """ 30 | 31 | def __init__( 32 | self, 33 | interpolate_epoch_interval: int = 20, 34 | range_start: int = -5, 35 | range_end: int = 5, 36 | steps: int = 11, 37 | num_samples: int = 2, 38 | normalize: bool = True, 39 | ): 40 | """ 41 | Args: 42 | interpolate_epoch_interval: default 20 43 | range_start: default -5 44 | range_end: default 5 45 | steps: number of steps between start and end 46 | num_samples: default 2 47 | normalize: default True (change image to (0, 1) range) 48 | """ 49 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 50 | raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.") 51 | 52 | super().__init__() 53 | self.interpolate_epoch_interval = interpolate_epoch_interval 54 | self.range_start = range_start 55 | self.range_end = range_end 56 | self.num_samples = num_samples 57 | self.normalize = normalize 58 | self.steps = steps 59 | 60 | def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: 61 | if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0: 62 | images = self.interpolate_latent_space(pl_module, latent_dim=pl_module.hparams.latent_dim) 63 | images = torch.cat(images, dim=0) 64 | 65 | num_rows = self.steps 66 | grid = torchvision.utils.make_grid(images, nrow=num_rows, normalize=self.normalize) 67 | str_title = f"{pl_module.__class__.__name__}_latent_space" 68 | trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step) 69 | 70 | def interpolate_latent_space(self, pl_module: LightningModule, latent_dim: int) -> List[Tensor]: 71 | images = [] 72 | with torch.no_grad(): 73 | pl_module.eval() 74 | for z1 in np.linspace(self.range_start, self.range_end, self.steps): 75 | for z2 in np.linspace(self.range_start, self.range_end, self.steps): 76 | # set all dims to zero 77 | z = torch.zeros(self.num_samples, latent_dim, device=pl_module.device) 78 | 79 | # set the first 2 dims to the value 80 | z[:, 0] = torch.tensor(z1) 81 | z[:, 1] = torch.tensor(z2) 82 | 83 | # sample 84 | # generate images 85 | img = pl_module(z) 86 | 87 | if len(img.size()) == 2: 88 | img = img.view(self.num_samples, *pl_module.img_dim) 89 | 90 | img = img[0] 91 | img = img.unsqueeze(0) 92 | images.append(img) 93 | 94 | pl_module.train() 95 | return images 96 | -------------------------------------------------------------------------------- /Pre-Training/Masked_Autoencoder/models/custom.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) ByteDance, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from typing import List 10 | from timm.models.registry import register_model 11 | 12 | 13 | class YourConvNet(nn.Module): 14 | """ 15 | This is a template for your custom ConvNet. 16 | It is required to implement the following three functions: `get_downsample_ratio`, `get_feature_map_channels`, `forward`. 17 | You can refer to the implementations in `pretrain\models\resnet.py` for an example. 18 | """ 19 | 20 | def get_downsample_ratio(self) -> int: 21 | """ 22 | This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
23 | 24 | :return: the TOTAL downsample ratio of the ConvNet. 25 | E.g., for a ResNet-50, this should return 32. 26 | """ 27 | raise NotImplementedError 28 | 29 | def get_feature_map_channels(self) -> List[int]: 30 | """ 31 | This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`). 32 | 33 | :return: a list of the number of channels of each feature map. 34 | E.g., for a ResNet-50, this should return [256, 512, 1024, 2048]. 35 | """ 36 | raise NotImplementedError 37 | 38 | def forward(self, inp_bchw: torch.Tensor, hierarchical=False): 39 | """ 40 | The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`). 41 | 42 | :param inp_bchw: input image tensor, shape: (batch_size, channels, height, width). 43 | :param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical). 44 | :return: 45 | - hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes). 46 | - hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`. 47 | E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map]. 48 | for an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)] 49 | """ 50 | raise NotImplementedError 51 | 52 | 53 | @register_model 54 | def your_convnet_small(pretrained=False, **kwargs): 55 | raise NotImplementedError 56 | return YourConvNet(**kwargs) 57 | 58 | 59 | @torch.no_grad() 60 | def convnet_test(): 61 | from timm.models import create_model 62 | cnn = create_model('your_convnet_small') 63 | print('get_downsample_ratio:', cnn.get_downsample_ratio()) 64 | print('get_feature_map_channels:', cnn.get_feature_map_channels()) 65 | 66 | downsample_ratio = cnn.get_downsample_ratio() 67 | feature_map_channels = cnn.get_feature_map_channels() 68 | 69 | # check the forward function 70 | B, C, H, W = 4, 3, 224, 224 71 | inp = torch.rand(B, C, H, W) 72 | feats = cnn(inp, hierarchical=True) 73 | assert isinstance(feats, list) 74 | assert len(feats) == len(feature_map_channels) 75 | print([tuple(t.shape) for t in feats]) 76 | 77 | # check the downsample ratio 78 | feats = cnn(inp, hierarchical=True) 79 | assert feats[-1].shape[-2] == H // downsample_ratio 80 | assert feats[-1].shape[-1] == W // downsample_ratio 81 | 82 | # check the channel number 83 | for feat, ch in zip(feats, feature_map_channels): 84 | assert feat.ndim == 4 85 | assert feat.shape[1] == ch 86 | 87 | 88 | if __name__ == '__main__': 89 | convnet_test() 90 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/datamodules/mnist_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Union 2 | 3 | from pl_bolts.datamodules.vision_datamodule import VisionDataModule 4 | from pl_bolts.datasets import MNIST 5 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 6 | from pl_bolts.utils.warnings import warn_missing_pkg 7 | 8 | if _TORCHVISION_AVAILABLE: 9 | from torchvision import transforms as transform_lib 10 | else: # pragma: no cover 11 | warn_missing_pkg("torchvision") 12 | 13 | 14 | class MNISTDataModule(VisionDataModule): 15 | """ 16 | .. 
figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png 17 | :width: 400 18 | :alt: MNIST 19 | 20 | Specs: 21 | - 10 classes (1 per digit) 22 | - Each image is (1 x 28 x 28) 23 | 24 | Standard MNIST, train, val, test splits and transforms 25 | 26 | Transforms:: 27 | 28 | mnist_transforms = transform_lib.Compose([ 29 | transform_lib.ToTensor() 30 | ]) 31 | 32 | Example:: 33 | 34 | from pl_bolts.datamodules import MNISTDataModule 35 | 36 | dm = MNISTDataModule('.') 37 | model = LitModel() 38 | 39 | Trainer().fit(model, datamodule=dm) 40 | """ 41 | 42 | name = "mnist" 43 | dataset_cls = MNIST 44 | dims = (1, 28, 28) 45 | 46 | def __init__( 47 | self, 48 | data_dir: Optional[str] = None, 49 | val_split: Union[int, float] = 0.2, 50 | num_workers: int = 0, 51 | normalize: bool = False, 52 | batch_size: int = 32, 53 | seed: int = 42, 54 | shuffle: bool = True, 55 | pin_memory: bool = True, 56 | drop_last: bool = False, 57 | *args: Any, 58 | **kwargs: Any, 59 | ) -> None: 60 | """ 61 | Args: 62 | data_dir: Where to save/load the data 63 | val_split: Percent (float) or number (int) of samples to use for the validation split 64 | num_workers: How many workers to use for loading data 65 | normalize: If true applies image normalize 66 | batch_size: How many samples per batch to load 67 | seed: Random seed to be used for train/val/test splits 68 | shuffle: If true shuffles the train data every epoch 69 | pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before 70 | returning them 71 | drop_last: If true drops the last incomplete batch 72 | """ 73 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 74 | raise ModuleNotFoundError( 75 | "You want to use MNIST dataset loaded from `torchvision` which is not installed yet." 76 | ) 77 | 78 | super().__init__( # type: ignore[misc] 79 | data_dir=data_dir, 80 | val_split=val_split, 81 | num_workers=num_workers, 82 | normalize=normalize, 83 | batch_size=batch_size, 84 | seed=seed, 85 | shuffle=shuffle, 86 | pin_memory=pin_memory, 87 | drop_last=drop_last, 88 | *args, 89 | **kwargs, 90 | ) 91 | 92 | @property 93 | def num_classes(self) -> int: 94 | """ 95 | Return: 96 | 10 97 | """ 98 | return 10 99 | 100 | def default_transforms(self) -> Callable: 101 | if self.normalize: 102 | mnist_transforms = transform_lib.Compose( 103 | [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] 104 | ) 105 | else: 106 | mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()]) 107 | 108 | return mnist_transforms 109 | -------------------------------------------------------------------------------- /Downstream/README.md: -------------------------------------------------------------------------------- 1 | # Downstream Tasks 2 | 3 | We tested our pre-training on three CT classification tasks: 4 | - **COVID-19**: Covid classification on lung CT scans (From Grand Challenge [https://covid-ct.grand-challenge.org/](https://covid-ct.grand-challenge.org/) or 5 | [https://doi.org/10.48550/arXiv.2003.13865](https://doi.org/10.48550/arXiv.2003.13865)) 6 | - **OrgMNIST**: Multi-class classification of 11 body organs on patches cropped around organs from abdominal CT scans (From MedMNIST Challenges [https://medmnist.com/](https://medmnist.com/) or [https://doi.org/10.1038/s41597-022-01721-8](https://doi.org/10.1038/s41597-022-01721-8)) 7 | - **Brain**: Brain hemorrhage classification on brain CT scans using an internal dataset of the Ulm University Medical Center 8 | 9 | We gradually reduced the training dataset size for
all three tasks to evaluate which pre-training method is best when only small annotated datasets are available. 10 | 11 | Here are our results: 12 | ![Results](https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/assets/75016933/83df9ede-bbf9-4eea-816c-f9de718ee764) 13 | 14 | 15 | 16 | ### How to Start: 17 | We have Jupyter notebooks with PyTorch Lightning and MONAI for the three Downstream Tasks. \ 18 | If you are using Conda on Linux, here is how to get started: 19 | 1. Open your terminal and follow these steps: 20 | 1. conda create --name SSL_Downstream python==3.10 21 | 2. conda activate SSL_Downstream 22 | 3. *CUDA 10.2:* conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=10.2 -c pytorch\ 23 | *CUDA 11.3:* conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch\ 24 | *CUDA 11.6:* conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge \ 25 | (The newest PyTorch should also work [https://pytorch.org/](https://pytorch.org/)) 26 | 4. cd ...SSL-MedicalImagining-CL-MAE/Downstream/ 27 | 5. pip install -r requirements.txt 28 | 6. Download Jupyter: conda install -c anaconda jupyter 29 | 2. Login to Wandb (or create an account [https://wandb.ai/](https://wandb.ai/)) 30 | 3. Open "OrgMNIST.ipynb" or "COVID_19.ipynb" or "Brain.ipynb" in Jupyter Notebook or Jupyter Lab 31 | 1. Fill out the first cell with your preferences (here you have to add the path to the downloaded pre-training checkpoints from the main README.md) 32 | 2. Run all cells 33 | 34 | 35 | ### Start Notebooks from Bash: 36 | This is not necessary, you can run everything directly in Jupyter Notebook or Jupyter Lab. However, this might be useful: 37 | 1. Open the notebook in Jupyter Lab 38 | 2. Click in the first code cell (this cell has all the parameters that need to be specified) 39 | 1. On the left click on the two gear wheels 40 | 2. Add a cell tag with the name "parameters" \ 41 | ![Parameters](https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/assets/75016933/afcd9342-a6a7-4921-a25a-c1fdcc827cd6) 42 | 3. Download papermill: conda install -c conda-forge papermill 43 | 4. Create a bash file (e.g. "file.sh"). All variables from the first code cell are parameters and can be specified in the bash file with -p ... 44 | 45 | ```bash 46 | # COVID-19 47 | papermill COVID-19.ipynb COVID-19.ipynb \ 48 | -p root_dir "path/where/results/should/be/saved" \ 49 | -p Run "WandB_Name_of_Run" \ 50 | -p pretrained_weights "/path/to/the/downloaded/checkpoints/SparK.pth" \ 51 | -p pre_train "SparK" \ 52 | 53 | # OrgMNIST 54 | papermill OrgMNIST.ipynb OrgMNIST.ipynb \ 55 | -p root_dir "path/where/results/should/be/saved" \ 56 | -p Run "WandB_Name_of_Run" \ 57 | -p pretrained_weights "/path/to/the/downloaded/SwAV.ckpt" \ 58 | -p pre_train "SwAV" \ 59 | 60 | ``` 61 | 5. Run the bash file (this will start the notebook)
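If you want to inspect a downloaded checkpoint outside of the notebooks, a minimal loading sketch could look like this (the path, the checkpoint layout, and the ResNet-50 backbone are assumptions; the notebooks contain the exact loading code for each pre-training method):

```python
import torch
from torchvision.models import resnet50

# load a downloaded pre-training checkpoint (placeholder path)
ckpt = torch.load("/path/to/the/downloaded/checkpoints/SparK.pth", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)  # Lightning checkpoints nest the weights

# initialize a backbone and load whatever weights match
model = resnet50()
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
```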
62 | 63 | 64 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/datamodules/binary_mnist_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Union 2 | 3 | from pl_bolts.datamodules.vision_datamodule import VisionDataModule 4 | from pl_bolts.datasets import BinaryMNIST 5 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 6 | from pl_bolts.utils.warnings import warn_missing_pkg 7 | 8 | if _TORCHVISION_AVAILABLE: 9 | from torchvision import transforms as transform_lib 10 | else: # pragma: no cover 11 | warn_missing_pkg("torchvision") 12 | 13 | 14 | class BinaryMNISTDataModule(VisionDataModule): 15 | """ 16 | .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png 17 | :width: 400 18 | :alt: MNIST 19 | 20 | Specs: 21 | - 10 classes (1 per digit) 22 | - Each image is (1 x 28 x 28) 23 | 24 | Binary MNIST, train, val, test splits and transforms 25 | 26 | Transforms:: 27 | 28 | mnist_transforms = transform_lib.Compose([ 29 | transform_lib.ToTensor() 30 | ]) 31 | 32 | Example:: 33 | 34 | from pl_bolts.datamodules import BinaryMNISTDataModule 35 | 36 | dm = BinaryMNISTDataModule('.') 37 | model = LitModel() 38 | 39 | Trainer().fit(model, datamodule=dm) 40 | """ 41 | 42 | name = "binary_mnist" 43 | dataset_cls = BinaryMNIST 44 | dims = (1, 28, 28) 45 | 46 | def __init__( 47 | self, 48 | data_dir: Optional[str] = None, 49 | val_split: Union[int, float] = 0.2, 50 | num_workers: int = 0, 51 | normalize: bool = False, 52 | batch_size: int = 32, 53 | seed: int = 42, 54 | shuffle: bool = True, 55 | pin_memory: bool = True, 56 | drop_last: bool = False, 57 | *args: Any, 58 | **kwargs: Any, 59 | ) -> None: 60 | """ 61 | Args: 62 | data_dir: Where to save/load the data 63 | val_split: Percent (float) or number (int) of samples to use for the validation split 64 | num_workers: How many workers to use for loading data 65 | normalize: If true applies image normalize 66 | batch_size: How many samples per batch to load 67 | seed: Random seed to be used for train/val/test splits 68 | shuffle: If true shuffles the train data every epoch 69 | pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before 70 | returning them 71 | drop_last: If true drops the last incomplete batch 72 | """ 73 | 74 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 75 | raise ModuleNotFoundError( 76 | "You want to use transforms loaded from `torchvision` which is not installed yet."
77 | ) 78 | 79 | super().__init__( # type: ignore[misc] 80 | data_dir=data_dir, 81 | val_split=val_split, 82 | num_workers=num_workers, 83 | normalize=normalize, 84 | batch_size=batch_size, 85 | seed=seed, 86 | shuffle=shuffle, 87 | pin_memory=pin_memory, 88 | drop_last=drop_last, 89 | *args, 90 | **kwargs, 91 | ) 92 | 93 | @property 94 | def num_classes(self) -> int: 95 | """ 96 | Return: 97 | 10 98 | """ 99 | return 10 100 | 101 | def default_transforms(self) -> Callable: 102 | if self.normalize: 103 | mnist_transforms = transform_lib.Compose( 104 | [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] 105 | ) 106 | else: 107 | mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()]) 108 | 109 | return mnist_transforms 110 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/callbacks/vision/image_generation.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from pytorch_lightning import Callback, LightningModule, Trainer 5 | 6 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 7 | from pl_bolts.utils.warnings import warn_missing_pkg 8 | 9 | if _TORCHVISION_AVAILABLE: 10 | import torchvision 11 | else: # pragma: no cover 12 | warn_missing_pkg("torchvision") 13 | 14 | 15 | class TensorboardGenerativeModelImageSampler(Callback): 16 | """Generates images and logs to tensorboard. Your model must implement the ``forward`` function for generation. 17 | 18 | Requirements:: 19 | 20 | # model must have img_dim arg 21 | model.img_dim = (1, 28, 28) 22 | 23 | # model forward must work for sampling 24 | z = torch.rand(batch_size, latent_dim) 25 | img_samples = your_model(z) 26 | 27 | Example:: 28 | 29 | from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler 30 | 31 | trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()]) 32 | """ 33 | 34 | def __init__( 35 | self, 36 | num_samples: int = 3, 37 | nrow: int = 8, 38 | padding: int = 2, 39 | normalize: bool = False, 40 | norm_range: Optional[Tuple[int, int]] = None, 41 | scale_each: bool = False, 42 | pad_value: int = 0, 43 | ) -> None: 44 | """ 45 | Args: 46 | num_samples: Number of images displayed in the grid. Default: ``3``. 47 | nrow: Number of images displayed in each row of the grid. 48 | The final grid size is ``(B / nrow, nrow)``. Default: ``8``. 49 | padding: Amount of padding. Default: ``2``. 50 | normalize: If ``True``, shift the image to the range (0, 1), 51 | by the min and max values specified by :attr:`range`. Default: ``False``. 52 | norm_range: Tuple (min, max) where min and max are numbers, 53 | then these numbers are used to normalize the image. By default, min and max 54 | are computed from the tensor. 55 | scale_each: If ``True``, scale each image in the batch of 56 | images separately rather than the (min, max) over all images. Default: ``False``. 57 | pad_value: Value for the padded pixels. Default: ``0``. 
58 | """ 59 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 60 | raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.") 61 | 62 | super().__init__() 63 | self.num_samples = num_samples 64 | self.nrow = nrow 65 | self.padding = padding 66 | self.normalize = normalize 67 | self.norm_range = norm_range 68 | self.scale_each = scale_each 69 | self.pad_value = pad_value 70 | 71 | def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: 72 | dim = (self.num_samples, pl_module.hparams.latent_dim) 73 | z = torch.normal(mean=0.0, std=1.0, size=dim, device=pl_module.device) 74 | 75 | # generate images 76 | with torch.no_grad(): 77 | pl_module.eval() 78 | images = pl_module(z) 79 | pl_module.train() 80 | 81 | if len(images.size()) == 2: 82 | img_dim = pl_module.img_dim 83 | images = images.view(self.num_samples, *img_dim) 84 | 85 | grid = torchvision.utils.make_grid( 86 | tensor=images, 87 | nrow=self.nrow, 88 | padding=self.padding, 89 | normalize=self.normalize, 90 | range=self.norm_range, 91 | scale_each=self.scale_each, 92 | pad_value=self.pad_value, 93 | ) 94 | str_title = f"{pl_module.__class__.__name__}_images" 95 | trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step) 96 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/datamodules/fashion_mnist_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional, Union 2 | 3 | from pl_bolts.datamodules.vision_datamodule import VisionDataModule 4 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 5 | from pl_bolts.utils.warnings import warn_missing_pkg 6 | 7 | if _TORCHVISION_AVAILABLE: 8 | from torchvision import transforms as transform_lib 9 | from torchvision.datasets import FashionMNIST 10 | else: # pragma: no cover 11 | warn_missing_pkg("torchvision") 12 | FashionMNIST = None 13 | 14 | 15 | class FashionMNISTDataModule(VisionDataModule): 16 | """ 17 | .. 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datamodules/fashion_mnist_datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional, Union
2 | 
3 | from pl_bolts.datamodules.vision_datamodule import VisionDataModule
4 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
5 | from pl_bolts.utils.warnings import warn_missing_pkg
6 | 
7 | if _TORCHVISION_AVAILABLE:
8 |     from torchvision import transforms as transform_lib
9 |     from torchvision.datasets import FashionMNIST
10 | else:  # pragma: no cover
11 |     warn_missing_pkg("torchvision")
12 |     FashionMNIST = None
13 | 
14 | 
15 | class FashionMNISTDataModule(VisionDataModule):
16 |     """
17 |     .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/
18 |         wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-Fashion-MNIST-Dataset.png
19 |         :width: 400
20 |         :alt: Fashion MNIST
21 | 
22 |     Specs:
23 |         - 10 classes (1 per type)
24 |         - Each image is (1 x 28 x 28)
25 | 
26 |     Standard FashionMNIST with train, val, test splits and transforms
27 | 
28 |     Transforms::
29 | 
30 |         mnist_transforms = transform_lib.Compose([
31 |             transform_lib.ToTensor()
32 |         ])
33 | 
34 |     Example::
35 | 
36 |         from pl_bolts.datamodules import FashionMNISTDataModule
37 | 
38 |         dm = FashionMNISTDataModule('.')
39 |         model = LitModel()
40 | 
41 |         Trainer().fit(model, datamodule=dm)
42 |     """
43 | 
44 |     name = "fashion_mnist"
45 |     dataset_cls = FashionMNIST
46 |     dims = (1, 28, 28)
47 | 
48 |     def __init__(
49 |         self,
50 |         data_dir: Optional[str] = None,
51 |         val_split: Union[int, float] = 0.2,
52 |         num_workers: int = 0,
53 |         normalize: bool = False,
54 |         batch_size: int = 32,
55 |         seed: int = 42,
56 |         shuffle: bool = True,
57 |         pin_memory: bool = True,
58 |         drop_last: bool = False,
59 |         *args: Any,
60 |         **kwargs: Any,
61 |     ) -> None:
62 |         """
63 |         Args:
64 |             data_dir: Where to save/load the data
65 |             val_split: Percent (float) or number (int) of samples to use for the validation split
66 |             num_workers: How many workers to use for loading data
67 |             normalize: If ``True``, applies image normalization
68 |             batch_size: How many samples per batch to load
69 |             seed: Random seed to be used for train/val/test splits
70 |             shuffle: If ``True``, shuffles the train data every epoch
71 |             pin_memory: If ``True``, the data loader will copy Tensors into CUDA pinned memory before
72 |                 returning them
73 |             drop_last: If ``True``, drops the last incomplete batch
74 |         """
75 |         if not _TORCHVISION_AVAILABLE:  # pragma: no cover
76 |             raise ModuleNotFoundError(
77 |                 "You want to use FashionMNIST dataset loaded from `torchvision` which is not installed yet."
78 |             )
79 | 
80 |         super().__init__(  # type: ignore[misc]
81 |             data_dir=data_dir,
82 |             val_split=val_split,
83 |             num_workers=num_workers,
84 |             normalize=normalize,
85 |             batch_size=batch_size,
86 |             seed=seed,
87 |             shuffle=shuffle,
88 |             pin_memory=pin_memory,
89 |             drop_last=drop_last,
90 |             *args,
91 |             **kwargs,
92 |         )
93 | 
94 |     @property
95 |     def num_classes(self) -> int:
96 |         """
97 |         Returns:
98 |             10
99 |         """
100 |         return 10
101 | 
102 |     def default_transforms(self) -> Callable:
103 |         if self.normalize:
104 |             mnist_transforms = transform_lib.Compose(
105 |                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
106 |             )
107 |         else:
108 |             mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])
109 | 
110 |         return mnist_transforms
111 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/rl/noisy_dqn_model.py:
--------------------------------------------------------------------------------
1 | """Noisy DQN."""
2 | import argparse
3 | from typing import Tuple
4 | 
5 | import numpy as np
6 | from pytorch_lightning import Trainer
7 | from torch import Tensor
8 | 
9 | from pl_bolts.datamodules.experience_source import Experience
10 | from pl_bolts.models.rl.common.networks import NoisyCNN
11 | from pl_bolts.models.rl.dqn_model import DQN
12 | 
13 | 
14 | class NoisyDQN(DQN):
15 |     """PyTorch Lightning implementation of `Noisy DQN <https://arxiv.org/abs/1706.10295>`_
16 | 
17 |     Paper authors: Meire Fortunato, Mohammad Gheshlaghi Azar, Bilal Piot, Jacob Menick, Ian Osband, Alex Graves,
18 |     Vlad Mnih, Remi Munos, Demis Hassabis, Olivier Pietquin, Charles Blundell, Shane Legg
19 | 
20 |     Model implemented by:
21 | 
22 |         - Donal Byrne
23 | 
24 |     Example:
25 |         >>> from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
26 |         ...
27 |         >>> model = NoisyDQN("PongNoFrameskip-v4")
28 | 
29 |     Train::
30 | 
31 |         trainer = Trainer()
32 |         trainer.fit(model)
33 | 
34 |     .. note:: Currently only supports CPU and single GPU training with `accelerator=dp`
35 |     """
36 | 
37 |     def build_networks(self) -> None:
38 |         """Initializes the Noisy DQN train and target networks."""
39 |         self.net = NoisyCNN(self.obs_shape, self.n_actions)
40 |         self.target_net = NoisyCNN(self.obs_shape, self.n_actions)
41 | 
42 |     def on_train_start(self) -> None:
43 |         """Set the agent's epsilon to 0, as the exploration comes from the network."""
44 |         self.agent.epsilon = 0.0
45 | 
46 |     def train_batch(
47 |         self,
48 |     ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
49 |         """Contains the logic for generating a new batch of data to be passed to the DataLoader. This is the same
50 |         function as the standard DQN except that we don't update epsilon, as it is always 0. The exploration comes
51 |         from the noisy network.
52 | 
53 |         Returns:
54 |             yields an Experience tuple containing the state, action, reward, done and next_state.
55 |         """
56 |         episode_reward = 0
57 |         episode_steps = 0
58 | 
59 |         while True:
60 |             self.total_steps += 1
61 |             action = self.agent(self.state, self.device)
62 | 
63 |             next_state, r, is_done, _ = self.env.step(action[0])
64 | 
65 |             episode_reward += r
66 |             episode_steps += 1
67 | 
68 |             exp = Experience(state=self.state, action=action[0], reward=r, done=is_done, new_state=next_state)
69 | 
70 |             self.buffer.append(exp)
71 |             self.state = next_state
72 | 
73 |             if is_done:
74 |                 self.done_episodes += 1
75 |                 self.total_rewards.append(episode_reward)
76 |                 self.total_episode_steps.append(episode_steps)
77 |                 self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len :]))
78 |                 self.state = self.env.reset()
79 |                 episode_steps = 0
80 |                 episode_reward = 0
81 | 
82 |             states, actions, rewards, dones, new_states = self.buffer.sample(self.batch_size)
83 | 
84 |             for idx, _ in enumerate(dones):
85 |                 yield states[idx], actions[idx], rewards[idx], dones[idx], new_states[idx]
86 | 
87 |             # Simulates epochs
88 |             if self.total_steps % self.batches_per_epoch == 0:
89 |                 break
90 | 
91 | 
92 | def cli_main():
93 |     parser = argparse.ArgumentParser(add_help=False)
94 | 
95 |     # trainer args
96 |     parser = Trainer.add_argparse_args(parser)
97 | 
98 |     # model args
99 |     parser = NoisyDQN.add_model_specific_args(parser)
100 |     args = parser.parse_args()
101 | 
102 |     model = NoisyDQN(**args.__dict__)
103 | 
104 |     trainer = Trainer.from_argparse_args(args)
105 |     trainer.fit(model)
106 | 
107 | 
108 | if __name__ == "__main__":
109 |     cli_main()
110 | 
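111 | 
112 | # Why epsilon can stay at zero (illustrative sketch only; the real layers live in
113 | # pl_bolts.models.rl.common.networks): a noisy linear layer learns a mean and a
114 | # sigma for every weight and resamples Gaussian noise on each forward pass, so
115 | # action selection stays stochastic without an epsilon-greedy schedule.
116 | #
117 | #     class NoisyLinearSketch(nn.Module):
118 | #         def __init__(self, in_features, out_features, sigma_init=0.017):
119 | #             super().__init__()
120 | #             self.mu = nn.Parameter(torch.empty(out_features, in_features).uniform_(-0.1, 0.1))
121 | #             self.sigma = nn.Parameter(torch.full((out_features, in_features), sigma_init))
122 | #
123 | #         def forward(self, x):
124 | #             eps = torch.randn_like(self.sigma)  # fresh noise every call
125 | #             return x @ (self.mu + self.sigma * eps).t()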
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/vision/image_gpt/gpt2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_lightning import LightningModule
3 | from torch import nn
4 | 
5 | 
6 | class Block(nn.Module):
7 |     def __init__(self, embed_dim, heads):
8 |         super().__init__()
9 |         self.ln_1 = nn.LayerNorm(embed_dim)
10 |         self.ln_2 = nn.LayerNorm(embed_dim)
11 |         self.attn = nn.MultiheadAttention(embed_dim, heads)
12 |         self.mlp = nn.Sequential(
13 |             nn.Linear(embed_dim, embed_dim * 4),
14 |             nn.GELU(),
15 |             nn.Linear(embed_dim * 4, embed_dim),
16 |         )
17 | 
18 |     def forward(self, x):
19 |         attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype)
20 |         attn_mask = torch.triu(attn_mask, diagonal=1)
21 | 
22 |         x = self.ln_1(x)
23 |         a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
24 |         x = x + a
25 |         m = self.mlp(self.ln_2(x))
26 |         x = x + m
27 |         return x
28 | 
29 | 
30 | class GPT2(LightningModule):
31 |     """GPT-2 from "Language Models are Unsupervised Multitask Learners".
32 | 
33 | 
34 |     Paper by: Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever
35 | 
36 |     Implementation contributed by:
37 | 
38 |         - Teddy Koker
39 | 
40 |     Example::
41 | 
42 |         from pl_bolts.models.vision import GPT2
43 | 
44 |         seq_len = 17
45 |         batch_size = 32
46 |         vocab_size = 16
47 |         x = torch.randint(0, vocab_size, (seq_len, batch_size))
48 |         model = GPT2(embed_dim=32, heads=2, layers=2, num_positions=seq_len, vocab_size=vocab_size, num_classes=4)
49 |         results = model(x)
50 |     """
51 | 
52 |     def __init__(
53 |         self,
54 |         embed_dim: int,
55 |         heads: int,
56 |         layers: int,
57 |         num_positions: int,
58 |         vocab_size: int,
59 |         num_classes: int,
60 |     ):
61 |         super().__init__()
62 |         self.save_hyperparameters()
63 | 
64 |         self._init_sos_token()
65 |         self._init_embeddings()
66 |         self._init_layers()
67 | 
68 |     def _init_sos_token(self):
69 |         self.sos = torch.nn.Parameter(torch.zeros(self.hparams.embed_dim))
70 |         nn.init.normal_(self.sos)
71 | 
72 |     def _init_embeddings(self):
73 |         self.token_embeddings = nn.Embedding(self.hparams.vocab_size, self.hparams.embed_dim)
74 |         self.position_embeddings = nn.Embedding(self.hparams.num_positions, self.hparams.embed_dim)
75 | 
76 |     def _init_layers(self):
77 |         self.layers = nn.ModuleList()
78 |         for _ in range(self.hparams.layers):
79 |             self.layers.append(Block(self.hparams.embed_dim, self.hparams.heads))
80 | 
81 |         self.ln_f = nn.LayerNorm(self.hparams.embed_dim)
82 |         self.head = nn.Linear(self.hparams.embed_dim, self.hparams.vocab_size, bias=False)
83 |         self.clf_head = nn.Linear(self.hparams.embed_dim, self.hparams.num_classes)
84 | 
85 |     def forward(self, x, classify=False):
86 |         """Expects input of shape ``[sequence_len, batch]``; if ``classify`` is True, returns classification logits."""
87 |         length, batch = x.shape
88 | 
89 |         h = self.token_embeddings(x.long())
90 | 
91 |         # prepend sos token
92 |         sos = torch.ones(1, batch, self.hparams.embed_dim, device=x.device, dtype=x.dtype) * self.sos
93 |         h = torch.cat([sos, h[:-1, :, :]], dim=0)
94 | 
95 |         # add positional embeddings
96 |         positions = torch.arange(length, device=x.device).unsqueeze(-1)
97 |         h = h + self.position_embeddings(positions).expand_as(h)
98 | 
99 |         # transformer
100 |         for layer in self.layers:
101 |             h = layer(h)
102 | 
103 |         if not classify:
104 |             # return logits
105 |             return self.head(h)
106 | 
107 |         h = torch.mean(h, dim=0)  # average pool over sequence
108 |         return self.clf_head(h)  # return classification logits
109 | 
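110 | 
111 | # Shape walk-through for the Example in the class docstring (sequence-first layout):
112 | #     x: [17, 32] token ids  ->  token_embeddings(x): [17, 32, 32]
113 | #     the learned SOS embedding is prepended and the last position dropped,
114 | #     so the sequence length stays 17
115 | #     model(x): [17, 32, 16]            (per-position vocab logits)
116 | #     model(x, classify=True): [32, 4]  (mean-pooled sequence -> class logits)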
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/README.md:
--------------------------------------------------------------------------------
1 | # Pre-Training with SwAV, MoCoV2, BYOL
2 | 
3 | We used the implementation from PyTorch Lightning Bolts: [https://lightning.ai/docs/pytorch/stable/ecosystem/bolts.html](https://lightning.ai/docs/pytorch/stable/ecosystem/bolts.html)
4 | 
5 | ### How to Start:
6 | 1. Download the LIDC data and run the preprocessing script as explained here: [https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main/Pre-Training/Data_Preprocessing](https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main/Pre-Training/Data_Preprocessing)
7 | 
8 | #### Option 1: Use the latest PyTorch Lightning Bolts implementation
9 | You can use the implementation of PyTorch Lightning Bolts directly. You only have to change the data loading.
10 | - SwAV: [https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/models/self_supervised/swav/swav_module.py](https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/models/self_supervised/swav/swav_module.py)
11 | - MoCoV2: [https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/models/self_supervised/moco/moco_module.py](https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/models/self_supervised/moco/moco_module.py)
12 | - BYOL: [https://github.com/Lightning-Universe/lightning-bolts/tree/master/src/pl_bolts/models/self_supervised/byol](https://github.com/Lightning-Universe/lightning-bolts/tree/master/src/pl_bolts/models/self_supervised/byol)
13 | 
14 | #### Option 2: Use our PyTorch Lightning Bolts adaptation
15 | 2. Change the folder structure of the preprocessed data to:
16 | ```bash
17 | LIDC-Data
18 |     /
19 |     train
20 | ```
21 | 3. Open your terminal and follow these steps:
22 |     1. conda create --name SSL_Contrastive python==3.10
23 |     2. conda activate SSL_Contrastive
24 |     3. conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia
25 |     4. cd .../SSL-MedicalImagining-CL-MAE/Pre-Training/Contrastive_Learning/
26 |     5. pip install -r requirements.txt
27 | 4. Start the pre-training with a bash script: \
28 | SwAV:
29 | ```bash
30 | #!/bin/bash
31 | 
32 | wandb login your_login_id
33 | python .../SSL-MedicalImagining-CL-MAE/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/swav/swav_module_lidc.py \
34 | --save_path /path/where/results/should/be/saved \
35 | --data_dir /path/to/the/LIDC-Data \
36 | --model Some_Name_for_WandB \
37 | --test Some_Name_for_WandB \
38 | --project WandB_project_name \
39 | --batch_size 128 \
40 | --group Bs_128 \
41 | --tags ["500Proto_Color2x04-2x02-Blur-Crop"] \
42 | --learning_rate 0.15 \
43 | --final_lr 0.00015 \
44 | --start_lr 0.3 \
45 | --freeze_prototypes_epochs 313 \
46 | --accumulate_grad_batches 1 \
47 | --optimizer lars
48 | ```
49 | 
50 | MoCo V2:
51 | ```bash
52 | #!/bin/bash
53 | 
54 | wandb login your_login_id
55 | python .../SSL-MedicalImagining-CL-MAE/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/moco/moco2_module.py \
56 | --dataset=medical \
57 | --batch_size=128 \
58 | --data_dir=/path/to/the/LIDC-Data \
59 | --savepath=/path/where/results/should/be/saved \
60 | --wandb_group=LIDC \
61 | --wandb_job_type=MoCo \
62 | --lambda_ 0.05 \
63 | --base_encoder=resnet50 \
64 | --max_epochs=800 \
65 | --num_workers=12 \
66 | --tags resnet50 LIDC MoCo
67 | ```
68 | 
69 | BYOL:
70 | ```bash
71 | #!/bin/bash
72 | 
73 | wandb login your_login_id
74 | python .../SSL-MedicalImagining-CL-MAE/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/byol/byol_module.py --gpus 1 \
75 | --data_dir /path/to/the/LIDC-Data \
76 | --batch_size 64 \
77 | --savepath /path/where/results/should/be/saved \
78 | --group BYOL \
79 | --name WandB_name
80 | ```
81 | For further information and other settings, please refer to the PyTorch Lightning Bolts GitHub: [https://github.com/Lightning-Universe/lightning-bolts/tree/master](https://github.com/Lightning-Universe/lightning-bolts/tree/master)
82 | 
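83 | 
84 | After pre-training, the saved checkpoint can be reused for fine-tuning. A minimal sketch for SwAV (the checkpoint file name is a placeholder for whatever `--save_path` produced; depending on the Lightning version you may also need to pass the original hyperparameters to `load_from_checkpoint`):
85 | ```python
86 | import torch
87 | 
88 | from pl_bolts.models.self_supervised import SwAV
89 | 
90 | # placeholder path: use the checkpoint written to --save_path
91 | model = SwAV.load_from_checkpoint("/path/where/results/should/be/saved/last.ckpt")
92 | 
93 | # Export the encoder weights for the downstream task; `model.model` holds the
94 | # backbone in this SwAV adaptation (the attribute name can differ per module).
95 | torch.save(model.model.state_dict(), "swav_resnet50_backbone.pt")
96 | ```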
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/mnist_module.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | 
3 | import torch
4 | from pytorch_lightning import LightningModule, Trainer
5 | from torch.nn import functional as F
6 | from torch.utils.data import DataLoader, random_split
7 | 
8 | from pl_bolts.datasets import MNIST
9 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
10 | from pl_bolts.utils.warnings import warn_missing_pkg
11 | 
12 | if _TORCHVISION_AVAILABLE:
13 |     from torchvision import transforms
14 | else:  # pragma: no cover
15 |     warn_missing_pkg("torchvision")
16 | 
17 | 
18 | class LitMNIST(LightningModule):
19 |     def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, data_dir="", **kwargs):
20 |         if not _TORCHVISION_AVAILABLE:  # pragma: no cover
21 |             raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")
22 | 
23 |         super().__init__()
24 |         self.save_hyperparameters()
25 | 
26 |         self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
27 |         self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
28 | 
29 |         self.mnist_train = None
30 |         self.mnist_val = None
31 | 
32 |     def forward(self, x):
33 |         x = x.view(x.size(0), -1)
34 |         x = torch.relu(self.l1(x))
35 |         x = self.l2(x)  # no activation on the output layer: cross_entropy expects raw logits
36 |         return x
37 | 
38 |     def training_step(self, batch, batch_idx):
39 |         x, y = batch
40 |         y_hat = self(x)
41 |         loss = F.cross_entropy(y_hat, y)
42 |         self.log("train_loss", loss)
43 |         return loss
44 | 
45 |     def validation_step(self, batch, batch_idx):
46 |         x, y = batch
47 |         y_hat = self(x)
48 |         loss = F.cross_entropy(y_hat, y)
49 |         self.log("val_loss", loss)
50 | 
51 |     def test_step(self, batch, batch_idx):
52 |         x, y = batch
53 |         y_hat = self(x)
54 |         loss = F.cross_entropy(y_hat, y)
55 |         self.log("test_loss", loss)
56 | 
57 |     def configure_optimizers(self):
58 |         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
59 | 
60 |     def prepare_data(self):
61 |         MNIST(self.hparams.data_dir, train=True, download=True, transform=transforms.ToTensor())
62 | 
63 |     def train_dataloader(self):
64 |         dataset = MNIST(self.hparams.data_dir, train=True, download=False, transform=transforms.ToTensor())
65 |         mnist_train, _ = random_split(dataset, [55000, 5000], generator=torch.Generator().manual_seed(42))  # fixed seed keeps the train/val splits disjoint across loaders
66 |         loader = DataLoader(mnist_train, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
67 |         return loader
68 | 
69 |     def val_dataloader(self):
70 |         dataset = MNIST(self.hparams.data_dir, train=True, download=False, transform=transforms.ToTensor())
71 |         _, mnist_val = random_split(dataset, [55000, 5000], generator=torch.Generator().manual_seed(42))  # same seed as in train_dataloader
72 |         loader = DataLoader(mnist_val, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
73 |         return loader
74 | 
75 |     def test_dataloader(self):
76 |         test_dataset = MNIST(self.hparams.data_dir, train=False, download=True, transform=transforms.ToTensor())
77 |         loader = DataLoader(test_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
78 |         return loader
79 | 
80 |     @staticmethod
81 |     def add_model_specific_args(parent_parser):
82 |         parser = ArgumentParser(parents=[parent_parser], add_help=False)
83 |         parser.add_argument("--batch_size", type=int, default=32)
84 |         parser.add_argument("--num_workers", type=int, default=4)
85 |         parser.add_argument("--hidden_dim", type=int, default=128)
86 |         parser.add_argument("--data_dir", type=str, default="")
87 |         parser.add_argument("--learning_rate", type=float, default=0.0001)
88 |         return parser
89 | 
90 | 
91 | def cli_main():
92 |     # args
93 |     parser = ArgumentParser()
94 |     parser = Trainer.add_argparse_args(parser)
95 |     parser = LitMNIST.add_model_specific_args(parser)
96 |     args = parser.parse_args()
97 | 
98 |     # model
99 |     model = LitMNIST(**vars(args))
100 | 
101 |     # training
102 |     trainer = Trainer.from_argparse_args(args)
103 |     trainer.fit(model)
104 | 
105 | 
106 | if __name__ == "__main__":  # pragma: no cover
107 |     cli_main()
108 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datasets/kitti_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 | 
6 | from pl_bolts.utils import _PIL_AVAILABLE
7 | from pl_bolts.utils.warnings import warn_missing_pkg
8 | 
9 | if _PIL_AVAILABLE:
10 |     from PIL import Image
11 | else:  # pragma: no cover
12 |     warn_missing_pkg("PIL", pypi_name="Pillow")
13 | 
14 | DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
15 | DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)
16 | 
17 | 
18 | class
KittiDataset(Dataset): 19 | """ 20 | Note: 21 | You need to have downloaded the Kitti dataset first and provide the path to where it is saved. 22 | You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015 23 | 24 | There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These 25 | useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored 26 | in `valid_labels`. 27 | 28 | The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index` 29 | (250 by default). It also sets all of the valid pixels to the appropriate value between 0 and 30 | `len(valid_labels)` (since that is the number of valid classes), so it can be used properly by 31 | the loss function when comparing with the output. 32 | """ 33 | 34 | IMAGE_PATH = os.path.join("training", "image_2") 35 | MASK_PATH = os.path.join("training", "semantic") 36 | 37 | def __init__( 38 | self, 39 | data_dir: str, 40 | img_size: tuple = (1242, 376), 41 | void_labels: list = DEFAULT_VOID_LABELS, 42 | valid_labels: list = DEFAULT_VALID_LABELS, 43 | transform=None, 44 | ): 45 | """ 46 | Args: 47 | data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/' 48 | img_size: image dimensions (width, height) 49 | void_labels: useless classes to be excluded from training 50 | valid_labels: useful classes to include 51 | """ 52 | if not _PIL_AVAILABLE: # pragma: no cover 53 | raise ModuleNotFoundError("You want to use `PIL` which is not installed yet.") 54 | 55 | self.img_size = img_size 56 | self.void_labels = void_labels 57 | self.valid_labels = valid_labels 58 | self.ignore_index = 250 59 | self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels)))) 60 | self.transform = transform 61 | 62 | self.data_dir = data_dir 63 | self.img_path = os.path.join(self.data_dir, self.IMAGE_PATH) 64 | self.mask_path = os.path.join(self.data_dir, self.MASK_PATH) 65 | self.img_list = self.get_filenames(self.img_path) 66 | self.mask_list = self.get_filenames(self.mask_path) 67 | 68 | def __len__(self): 69 | return len(self.img_list) 70 | 71 | def __getitem__(self, idx): 72 | img = Image.open(self.img_list[idx]) 73 | img = img.resize(self.img_size) 74 | img = np.array(img) 75 | 76 | mask = Image.open(self.mask_list[idx]).convert("L") 77 | mask = mask.resize(self.img_size) 78 | mask = np.array(mask) 79 | mask = self.encode_segmap(mask) 80 | 81 | if self.transform: 82 | img = self.transform(img) 83 | 84 | return img, mask 85 | 86 | def encode_segmap(self, mask): 87 | """Sets void classes to zero so they won't be considered for training.""" 88 | for voidc in self.void_labels: 89 | mask[mask == voidc] = self.ignore_index 90 | for validc in self.valid_labels: 91 | mask[mask == validc] = self.class_map[validc] 92 | # remove extra idxs from updated dataset 93 | mask[mask > 18] = self.ignore_index 94 | return mask 95 | 96 | def get_filenames(self, path): 97 | """Returns a list of absolute paths to images inside given `path`""" 98 | files_list = list() 99 | for filename in os.listdir(path): 100 | files_list.append(os.path.join(path, filename)) 101 | return files_list 102 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/models/vision/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn 
import functional as F 4 | 5 | 6 | class UNet(nn.Module): 7 | """ 8 | Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation 9 | `_ 10 | 11 | Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox 12 | 13 | Implemented by: 14 | 15 | - `Annika Brundyn `_ 16 | - `Akshay Kulkarni `_ 17 | 18 | Args: 19 | num_classes: Number of output classes required 20 | input_channels: Number of channels in input images (default 3) 21 | num_layers: Number of layers in each side of U-net (default 5) 22 | features_start: Number of features in first layer (default 64) 23 | bilinear: Whether to use bilinear interpolation or transposed convolutions (default) for upsampling. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | num_classes: int, 29 | input_channels: int = 3, 30 | num_layers: int = 5, 31 | features_start: int = 64, 32 | bilinear: bool = False, 33 | ): 34 | 35 | if num_layers < 1: 36 | raise ValueError(f"num_layers = {num_layers}, expected: num_layers > 0") 37 | 38 | super().__init__() 39 | self.num_layers = num_layers 40 | 41 | layers = [DoubleConv(input_channels, features_start)] 42 | 43 | feats = features_start 44 | for _ in range(num_layers - 1): 45 | layers.append(Down(feats, feats * 2)) 46 | feats *= 2 47 | 48 | for _ in range(num_layers - 1): 49 | layers.append(Up(feats, feats // 2, bilinear)) 50 | feats //= 2 51 | 52 | layers.append(nn.Conv2d(feats, num_classes, kernel_size=1)) 53 | 54 | self.layers = nn.ModuleList(layers) 55 | 56 | def forward(self, x): 57 | xi = [self.layers[0](x)] 58 | # Down path 59 | for layer in self.layers[1 : self.num_layers]: 60 | xi.append(layer(xi[-1])) 61 | # Up path 62 | for i, layer in enumerate(self.layers[self.num_layers : -1]): 63 | xi[-1] = layer(xi[-1], xi[-2 - i]) 64 | return self.layers[-1](xi[-1]) 65 | 66 | 67 | class DoubleConv(nn.Module): 68 | """[ Conv2d => BatchNorm (optional) => ReLU ] x 2.""" 69 | 70 | def __init__(self, in_ch: int, out_ch: int): 71 | super().__init__() 72 | self.net = nn.Sequential( 73 | nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), 74 | nn.BatchNorm2d(out_ch), 75 | nn.ReLU(inplace=True), 76 | nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), 77 | nn.BatchNorm2d(out_ch), 78 | nn.ReLU(inplace=True), 79 | ) 80 | 81 | def forward(self, x): 82 | return self.net(x) 83 | 84 | 85 | class Down(nn.Module): 86 | """Downscale with MaxPool => DoubleConvolution block.""" 87 | 88 | def __init__(self, in_ch: int, out_ch: int): 89 | super().__init__() 90 | self.net = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), DoubleConv(in_ch, out_ch)) 91 | 92 | def forward(self, x): 93 | return self.net(x) 94 | 95 | 96 | class Up(nn.Module): 97 | """Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature 98 | map from contracting path, followed by DoubleConv.""" 99 | 100 | def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): 101 | super().__init__() 102 | self.upsample = None 103 | if bilinear: 104 | self.upsample = nn.Sequential( 105 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), 106 | nn.Conv2d(in_ch, in_ch // 2, kernel_size=1), 107 | ) 108 | else: 109 | self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2) 110 | 111 | self.conv = DoubleConv(in_ch, out_ch) 112 | 113 | def forward(self, x1, x2): 114 | x1 = self.upsample(x1) 115 | 116 | # Pad x1 to the size of x2 117 | diff_h = x2.shape[2] - x1.shape[2] 118 | diff_w = x2.shape[3] - x1.shape[3] 119 | 120 | x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2, 
diff_h // 2, diff_h - diff_h // 2]) 121 | 122 | # Concatenate along the channels axis 123 | x = torch.cat([x2, x1], dim=1) 124 | return self.conv(x) 125 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/datasets/ssl_amdim_datasets.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Callable, Optional 3 | 4 | import numpy as np 5 | 6 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 7 | from pl_bolts.utils.warnings import warn_missing_pkg 8 | 9 | if _TORCHVISION_AVAILABLE: 10 | from torchvision.datasets import CIFAR10 11 | else: # pragma: no cover 12 | warn_missing_pkg("torchvision") 13 | CIFAR10 = object 14 | 15 | 16 | class SSLDatasetMixin(ABC): 17 | @classmethod 18 | def generate_train_val_split(cls, examples, labels, pct_val): 19 | """Splits dataset uniformly across classes.""" 20 | nb_classes = len(set(labels)) 21 | 22 | nb_val_images = int(len(examples) * pct_val) // nb_classes 23 | 24 | val_x = [] 25 | val_y = [] 26 | train_x = [] 27 | train_y = [] 28 | 29 | cts = {x: 0 for x in range(nb_classes)} 30 | for img, class_idx in zip(examples, labels): 31 | 32 | # allow labeled 33 | if cts[class_idx] < nb_val_images: 34 | val_x.append(img) 35 | val_y.append(class_idx) 36 | cts[class_idx] += 1 37 | else: 38 | train_x.append(img) 39 | train_y.append(class_idx) 40 | 41 | val_x = np.stack(val_x) 42 | train_x = np.stack(train_x) 43 | return val_x, val_y, train_x, train_y 44 | 45 | @classmethod 46 | def select_nb_imgs_per_class(cls, examples, labels, nb_imgs_in_val): 47 | """Splits a dataset into two parts. 48 | 49 | The labeled split has nb_imgs_in_val per class 50 | """ 51 | nb_classes = len(set(labels)) 52 | 53 | # def partition_train_set(self, imgs, nb_imgs_in_val): 54 | labeled = [] 55 | labeled_y = [] 56 | unlabeled = [] 57 | unlabeled_y = [] 58 | 59 | cts = {x: 0 for x in range(nb_classes)} 60 | for img_name, class_idx in zip(examples, labels): 61 | 62 | # allow labeled 63 | if cts[class_idx] < nb_imgs_in_val: 64 | labeled.append(img_name) 65 | labeled_y.append(class_idx) 66 | cts[class_idx] += 1 67 | else: 68 | unlabeled.append(img_name) 69 | unlabeled_y.append(class_idx) 70 | 71 | labeled = np.stack(labeled) 72 | 73 | return labeled, labeled_y 74 | 75 | @classmethod 76 | def deterministic_shuffle(cls, x, y): 77 | 78 | n = len(x) 79 | idxs = list(range(0, n)) 80 | np.random.seed(1234) 81 | np.random.shuffle(idxs) 82 | 83 | x = x[idxs] 84 | 85 | y = np.asarray(y) 86 | y = y[idxs] 87 | y = list(y) 88 | 89 | return x, y 90 | 91 | 92 | class CIFAR10Mixed(SSLDatasetMixin, CIFAR10): 93 | def __init__( 94 | self, 95 | root: str, 96 | split: str = "val", 97 | transform: Optional[Callable] = None, 98 | target_transform: Optional[Callable] = None, 99 | download: bool = False, 100 | nb_labeled_per_class: Optional[int] = None, 101 | val_pct: float = 0.10, 102 | ): 103 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 104 | raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.") 105 | 106 | if nb_labeled_per_class == -1: 107 | nb_labeled_per_class = None 108 | 109 | # use train for all of these splits 110 | train = split in ("val", "train", "train+unlabeled") 111 | super().__init__(root, train, transform, target_transform, download) 112 | 113 | # modify only for val, train 114 | if split != "test": 115 | # limit nb of examples per class 116 | X_test, y_test, X_train, y_train = 
self.generate_train_val_split(self.data, self.targets, val_pct)
117 | 
118 |             # shuffle idxs representing the data
119 |             X_train, y_train = self.deterministic_shuffle(X_train, y_train)
120 |             X_test, y_test = self.deterministic_shuffle(X_test, y_test)
121 | 
122 |             if split == "val":
123 |                 self.data = X_test
124 |                 self.targets = y_test
125 | 
126 |             else:
127 |                 self.data = X_train
128 |                 self.targets = y_train
129 | 
130 |         # limit the number of items per class
131 |         if nb_labeled_per_class is not None:
132 |             self.data, self.targets = self.select_nb_imgs_per_class(self.data, self.targets, nb_labeled_per_class)
133 | 
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/vision/segmentation.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | 
3 | import torch
4 | from pytorch_lightning import LightningModule, Trainer, seed_everything
5 | from torch.nn import functional as F
6 | 
7 | from pl_bolts.models.vision.unet import UNet
8 | 
9 | 
10 | class SemSegment(LightningModule):
11 |     def __init__(
12 |         self,
13 |         lr: float = 0.01,
14 |         num_classes: int = 19,
15 |         num_layers: int = 5,
16 |         features_start: int = 64,
17 |         bilinear: bool = False,
18 |     ):
19 |         """Basic model for semantic segmentation. Uses UNet architecture by default.
20 | 
21 |         The default parameters in this model are for the KITTI dataset. Note, if you'd like to use this model as is,
22 |         you will first need to download the KITTI dataset yourself. You can download the dataset
23 |         `here <http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015>`_.
24 | 
25 |         Implemented by:
26 | 
27 |             - Annika Brundyn
28 | 
29 |         Args:
30 |             num_layers: number of layers in each side of U-net (default 5)
31 |             features_start: number of features in first layer (default 64)
32 |             bilinear: whether to use bilinear interpolation (True) or transposed convolutions (default) for upsampling.
33 |             lr: learning rate (default 0.01)
34 |         """
35 |         super().__init__()
36 | 
37 |         self.num_classes = num_classes
38 |         self.num_layers = num_layers
39 |         self.features_start = features_start
40 |         self.bilinear = bilinear
41 |         self.lr = lr
42 | 
43 |         self.net = UNet(
44 |             num_classes=num_classes,
45 |             num_layers=self.num_layers,
46 |             features_start=self.features_start,
47 |             bilinear=self.bilinear,
48 |         )
49 | 
50 |     def forward(self, x):
51 |         return self.net(x)
52 | 
53 |     def training_step(self, batch, batch_nb):
54 |         img, mask = batch
55 |         img = img.float()
56 |         mask = mask.long()
57 |         out = self(img)
58 |         loss_val = F.cross_entropy(out, mask, ignore_index=250)
59 |         log_dict = {"train_loss": loss_val}
60 |         return {"loss": loss_val, "log": log_dict, "progress_bar": log_dict}
61 | 
62 |     def validation_step(self, batch, batch_idx):
63 |         img, mask = batch
64 |         img = img.float()
65 |         mask = mask.long()
66 |         out = self(img)
67 |         loss_val = F.cross_entropy(out, mask, ignore_index=250)
68 |         return {"val_loss": loss_val}
69 | 
70 |     def validation_epoch_end(self, outputs):
71 |         loss_val = torch.stack([x["val_loss"] for x in outputs]).mean()
72 |         log_dict = {"val_loss": loss_val}
73 |         return {"log": log_dict, "val_loss": log_dict["val_loss"], "progress_bar": log_dict}
74 | 
75 |     def configure_optimizers(self):
76 |         opt = torch.optim.Adam(self.net.parameters(), lr=self.lr)
77 |         sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
78 |         return [opt], [sch]
79 | 
80 |     @staticmethod
81 |     def add_model_specific_args(parent_parser):
82 |         parser = ArgumentParser(parents=[parent_parser], add_help=False)
83 |         parser.add_argument("--lr", type=float, default=0.01, help="adam: learning rate")
84 |         parser.add_argument("--num_layers", type=int, default=5, help="number of layers in each side of the u-net")
85 |         parser.add_argument("--features_start", type=int, default=64, help="number of features in the first layer")
86 |         parser.add_argument(
87 |             "--bilinear", action="store_true", default=False, help="whether to use bilinear interpolation or transposed convolutions"
88 |         )
89 | 
90 |         return parser
91 | 
92 | 
93 | def cli_main():
94 |     from pl_bolts.datamodules import KittiDataModule
95 | 
96 |     seed_everything(1234)
97 | 
98 |     parser = ArgumentParser()
99 |     # trainer args
100 |     parser = Trainer.add_argparse_args(parser)
101 |     # model args
102 |     parser = SemSegment.add_model_specific_args(parser)
103 |     # datamodule args
104 |     parser = KittiDataModule.add_argparse_args(parser)
105 | 
106 |     args = parser.parse_args()
107 | 
108 |     # data
109 |     dm = KittiDataModule.from_argparse_args(args)
110 | 
111 |     # model
112 |     model = SemSegment(**args.__dict__)
113 | 
114 |     # train
115 |     trainer = Trainer.from_argparse_args(args)
116 |     trainer.fit(model, datamodule=dm)
117 | 
118 | 
119 | if __name__ == "__main__":
120 |     cli_main()
121 | 
--------------------------------------------------------------------------------
/Pre-Training/Masked_Autoencoder/utils/arg_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import json
8 | import os
9 | import sys
10 | 
11 | from tap import Tap
12 | 
13 | import dist
14 | 
15 | 
16 | class Args(Tap):
17 |     # environment
18 |     exp_name: str = 'your_exp_name'
19 |     exp_dir: str = 'your_exp_dir'   # will be created if it does not exist
20 |     data_path: str = 'imagenet_data_path'
21 |     resume_from: str = ''   # resume from some checkpoint.pth
22 | 
23 |     # SparK hyperparameters
24 |     mask: float = 0.6   # mask ratio, should be in (0, 1)
25 | 
26 |     # encoder hyperparameters
27 |     model: str = 'resnet50'
28 |     input_size: int = 224
29 |     sbn: bool = True
30 | 
31 |     # data hyperparameters
32 |     bs: int = 4096
33 |     dataloader_workers: int = 8
34 | 
35 |     # pre-training hyperparameters
36 |     dp: float = 0.0
37 |     base_lr: float = 2e-4
38 |     wd: float = 0.04
39 |     wde: float = 0.2
40 |     ep: int = 1600
41 |     wp_ep: int = 40
42 |     clip: float = 5.
43 |     opt: str = 'lamb'
44 |     ada: float = 0.
45 | 
46 |     # NO NEED TO SPECIFY; each of these args is updated automatically at runtime
47 |     lr: float = None
48 |     batch_size_per_gpu: int = 0
49 |     glb_batch_size: int = 0
50 |     densify_norm: str = ''
51 |     device: str = 'cpu'
52 |     local_rank: int = 0
53 |     cmd: str = ' '.join(sys.argv[1:])
54 |     commit_id: str = os.popen('git rev-parse HEAD').read().strip() or '[unknown]'
55 |     commit_msg: str = (os.popen('git log -1').read().strip().splitlines() or ['[unknown]'])[-1].strip()
56 |     last_loss: float = 0.
57 |     cur_ep: str = ''
58 |     remain_time: str = ''
59 |     finish_time: str = ''
60 |     first_logging: bool = True
61 |     log_txt_name: str = '{args.exp_dir}/pretrain_log.txt'
62 |     tb_lg_dir: str = ''   # tensorboard log directory
63 | 
64 |     @property
65 |     def is_convnext(self):
66 |         return 'convnext' in self.model or 'cnx' in self.model
67 | 
68 |     @property
69 |     def is_resnet(self):
70 |         return 'resnet' in self.model
71 | 
72 |     def log_epoch(self):
73 |         if not dist.is_local_master():
74 |             return
75 | 
76 |         if self.first_logging:
77 |             self.first_logging = False
78 |             with open(self.log_txt_name, 'w') as fp:
79 |                 json.dump({
80 |                     'name': self.exp_name, 'cmd': self.cmd, 'git_commit_id': self.commit_id, 'git_commit_msg': self.commit_msg,
81 |                     'model': self.model,
82 |                 }, fp)
83 |                 fp.write('\n\n')
84 | 
85 |         with open(self.log_txt_name, 'a') as fp:
86 |             json.dump({
87 |                 'cur_ep': self.cur_ep,
88 |                 'last_L': self.last_loss,
89 |                 'rema': self.remain_time, 'fini': self.finish_time,
90 |             }, fp)
91 |             fp.write('\n')
92 | 
93 | 
94 | def init_dist_and_get_args():
95 |     from utils import misc
96 | 
97 |     # initialize
98 |     args = Args(explicit_bool=True).parse_args()
99 |     e = os.path.abspath(args.exp_dir)
100 |     d, e = os.path.dirname(e), os.path.basename(e)
101 |     e = ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in e)
102 |     args.exp_dir = os.path.join(d, e)
103 | 
104 |     os.makedirs(args.exp_dir, exist_ok=True)
105 |     args.log_txt_name = os.path.join(args.exp_dir, 'pretrain_log.txt')
106 |     args.tb_lg_dir = args.tb_lg_dir or os.path.join(args.exp_dir, 'tensorboard_log')
107 |     try:
108 |         os.makedirs(args.tb_lg_dir, exist_ok=True)
109 |     except OSError:
110 |         pass
111 | 
112 |     misc.init_distributed_environ(exp_dir=args.exp_dir)
113 | 
114 |     # update args
115 |     if not dist.initialized():
116 |         args.sbn = False
117 |     args.first_logging = True
118 |     args.device = dist.get_device()
119 |     args.batch_size_per_gpu = args.bs // dist.get_world_size()
120 |     args.glb_batch_size = args.batch_size_per_gpu * dist.get_world_size()
121 | 
122 |     if args.is_resnet:
123 |         args.ada = args.ada or 0.95
124 |         args.densify_norm = 'bn'
125 | 
126 |     if args.is_convnext:
127 | 
args.ada = args.ada or 0.999 128 | args.densify_norm = 'ln' 129 | 130 | args.opt = args.opt.lower() 131 | args.lr = args.base_lr * args.glb_batch_size / 256 132 | args.wde = args.wde or args.wd 133 | 134 | return args 135 | -------------------------------------------------------------------------------- /Pre-Training/Contrastive_Learning/pl_bolts/utils/semi_supervised.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Sequence, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from torch import Tensor 7 | 8 | from pl_bolts.utils import _SKLEARN_AVAILABLE 9 | from pl_bolts.utils.warnings import warn_missing_pkg 10 | 11 | if _SKLEARN_AVAILABLE: 12 | from sklearn.utils import shuffle as sk_shuffle 13 | else: # pragma: no cover 14 | warn_missing_pkg("sklearn", pypi_name="scikit-learn") 15 | 16 | 17 | class Identity(torch.nn.Module): 18 | """An identity class to replace arbitrary layers in pretrained models. 19 | 20 | Example:: 21 | 22 | from pl_bolts.utils import Identity 23 | 24 | model = resnet18() 25 | model.fc = Identity() 26 | """ 27 | 28 | def __init__(self) -> None: 29 | super().__init__() 30 | 31 | def forward(self, x: Tensor) -> Tensor: 32 | return x 33 | 34 | 35 | def balance_classes( 36 | X: Union[Tensor, np.ndarray], Y: Union[Tensor, np.ndarray, Sequence[int]], batch_size: int 37 | ) -> Tuple[np.ndarray, np.ndarray]: 38 | """Makes sure each batch has an equal amount of data from each class. Perfect balance. 39 | 40 | Args: 41 | X: input features 42 | Y: mixed labels (ints) 43 | batch_size: the ultimate batch size 44 | """ 45 | if not _SKLEARN_AVAILABLE: # pragma: no cover 46 | raise ModuleNotFoundError("You want to use `shuffle` function from `scikit-learn` which is not installed yet.") 47 | 48 | nb_classes = len(set(Y)) 49 | 50 | nb_batches = math.ceil(len(Y) / batch_size) 51 | 52 | # sort by classes 53 | final_batches_x: List[list] = [[] for i in range(nb_batches)] 54 | final_batches_y: List[list] = [[] for i in range(nb_batches)] 55 | 56 | # Y needs to be np arr 57 | Y = np.asarray(Y) 58 | 59 | # pick chunk size for each class using the largest split 60 | chunk_sizes = [] 61 | for class_i in range(nb_classes): 62 | mask = Y == class_i 63 | y = Y[mask] 64 | chunk_sizes.append(math.ceil(len(y) / nb_batches)) 65 | chunk_size = max(chunk_sizes) 66 | # force chunk size to be even 67 | if chunk_size % 2 != 0: 68 | chunk_size -= 1 69 | 70 | # divide each class into each batch 71 | for class_i in range(nb_classes): 72 | mask = Y == class_i 73 | x = X[mask] 74 | y = Y[mask] 75 | 76 | # shuffle items in the class 77 | x, y = sk_shuffle(x, y, random_state=123) 78 | 79 | # divide the class into the batches 80 | for i_start in range(0, len(y), chunk_size): 81 | batch_i = i_start // chunk_size 82 | i_end = i_start + chunk_size 83 | 84 | if len(final_batches_x) > batch_i: 85 | final_batches_x[batch_i].append(x[i_start:i_end]) 86 | final_batches_y[batch_i].append(y[i_start:i_end]) 87 | 88 | # merge into full dataset 89 | final_batches_x = [np.concatenate(x, axis=0) for x in final_batches_x if len(x) > 0] 90 | final_batches_x = np.concatenate(final_batches_x, axis=0) 91 | 92 | final_batches_y = [np.concatenate(x, axis=0) for x in final_batches_y if len(x) > 0] 93 | final_batches_y = np.concatenate(final_batches_y, axis=0) 94 | 95 | return final_batches_x, final_batches_y 96 | 97 | 98 | def generate_half_labeled_batches( 99 | smaller_set_X: np.ndarray, 100 | smaller_set_Y: np.ndarray, 101 | larger_set_X: 
np.ndarray,
102 |     larger_set_Y: np.ndarray,
103 |     batch_size: int,
104 | ) -> Tuple[np.ndarray, np.ndarray]:
105 |     """Given a labeled dataset and an unlabeled dataset, this function generates a joint pair where half the
106 |     batches are labeled and the other half is not."""
107 |     X = []
108 |     Y = []
109 |     half_batch = batch_size // 2
110 | 
111 |     n_larger = len(larger_set_X)
112 |     n_smaller = len(smaller_set_X)
113 |     for i_start in range(0, n_larger, half_batch):
114 |         i_end = i_start + half_batch
115 | 
116 |         X_larger = larger_set_X[i_start:i_end]
117 |         Y_larger = larger_set_Y[i_start:i_end]
118 | 
119 |         # pull out labeled part
120 |         smaller_start = i_start % (n_smaller - half_batch)
121 |         smaller_end = smaller_start + half_batch
122 | 
123 |         X_small = smaller_set_X[smaller_start:smaller_end]
124 |         Y_small = smaller_set_Y[smaller_start:smaller_end]
125 | 
126 |         X.extend([X_larger, X_small])
127 |         Y.extend([Y_larger, Y_small])
128 | 
129 |     # concatenate the collected half-batches into the final arrays
130 |     X = np.concatenate(X, axis=0)
131 |     Y = np.concatenate(Y, axis=0)
132 | 
133 |     return X, Y
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/losses/rl.py:
--------------------------------------------------------------------------------
1 | """Loss functions for the RL models."""
2 | 
3 | from typing import List, Tuple
4 | 
5 | import numpy as np
6 | import torch
7 | from torch import Tensor, nn
8 | 
9 | 
10 | def dqn_loss(batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor], net: nn.Module, target_net: nn.Module, gamma: float = 0.99) -> Tensor:
11 |     """Calculates the MSE loss using a mini batch from the replay buffer.
12 | 
13 |     Args:
14 |         batch: current mini batch of replay data
15 |         net: main training network
16 |         target_net: target network of the main training network
17 |         gamma: discount factor
18 | 
19 |     Returns:
20 |         loss
21 |     """
22 |     states, actions, rewards, dones, next_states = batch
23 | 
24 |     actions = actions.long().squeeze(-1)
25 | 
26 |     state_action_values = net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
27 | 
28 |     with torch.no_grad():
29 |         next_state_values = target_net(next_states).max(1)[0]
30 |         next_state_values[dones] = 0.0
31 |         next_state_values = next_state_values.detach()
32 | 
33 |     expected_state_action_values = next_state_values * gamma + rewards
34 | 
35 |     return nn.MSELoss()(state_action_values, expected_state_action_values)
36 | 
37 | 
38 | def double_dqn_loss(
39 |     batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
40 |     net: nn.Module,
41 |     target_net: nn.Module,
42 |     gamma: float = 0.99,
43 | ) -> Tensor:
44 |     """Calculates the MSE loss using a mini batch from the replay buffer. This improves on the original
45 |     DQN loss with double DQN: the actions chosen by the train network are used to pick the values
46 |     from the target network. This code is heavily commented in order to explain the process clearly.
47 | 
48 |     Args:
49 |         batch: current mini batch of replay data
50 |         net: main training network
51 |         target_net: target network of the main training network
52 |         gamma: discount factor
53 | 
54 |     Returns:
55 |         loss
56 |     """
57 |     states, actions, rewards, dones, next_states = batch  # batch of experiences, batch_size = 16
58 | 
59 |     actions = actions.long().squeeze(-1)
60 | 
61 |     state_action_values = net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
62 | 
63 |     # don't want to mess with gradients when using the target network
64 |     with torch.no_grad():
65 |         next_outputs = net(next_states)  # [16, 2], [batch, action_space]
66 | 
67 |         next_state_acts = next_outputs.max(1)[1].unsqueeze(-1)  # take action at the index with the highest value
68 |         next_tgt_out = target_net(next_states)
69 | 
70 |         # Take the value of the action chosen by the train network
71 |         next_state_values = next_tgt_out.gather(1, next_state_acts).squeeze(-1)
72 |         next_state_values[dones] = 0.0  # any steps flagged as done get a 0 value
73 |         next_state_values = next_state_values.detach()  # remove values from the graph, no grads needed
74 | 
75 |     # calc expected discounted return of next_state_values
76 |     expected_state_action_values = next_state_values * gamma + rewards
77 | 
78 |     # Standard MSE loss between the state action values of the current state and the
79 |     # expected state action values of the next state
80 |     return nn.MSELoss()(state_action_values, expected_state_action_values)
81 | 
82 | 
83 | def per_dqn_loss(
84 |     batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor],
85 |     batch_weights: List,
86 |     net: nn.Module,
87 |     target_net: nn.Module,
88 |     gamma: float = 0.99,
89 | ) -> Tuple[Tensor, np.ndarray]:
90 |     """Calculates the MSE loss with the priority weights of the batch from the PER buffer.
91 | 
92 |     Args:
93 |         batch: current mini batch of replay data
94 |         batch_weights: how each of these samples are weighted in terms of priority
95 |         net: main training network
96 |         target_net: target network of the main training network
97 |         gamma: discount factor
98 | 
99 |     Returns:
100 |         loss and batch_weights
101 |     """
102 |     states, actions, rewards, dones, next_states = batch
103 | 
104 |     actions = actions.long()
105 | 
106 |     batch_weights = torch.tensor(batch_weights)
107 | 
108 |     actions_v = actions.unsqueeze(-1)
109 |     outputs = net(states)
110 |     state_action_vals = outputs.gather(1, actions_v)
111 |     state_action_vals = state_action_vals.squeeze(-1)
112 | 
113 |     with torch.no_grad():
114 |         next_s_vals = target_net(next_states).max(1)[0]
115 |         next_s_vals[dones] = 0.0
116 |         exp_sa_vals = next_s_vals.detach() * gamma + rewards
117 |     loss = (state_action_vals - exp_sa_vals) ** 2
118 |     losses_v = batch_weights * loss
119 |     return losses_v.mean(), (losses_v + 1e-5).data.cpu().numpy()
120 | 
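121 | 
122 | # Worked example of the target computation in dqn_loss (batch of 3, gamma = 0.99):
123 | #     rewards        = [1.0, 0.5, 2.0]
124 | #     dones          = [False, True, False]
125 | #     target_net max = [4.0, 7.0, 3.0]  -> zeroed where done: [4.0, 0.0, 3.0]
126 | #     expected_state_action_values = [1.0 + 0.99 * 4.0, 0.5, 2.0 + 0.99 * 3.0]
127 | #                                  = [4.96, 0.50, 4.97]
--------------------------------------------------------------------------------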