`_.
42 |
43 | Args:
44 | preds: an Nx4 batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
45 | target: an Mx4 batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
46 |
47 | Example:
48 |
49 | >>> import torch
50 | >>> from pl_bolts.metrics.object_detection import giou
51 | >>> preds = torch.tensor([[100, 100, 200, 200]])
52 | >>> target = torch.tensor([[150, 150, 250, 250]])
53 | >>> giou(preds, target)
54 | tensor([[-0.0794]])
55 |
56 | Returns:
57 | GIoU in an NxM tensor containing the pairwise GIoU values for every element in preds and target,
58 | where N is the number of prediction bounding boxes and M is the number of target bounding boxes
59 | """
60 | x_min = torch.max(preds[:, None, 0], target[:, 0])
61 | y_min = torch.max(preds[:, None, 1], target[:, 1])
62 | x_max = torch.min(preds[:, None, 2], target[:, 2])
63 | y_max = torch.min(preds[:, None, 3], target[:, 3])
64 | intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0)
65 | pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1])
66 | target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
67 | union = pred_area[:, None] + target_area - intersection
68 | C_x_min = torch.min(preds[:, None, 0], target[:, 0])
69 | C_y_min = torch.min(preds[:, None, 1], target[:, 1])
70 | C_x_max = torch.max(preds[:, None, 2], target[:, 2])
71 | C_y_max = torch.max(preds[:, None, 3], target[:, 3])
72 | C_area = (C_x_max - C_x_min).clamp(min=0) * (C_y_max - C_y_min).clamp(min=0)
73 | iou = torch.true_divide(intersection, union)
74 | giou = iou - torch.true_divide((C_area - union), C_area)
75 | return giou
76 |
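# NOTE: added usage sketch (not part of the upstream pl_bolts file). It illustrates the pairwise
# NxM behaviour described in the Returns section above, assuming `torch` and `giou` from this
# module; the box values are arbitrary, and identical boxes give a GIoU of 1.
if __name__ == "__main__":
    preds = torch.tensor([[100.0, 100.0, 200.0, 200.0], [50.0, 50.0, 150.0, 150.0]])
    targets = torch.tensor([[150.0, 150.0, 250.0, 250.0], [100.0, 100.0, 200.0, 200.0], [0.0, 0.0, 50.0, 50.0]])
    pairwise = giou(preds, targets)
    print(pairwise.shape)  # torch.Size([2, 3]) -- N=2 predictions x M=3 targets
    print(pairwise[0, 1])  # tensor(1.) -- identical boxes overlap perfectly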
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/setup_tools.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright The PyTorch Lightning team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import os
16 | import re
17 | from typing import List
18 |
19 | _PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
20 |
21 |
22 | def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_char: str = "#") -> List[str]:
23 | """Load requirements from a file.
24 |
25 | >>> _load_requirements(_PROJECT_ROOT) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
26 | ['torch...', 'pytorch-lightning...'...]
27 | """
28 | with open(os.path.join(path_dir, file_name)) as file:
29 | lines = [ln.strip() for ln in file.readlines()]
30 | reqs = []
31 | for ln in lines:
32 | # filter out comments
33 | if comment_char in ln:
34 | ln = ln[: ln.index(comment_char)].strip()
35 | # skip directly installed dependencies
36 | if ln.startswith("http"):
37 | continue
38 | if ln: # if requirement is not empty
39 | reqs.append(ln)
40 | return reqs
41 |
42 |
43 | def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str:
44 | """Load readme as description.
45 |
46 | >>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
47 | '...'
48 | """
49 | path_readme = os.path.join(path_dir, "README.md")
50 | text = open(path_readme, encoding="utf-8").read()
51 |
52 | # drop images from readme
53 | text = text.replace("", "")
54 |
55 | # https://github.com/PyTorchLightning/pytorch-lightning/raw/master/docs/source/_images/lightning_module/pt_to_png
56 | github_source_url = os.path.join(homepage, "raw", ver)
57 | # replace relative repository path to absolute link to the release
58 | # do not replace all "docs" since elsewhere in the readme we refer to other sources with a particular path to docs
59 | text = text.replace("docs/source/_images/", f"{os.path.join(github_source_url, 'docs/source/_images/')}")
60 |
61 | # readthedocs badge
62 | text = text.replace("badge/?version=stable", f"badge/?version={ver}")
63 | text = text.replace("lightning-bolts.readthedocs.io/en/stable/", f"lightning-bolts.readthedocs.io/en/{ver}")
64 | # codecov badge
65 | text = text.replace("/branch/master/graph/badge.svg", f"/release/{ver}/graph/badge.svg")
66 | # replace github badges for release ones
67 | text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={ver}")
68 |
69 | skip_begin = r""
70 | skip_end = r""
71 | # todo: wrap content as commented description
72 | text = re.sub(rf"{skip_begin}.+?{skip_end}", "", text, flags=re.IGNORECASE + re.DOTALL)
73 |
74 | # # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png
75 | # github_release_url = os.path.join(homepage, "releases", "download", ver)
76 | # # download badge and replace url with local file
77 | # text = _parse_for_badge(text, github_release_url)
78 | return text
79 |
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/gans/dcgan/components.py:
--------------------------------------------------------------------------------
1 | # Based on https://github.com/pytorch/examples/blob/master/dcgan/main.py
2 | from torch import Tensor, nn
3 |
4 |
5 | class DCGANGenerator(nn.Module):
6 | def __init__(self, latent_dim: int, feature_maps: int, image_channels: int) -> None:
7 | """
8 | Args:
9 | latent_dim: Dimension of the latent space
10 | feature_maps: Number of feature maps to use
11 | image_channels: Number of channels of the images from the dataset
12 | """
13 | super().__init__()
14 | self.gen = nn.Sequential(
15 | self._make_gen_block(latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0),
16 | self._make_gen_block(feature_maps * 8, feature_maps * 4),
17 | self._make_gen_block(feature_maps * 4, feature_maps * 2),
18 | self._make_gen_block(feature_maps * 2, feature_maps),
19 | self._make_gen_block(feature_maps, image_channels, last_block=True),
20 | )
21 |
22 | @staticmethod
23 | def _make_gen_block(
24 | in_channels: int,
25 | out_channels: int,
26 | kernel_size: int = 4,
27 | stride: int = 2,
28 | padding: int = 1,
29 | bias: bool = False,
30 | last_block: bool = False,
31 | ) -> nn.Sequential:
32 | if not last_block:
33 | gen_block = nn.Sequential(
34 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
35 | nn.BatchNorm2d(out_channels),
36 | nn.ReLU(True),
37 | )
38 | else:
39 | gen_block = nn.Sequential(
40 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
41 | nn.Tanh(),
42 | )
43 |
44 | return gen_block
45 |
46 | def forward(self, noise: Tensor) -> Tensor:
47 | return self.gen(noise)
48 |
49 |
50 | class DCGANDiscriminator(nn.Module):
51 | def __init__(self, feature_maps: int, image_channels: int) -> None:
52 | """
53 | Args:
54 | feature_maps: Number of feature maps to use
55 | image_channels: Number of channels of the images from the dataset
56 | """
57 | super().__init__()
58 | self.disc = nn.Sequential(
59 | self._make_disc_block(image_channels, feature_maps, batch_norm=False),
60 | self._make_disc_block(feature_maps, feature_maps * 2),
61 | self._make_disc_block(feature_maps * 2, feature_maps * 4),
62 | self._make_disc_block(feature_maps * 4, feature_maps * 8),
63 | self._make_disc_block(feature_maps * 8, 1, kernel_size=4, stride=1, padding=0, last_block=True),
64 | )
65 |
66 | @staticmethod
67 | def _make_disc_block(
68 | in_channels: int,
69 | out_channels: int,
70 | kernel_size: int = 4,
71 | stride: int = 2,
72 | padding: int = 1,
73 | bias: bool = False,
74 | batch_norm: bool = True,
75 | last_block: bool = False,
76 | ) -> nn.Sequential:
77 | if not last_block:
78 | disc_block = nn.Sequential(
79 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
80 | nn.BatchNorm2d(out_channels) if batch_norm else nn.Identity(),
81 | nn.LeakyReLU(0.2, inplace=True),
82 | )
83 | else:
84 | disc_block = nn.Sequential(
85 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias),
86 | nn.Sigmoid(),
87 | )
88 |
89 | return disc_block
90 |
91 | def forward(self, x: Tensor) -> Tensor:
92 | return self.disc(x).view(-1, 1).squeeze(1)
93 |
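# NOTE: added shape-check sketch (not part of the upstream pl_bolts file). With the blocks defined
# above, the generator maps (B, latent_dim, 1, 1) noise to (B, image_channels, 64, 64) images, and
# the discriminator maps such images back to one score per sample.
if __name__ == "__main__":
    import torch

    generator = DCGANGenerator(latent_dim=100, feature_maps=64, image_channels=3)
    discriminator = DCGANDiscriminator(feature_maps=64, image_channels=3)
    noise = torch.rand(8, 100, 1, 1)
    fake_images = generator(noise)
    print(fake_images.shape)                 # torch.Size([8, 3, 64, 64])
    print(discriminator(fake_images).shape)  # torch.Size([8])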
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/callbacks/variational.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import torch
5 | from pytorch_lightning import LightningModule, Trainer
6 | from pytorch_lightning.callbacks import Callback
7 | from torch import Tensor
8 |
9 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
10 | from pl_bolts.utils.warnings import warn_missing_pkg
11 |
12 | if _TORCHVISION_AVAILABLE:
13 | import torchvision
14 | else: # pragma: no cover
15 | warn_missing_pkg("torchvision")
16 |
17 |
18 | class LatentDimInterpolator(Callback):
19 | """Interpolates the latent space for a model by setting all dims to zero and stepping through the first two
20 | dims increasing one unit at a time.
21 |
22 | Default interpolates between [-5, 5] (-5, -4, -3, ..., 3, 4, 5)
23 |
24 | Example::
25 |
26 | from pl_bolts.callbacks import LatentDimInterpolator
27 |
28 | Trainer(callbacks=[LatentDimInterpolator()])
29 | """
30 |
31 | def __init__(
32 | self,
33 | interpolate_epoch_interval: int = 20,
34 | range_start: int = -5,
35 | range_end: int = 5,
36 | steps: int = 11,
37 | num_samples: int = 2,
38 | normalize: bool = True,
39 | ):
40 | """
41 | Args:
42 | interpolate_epoch_interval: run the interpolation every this many epochs (default 20)
43 | range_start: lower bound of the interpolation range (default -5)
44 | range_end: upper bound of the interpolation range (default 5)
45 | steps: number of steps between start and end
46 | num_samples: number of latent vectors sampled per interpolation point (default 2)
47 | normalize: default True (change image to (0, 1) range)
48 | """
49 | if not _TORCHVISION_AVAILABLE: # pragma: no cover
50 | raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")
51 |
52 | super().__init__()
53 | self.interpolate_epoch_interval = interpolate_epoch_interval
54 | self.range_start = range_start
55 | self.range_end = range_end
56 | self.num_samples = num_samples
57 | self.normalize = normalize
58 | self.steps = steps
59 |
60 | def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
61 | if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0:
62 | images = self.interpolate_latent_space(pl_module, latent_dim=pl_module.hparams.latent_dim)
63 | images = torch.cat(images, dim=0)
64 |
65 | num_rows = self.steps
66 | grid = torchvision.utils.make_grid(images, nrow=num_rows, normalize=self.normalize)
67 | str_title = f"{pl_module.__class__.__name__}_latent_space"
68 | trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step)
69 |
70 | def interpolate_latent_space(self, pl_module: LightningModule, latent_dim: int) -> List[Tensor]:
71 | images = []
72 | with torch.no_grad():
73 | pl_module.eval()
74 | for z1 in np.linspace(self.range_start, self.range_end, self.steps):
75 | for z2 in np.linspace(self.range_start, self.range_end, self.steps):
76 | # set all dims to zero
77 | z = torch.zeros(self.num_samples, latent_dim, device=pl_module.device)
78 |
79 | # set the first 2 dims to the value
80 | z[:, 0] = torch.tensor(z1)
81 | z[:, 1] = torch.tensor(z2)
82 |
83 | # sample
84 | # generate images
85 | img = pl_module(z)
86 |
87 | if len(img.size()) == 2:
88 | img = img.view(self.num_samples, *pl_module.img_dim)
89 |
90 | img = img[0]
91 | img = img.unsqueeze(0)
92 | images.append(img)
93 |
94 | pl_module.train()
95 | return images
96 |
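# NOTE: added illustration (not part of the upstream pl_bolts file). It only shows the size of the
# interpolation lattice that `interpolate_latent_space` walks: with the default settings, 11 x 11
# (z1, z2) pairs are generated, so an 11 x 11 image grid is logged (nrow = steps).
if __name__ == "__main__":
    cb = LatentDimInterpolator()
    lattice = [
        (z1, z2)
        for z1 in np.linspace(cb.range_start, cb.range_end, cb.steps)
        for z2 in np.linspace(cb.range_start, cb.range_end, cb.steps)
    ]
    print(len(lattice))  # 121 interpolation points -> 11 x 11 grid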
--------------------------------------------------------------------------------
/Pre-Training/Masked_Autoencoder/models/custom.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from typing import List
10 | from timm.models.registry import register_model
11 |
12 |
13 | class YourConvNet(nn.Module):
14 | """
15 | This is a template for your custom ConvNet.
16 | It is required to implement the following three functions: `get_downsample_ratio`, `get_feature_map_channels`, `forward`.
17 | You can refer to the implementations in `pretrain/models/resnet.py` for an example.
18 | """
19 |
20 | def get_downsample_ratio(self) -> int:
21 | """
22 | This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
23 |
24 | :return: the TOTAL downsample ratio of the ConvNet.
25 | E.g., for a ResNet-50, this should return 32.
26 | """
27 | raise NotImplementedError
28 |
29 | def get_feature_map_channels(self) -> List[int]:
30 | """
31 | This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).
32 |
33 | :return: a list of the number of channels of each feature map.
34 | E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
35 | """
36 | raise NotImplementedError
37 |
38 | def forward(self, inp_bchw: torch.Tensor, hierarchical=False):
39 | """
40 | The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`).
41 |
42 | :param inp_bchw: input image tensor, shape: (batch_size, channels, height, width).
43 | :param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical).
44 | :return:
45 | - hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes).
46 | - hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`.
47 | E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map].
48 | for an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)]
49 | """
50 | raise NotImplementedError
51 |
52 |
53 | @register_model
54 | def your_convnet_small(pretrained=False, **kwargs):
55 | raise NotImplementedError
56 | return YourConvNet(**kwargs)
57 |
58 |
59 | @torch.no_grad()
60 | def convnet_test():
61 | from timm.models import create_model
62 | cnn = create_model('your_convnet_small')
63 | print('get_downsample_ratio:', cnn.get_downsample_ratio())
64 | print('get_feature_map_channels:', cnn.get_feature_map_channels())
65 |
66 | downsample_ratio = cnn.get_downsample_ratio()
67 | feature_map_channels = cnn.get_feature_map_channels()
68 |
69 | # check the forward function
70 | B, C, H, W = 4, 3, 224, 224
71 | inp = torch.rand(B, C, H, W)
72 | feats = cnn(inp, hierarchical=True)
73 | assert isinstance(feats, list)
74 | assert len(feats) == len(feature_map_channels)
75 | print([tuple(t.shape) for t in feats])
76 |
77 | # check the downsample ratio
78 | feats = cnn(inp, hierarchical=True)
79 | assert feats[-1].shape[-2] == H // downsample_ratio
80 | assert feats[-1].shape[-1] == W // downsample_ratio
81 |
82 | # check the channel number
83 | for feat, ch in zip(feats, feature_map_channels):
84 | assert feat.ndim == 4
85 | assert feat.shape[1] == ch
86 |
87 |
88 | if __name__ == '__main__':
89 | convnet_test()
90 |
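# NOTE: added illustration appended to the template above (not part of the original SparK file).
# A minimal toy ConvNet that satisfies the three-function contract described in `YourConvNet`;
# the class name, channel sizes, and four stride-2 stages are arbitrary choices made only to show
# the expected shapes, not a recommended architecture.
class TinyConvNetExample(nn.Module):
    def __init__(self, num_classes: int = 1000):
        super().__init__()
        chans = [32, 64, 128, 256]  # one entry per stage -> get_feature_map_channels()
        stages, in_c = [], 3
        for c in chans:
            # each stage halves the resolution -> total downsample ratio of 2**4 = 16
            stages.append(
                nn.Sequential(
                    nn.Conv2d(in_c, c, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(c),
                    nn.ReLU(inplace=True),
                )
            )
            in_c = c
        self.stages = nn.ModuleList(stages)
        self.head = nn.Linear(chans[-1], num_classes)
        self._chans = chans

    def get_downsample_ratio(self) -> int:
        return 16

    def get_feature_map_channels(self) -> List[int]:
        return list(self._chans)

    def forward(self, inp_bchw: torch.Tensor, hierarchical=False):
        feats, x = [], inp_bchw
        for stage in self.stages:
            x = stage(x)
            feats.append(x)
        if hierarchical:
            return feats  # list of feature maps, one per stage
        return self.head(x.mean(dim=(2, 3)))  # global average pool -> classification logits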
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datamodules/mnist_datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional, Union
2 |
3 | from pl_bolts.datamodules.vision_datamodule import VisionDataModule
4 | from pl_bolts.datasets import MNIST
5 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
6 | from pl_bolts.utils.warnings import warn_missing_pkg
7 |
8 | if _TORCHVISION_AVAILABLE:
9 | from torchvision import transforms as transform_lib
10 | else: # pragma: no cover
11 | warn_missing_pkg("torchvision")
12 |
13 |
14 | class MNISTDataModule(VisionDataModule):
15 | """
16 | .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
17 | :width: 400
18 | :alt: MNIST
19 |
20 | Specs:
21 | - 10 classes (1 per digit)
22 | - Each image is (1 x 28 x 28)
23 |
24 | Standard MNIST, train, val, test splits and transforms
25 |
26 | Transforms::
27 |
28 | mnist_transforms = transform_lib.Compose([
29 | transform_lib.ToTensor()
30 | ])
31 |
32 | Example::
33 |
34 | from pl_bolts.datamodules import MNISTDataModule
35 |
36 | dm = MNISTDataModule('.')
37 | model = LitModel()
38 |
39 | Trainer().fit(model, datamodule=dm)
40 | """
41 |
42 | name = "mnist"
43 | dataset_cls = MNIST
44 | dims = (1, 28, 28)
45 |
46 | def __init__(
47 | self,
48 | data_dir: Optional[str] = None,
49 | val_split: Union[int, float] = 0.2,
50 | num_workers: int = 0,
51 | normalize: bool = False,
52 | batch_size: int = 32,
53 | seed: int = 42,
54 | shuffle: bool = True,
55 | pin_memory: bool = True,
56 | drop_last: bool = False,
57 | *args: Any,
58 | **kwargs: Any,
59 | ) -> None:
60 | """
61 | Args:
62 | data_dir: Where to save/load the data
63 | val_split: Percent (float) or number (int) of samples to use for the validation split
64 | num_workers: How many workers to use for loading data
65 | normalize: If true, applies image normalization
66 | batch_size: How many samples per batch to load
67 | seed: Random seed to be used for train/val/test splits
68 | shuffle: If true shuffles the train data every epoch
69 | pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
70 | returning them
71 | drop_last: If true drops the last incomplete batch
72 | """
73 | if not _TORCHVISION_AVAILABLE: # pragma: no cover
74 | raise ModuleNotFoundError(
75 | "You want to use MNIST dataset loaded from `torchvision` which is not installed yet."
76 | )
77 |
78 | super().__init__( # type: ignore[misc]
79 | data_dir=data_dir,
80 | val_split=val_split,
81 | num_workers=num_workers,
82 | normalize=normalize,
83 | batch_size=batch_size,
84 | seed=seed,
85 | shuffle=shuffle,
86 | pin_memory=pin_memory,
87 | drop_last=drop_last,
88 | *args,
89 | **kwargs,
90 | )
91 |
92 | @property
93 | def num_classes(self) -> int:
94 | """
95 | Return:
96 | 10
97 | """
98 | return 10
99 |
100 | def default_transforms(self) -> Callable:
101 | if self.normalize:
102 | mnist_transforms = transform_lib.Compose(
103 | [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
104 | )
105 | else:
106 | mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])
107 |
108 | return mnist_transforms
109 |
--------------------------------------------------------------------------------
/Downstream/README.md:
--------------------------------------------------------------------------------
1 | # Downstream Tasks
2 |
3 | We tested our pre-training on three CT classification tasks:
4 | - **COVID-19**: Covid classification on lung CT scans (From Grand Challenge [https://covid-ct.grand-challenge.org/](https://covid-ct.grand-challenge.org/) or
5 | [https://doi.org/10.48550/arXiv.2003.13865](https://doi.org/10.48550/arXiv.2003.13865))
6 | - **OrgMNIST**: Multi-class classification of 11 body organs on patches cropped around organs from abdominal CT scans (From MedMNIST Challenges [https://medmnist.com/](https://medmnist.com/) or [https://doi.org/10.1038/s41597-022-01721-8](https://doi.org/10.1038/s41597-022-01721-8))
7 | - **Brain**: Brain hemorrhage classification on brain CT scans from an internal dataset of Ulm University Medical Center
8 |
9 | We gradually reduced the training dataset size for all three tasks to evaluate which pre-training method is best when only small annotated datasets are available.
10 |
11 | Here are our results:
12 | 
13 |
14 |
15 |
16 | ### How to Start:
17 | We provide Jupyter notebooks with PyTorch Lightning and MONAI for the three downstream tasks. \
18 | If you are using Conda on Linux, here is how to get started (a consolidated example script is shown after the list):
19 | 1. Open your terminal and follow these steps:
20 | 1. conda create --name SSL_Downstream python==3.10
21 | 2. conda activate SSL_Downstream
22 | 3. *CUDA 10.2:* conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=10.2 -c pytorch\
23 | *CUDA 11.3:* conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch\
24 | *CUDA 11.6:* conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge \
25 | (The newest PyTorch should also work [https://pytorch.org/](https://pytorch.org/))
26 | 4. cd ...SSL-MedicalImagining-CL-MAE/Downstream/
27 | 5. pip install -r requirements.txt
28 | 6. Install Jupyter: conda install -c anaconda jupyter
29 | 2. Login to Wandb (or create an account [https://wandb.ai/](https://wandb.ai/))
30 | 3. Open "OrgMNIST.ipynb" or "COVID_19.ipynb" or "Brain.ipynb" in Jupyter Notebook or Jupyter Lab
31 | 1. Fill out the first cell with your preferences (Here you have to add the path to the downloaded pre-training checkpoints from the main README.md)
32 | 2. Run all cells
33 |
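For convenience, here is an example of the same Conda setup collected into one bash script (CUDA 11.6 variant shown; adjust the PyTorch install line to your CUDA version and replace "..." with the path to your clone):

```bash
# consolidated version of steps 1.1 - 1.6 above
conda create --name SSL_Downstream python==3.10
conda activate SSL_Downstream
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
cd .../SSL-MedicalImagining-CL-MAE/Downstream/
pip install -r requirements.txt
conda install -c anaconda jupyter
```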
34 |
35 | ### Start Notebooks from Bash:
36 | This is not necessary; you can run everything directly in Jupyter Notebook or Jupyter Lab. However, this might be useful if you want to start the notebooks from the command line:
37 | 1. Open the notebook in Jupyter Lab
38 | 2. Click in the first code cell (this cell has all the parameters that need to be specified)
39 | 1. On the left click on the two gear wheels
40 | 2. Add a cell tag with the name "parameters" \
41 | 
42 | 3. Install papermill: conda install -c conda-forge papermill
43 | 4. Create a bash file (e.g. "file.sh"). All variables from the first code cell are parameters and can be specified in the bash file with -p ...
44 |
45 | ```bash
46 | # COVID-19
47 | papermill COVID-19.ipynb COVID-19.ipynb \
48 | -p root_dir "path/where/results/should/be/saved" \
49 | -p Run "WandB_Name_of_Run" \
50 | -p pretrained_weights "/path/to/the/downloaded/checkpoints/SparK.pth" \
51 | -p pre_train "SparK" \
52 |
53 | # OrgMNIST
54 | papermill OrgMNIST.ipynb OrgMNIST.ipynb \
55 | -p root_dir "path/where/results/should/be/saved" \
56 | -p Run "WandB_Name_of_Run" \
57 | -p pretrained_weights "/path/to/the/downloaded/SwAV.ckpt" \
58 | -p pre_train "SwAV" \
59 |
60 | ```
61 | 5. Run the bash file (this will start the notebook)
62 |
63 |
64 |
65 |
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datamodules/binary_mnist_datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional, Union
2 |
3 | from pl_bolts.datamodules.vision_datamodule import VisionDataModule
4 | from pl_bolts.datasets import BinaryMNIST
5 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
6 | from pl_bolts.utils.warnings import warn_missing_pkg
7 |
8 | if _TORCHVISION_AVAILABLE:
9 | from torchvision import transforms as transform_lib
10 | else: # pragma: no cover
11 | warn_missing_pkg("torchvision")
12 |
13 |
14 | class BinaryMNISTDataModule(VisionDataModule):
15 | """
16 | .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
17 | :width: 400
18 | :alt: MNIST
19 |
20 | Specs:
21 | - 10 classes (1 per digit)
22 | - Each image is (1 x 28 x 28)
23 |
24 | Binary MNIST, train, val, test splits and transforms
25 |
26 | Transforms::
27 |
28 | mnist_transforms = transform_lib.Compose([
29 | transform_lib.ToTensor()
30 | ])
31 |
32 | Example::
33 |
34 | from pl_bolts.datamodules import BinaryMNISTDataModule
35 |
36 | dm = BinaryMNISTDataModule('.')
37 | model = LitModel()
38 |
39 | Trainer().fit(model, datamodule=dm)
40 | """
41 |
42 | name = "binary_mnist"
43 | dataset_cls = BinaryMNIST
44 | dims = (1, 28, 28)
45 |
46 | def __init__(
47 | self,
48 | data_dir: Optional[str] = None,
49 | val_split: Union[int, float] = 0.2,
50 | num_workers: int = 0,
51 | normalize: bool = False,
52 | batch_size: int = 32,
53 | seed: int = 42,
54 | shuffle: bool = True,
55 | pin_memory: bool = True,
56 | drop_last: bool = False,
57 | *args: Any,
58 | **kwargs: Any,
59 | ) -> None:
60 | """
61 | Args:
62 | data_dir: Where to save/load the data
63 | val_split: Percent (float) or number (int) of samples to use for the validation split
64 | num_workers: How many workers to use for loading data
65 | normalize: If true, applies image normalization
66 | batch_size: How many samples per batch to load
67 | seed: Random seed to be used for train/val/test splits
68 | shuffle: If true shuffles the train data every epoch
69 | pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
70 | returning them
71 | drop_last: If true drops the last incomplete batch
72 | """
73 |
74 | if not _TORCHVISION_AVAILABLE: # pragma: no cover
75 | raise ModuleNotFoundError(
76 | "You want to use transforms loaded from `torchvision` which is not installed yet."
77 | )
78 |
79 | super().__init__( # type: ignore[misc]
80 | data_dir=data_dir,
81 | val_split=val_split,
82 | num_workers=num_workers,
83 | normalize=normalize,
84 | batch_size=batch_size,
85 | seed=seed,
86 | shuffle=shuffle,
87 | pin_memory=pin_memory,
88 | drop_last=drop_last,
89 | *args,
90 | **kwargs,
91 | )
92 |
93 | @property
94 | def num_classes(self) -> int:
95 | """
96 | Return:
97 | 10
98 | """
99 | return 10
100 |
101 | def default_transforms(self) -> Callable:
102 | if self.normalize:
103 | mnist_transforms = transform_lib.Compose(
104 | [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
105 | )
106 | else:
107 | mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])
108 |
109 | return mnist_transforms
110 |
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/callbacks/vision/image_generation.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | from pytorch_lightning import Callback, LightningModule, Trainer
5 |
6 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
7 | from pl_bolts.utils.warnings import warn_missing_pkg
8 |
9 | if _TORCHVISION_AVAILABLE:
10 | import torchvision
11 | else: # pragma: no cover
12 | warn_missing_pkg("torchvision")
13 |
14 |
15 | class TensorboardGenerativeModelImageSampler(Callback):
16 | """Generates images and logs to tensorboard. Your model must implement the ``forward`` function for generation.
17 |
18 | Requirements::
19 |
20 | # model must have img_dim arg
21 | model.img_dim = (1, 28, 28)
22 |
23 | # model forward must work for sampling
24 | z = torch.rand(batch_size, latent_dim)
25 | img_samples = your_model(z)
26 |
27 | Example::
28 |
29 | from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler
30 |
31 | trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()])
32 | """
33 |
34 | def __init__(
35 | self,
36 | num_samples: int = 3,
37 | nrow: int = 8,
38 | padding: int = 2,
39 | normalize: bool = False,
40 | norm_range: Optional[Tuple[int, int]] = None,
41 | scale_each: bool = False,
42 | pad_value: int = 0,
43 | ) -> None:
44 | """
45 | Args:
46 | num_samples: Number of images displayed in the grid. Default: ``3``.
47 | nrow: Number of images displayed in each row of the grid.
48 | The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
49 | padding: Amount of padding. Default: ``2``.
50 | normalize: If ``True``, shift the image to the range (0, 1),
51 | by the min and max values specified by :attr:`range`. Default: ``False``.
52 | norm_range: Tuple (min, max) where min and max are numbers,
53 | then these numbers are used to normalize the image. By default, min and max
54 | are computed from the tensor.
55 | scale_each: If ``True``, scale each image in the batch of
56 | images separately rather than the (min, max) over all images. Default: ``False``.
57 | pad_value: Value for the padded pixels. Default: ``0``.
58 | """
59 | if not _TORCHVISION_AVAILABLE: # pragma: no cover
60 | raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")
61 |
62 | super().__init__()
63 | self.num_samples = num_samples
64 | self.nrow = nrow
65 | self.padding = padding
66 | self.normalize = normalize
67 | self.norm_range = norm_range
68 | self.scale_each = scale_each
69 | self.pad_value = pad_value
70 |
71 | def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
72 | dim = (self.num_samples, pl_module.hparams.latent_dim)
73 | z = torch.normal(mean=0.0, std=1.0, size=dim, device=pl_module.device)
74 |
75 | # generate images
76 | with torch.no_grad():
77 | pl_module.eval()
78 | images = pl_module(z)
79 | pl_module.train()
80 |
81 | if len(images.size()) == 2:
82 | img_dim = pl_module.img_dim
83 | images = images.view(self.num_samples, *img_dim)
84 |
85 | grid = torchvision.utils.make_grid(
86 | tensor=images,
87 | nrow=self.nrow,
88 | padding=self.padding,
89 | normalize=self.normalize,
90 | range=self.norm_range,
91 | scale_each=self.scale_each,
92 | pad_value=self.pad_value,
93 | )
94 | str_title = f"{pl_module.__class__.__name__}_images"
95 | trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step)
96 |
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datamodules/fashion_mnist_datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional, Union
2 |
3 | from pl_bolts.datamodules.vision_datamodule import VisionDataModule
4 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
5 | from pl_bolts.utils.warnings import warn_missing_pkg
6 |
7 | if _TORCHVISION_AVAILABLE:
8 | from torchvision import transforms as transform_lib
9 | from torchvision.datasets import FashionMNIST
10 | else: # pragma: no cover
11 | warn_missing_pkg("torchvision")
12 | FashionMNIST = None
13 |
14 |
15 | class FashionMNISTDataModule(VisionDataModule):
16 | """
17 | .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/
18 | wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-Fashion-MNIST-Dataset.png
19 | :width: 400
20 | :alt: Fashion MNIST
21 |
22 | Specs:
23 | - 10 classes (1 per type)
24 | - Each image is (1 x 28 x 28)
25 |
26 | Standard FashionMNIST, train, val, test splits and transforms
27 |
28 | Transforms::
29 |
30 | mnist_transforms = transform_lib.Compose([
31 | transform_lib.ToTensor()
32 | ])
33 |
34 | Example::
35 |
36 | from pl_bolts.datamodules import FashionMNISTDataModule
37 |
38 | dm = FashionMNISTDataModule('.')
39 | model = LitModel()
40 |
41 | Trainer().fit(model, datamodule=dm)
42 | """
43 |
44 | name = "fashion_mnist"
45 | dataset_cls = FashionMNIST
46 | dims = (1, 28, 28)
47 |
48 | def __init__(
49 | self,
50 | data_dir: Optional[str] = None,
51 | val_split: Union[int, float] = 0.2,
52 | num_workers: int = 0,
53 | normalize: bool = False,
54 | batch_size: int = 32,
55 | seed: int = 42,
56 | shuffle: bool = True,
57 | pin_memory: bool = True,
58 | drop_last: bool = False,
59 | *args: Any,
60 | **kwargs: Any,
61 | ) -> None:
62 | """
63 | Args:
64 | data_dir: Where to save/load the data
65 | val_split: Percent (float) or number (int) of samples to use for the validation split
66 | num_workers: How many workers to use for loading data
67 | normalize: If true, applies image normalization
68 | batch_size: How many samples per batch to load
69 | seed: Random seed to be used for train/val/test splits
70 | shuffle: If true shuffles the train data every epoch
71 | pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
72 | returning them
73 | drop_last: If true drops the last incomplete batch
74 | """
75 | if not _TORCHVISION_AVAILABLE: # pragma: no cover
76 | raise ModuleNotFoundError(
77 | "You want to use FashionMNIST dataset loaded from `torchvision` which is not installed yet."
78 | )
79 |
80 | super().__init__( # type: ignore[misc]
81 | data_dir=data_dir,
82 | val_split=val_split,
83 | num_workers=num_workers,
84 | normalize=normalize,
85 | batch_size=batch_size,
86 | seed=seed,
87 | shuffle=shuffle,
88 | pin_memory=pin_memory,
89 | drop_last=drop_last,
90 | *args,
91 | **kwargs,
92 | )
93 |
94 | @property
95 | def num_classes(self) -> int:
96 | """
97 | Return:
98 | 10
99 | """
100 | return 10
101 |
102 | def default_transforms(self) -> Callable:
103 | if self.normalize:
104 | mnist_transforms = transform_lib.Compose(
105 | [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
106 | )
107 | else:
108 | mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])
109 |
110 | return mnist_transforms
111 |
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/rl/noisy_dqn_model.py:
--------------------------------------------------------------------------------
1 | """Noisy DQN."""
2 | import argparse
3 | from typing import Tuple
4 |
5 | import numpy as np
6 | from pytorch_lightning import Trainer
7 | from torch import Tensor
8 |
9 | from pl_bolts.datamodules.experience_source import Experience
10 | from pl_bolts.models.rl.common.networks import NoisyCNN
11 | from pl_bolts.models.rl.dqn_model import DQN
12 |
13 |
14 | class NoisyDQN(DQN):
15 | """PyTorch Lightning implementation of `Noisy DQN `_
16 |
17 | Paper authors: Meire Fortunato, Mohammad Gheshlaghi Azar, Bilal Piot, Jacob Menick, Ian Osband, Alex Graves,
18 | Vlad Mnih, Remi Munos, Demis Hassabis, Olivier Pietquin, Charles Blundell, Shane Legg
19 |
20 | Model implemented by:
21 |
22 | - `Donal Byrne `
23 |
24 | Example:
25 | >>> from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN
26 | ...
27 | >>> model = NoisyDQN("PongNoFrameskip-v4")
28 |
29 | Train::
30 |
31 | trainer = Trainer()
32 | trainer.fit(model)
33 |
34 | .. note:: Currently only supports CPU and single GPU training with `accelerator=dp`
35 | """
36 |
37 | def build_networks(self) -> None:
38 | """Initializes the Noisy DQN train and target networks."""
39 | self.net = NoisyCNN(self.obs_shape, self.n_actions)
40 | self.target_net = NoisyCNN(self.obs_shape, self.n_actions)
41 |
42 | def on_train_start(self) -> None:
43 | """Set the agent's epsilon to 0 as the exploration comes from the network."""
44 | self.agent.epsilon = 0.0
45 |
46 | def train_batch(
47 | self,
48 | ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
49 | """Contains the logic for generating a new batch of data to be passed to the DataLoader. This is the same
50 | function as the standard DQN except that we don't update epsilon as it is always 0. The exploration comes
51 | from the noisy network.
52 |
53 | Returns:
54 | yields an Experience tuple containing the state, action, reward, done and next_state.
55 | """
56 | episode_reward = 0
57 | episode_steps = 0
58 |
59 | while True:
60 | self.total_steps += 1
61 | action = self.agent(self.state, self.device)
62 |
63 | next_state, r, is_done, _ = self.env.step(action[0])
64 |
65 | episode_reward += r
66 | episode_steps += 1
67 |
68 | exp = Experience(state=self.state, action=action[0], reward=r, done=is_done, new_state=next_state)
69 |
70 | self.buffer.append(exp)
71 | self.state = next_state
72 |
73 | if is_done:
74 | self.done_episodes += 1
75 | self.total_rewards.append(episode_reward)
76 | self.total_episode_steps.append(episode_steps)
77 | self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len :]))
78 | self.state = self.env.reset()
79 | episode_steps = 0
80 | episode_reward = 0
81 |
82 | states, actions, rewards, dones, new_states = self.buffer.sample(self.batch_size)
83 |
84 | for idx, _ in enumerate(dones):
85 | yield states[idx], actions[idx], rewards[idx], dones[idx], new_states[idx]
86 |
87 | # Simulates epochs
88 | if self.total_steps % self.batches_per_epoch == 0:
89 | break
90 |
91 |
92 | def cli_main():
93 | parser = argparse.ArgumentParser(add_help=False)
94 |
95 | # trainer args
96 | parser = Trainer.add_argparse_args(parser)
97 |
98 | # model args
99 | parser = NoisyDQN.add_model_specific_args(parser)
100 | args = parser.parse_args()
101 |
102 | model = NoisyDQN(**args.__dict__)
103 |
104 | trainer = Trainer.from_argparse_args(args)
105 | trainer.fit(model)
106 |
107 |
108 | if __name__ == "__main__":
109 | cli_main()
110 |
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/vision/image_gpt/gpt2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_lightning import LightningModule
3 | from torch import nn
4 |
5 |
6 | class Block(nn.Module):
7 | def __init__(self, embed_dim, heads):
8 | super().__init__()
9 | self.ln_1 = nn.LayerNorm(embed_dim)
10 | self.ln_2 = nn.LayerNorm(embed_dim)
11 | self.attn = nn.MultiheadAttention(embed_dim, heads)
12 | self.mlp = nn.Sequential(
13 | nn.Linear(embed_dim, embed_dim * 4),
14 | nn.GELU(),
15 | nn.Linear(embed_dim * 4, embed_dim),
16 | )
17 |
18 | def forward(self, x):
19 | attn_mask = torch.full((len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype)
20 | attn_mask = torch.triu(attn_mask, diagonal=1)
21 |
22 | x = self.ln_1(x)
23 | a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
24 | x = x + a
25 | m = self.mlp(self.ln_2(x))
26 | x = x + m
27 | return x
28 |
29 |
30 | class GPT2(LightningModule):
31 | """GPT-2 from `Language Models are Unsupervised Multitask Learners `_
33 |
34 | Paper by: Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever
35 |
36 | Implementation contributed by:
37 |
38 | - `Teddy Koker `_
39 |
40 | Example::
41 |
42 | from pl_bolts.models.vision import GPT2
43 |
44 | seq_len = 17
45 | batch_size = 32
46 | vocab_size = 16
47 | x = torch.randint(0, vocab_size, (seq_len, batch_size))
48 | model = GPT2(embed_dim=32, heads=2, layers=2, num_positions=seq_len, vocab_size=vocab_size, num_classes=4)
49 | results = model(x)
50 | """
51 |
52 | def __init__(
53 | self,
54 | embed_dim: int,
55 | heads: int,
56 | layers: int,
57 | num_positions: int,
58 | vocab_size: int,
59 | num_classes: int,
60 | ):
61 | super().__init__()
62 | self.save_hyperparameters()
63 |
64 | self._init_sos_token()
65 | self._init_embeddings()
66 | self._init_layers()
67 |
68 | def _init_sos_token(self):
69 | self.sos = torch.nn.Parameter(torch.zeros(self.hparams.embed_dim))
70 | nn.init.normal_(self.sos)
71 |
72 | def _init_embeddings(self):
73 | self.token_embeddings = nn.Embedding(self.hparams.vocab_size, self.hparams.embed_dim)
74 | self.position_embeddings = nn.Embedding(self.hparams.num_positions, self.hparams.embed_dim)
75 |
76 | def _init_layers(self):
77 | self.layers = nn.ModuleList()
78 | for _ in range(self.hparams.layers):
79 | self.layers.append(Block(self.hparams.embed_dim, self.hparams.heads))
80 |
81 | self.ln_f = nn.LayerNorm(self.hparams.embed_dim)
82 | self.head = nn.Linear(self.hparams.embed_dim, self.hparams.vocab_size, bias=False)
83 | self.clf_head = nn.Linear(self.hparams.embed_dim, self.hparams.num_classes)
84 |
85 | def forward(self, x, classify=False):
86 | """Expect input as shape [sequence len, batch]. If classify, return classification logits."""
87 | length, batch = x.shape
88 |
89 | h = self.token_embeddings(x.long())
90 |
91 | # prepend sos token
92 | sos = torch.ones(1, batch, self.hparams.embed_dim, device=x.device, dtype=x.dtype) * self.sos
93 | h = torch.cat([sos, h[:-1, :, :]], axis=0)
94 |
95 | # add positional embeddings
96 | positions = torch.arange(length, device=x.device).unsqueeze(-1)
97 | h = h + self.position_embeddings(positions).expand_as(h)
98 |
99 | # transformer
100 | for layer in self.layers:
101 | h = layer(h)
102 |
103 | if not classify:
104 | # return logits
105 | return self.head(h)
106 |
107 | h = torch.mean(h, dim=0) # average pool over sequence
108 | return self.clf_head(h) # return classification logits
109 |
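# NOTE: added usage sketch (not part of the upstream pl_bolts file). It mirrors the docstring
# example above but also calls the model with classify=True: the sequence is average-pooled and
# mapped to `num_classes` logits, giving (batch_size, num_classes) instead of per-token logits.
if __name__ == "__main__":
    seq_len, batch_size, vocab_size = 17, 32, 16
    x = torch.randint(0, vocab_size, (seq_len, batch_size))
    model = GPT2(embed_dim=32, heads=2, layers=2, num_positions=seq_len, vocab_size=vocab_size, num_classes=4)
    print(model(x).shape)                 # torch.Size([17, 32, 16]) -- per-token vocab logits
    print(model(x, classify=True).shape)  # torch.Size([32, 4])      -- classification logits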
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/README.md:
--------------------------------------------------------------------------------
1 | # Pre-Training with SwAV, MoCoV2, BYOL
2 |
3 | We used the implementation of PyTorch Lightning Bolts [https://lightning.ai/docs/pytorch/stable/ecosystem/bolts.html](https://lightning.ai/docs/pytorch/stable/ecosystem/bolts.html)
4 |
5 | ### How to Start:
6 | 1. Download the LIDC data and run the preprocessing script as explained here: [https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main/Pre-Training/Data_Preprocessing](https://github.com/Wolfda95/SSL-MedicalImagining-CL-MAE/tree/main/Pre-Training/Data_Preprocessing)
7 |
8 | #### Option 1: Use the latest PyTorch Lightning Bolts implementation
9 | You can use the implementation of PyTorch Lightning Bolts. You only have to change the data loading.
10 | - SwAV: [https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/models/self_supervised/swav/swav_module.py](https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/models/self_supervised/swav/swav_module.py)
11 | - MoCoV2: [https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/models/self_supervised/moco/moco_module.py](https://github.com/Lightning-Universe/lightning-bolts/blob/master/src/pl_bolts/models/self_supervised/moco/moco_module.py)
12 | - BYOL: [https://github.com/Lightning-Universe/lightning-bolts/tree/master/src/pl_bolts/models/self_supervised/byol](https://github.com/Lightning-Universe/lightning-bolts/tree/master/src/pl_bolts/models/self_supervised/byol)
13 |
14 | #### Option 2: Use our PyTorch Lightning Bolts adaptation
15 | 2. Change the folder structure of the preprocessed data to:
16 | ```bash
17 | LIDC-Data
18 | └── train
19 |
20 | ```
21 | 3. Open your terminal and follow these steps:
22 | 1. conda create --name SSL_Contrastive python==3.10
23 | 2. conda activate SSL_Contrastive
24 | 3. conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia
25 | 4. cd .../SSL-MedicalImagining-CL-MAE/Pre-Training/Contrastive_Learning/
26 | 5. pip install -r requirements.txt
27 | 4. Start the pre-training with a bash script: \
28 | SwAV:
29 | ```bash
30 | #!/bin/bash
31 |
32 | wandb login your_login_id
33 | python .../SSL-MedicalImagining-CL-MAE/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/swav/swav_module_lidc.py \
34 | --save_path /path/where/results/should/be/saved \
35 | --data_dir /path/to/the/LIDC-Data \
36 | --model Some_Name_for_WandB \
37 | --test Some_Name_for_WandB \
38 | --project WandB_project_name \
39 | --batch_size 128 \
40 | --group Bs_128 \
41 | --tags ["500Proto_Color2x04-2x02-Blur-Crop"] \
42 | --learning_rate 0.15 \
43 | --final_lr 0.00015 \
44 | --start_lr 0.3 \
45 | --freeze_prototypes_epochs 313 \
46 | --accumulate_grad_batches 1 \
47 | --optimizer lars \
48 | ```
49 |
50 | MoCo V2:
51 | ```bash
52 | #!/bin/bash
53 |
54 | wandb login your_login_id
55 | python .../SSL-MedicalImagining-CL-MAE/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/moco/moco2_module.py \
56 | --dataset=medical \
57 | --batch_size=128 \
58 | --data_dir=/path/to/the/LIDC-Data \
59 | --savepath=/path/where/results/should/be/saved \
60 | --wandb_group=LIDC \
61 | --wandb_job_type=MoCo \
62 | --lambda_ 0.05 \
63 | --base_encoder=resnet50 \
64 | --max_epochs=800 \
65 | --num_workers=12 \
66 | --tags resnet50 LIDC MoCo \
67 | ```
68 |
69 | BYOL:
70 | ```bash
71 | #!/bin/bash
72 |
73 | wandb login your_login_id
74 | python .../SSL-MedicalImagining-CL-MAE/Pre-Training/Contrastive_Learning/pl_bolts/models/self_supervised/byol/byol_module.py --gpus 1 \
75 | --data_dir /path/to/the/LIDC-Data \
76 | --batch_size 64 \
77 | --savepath /path/where/results/should/be/saved \
78 | --group BYOL \
79 | --name WandB_name \
80 | ```
81 | For further information and other settings, please refer to the PyTorch Lightning Bolts GitHub: [https://github.com/Lightning-Universe/lightning-bolts/tree/master](https://github.com/Lightning-Universe/lightning-bolts/tree/master)
82 |
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/mnist_module.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import torch
4 | from pytorch_lightning import LightningModule, Trainer
5 | from torch.nn import functional as F
6 | from torch.utils.data import DataLoader, random_split
7 |
8 | from pl_bolts.datasets import MNIST
9 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
10 | from pl_bolts.utils.warnings import warn_missing_pkg
11 |
12 | if _TORCHVISION_AVAILABLE:
13 | from torchvision import transforms
14 | else: # pragma: no cover
15 | warn_missing_pkg("torchvision")
16 |
17 |
18 | class LitMNIST(LightningModule):
19 | def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, data_dir="", **kwargs):
20 | if not _TORCHVISION_AVAILABLE: # pragma: no cover
21 | raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")
22 |
23 | super().__init__()
24 | self.save_hyperparameters()
25 |
26 | self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
27 | self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
28 |
29 | self.mnist_train = None
30 | self.mnist_val = None
31 |
32 | def forward(self, x):
33 | x = x.view(x.size(0), -1)
34 | x = torch.relu(self.l1(x))
35 | x = torch.relu(self.l2(x))
36 | return x
37 |
38 | def training_step(self, batch, batch_idx):
39 | x, y = batch
40 | y_hat = self(x)
41 | loss = F.cross_entropy(y_hat, y)
42 | self.log("train_loss", loss)
43 | return loss
44 |
45 | def validation_step(self, batch, batch_idx):
46 | x, y = batch
47 | y_hat = self(x)
48 | loss = F.cross_entropy(y_hat, y)
49 | self.log("val_loss", loss)
50 |
51 | def test_step(self, batch, batch_idx):
52 | x, y = batch
53 | y_hat = self(x)
54 | loss = F.cross_entropy(y_hat, y)
55 | self.log("test_loss", loss)
56 |
57 | def configure_optimizers(self):
58 | return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
59 |
60 | def prepare_data(self):
61 | MNIST(self.hparams.data_dir, train=True, download=True, transform=transforms.ToTensor())
62 |
63 | def train_dataloader(self):
64 | dataset = MNIST(self.hparams.data_dir, train=True, download=False, transform=transforms.ToTensor())
65 | mnist_train, _ = random_split(dataset, [55000, 5000])
66 | loader = DataLoader(mnist_train, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
67 | return loader
68 |
69 | def val_dataloader(self):
70 | dataset = MNIST(self.hparams.data_dir, train=True, download=False, transform=transforms.ToTensor())
71 | _, mnist_val = random_split(dataset, [55000, 5000])
72 | loader = DataLoader(mnist_val, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
73 | return loader
74 |
75 | def test_dataloader(self):
76 | test_dataset = MNIST(self.hparams.data_dir, train=False, download=True, transform=transforms.ToTensor())
77 | loader = DataLoader(test_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
78 | return loader
79 |
80 | @staticmethod
81 | def add_model_specific_args(parent_parser):
82 | parser = ArgumentParser(parents=[parent_parser], add_help=False)
83 | parser.add_argument("--batch_size", type=int, default=32)
84 | parser.add_argument("--num_workers", type=int, default=4)
85 | parser.add_argument("--hidden_dim", type=int, default=128)
86 | parser.add_argument("--data_dir", type=str, default="")
87 | parser.add_argument("--learning_rate", type=float, default=0.0001)
88 | return parser
89 |
90 |
91 | def cli_main():
92 | # args
93 | parser = ArgumentParser()
94 | parser = Trainer.add_argparse_args(parser)
95 | parser = LitMNIST.add_model_specific_args(parser)
96 | args = parser.parse_args()
97 |
98 | # model
99 | model = LitMNIST(**vars(args))
100 |
101 | # training
102 | trainer = Trainer.from_argparse_args(args)
103 | trainer.fit(model)
104 |
105 |
106 | if __name__ == "__main__": # pragma: no cover
107 | cli_main()
108 |
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datasets/kitti_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 |
6 | from pl_bolts.utils import _PIL_AVAILABLE
7 | from pl_bolts.utils.warnings import warn_missing_pkg
8 |
9 | if _PIL_AVAILABLE:
10 | from PIL import Image
11 | else: # pragma: no cover
12 | warn_missing_pkg("PIL", pypi_name="Pillow")
13 |
14 | DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1)
15 | DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33)
16 |
17 |
18 | class KittiDataset(Dataset):
19 | """
20 | Note:
21 | You need to have downloaded the Kitti dataset first and provide the path to where it is saved.
22 | You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015
23 |
24 | There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These
25 | useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored
26 | in `valid_labels`.
27 |
28 | The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index`
29 | (250 by default). It also sets all of the valid pixels to the appropriate value between 0 and
30 | `len(valid_labels)` (since that is the number of valid classes), so it can be used properly by
31 | the loss function when comparing with the output.
32 | """
33 |
34 | IMAGE_PATH = os.path.join("training", "image_2")
35 | MASK_PATH = os.path.join("training", "semantic")
36 |
37 | def __init__(
38 | self,
39 | data_dir: str,
40 | img_size: tuple = (1242, 376),
41 | void_labels: list = DEFAULT_VOID_LABELS,
42 | valid_labels: list = DEFAULT_VALID_LABELS,
43 | transform=None,
44 | ):
45 | """
46 | Args:
47 | data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/'
48 | img_size: image dimensions (width, height)
49 | void_labels: useless classes to be excluded from training
50 | valid_labels: useful classes to include
51 | """
52 | if not _PIL_AVAILABLE: # pragma: no cover
53 | raise ModuleNotFoundError("You want to use `PIL` which is not installed yet.")
54 |
55 | self.img_size = img_size
56 | self.void_labels = void_labels
57 | self.valid_labels = valid_labels
58 | self.ignore_index = 250
59 | self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels))))
60 | self.transform = transform
61 |
62 | self.data_dir = data_dir
63 | self.img_path = os.path.join(self.data_dir, self.IMAGE_PATH)
64 | self.mask_path = os.path.join(self.data_dir, self.MASK_PATH)
65 | self.img_list = self.get_filenames(self.img_path)
66 | self.mask_list = self.get_filenames(self.mask_path)
67 |
68 | def __len__(self):
69 | return len(self.img_list)
70 |
71 | def __getitem__(self, idx):
72 | img = Image.open(self.img_list[idx])
73 | img = img.resize(self.img_size)
74 | img = np.array(img)
75 |
76 | mask = Image.open(self.mask_list[idx]).convert("L")
77 | mask = mask.resize(self.img_size)
78 | mask = np.array(mask)
79 | mask = self.encode_segmap(mask)
80 |
81 | if self.transform:
82 | img = self.transform(img)
83 |
84 | return img, mask
85 |
86 | def encode_segmap(self, mask):
87 | """Sets void classes to ignore_index so they won't be considered for training."""
88 | for voidc in self.void_labels:
89 | mask[mask == voidc] = self.ignore_index
90 | for validc in self.valid_labels:
91 | mask[mask == validc] = self.class_map[validc]
92 | # remove extra idxs from updated dataset
93 | mask[mask > 18] = self.ignore_index
94 | return mask
95 |
96 | def get_filenames(self, path):
97 | """Returns a list of absolute paths to images inside the given `path`."""
98 | files_list = list()
99 | for filename in os.listdir(path):
100 | files_list.append(os.path.join(path, filename))
101 | return files_list
102 |
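# NOTE: added illustration (not part of the upstream pl_bolts file). It demonstrates the mapping
# described in the class docstring without needing the KITTI files on disk: void labels go to
# ignore_index (250) and valid labels are remapped to 0..len(valid_labels) - 1.
if __name__ == "__main__":
    ds = KittiDataset.__new__(KittiDataset)  # bypass __init__; no data required for this demo
    ds.void_labels = DEFAULT_VOID_LABELS
    ds.valid_labels = DEFAULT_VALID_LABELS
    ds.ignore_index = 250
    ds.class_map = dict(zip(ds.valid_labels, range(len(ds.valid_labels))))

    raw_mask = np.array([[0, 7], [8, 33]])  # 0 is a void label; 7, 8 and 33 are valid labels
    print(ds.encode_segmap(raw_mask))       # [[250   0]
                                            #  [  1  18]]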
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/vision/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | class UNet(nn.Module):
7 | """
8 | Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation
9 | `_
10 |
11 | Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox
12 |
13 | Implemented by:
14 |
15 | - `Annika Brundyn `_
16 | - `Akshay Kulkarni `_
17 |
18 | Args:
19 | num_classes: Number of output classes required
20 | input_channels: Number of channels in input images (default 3)
21 | num_layers: Number of layers in each side of U-net (default 5)
22 | features_start: Number of features in first layer (default 64)
23 | bilinear: Whether to use bilinear interpolation or transposed convolutions (default) for upsampling.
24 | """
25 |
26 | def __init__(
27 | self,
28 | num_classes: int,
29 | input_channels: int = 3,
30 | num_layers: int = 5,
31 | features_start: int = 64,
32 | bilinear: bool = False,
33 | ):
34 |
35 | if num_layers < 1:
36 | raise ValueError(f"num_layers = {num_layers}, expected: num_layers > 0")
37 |
38 | super().__init__()
39 | self.num_layers = num_layers
40 |
41 | layers = [DoubleConv(input_channels, features_start)]
42 |
43 | feats = features_start
44 | for _ in range(num_layers - 1):
45 | layers.append(Down(feats, feats * 2))
46 | feats *= 2
47 |
48 | for _ in range(num_layers - 1):
49 | layers.append(Up(feats, feats // 2, bilinear))
50 | feats //= 2
51 |
52 | layers.append(nn.Conv2d(feats, num_classes, kernel_size=1))
53 |
54 | self.layers = nn.ModuleList(layers)
55 |
56 | def forward(self, x):
57 | xi = [self.layers[0](x)]
58 | # Down path
59 | for layer in self.layers[1 : self.num_layers]:
60 | xi.append(layer(xi[-1]))
61 | # Up path
62 | for i, layer in enumerate(self.layers[self.num_layers : -1]):
63 | xi[-1] = layer(xi[-1], xi[-2 - i])
64 | return self.layers[-1](xi[-1])
65 |
66 |
67 | class DoubleConv(nn.Module):
68 | """[ Conv2d => BatchNorm (optional) => ReLU ] x 2."""
69 |
70 | def __init__(self, in_ch: int, out_ch: int):
71 | super().__init__()
72 | self.net = nn.Sequential(
73 | nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
74 | nn.BatchNorm2d(out_ch),
75 | nn.ReLU(inplace=True),
76 | nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
77 | nn.BatchNorm2d(out_ch),
78 | nn.ReLU(inplace=True),
79 | )
80 |
81 | def forward(self, x):
82 | return self.net(x)
83 |
84 |
85 | class Down(nn.Module):
86 | """Downscale with MaxPool => DoubleConvolution block."""
87 |
88 | def __init__(self, in_ch: int, out_ch: int):
89 | super().__init__()
90 | self.net = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2), DoubleConv(in_ch, out_ch))
91 |
92 | def forward(self, x):
93 | return self.net(x)
94 |
95 |
96 | class Up(nn.Module):
97 | """Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature
98 | map from contracting path, followed by DoubleConv."""
99 |
100 | def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False):
101 | super().__init__()
102 | self.upsample = None
103 | if bilinear:
104 | self.upsample = nn.Sequential(
105 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
106 | nn.Conv2d(in_ch, in_ch // 2, kernel_size=1),
107 | )
108 | else:
109 | self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)
110 |
111 | self.conv = DoubleConv(in_ch, out_ch)
112 |
113 | def forward(self, x1, x2):
114 | x1 = self.upsample(x1)
115 |
116 | # Pad x1 to the size of x2
117 | diff_h = x2.shape[2] - x1.shape[2]
118 | diff_w = x2.shape[3] - x1.shape[3]
119 |
120 | x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
121 |
122 | # Concatenate along the channels axis
123 | x = torch.cat([x2, x1], dim=1)
124 | return self.conv(x)
125 |
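# NOTE: added usage sketch (not part of the upstream pl_bolts file). With the default 5 layers the
# encoder downsamples by 2**4 = 16, so an input whose spatial size is a multiple of 16 is returned
# at full resolution with `num_classes` channels.
if __name__ == "__main__":
    net = UNet(num_classes=19, input_channels=3, num_layers=5, features_start=64, bilinear=False)
    x = torch.rand(1, 3, 96, 96)
    print(net(x).shape)  # torch.Size([1, 19, 96, 96])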
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/datasets/ssl_amdim_datasets.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Callable, Optional
3 |
4 | import numpy as np
5 |
6 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
7 | from pl_bolts.utils.warnings import warn_missing_pkg
8 |
9 | if _TORCHVISION_AVAILABLE:
10 | from torchvision.datasets import CIFAR10
11 | else: # pragma: no cover
12 | warn_missing_pkg("torchvision")
13 | CIFAR10 = object
14 |
15 |
16 | class SSLDatasetMixin(ABC):
17 | @classmethod
18 | def generate_train_val_split(cls, examples, labels, pct_val):
19 | """Splits dataset uniformly across classes."""
20 | nb_classes = len(set(labels))
21 |
22 | nb_val_images = int(len(examples) * pct_val) // nb_classes
23 |
24 | val_x = []
25 | val_y = []
26 | train_x = []
27 | train_y = []
28 |
29 | cts = {x: 0 for x in range(nb_classes)}
30 | for img, class_idx in zip(examples, labels):
31 |
32 | # fill the validation split first, up to nb_val_images per class
33 | if cts[class_idx] < nb_val_images:
34 | val_x.append(img)
35 | val_y.append(class_idx)
36 | cts[class_idx] += 1
37 | else:
38 | train_x.append(img)
39 | train_y.append(class_idx)
40 |
41 | val_x = np.stack(val_x)
42 | train_x = np.stack(train_x)
43 | return val_x, val_y, train_x, train_y
44 |
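# Usage sketch (toy data, assumed for illustration): a 10% validation split drawn uniformly per class.
#
#   val_x, val_y, train_x, train_y = SSLDatasetMixin.generate_train_val_split(
#       examples=np.random.rand(100, 32, 32, 3),
#       labels=[i % 10 for i in range(100)],
#       pct_val=0.10,
#   )
#   # -> val_x holds 1 image per class (10 total), train_x holds the remaining 90.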
45 | @classmethod
46 | def select_nb_imgs_per_class(cls, examples, labels, nb_imgs_in_val):
47 | """Splits a dataset into two parts.
48 |
49 | The labeled split has nb_imgs_in_val per class
50 | """
51 | nb_classes = len(set(labels))
52 |
54 | labeled = []
55 | labeled_y = []
56 | unlabeled = []
57 | unlabeled_y = []
58 |
59 | cts = {x: 0 for x in range(nb_classes)}
60 | for img_name, class_idx in zip(examples, labels):
61 |
62 | # add to the labeled split until nb_imgs_in_val examples per class are collected
63 | if cts[class_idx] < nb_imgs_in_val:
64 | labeled.append(img_name)
65 | labeled_y.append(class_idx)
66 | cts[class_idx] += 1
67 | else:
68 | unlabeled.append(img_name)
69 | unlabeled_y.append(class_idx)
70 |
71 | labeled = np.stack(labeled)
72 |
73 | return labeled, labeled_y
74 |
75 | @classmethod
76 | def deterministic_shuffle(cls, x, y):
77 |
78 | n = len(x)
79 | idxs = list(range(0, n))
80 | np.random.seed(1234)
81 | np.random.shuffle(idxs)
82 |
83 | x = x[idxs]
84 |
85 | y = np.asarray(y)
86 | y = y[idxs]
87 | y = list(y)
88 |
89 | return x, y
90 |
91 |
92 | class CIFAR10Mixed(SSLDatasetMixin, CIFAR10):
93 | def __init__(
94 | self,
95 | root: str,
96 | split: str = "val",
97 | transform: Optional[Callable] = None,
98 | target_transform: Optional[Callable] = None,
99 | download: bool = False,
100 | nb_labeled_per_class: Optional[int] = None,
101 | val_pct: float = 0.10,
102 | ):
103 | if not _TORCHVISION_AVAILABLE: # pragma: no cover
104 | raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")
105 |
106 | if nb_labeled_per_class == -1:
107 | nb_labeled_per_class = None
108 |
109 | # use train for all of these splits
110 | train = split in ("val", "train", "train+unlabeled")
111 | super().__init__(root, train, transform, target_transform, download)
112 |
113 | # modify only for val, train
114 | if split != "test":
115 | # limit nb of examples per class
116 | X_test, y_test, X_train, y_train = self.generate_train_val_split(self.data, self.targets, val_pct)
117 |
118 | # shuffle idxs representing the data
119 | X_train, y_train = self.deterministic_shuffle(X_train, y_train)
120 | X_test, y_test = self.deterministic_shuffle(X_test, y_test)
121 |
122 | if split == "val":
123 | self.data = X_test
124 | self.targets = y_test
125 |
126 | else:
127 | self.data = X_train
128 | self.targets = y_train
129 |
130 | # limit the number of items per class
131 | if nb_labeled_per_class is not None:
132 | self.data, self.targets = self.select_nb_imgs_per_class(self.data, self.targets, nb_labeled_per_class)
133 |
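# Usage sketch (hypothetical root path; requires torchvision and a CIFAR-10 download):
#
#   train_ds = CIFAR10Mixed(root='/tmp/cifar10', split='train', download=True, nb_labeled_per_class=100)
#   val_ds = CIFAR10Mixed(root='/tmp/cifar10', split='val', download=True)
#
# 'train' and 'val' are both carved deterministically out of the CIFAR-10 train set; 'test' uses the test set.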
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/models/vision/segmentation.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import torch
4 | from pytorch_lightning import LightningModule, Trainer, seed_everything
5 | from torch.nn import functional as F
6 |
7 | from pl_bolts.models.vision.unet import UNet
8 |
9 |
10 | class SemSegment(LightningModule):
11 | def __init__(
12 | self,
13 | lr: float = 0.01,
14 | num_classes: int = 19,
15 | num_layers: int = 5,
16 | features_start: int = 64,
17 | bilinear: bool = False,
18 | ):
19 | """Basic model for semantic segmentation. Uses UNet architecture by default.
20 |
21 | The default parameters in this model are for the KITTI dataset. Note, if you'd like to use this model as is,
22 | you will first need to download the KITTI dataset yourself. You can download the dataset from the
23 | official KITTI website.
24 |
25 | Implemented by:
26 |
27 | - Annika Brundyn
28 |
29 | Args:
30 | num_layers: number of layers in each side of U-net (default 5)
31 | features_start: number of features in first layer (default 64)
32 | bilinear: whether to use bilinear interpolation (True) or transposed convolutions (default) for upsampling.
33 | lr: learning rate (default 0.01)
34 | """
35 | super().__init__()
36 |
37 | self.num_classes = num_classes
38 | self.num_layers = num_layers
39 | self.features_start = features_start
40 | self.bilinear = bilinear
41 | self.lr = lr
42 |
43 | self.net = UNet(
44 | num_classes=num_classes,
45 | num_layers=self.num_layers,
46 | features_start=self.features_start,
47 | bilinear=self.bilinear,
48 | )
49 |
50 | def forward(self, x):
51 | return self.net(x)
52 |
53 | def training_step(self, batch, batch_nb):
54 | img, mask = batch
55 | img = img.float()
56 | mask = mask.long()
57 | out = self(img)
58 | loss_val = F.cross_entropy(out, mask, ignore_index=250)
59 | log_dict = {"train_loss": loss_val}
60 | return {"loss": loss_val, "log": log_dict, "progress_bar": log_dict}
61 |
62 | def validation_step(self, batch, batch_idx):
63 | img, mask = batch
64 | img = img.float()
65 | mask = mask.long()
66 | out = self(img)
67 | loss_val = F.cross_entropy(out, mask, ignore_index=250)
68 | return {"val_loss": loss_val}
69 |
70 | def validation_epoch_end(self, outputs):
71 | loss_val = torch.stack([x["val_loss"] for x in outputs]).mean()
72 | log_dict = {"val_loss": loss_val}
73 | return {"log": log_dict, "val_loss": log_dict["val_loss"], "progress_bar": log_dict}
74 |
75 | def configure_optimizers(self):
76 | opt = torch.optim.Adam(self.net.parameters(), lr=self.lr)
77 | sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
78 | return [opt], [sch]
79 |
80 | @staticmethod
81 | def add_model_specific_args(parent_parser):
82 | parser = ArgumentParser(parents=[parent_parser], add_help=False)
83 | parser.add_argument("--lr", type=float, default=0.01, help="adam: learning rate")
84 | parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
85 | parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
86 | parser.add_argument(
87 | "--bilinear", action="store_true", default=False, help="whether to use bilinear interpolation or transposed"
88 | )
89 |
90 | return parser
91 |
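# Programmatic usage sketch (dummy tensor, assumed for illustration):
#
#   model = SemSegment(num_classes=19)
#   logits = model(torch.rand(2, 3, 96, 320))  # -> per-pixel class scores of shape (2, 19, 96, 320)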
92 |
93 | def cli_main():
94 | from pl_bolts.datamodules import KittiDataModule
95 |
96 | seed_everything(1234)
97 |
98 | parser = ArgumentParser()
99 | # trainer args
100 | parser = Trainer.add_argparse_args(parser)
101 | # model args
102 | parser = SemSegment.add_model_specific_args(parser)
103 | # datamodule args
104 | parser = KittiDataModule.add_argparse_args(parser)
105 |
106 | args = parser.parse_args()
107 |
108 | # data
109 | dm = KittiDataModule.from_argparse_args(args)
110 |
111 | # model
112 | model = SemSegment(lr=args.lr, num_layers=args.num_layers, features_start=args.features_start, bilinear=args.bilinear)
113 |
114 | # train
115 | trainer = Trainer.from_argparse_args(args)
116 | trainer.fit(model, datamodule=dm)
117 |
118 |
119 | if __name__ == "__main__":
120 | cli_main()
121 |
--------------------------------------------------------------------------------
/Pre-Training/Masked_Autoencoder/utils/arg_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) ByteDance, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import json
8 | import os
9 | import sys
10 |
11 | from tap import Tap
12 |
13 | import dist
14 |
15 |
16 | class Args(Tap):
17 | # environment
18 | exp_name: str = 'your_exp_name'
19 | exp_dir: str = 'your_exp_dir' # will be created if not exists
20 | data_path: str = 'imagenet_data_path'
21 | resume_from: str = '' # resume from some checkpoint.pth
22 |
23 | # SparK hyperparameters
24 | mask: float = 0.6 # mask ratio, should be in (0, 1)
25 |
26 | # encoder hyperparameters
27 | model: str = 'resnet50'
28 | input_size: int = 224
29 | sbn: bool = True
30 |
31 | # data hyperparameters
32 | bs: int = 4096
33 | dataloader_workers: int = 8
34 |
35 | # pre-training hyperparameters
36 | dp: float = 0.0
37 | base_lr: float = 2e-4
38 | wd: float = 0.04
39 | wde: float = 0.2
40 | ep: int = 1600
41 | wp_ep: int = 40
42 | clip: float = 5.
43 | opt: str = 'lamb'
44 | ada: float = 0.
45 |
46 | # NO NEED TO SPECIFY; each of these args is updated automatically at runtime
47 | lr: float = None
48 | batch_size_per_gpu: int = 0
49 | glb_batch_size: int = 0
50 | densify_norm: str = ''
51 | device: str = 'cpu'
52 | local_rank: int = 0
53 | cmd: str = ' '.join(sys.argv[1:])
54 | commit_id: str = os.popen('git rev-parse HEAD').read().strip() or '[unknown]'
55 | commit_msg: str = (os.popen('git log -1').read().strip().splitlines() or ['[unknown]'])[-1].strip()
56 | last_loss: float = 0.
57 | cur_ep: str = ''
58 | remain_time: str = ''
59 | finish_time: str = ''
60 | first_logging: bool = True
61 | log_txt_name: str = '{args.exp_dir}/pretrain_log.txt'
62 | tb_lg_dir: str = '' # tensorboard log directory
63 |
64 | @property
65 | def is_convnext(self):
66 | return 'convnext' in self.model or 'cnx' in self.model
67 |
68 | @property
69 | def is_resnet(self):
70 | return 'resnet' in self.model
71 |
72 | def log_epoch(self):
73 | if not dist.is_local_master():
74 | return
75 |
76 | if self.first_logging:
77 | self.first_logging = False
78 | with open(self.log_txt_name, 'w') as fp:
79 | json.dump({
80 | 'name': self.exp_name, 'cmd': self.cmd, 'git_commit_id': self.commit_id, 'git_commit_msg': self.commit_msg,
81 | 'model': self.model,
82 | }, fp)
83 | fp.write('\n\n')
84 |
85 | with open(self.log_txt_name, 'a') as fp:
86 | json.dump({
87 | 'cur_ep': self.cur_ep,
88 | 'last_L': self.last_loss,
89 | 'rema': self.remain_time, 'fini': self.finish_time,
90 | }, fp)
91 | fp.write('\n')
92 |
93 |
94 | def init_dist_and_get_args():
95 | from utils import misc
96 |
97 | # initialize
98 | args = Args(explicit_bool=True).parse_args()
99 | e = os.path.abspath(args.exp_dir)
100 | d, e = os.path.dirname(e), os.path.basename(e)
101 | e = ''.join(ch if (ch.isalnum() or ch == '-') else '_' for ch in e)
102 | args.exp_dir = os.path.join(d, e)
103 |
104 | os.makedirs(args.exp_dir, exist_ok=True)
105 | args.log_txt_name = os.path.join(args.exp_dir, 'pretrain_log.txt')
106 | args.tb_lg_dir = args.tb_lg_dir or os.path.join(args.exp_dir, 'tensorboard_log')
107 | try:
108 | os.makedirs(args.tb_lg_dir, exist_ok=True)
109 | except OSError:
110 | pass
111 |
112 | misc.init_distributed_environ(exp_dir=args.exp_dir)
113 |
114 | # update args
115 | if not dist.initialized():
116 | args.sbn = False
117 | args.first_logging = True
118 | args.device = dist.get_device()
119 | args.batch_size_per_gpu = args.bs // dist.get_world_size()
120 | args.glb_batch_size = args.batch_size_per_gpu * dist.get_world_size()
121 |
122 | if args.is_resnet:
123 | args.ada = args.ada or 0.95
124 | args.densify_norm = 'bn'
125 |
126 | if args.is_convnext:
127 | args.ada = args.ada or 0.999
128 | args.densify_norm = 'ln'
129 |
130 | args.opt = args.opt.lower()
131 | args.lr = args.base_lr * args.glb_batch_size / 256
132 | args.wde = args.wde or args.wd
133 |
134 | return args
135 |
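# Worked example of the linear LR scaling rule above (values taken from the defaults in Args, assuming
# `bs` divides evenly across the GPUs so that glb_batch_size == bs): with base_lr = 2e-4 and a global
# batch size of 4096, the peak learning rate becomes lr = 2e-4 * 4096 / 256 = 3.2e-3; halving the global
# batch size halves the peak learning rate.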
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/utils/semi_supervised.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Sequence, Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | from torch import Tensor
7 |
8 | from pl_bolts.utils import _SKLEARN_AVAILABLE
9 | from pl_bolts.utils.warnings import warn_missing_pkg
10 |
11 | if _SKLEARN_AVAILABLE:
12 | from sklearn.utils import shuffle as sk_shuffle
13 | else: # pragma: no cover
14 | warn_missing_pkg("sklearn", pypi_name="scikit-learn")
15 |
16 |
17 | class Identity(torch.nn.Module):
18 | """An identity class to replace arbitrary layers in pretrained models.
19 |
20 | Example::
21 |
22 | from pl_bolts.utils import Identity
23 |
24 | model = resnet18()
25 | model.fc = Identity()
26 | """
27 |
28 | def __init__(self) -> None:
29 | super().__init__()
30 |
31 | def forward(self, x: Tensor) -> Tensor:
32 | return x
33 |
34 |
35 | def balance_classes(
36 | X: Union[Tensor, np.ndarray], Y: Union[Tensor, np.ndarray, Sequence[int]], batch_size: int
37 | ) -> Tuple[np.ndarray, np.ndarray]:
38 | """Makes sure each batch has an equal amount of data from each class. Perfect balance.
39 |
40 | Args:
41 | X: input features
42 | Y: mixed labels (ints)
43 | batch_size: the ultimate batch size
44 | """
45 | if not _SKLEARN_AVAILABLE: # pragma: no cover
46 | raise ModuleNotFoundError("You want to use `shuffle` function from `scikit-learn` which is not installed yet.")
47 |
48 | nb_classes = len(set(Y))
49 |
50 | nb_batches = math.ceil(len(Y) / batch_size)
51 |
52 | # sort by classes
53 | final_batches_x: List[list] = [[] for i in range(nb_batches)]
54 | final_batches_y: List[list] = [[] for i in range(nb_batches)]
55 |
56 | # Y needs to be np arr
57 | Y = np.asarray(Y)
58 |
59 | # pick chunk size for each class using the largest split
60 | chunk_sizes = []
61 | for class_i in range(nb_classes):
62 | mask = Y == class_i
63 | y = Y[mask]
64 | chunk_sizes.append(math.ceil(len(y) / nb_batches))
65 | chunk_size = max(chunk_sizes)
66 | # force chunk size to be even
67 | if chunk_size % 2 != 0:
68 | chunk_size -= 1
69 |
70 | # divide each class into each batch
71 | for class_i in range(nb_classes):
72 | mask = Y == class_i
73 | x = X[mask]
74 | y = Y[mask]
75 |
76 | # shuffle items in the class
77 | x, y = sk_shuffle(x, y, random_state=123)
78 |
79 | # divide the class into the batches
80 | for i_start in range(0, len(y), chunk_size):
81 | batch_i = i_start // chunk_size
82 | i_end = i_start + chunk_size
83 |
84 | if len(final_batches_x) > batch_i:
85 | final_batches_x[batch_i].append(x[i_start:i_end])
86 | final_batches_y[batch_i].append(y[i_start:i_end])
87 |
88 | # merge into full dataset
89 | final_batches_x = [np.concatenate(x, axis=0) for x in final_batches_x if len(x) > 0]
90 | final_batches_x = np.concatenate(final_batches_x, axis=0)
91 |
92 | final_batches_y = [np.concatenate(x, axis=0) for x in final_batches_y if len(x) > 0]
93 | final_batches_y = np.concatenate(final_batches_y, axis=0)
94 |
95 | return final_batches_x, final_batches_y
96 |
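# Minimal usage sketch (toy arrays, assumed for illustration):
#
#   X = np.random.rand(100, 8)
#   Y = np.random.randint(0, 2, size=100).tolist()
#   X_bal, Y_bal = balance_classes(X, Y, batch_size=10)
#
# The returned arrays interleave per-class chunks so that consecutive groups of rows contain a similar
# number of samples from every class.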
97 |
98 | def generate_half_labeled_batches(
99 | smaller_set_X: np.ndarray,
100 | smaller_set_Y: np.ndarray,
101 | larger_set_X: np.ndarray,
102 | larger_set_Y: np.ndarray,
103 | batch_size: int,
104 | ) -> Tuple[np.ndarray, np.ndarray]:
105 | """Given a labeled dataset and an unlabeled dataset, this function generates a joint pair where half the
106 | batches are labeled and the other half is not."""
107 | X = []
108 | Y = []
109 | half_batch = batch_size // 2
110 |
111 | n_larger = len(larger_set_X)
112 | n_smaller = len(smaller_set_X)
113 | for i_start in range(0, n_larger, half_batch):
114 | i_end = i_start + half_batch
115 |
116 | X_larger = larger_set_X[i_start:i_end]
117 | Y_larger = larger_set_Y[i_start:i_end]
118 |
119 | # pull out labeled part
120 | smaller_start = i_start % (n_smaller - half_batch)
121 | smaller_end = smaller_start + half_batch
122 |
123 | X_small = smaller_set_X[smaller_start:smaller_end]
124 | Y_small = smaller_set_Y[smaller_start:smaller_end]
125 |
126 | X.extend([X_larger, X_small])
127 | Y.extend([Y_larger, Y_small])
128 |
129 | # concatenate the interleaved half-batches into full arrays
130 | X = np.concatenate(X, axis=0)
131 | Y = np.concatenate(Y, axis=0)
132 |
133 | return X, Y
134 |
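# Usage sketch (toy shapes, assumed for illustration): half of every batch comes from the larger
# (typically unlabeled) set and the other half from the smaller (labeled) set.
#
#   X_mix, Y_mix = generate_half_labeled_batches(
#       smaller_set_X=np.ones((20, 8)), smaller_set_Y=np.zeros(20),
#       larger_set_X=np.zeros((100, 8)), larger_set_Y=np.full(100, -1),
#       batch_size=10,
#   )
#   # every consecutive block of 10 rows holds 5 rows from the larger set followed by 5 labeled rows.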
--------------------------------------------------------------------------------
/Pre-Training/Contrastive_Learning/pl_bolts/losses/rl.py:
--------------------------------------------------------------------------------
1 | """Loss functions for the RL models."""
2 |
3 | from typing import List, Tuple
4 |
5 | import numpy as np
6 | import torch
7 | from torch import Tensor, nn
8 |
9 |
10 | def dqn_loss(batch: Tuple[Tensor, Tensor], net: nn.Module, target_net: nn.Module, gamma: float = 0.99) -> Tensor:
11 | """Calculates the mse loss using a mini batch from the replay buffer.
12 |
13 | Args:
14 | batch: current mini batch of replay data
15 | net: main training network
16 | target_net: target network of the main training network
17 | gamma: discount factor
18 |
19 | Returns:
20 | loss
21 | """
22 | states, actions, rewards, dones, next_states = batch
23 |
24 | actions = actions.long().squeeze(-1)
25 |
26 | state_action_values = net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
27 |
28 | with torch.no_grad():
29 | next_state_values = target_net(next_states).max(1)[0]
30 | next_state_values[dones] = 0.0
31 | next_state_values = next_state_values.detach()
32 |
33 | expected_state_action_values = next_state_values * gamma + rewards
34 |
35 | return nn.MSELoss()(state_action_values, expected_state_action_values)
36 |
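# Shape sketch (toy networks and batch, assumed for illustration): the regression target is
# r + gamma * max_a' Q_target(s', a'), with the value of terminal states forced to 0.
#
#   net, target_net = nn.Linear(8, 2), nn.Linear(8, 2)   # 8-dim observations, 2 actions
#   batch = (
#       torch.rand(16, 8),                    # states
#       torch.randint(0, 2, (16, 1)),         # actions
#       torch.rand(16),                       # rewards
#       torch.zeros(16, dtype=torch.bool),    # dones
#       torch.rand(16, 8),                    # next_states
#   )
#   loss = dqn_loss(batch, net, target_net)   # scalar tensor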
37 |
38 | def double_dqn_loss(
39 | batch: Tuple[Tensor, Tensor],
40 | net: nn.Module,
41 | target_net: nn.Module,
42 | gamma: float = 0.99,
43 | ) -> Tensor:
44 | """Calculates the mse loss using a mini batch from the replay buffer. This uses an improvement to the original
45 | DQN loss by using the double dqn. This is shown by using the actions of the train network to pick the value
46 | from the target network. This code is heavily commented in order to explain the process clearly.
47 |
48 | Args:
49 | batch: current mini batch of replay data
50 | net: main training network
51 | target_net: target network of the main training network
52 | gamma: discount factor
53 |
54 | Returns:
55 | loss
56 | """
57 | states, actions, rewards, dones, next_states = batch # batch of experiences, batch_size = 16
58 |
59 | actions = actions.long().squeeze(-1)
60 |
61 | state_action_values = net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)
62 |
63 | # dont want to mess with gradients when using the target network
64 | with torch.no_grad():
65 | next_outputs = net(next_states) # [16, 2], [batch, action_space]
66 |
67 | next_state_acts = next_outputs.max(1)[1].unsqueeze(-1) # take action at the index with the highest value
68 | next_tgt_out = target_net(next_states)
69 |
70 | # Take the value of the action chosen by the train network
71 | next_state_values = next_tgt_out.gather(1, next_state_acts).squeeze(-1)
72 | next_state_values[dones] = 0.0 # any steps flagged as done get a 0 value
73 | next_state_values = next_state_values.detach() # remove values from the graph, no grads needed
74 |
75 | # calc expected discounted return of next_state_values
76 | expected_state_action_values = next_state_values * gamma + rewards
77 |
78 | # Standard MSE loss between the state action values of the current state and the
79 | # expected state action values of the next state
80 | return nn.MSELoss()(state_action_values, expected_state_action_values)
81 |
82 |
83 | def per_dqn_loss(
84 | batch: Tuple[Tensor, Tensor],
85 | batch_weights: List,
86 | net: nn.Module,
87 | target_net: nn.Module,
88 | gamma: float = 0.99,
89 | ) -> Tuple[Tensor, np.ndarray]:
90 | """Calculates the mse loss with the priority weights of the batch from the PER buffer.
91 |
92 | Args:
93 | batch: current mini batch of replay data
94 | batch_weights: how each of these samples are weighted in terms of priority
95 | net: main training network
96 | target_net: target network of the main training network
97 | gamma: discount factor
98 |
99 | Returns:
100 | loss and the per-sample values used to update the priorities in the PER buffer
101 | """
102 | states, actions, rewards, dones, next_states = batch
103 |
104 | actions = actions.long()
105 |
106 | batch_weights = torch.tensor(batch_weights)
107 |
108 | actions_v = actions.unsqueeze(-1)
109 | outputs = net(states)
110 | state_action_vals = outputs.gather(1, actions_v)
111 | state_action_vals = state_action_vals.squeeze(-1)
112 |
113 | with torch.no_grad():
114 | next_s_vals = target_net(next_states).max(1)[0]
115 | next_s_vals[dones] = 0.0
116 | exp_sa_vals = next_s_vals.detach() * gamma + rewards
117 | loss = (state_action_vals - exp_sa_vals) ** 2
118 | losses_v = batch_weights * loss
119 | return losses_v.mean(), (losses_v + 1e-5).data.cpu().numpy()
120 |
--------------------------------------------------------------------------------