├── utils
│ ├── transforms
│ │ ├── __init__.py
│ │ └── transforms.py
│ ├── dataset
│ │ ├── types.py
│ │ ├── cifar100_datamodule.py
│ │ ├── miniimagenet_datamodule.py
│ │ ├── datasets.py
│ │ ├── datamodule.py
│ │ ├── domainnet_real_datamodule.py
│ │ ├── svhn_datamodule.py
│ │ ├── miniimagenet.py
│ │ ├── cifar10_datamodule.py
│ │ ├── __init__.py
│ │ ├── domainnet_real.py
│ │ ├── ssl_datamodule.py
│ │ └── utils.py
│ ├── types.py
│ ├── loggers
│ │ ├── __init__.py
│ │ ├── print_logger.py
│ │ ├── log_aggregator.py
│ │ ├── wandb_logger.py
│ │ └── logger.py
│ ├── timing.py
│ ├── file_io.py
│ ├── utils.py
│ ├── __init__.py
│ └── metrics.py
├── media
│ └── pairloss_anim.png
├── requirements.txt
├── loss
│ ├── pair_loss
│ │ ├── __init__.py
│ │ ├── utils.py
│ │ └── pair_loss.py
│ ├── utils
│ │ ├── __init__.py
│ │ └── utils.py
│ ├── types.py
│ ├── __init__.py
│ ├── loss.py
│ └── visualization.py
├── models
│ ├── mixmatch
│ │ ├── types.py
│ │ ├── __init__.py
│ │ ├── mixmatch.py
│ │ ├── utils.py
│ │ ├── simple_mixmatch.py
│ │ └── mixmatch_base.py
│ ├── types.py
│ ├── models
│ │ ├── utils.py
│ │ ├── __init__.py
│ │ ├── ema.py
│ │ ├── wide_resnet.py
│ │ └── resnet.py
│ ├── optimization
│ │ ├── types.py
│ │ ├── __init__.py
│ │ └── lr_scheduler.py
│ ├── rampup.py
│ ├── augmentation
│ │ ├── augmenter.py
│ │ ├── __init__.py
│ │ └── randaugment.py
│ ├── utils.py
│ └── __init__.py
├── runs
│ ├── miniimagenet_args.txt
│ ├── cifar10_args.txt
│ └── cifar100_args.txt
├── .gitignore
├── environment.yaml
├── .gitattributes
├── example_logger_config.yaml
├── main_ddp.py
├── ablation_estimator.py
├── main.py
├── README.md
├── LICENSE
└── checkpoint_saver.py
/utils/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .transforms import *
2 |
--------------------------------------------------------------------------------
/media/pairloss_anim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zijian-hu/SimPLE/HEAD/media/pairloss_anim.png
--------------------------------------------------------------------------------
/utils/dataset/types.py:
--------------------------------------------------------------------------------
1 | from .utils import BatchType, LoaderType, BatchGeneratorType
2 |
3 | __all__ = [
4 | # types
5 | "BatchType",
6 | "LoaderType",
7 | "BatchGeneratorType",
8 | ]
9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | future
2 | h5py
3 | matplotlib
4 | numpy
5 | pillow
6 | plotly>=4.0.0
7 | pyyaml
8 | pandas
9 | scikit-learn
10 | scipy
11 | torch>=1.6.0,<=1.9.0
12 | torchvision>=0.7.0,<=0.10.0
13 | kornia==0.5.0
14 | wandb
15 | tqdm
16 |
--------------------------------------------------------------------------------
/loss/pair_loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .pair_loss import PairLoss
2 |
3 | # modules
4 | from . import utils
5 |
6 |
7 | __all__ = [
8 | # modules
9 | "utils",
10 |
11 | # classes
12 | "PairLoss",
13 |
14 | # functions
15 | ]
16 |
--------------------------------------------------------------------------------
/utils/types.py:
--------------------------------------------------------------------------------
1 | from .dataset.types import BatchType, LoaderType, BatchGeneratorType
2 | from .metrics import MetricDictType
3 |
4 | __all__ = [
5 | # types
6 | "BatchType",
7 | "LoaderType",
8 | "BatchGeneratorType",
9 | "MetricDictType",
10 | ]
11 |
--------------------------------------------------------------------------------
/models/mixmatch/types.py:
--------------------------------------------------------------------------------
1 | from .mixmatch import MixMatch
2 | from .simple_mixmatch import SimPLE
3 | from .mixmatch_base import MixMatchBase as MixMatchEnhanced
4 |
5 | from typing import Union
6 |
7 | MixMatchFunctionType = Union[MixMatch, MixMatchEnhanced, SimPLE]
8 |
9 | __all__ = [
10 | "MixMatchFunctionType",
11 | ]
12 |
--------------------------------------------------------------------------------
/utils/loggers/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import Logger
2 | from .wandb_logger import WandbLogger
3 | from .print_logger import PrintLogger
4 |
5 | from .log_aggregator import LogAggregator
6 |
7 | __all__ = [
8 | # classes
9 | "Logger",
10 | "WandbLogger",
11 | "PrintLogger",
12 | "LogAggregator",
13 | ]
14 |
--------------------------------------------------------------------------------
/models/types.py:
--------------------------------------------------------------------------------
1 | from .optimization.types import LRSchedulerType, ParametersType, ParametersGroupType, OptimizerParametersType
2 | from .mixmatch.types import MixMatchFunctionType
3 |
4 | __all__ = [
5 | "LRSchedulerType",
6 | "ParametersType",
7 | "ParametersGroupType",
8 | "OptimizerParametersType",
9 | "MixMatchFunctionType",
10 | ]
11 |
--------------------------------------------------------------------------------
/runs/miniimagenet_args.txt:
--------------------------------------------------------------------------------
1 | -e
2 | simple
3 | --num-epochs
4 | 2048
5 | --dataset
6 | miniimagenet
7 | --labeled-train-size
8 | 4000
9 | --validation-size
10 | 7200
11 | --batch-size
12 | 16
13 | --K-strong
14 | 7
15 | --lambda-u
16 | 300
17 | --lambda-pair
18 | 300
19 | --conf-threshold
20 | 0.95
21 | --sim-threshold
22 | 0.9
23 | --ema-type
24 | full
25 | --max-grad-norm
26 | 5
--------------------------------------------------------------------------------
/runs/cifar10_args.txt:
--------------------------------------------------------------------------------
1 | -e
2 | simple
3 | --num-epochs
4 | 1024
5 | --dataset
6 | cifar10
7 | --K-strong
8 | 7
9 | --labeled-train-size
10 | 4000
11 | --lr
12 | 0.03
13 | --weight-decay
14 | 5e-4
15 | --lambda-u
16 | 75
17 | --lambda-pair
18 | 75
19 | --conf-threshold
20 | 0.95
21 | --sim-threshold
22 | 0.9
23 | --max-grad-norm
24 | 5
25 | --optimizer-type
26 | sgd
27 | --lr-scheduler-type
28 | cosine_decay
--------------------------------------------------------------------------------
/runs/cifar100_args.txt:
--------------------------------------------------------------------------------
1 | -e
2 | simple
3 | --dataset
4 | cifar100
5 | --K-strong
6 | 4
7 | --labeled-train-size
8 | 10000
9 | --lr
10 | 0.03
11 | --weight-decay
12 | 0.001
13 | --lambda-u
14 | 150
15 | --lambda-pair
16 | 150
17 | --conf-threshold
18 | 0.95
19 | --sim-threshold
20 | 0.9
21 | --max-grad-norm
22 | 5
23 | --model-type
24 | wrn28-8
25 | --optimizer-type
26 | sgd
27 | --lr-scheduler-type
28 | cosine_decay
--------------------------------------------------------------------------------
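
Note: the run files above hold one command-line token per line. Below is a minimal sketch of how such files are typically consumed through argparse's @-file support; the real parser is built by utils.cli.get_arg_parser (not included in this dump), so the flags declared here are an illustrative assumption.

    # Illustrative sketch only, not the repository's verified CLI.
    from argparse import ArgumentParser

    parser = ArgumentParser(fromfile_prefix_chars="@")
    parser.add_argument("--dataset")
    parser.add_argument("--lr", type=float)

    # equivalent to: python main.py @runs/cifar10_args.txt
    # flags not declared above are left in the extras list by parse_known_args
    args, _ = parser.parse_known_args(["@runs/cifar10_args.txt"])
    print(args.dataset, args.lr)  # cifar10 0.03
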
/.gitignore:
--------------------------------------------------------------------------------
1 | # CMake
2 | cmake-build-*/
3 |
4 | # JetBrains IDE
5 | .idea/
6 |
7 | # VS Code
8 | .vscode/
9 |
10 | # Mac
11 | *.DS_Store
12 |
13 | # Node.js
14 | node_modules/
15 | package-lock.json
16 |
17 | # Python
18 | __pycache__/
19 | *.pyc
20 |
21 | # Jupyter notebook
22 | .ipynb_checkpoints
23 |
24 | # pyenv files
25 | .python-version
26 |
27 | # data files
28 | data/
29 | logs/
30 | wandb/
31 | lightning_logs/
32 |
--------------------------------------------------------------------------------
/models/mixmatch/__init__.py:
--------------------------------------------------------------------------------
1 | from .mixmatch import MixMatch
2 | from .mixmatch_base import MixMatchBase as MixMatchEnhanced
3 | from .simple_mixmatch import SimPLE
4 |
5 | # modules
6 | from . import utils
7 | from . import types
8 |
9 | __all__ = [
10 | # modules
11 | "utils",
12 | "types",
13 |
14 | # classes
15 | "MixMatch",
16 | "MixMatchEnhanced",
17 | "SimPLE",
18 |
19 | # functions
20 | ]
21 |
--------------------------------------------------------------------------------
/loss/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import (reduce_tensor, to_tensor, bha_coeff_log_prob, bha_coeff, bha_coeff_distance,
2 | l2_distance, pairwise_apply)
3 |
4 | __all__ = [
5 | # modules
6 | # classes
7 | # functions
8 | "reduce_tensor",
9 | "to_tensor",
10 | "bha_coeff_log_prob",
11 | "bha_coeff",
12 | "bha_coeff_distance",
13 | "l2_distance",
14 | "pairwise_apply",
15 | ]
16 |
--------------------------------------------------------------------------------
/models/models/utils.py:
--------------------------------------------------------------------------------
1 | # for type hint
2 | from typing import Union
3 |
4 | from torch.nn import Module, DataParallel
5 | from torch.nn.parallel import DistributedDataParallel
6 |
7 | ModelType = Union[Module, DataParallel, DistributedDataParallel]
8 |
9 |
10 | def unwrap_model(model: ModelType) -> Module:
11 | if hasattr(model, "module"):
12 | return model.module
13 | else:
14 | return model
15 |
16 |
17 | __all__ = [
18 | "unwrap_model",
19 | ]
20 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: simple
2 | channels:
3 | - defaults
4 | - pytorch
5 | - conda-forge
6 | dependencies:
7 | - python=3.8
8 | - pip
9 | - pillow
10 | - tqdm
11 | - pyyaml
12 | - matplotlib
13 | - future
14 | - h5py
15 | - numpy
16 | - matplotlib
17 | - plotly
18 | - pandas
19 | - scipy
20 | - scikit-learn
21 | - cudatoolkit=10.2
22 | - pytorch::torchvision
23 | - pytorch::pytorch=1.6.0
24 | - pip:
25 | - kornia==0.5.0
26 | - wandb
27 |
--------------------------------------------------------------------------------
/models/optimization/types.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Iterable, Dict
2 | from torch.optim.lr_scheduler import LambdaLR, StepLR
3 | from torch.nn import Parameter
4 |
5 | LRSchedulerType = Union[LambdaLR, StepLR]
6 | ParametersType = Iterable[Parameter]
7 | ParametersGroupType = Iterable[Dict[str, Union[Parameter, float, int]]]
8 | OptimizerParametersType = Union[ParametersType, ParametersGroupType]
9 |
10 | __all__ = [
11 | "LRSchedulerType",
12 | "ParametersType",
13 | "ParametersGroupType",
14 | "OptimizerParametersType",
15 | ]
16 |
--------------------------------------------------------------------------------
/utils/timing.py:
--------------------------------------------------------------------------------
1 | from timeit import default_timer as timer
2 | from functools import wraps
3 | from datetime import timedelta
4 |
5 |
6 | def timing(func):
7 | # see https://stackoverflow.com/a/27737385/5838091
8 | @wraps(func)
9 | def wrap(*args, **kwargs):
10 | start_time = timer()
11 |
12 | result = func(*args, **kwargs)
13 |
14 | time_elapsed = timer() - start_time
15 | print(f"Total time for {func.__name__}: {str(timedelta(seconds=time_elapsed))}", flush=True)
16 |
17 | return result
18 |
19 | return wrap
20 |
--------------------------------------------------------------------------------
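
A minimal usage sketch for the @timing decorator above (assuming the repo root is on PYTHONPATH):

    from time import sleep

    from utils.timing import timing

    @timing
    def train_one_epoch():
        sleep(0.5)  # stand-in for real work

    train_one_epoch()
    # prints something like: Total time for train_one_epoch: 0:00:00.500123
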
/.gitattributes:
--------------------------------------------------------------------------------
1 | # This is an example git attribute file
2 |
3 | # Set the default behavior, in case people don't have core.autocrlf set.
4 | * text=auto
5 |
6 | # Explicitly declare text files you want to always be normalized and converted
7 | # to native line endings on checkout.
8 | *.c text
9 | *.h text
10 | *.py text
11 |
12 | # Declare files that will always have CRLF line endings on checkout.
13 | *.sln text eol=crlf
14 |
15 | # Declare files that will always have LF line endings on checkout.
16 | *.sh text eol=lf
17 |
18 | # Denote all files that are truly binary and should not be modified.
19 | *.png binary
20 | *.jpg binary
21 |
--------------------------------------------------------------------------------
/example_logger_config.yaml:
--------------------------------------------------------------------------------
1 | # A short display name for this run, which is how you'll identify this run in the UI
2 | name: null
3 |
4 | # An entity is a username or team name where you're sending runs
5 | entity: null
6 |
7 | # The name of the project where you're sending the new run
8 | project: null
9 |
10 | # A list of strings, which will populate the list of tags on this run in the UI
11 | tags: [ ]
12 |
13 | # A longer description of the run, like a -m commit message in git
14 | notes: null
15 |
16 | # if set to True, the run auto resumes; can also be a unique string for manual resuming
17 | resume: False
18 |
19 | # Can be "online", "offline" or "disabled". Defaults to online
20 | mode: online
21 |
--------------------------------------------------------------------------------
/loss/types.py:
--------------------------------------------------------------------------------
1 | # for type hint
2 | from typing import Union, Dict
3 | from torch import Tensor
4 | from plotly.graph_objects import Figure
5 | from wandb import Histogram
6 |
7 | from .utils import (bha_coeff, bha_coeff_distance, l2_distance)
8 | from .loss import softmax_cross_entropy_loss, bha_coeff_loss, l2_dist_loss
9 |
10 | LogDictType = Dict[str, Tensor]
11 | PlotDictType = Dict[str, Union[Figure, Histogram]]
12 | LossInfoType = Union[Dict[str, Union[LogDictType, PlotDictType]], LogDictType]
13 |
14 | SimilarityType = Union[bha_coeff]
15 | DistanceType = Union[bha_coeff_distance, l2_distance]
16 | DistanceLossType = Union[softmax_cross_entropy_loss, l2_dist_loss, bha_coeff_loss]
17 |
18 | __all__ = [
19 | "LogDictType",
20 | "PlotDictType",
21 | "LossInfoType",
22 | "SimilarityType",
23 | "DistanceType",
24 | "DistanceLossType",
25 | ]
26 |
--------------------------------------------------------------------------------
/loss/pair_loss/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # for type hint
4 | from torch import Tensor
5 |
6 |
7 | def get_pair_indices(inputs: Tensor, ordered_pair: bool = False) -> Tensor:
8 | """
 9 | Get indices for all pairs of elements in the input tensor.
10 |
11 | Args:
12 | inputs: input tensor
13 | ordered_pair: if True, return ordered pairs, i.e. both (i, j) and (j, i) are included
14 |
15 | Returns: a tensor of shape (K, 2) where K = choose(len(inputs), 2) if ordered_pair is False,
16 | else K = 2 * choose(len(inputs), 2); each row holds two indices into inputs.
17 |
18 | """
19 | indices = torch.combinations(torch.tensor(range(len(inputs))), r=2)
20 |
21 | if ordered_pair:
22 | # make pairs ordered (e.g. both (0,1) and (1,0) are included)
23 | indices = torch.cat((indices, indices[:, [1, 0]]), dim=0)
24 |
25 | return indices
26 |
--------------------------------------------------------------------------------
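
A quick sanity check for get_pair_indices: 3 inputs yield choose(3, 2) = 3 unordered pairs, and 6 rows once ordered pairs are included (assuming the repo root is on PYTHONPATH):

    import torch

    from loss.pair_loss.utils import get_pair_indices

    inputs = torch.zeros(3, 10)  # only len(inputs) matters here
    print(get_pair_indices(inputs))
    # tensor([[0, 1],
    #         [0, 2],
    #         [1, 2]])
    print(get_pair_indices(inputs, ordered_pair=True).shape)  # torch.Size([6, 2])
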
/models/mixmatch/mixmatch.py:
--------------------------------------------------------------------------------
1 | from .mixmatch_base import MixMatchBase
2 |
3 | # for type hint
4 | from torch.nn import Module
5 |
6 |
7 | class MixMatch(MixMatchBase):
8 | def __init__(self,
9 | augmenter: Module,
10 | num_classes: int,
11 | temperature: float,
12 | num_augmentations: int,
13 | alpha: float,
14 | train_label_guessing: bool):
15 | super(MixMatch, self).__init__(augmenter=augmenter,
16 | strong_augmenter=None,
17 | num_classes=num_classes,
18 | temperature=temperature,
19 | num_augmentations=num_augmentations,
20 | num_strong_augmentations=0,
21 | alpha=alpha,
22 | is_strong_augment_x=False,
23 | train_label_guessing=train_label_guessing)
24 |
--------------------------------------------------------------------------------
/utils/dataset/cifar100_datamodule.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import CIFAR100
2 |
3 | from .cifar10_datamodule import CIFAR10DataModule
4 |
5 | # for type hint
6 | from typing import Optional
7 |
8 |
9 | class CIFAR100DataModule(CIFAR10DataModule):
10 | num_classes: int = 100
11 |
12 | total_train_size: int = 50_000
13 | total_test_size: int = 10_000
14 |
15 | DATASET = CIFAR100
16 |
17 | def __init__(self,
18 | data_dir: str,
19 | labeled_train_size: int,
20 | validation_size: int,
21 | unlabeled_train_size: Optional[int] = None,
22 | **kwargs):
23 | super(CIFAR100DataModule, self).__init__(
24 | data_dir=data_dir,
25 | labeled_train_size=labeled_train_size,
26 | validation_size=validation_size,
27 | unlabeled_train_size=unlabeled_train_size,
28 | **kwargs)
29 |
30 | # dataset stats
31 | # CIFAR-100 mean, std values in CHW
32 | self.dataset_mean = [0.44091784, 0.50707516, 0.48654887]
33 | self.dataset_std = [0.27615047, 0.26733429, 0.25643846]
34 |
--------------------------------------------------------------------------------
/main_ddp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import distributed
3 |
4 | import warnings
5 |
6 | from main import main as main_single_thread
7 | from utils import get_args, timing, set_random_seed
8 |
9 | # for type hint
10 | from typing import Optional
11 | from argparse import Namespace
12 | from utils.dataset import SSLDataModule
13 |
14 | IS_DISTRIBUTED_AVAILABLE = distributed.is_available()
15 |
16 |
17 | @timing
18 | def main(args: Namespace, datamodule: Optional[SSLDataModule] = None):
19 | if IS_DISTRIBUTED_AVAILABLE and torch.cuda.is_available() and torch.cuda.device_count() > 1:
20 | distributed.init_process_group(backend='nccl')
21 |
22 | torch.cuda.set_device(args.local_rank)
23 | device = torch.device("cuda", args.local_rank)
24 | else:
25 | warnings.warn("Cannot initializePyTorch distributed training, fallback to single GPU training")
26 | device = None
27 |
28 | main_single_thread(args=args, datamodule=datamodule, device=device)
29 |
30 |
31 | if __name__ == '__main__':
32 | parsed_args = get_args()
33 |
34 | # fix random seed
35 | set_random_seed(parsed_args.seed, is_cudnn_deterministic=parsed_args.debug_mode)
36 |
37 | main(parsed_args)
38 |
--------------------------------------------------------------------------------
/models/mixmatch/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.nn import functional as F
4 |
5 | from functools import reduce
6 |
7 | from ..utils import set_model_mode
8 |
9 | # for type hint
10 | from typing import Sequence, Tuple
11 | from torch import Tensor
12 | from torch.nn import Module
13 |
14 |
15 | def label_guessing(batches: Sequence[Tensor], model: Module, is_train_mode: bool = True) -> Tensor:
16 | with set_model_mode(model, is_train_mode):
17 | with torch.no_grad():
18 | probs = [F.softmax(model(batch), dim=1) for batch in batches]
19 | mean_prob = reduce(lambda x, y: x + y, probs) / len(batches)
20 |
21 | return mean_prob
22 |
23 |
24 | def sharpen(x: Tensor, temperature: float) -> Tensor:
25 | sharpened_x = x ** (1 / temperature)
26 | return sharpened_x / sharpened_x.sum(dim=1, keepdim=True)
27 |
28 |
29 | def mixup(x1: Tensor, x2: Tensor, y1: Tensor, y2: Tensor, alpha: float) -> Tuple[Tensor, Tensor]:
30 | # "lambda" is a reserved word in Python, so use "lam" instead
31 | lam = np.random.beta(alpha, alpha)
32 | lam = max(lam, 1 - lam)
33 | x = lam * x1 + (1 - lam) * x2
34 | y = lam * y1 + (1 - lam) * y2
35 | return x, y
36 |
--------------------------------------------------------------------------------
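
A worked example for sharpen: with temperature 0.5, each probability is squared and the row is renormalized, pulling mass toward the argmax:

    import torch

    from models.mixmatch.utils import sharpen

    p = torch.tensor([[0.40, 0.35, 0.25]])
    print(sharpen(p, temperature=0.5))
    # p ** 2 = [0.1600, 0.1225, 0.0625]; row sum 0.3450
    # -> tensor([[0.4638, 0.3551, 0.1812]])
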
/models/optimization/__init__.py:
--------------------------------------------------------------------------------
1 | from torch.optim import SGD, AdamW
2 |
3 | from .lr_scheduler import build_lr_scheduler
4 |
5 | from . import lr_scheduler
6 | from . import types
7 |
8 | # for type hint
9 | from torch.optim.optimizer import Optimizer
10 |
11 | from .types import OptimizerParametersType
12 |
13 |
14 | def build_optimizer(optimizer_type: str,
15 | params: OptimizerParametersType,
16 | learning_rate: float,
17 | weight_decay: float,
18 | momentum: float) -> Optimizer:
19 | if optimizer_type == "sgd":
20 | optimizer = SGD(params,
21 | lr=learning_rate,
22 | weight_decay=weight_decay,
23 | momentum=momentum,
24 | nesterov=True)
25 |
26 | elif optimizer_type == "adamw":
27 | optimizer = AdamW(params, lr=learning_rate, weight_decay=weight_decay)
28 |
29 | else:
30 | raise NotImplementedError(f"\"{optimizer_type}\" is not a supported optimizer type")
31 |
32 | return optimizer
33 |
34 |
35 | __all__ = [
36 | # classes
37 | # modules
38 | "lr_scheduler",
39 | "types",
40 |
41 | # functions
42 | "build_optimizer",
43 | "build_lr_scheduler",
44 | ]
45 |
--------------------------------------------------------------------------------
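
A usage sketch for build_optimizer; lr and weight decay mirror the CIFAR-10 run file above, while the momentum value here is an illustrative assumption:

    from torch.nn import Linear

    from models.optimization import build_optimizer

    model = Linear(10, 2)
    optimizer = build_optimizer(optimizer_type="sgd",
                                params=model.parameters(),
                                learning_rate=0.03,
                                weight_decay=5e-4,
                                momentum=0.9)
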
/utils/dataset/miniimagenet_datamodule.py:
--------------------------------------------------------------------------------
1 | from .miniimagenet import MiniImageNet
2 | from .cifar10_datamodule import CIFAR10DataModule
3 |
4 | # for type hint
5 | from typing import Optional, Tuple
6 |
7 |
8 | class MiniImageNetDataModule(CIFAR10DataModule):
9 | num_classes: int = 100
10 |
11 | total_train_size: int = 50_000
12 | total_test_size: int = 10_000
13 |
14 | DATASET = MiniImageNet
15 |
16 | def __init__(self,
17 | data_dir: str,
18 | labeled_train_size: int,
19 | validation_size: int,
20 | unlabeled_train_size: Optional[int] = None,
21 | dims: Optional[Tuple[int, ...]] = None,
22 | **kwargs):
23 | if dims is None:
24 | dims = (3, 84, 84)
25 |
26 | super(MiniImageNetDataModule, self).__init__(
27 | data_dir=data_dir,
28 | labeled_train_size=labeled_train_size,
29 | validation_size=validation_size,
30 | unlabeled_train_size=unlabeled_train_size,
31 | dims=dims,
32 | **kwargs)
33 |
34 | # dataset stats
35 | # Mini-ImageNet mean, std values in CHW
36 | self.dataset_mean = [0.40233998, 0.47269102, 0.44823737]
37 | self.dataset_std = [0.2884859, 0.28327602, 0.27511246]
38 |
--------------------------------------------------------------------------------
/utils/file_io.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import yaml
3 |
4 | from pathlib import Path
5 | import re
6 |
7 | # for type hint
8 | from typing import Union, Pattern, Optional, Dict, Any, List
9 |
10 |
11 | def find_checkpoint_path(checkpoint_dir: Union[str, Path], step_filter: Union[Pattern, str]) -> Optional[Path]:
12 | checkpoint_dir_path = Path(checkpoint_dir)
13 | output_file = None
14 | max_step_num = -np.inf
15 |
16 | for file_item in checkpoint_dir_path.iterdir():
17 | if not file_item.is_file():
18 | continue
19 |
20 | search_result = re.search(step_filter, file_item.name)
21 | if search_result is None:
22 | continue
23 |
24 | step_num = int(search_result.group(1))
25 | if step_num > max_step_num:
26 | max_step_num = step_num
27 | output_file = file_item
28 |
29 | return output_file
30 |
31 |
32 | def read_yaml(path: Union[str, Path]) -> Dict[str, Any]:
33 | with open(path, "r") as f:
34 | return yaml.safe_load(f)
35 |
36 |
37 | def find_all_files(checkpoint_dir: Union[str, Path], search_pattern: Union[Pattern, str]) -> List[Path]:
38 | checkpoint_dir_path = Path(checkpoint_dir)
39 |
40 | return [file_item for file_item in checkpoint_dir_path.iterdir()
41 | if file_item.is_file() and re.search(pattern=search_pattern, string=file_item.name) is not None]
42 |
--------------------------------------------------------------------------------
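
A usage sketch for find_checkpoint_path. The file-name pattern below is hypothetical; the only hard requirement is that the regex's first capture group is the step number, since the function reads search_result.group(1):

    from utils.file_io import find_checkpoint_path

    latest = find_checkpoint_path("logs", step_filter=r"checkpoint_(\d+)\.pth")
    if latest is not None:
        print(f"resuming from {latest}")
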
/utils/dataset/datasets.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, Subset
2 |
3 | # for type hint
4 | from torchvision.datasets import VisionDataset
5 | from typing import Union, Optional, Callable
6 |
7 |
8 | class LabeledDataset(VisionDataset):
9 | def __init__(self,
10 | dataset: Union[Dataset, Subset],
11 | root: str,
12 | min_size: int = 0,
13 | transforms: Optional[Callable] = None,
14 | transform: Optional[Callable] = None,
15 | target_transform: Optional[Callable] = None):
16 | super().__init__(root, transforms, transform, target_transform)
17 |
18 | self.dataset = dataset
19 |
20 | self.min_size = min_size
21 |
22 | @property
23 | def min_size(self) -> int:
24 | return self._min_size if len(self.dataset) > 0 else 0
25 |
26 | @min_size.setter
27 | def min_size(self, min_size: int) -> None:
28 | if min_size < 0:
29 | raise ValueError(f"only non-negative min_size is allowed")
30 |
31 | self._min_size = min_size
32 |
33 | def __getitem__(self, index: int):
34 | img, target = self.dataset[index % len(self.dataset)]
35 |
36 | if self.transform is not None:
37 | img = self.transform(img)
38 |
39 | if self.target_transform is not None:
40 | target = self.target_transform(target)
41 |
42 | return img, target
43 |
44 | def __len__(self):
45 | return max(len(self.dataset), self.min_size)
46 |
--------------------------------------------------------------------------------
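
Note how min_size interacts with the modulo in __getitem__: __len__ reports max(len(dataset), min_size) and indices wrap around, so a small labeled split can be oversampled to match the epoch length of a larger loader. A minimal sketch:

    import torch
    from torch.utils.data import TensorDataset

    from utils.dataset.datasets import LabeledDataset

    base = TensorDataset(torch.randn(100, 3), torch.zeros(100, dtype=torch.long))
    wrapped = LabeledDataset(dataset=base, root="", min_size=400)

    print(len(wrapped))  # 400
    x, y = wrapped[250]  # same sample as wrapped[50], since 250 % 100 == 50
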
/utils/dataset/datamodule.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | # for type hint
4 | from typing import Optional, Tuple, Callable, Union, List, Generator, Any
5 | from torch.utils.data import DataLoader
6 |
7 |
8 | class DataModule(ABC):
9 | def __init__(self,
10 | train_transform: Optional[Callable] = None,
11 | val_transform: Optional[Callable] = None,
12 | test_transform: Optional[Callable] = None,
13 | dims: Optional[Tuple[int, ...]] = None):
14 | self.train_transform = train_transform
15 | self.val_transform = val_transform
16 | self.test_transform = test_transform
17 | self.dims = dims
18 |
19 | @abstractmethod
20 | def prepare_data(self, *args, **kwargs):
21 | # download, split, etc...
22 | # only called on 1 GPU/TPU in distributed
23 | pass
24 |
25 | @abstractmethod
26 | def setup(self, stage: Optional[str] = None):
27 | # make assignments here (val/train/test split)
28 | # called on every process in DDP
29 | pass
30 |
31 | @abstractmethod
32 | def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
33 | pass
34 |
35 | @abstractmethod
36 | def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
37 | pass
38 |
39 | @abstractmethod
40 | def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
41 | pass
42 |
43 | def get_train_batch(self, *args, **kwargs) -> Generator[Any, Any, Any]:
44 | pass
45 |
--------------------------------------------------------------------------------
/utils/dataset/domainnet_real_datamodule.py:
--------------------------------------------------------------------------------
1 | from .cifar10_datamodule import CIFAR10DataModule
2 | from .domainnet_real import DomainNetReal
3 | from .utils import per_class_random_split
4 |
5 | # for type hint
6 | from typing import Optional, Tuple, List
7 | from torch.utils.data import Dataset
8 |
9 |
10 | class DomainNetRealDataModule(CIFAR10DataModule):
11 | num_classes: int = 345
12 |
13 | total_train_size: int = 120_906
14 | total_test_size: int = 52_041
15 |
16 | DATASET = DomainNetReal
17 |
18 | def __init__(self,
19 | data_dir: str,
20 | labeled_train_size: int,
21 | validation_size: int,
22 | unlabeled_train_size: Optional[int] = None,
23 | dims: Optional[Tuple[int, ...]] = None,
24 | **kwargs):
25 | if dims is None:
26 | dims = (3, 224, 224)
27 |
28 | super(DomainNetRealDataModule, self).__init__(
29 | data_dir=data_dir,
30 | labeled_train_size=labeled_train_size,
31 | validation_size=validation_size,
32 | unlabeled_train_size=unlabeled_train_size,
33 | dims=dims,
34 | **kwargs)
35 |
36 | # dataset stats
37 | # DomainNet-Real mean, std values in CHW
38 | self.dataset_mean = [0.54873651, 0.60511086, 0.5840634]
39 | self.dataset_std = [0.33955591, 0.32637834, 0.31887854]
40 |
41 | def split_dataset(self, dataset: Dataset, **kwargs) -> List[Dataset]:
42 | split_kwargs = dict(lengths=[self.validation_size, self.labeled_train_size],
43 | num_classes=self.num_classes,
44 | uneven_split=True)
45 |
46 | # update split arguments
47 | split_kwargs.update(kwargs)
48 |
49 | return per_class_random_split(dataset, **split_kwargs)
50 |
--------------------------------------------------------------------------------
/ablation_estimator.py:
--------------------------------------------------------------------------------
1 | from torch.nn import functional as F
2 |
3 | from simple_estimator import SimPLEEstimator
4 |
5 | # for type hint
6 | from typing import Tuple, Dict
7 | from torch import Tensor
8 | from loss.types import LossInfoType
9 |
10 |
11 | class AblationEstimator(SimPLEEstimator):
12 | def training_step(self, batch: Tuple[Tuple[Tensor, Tensor], ...], batch_idx: int) -> Tuple[Tensor, LossInfoType]:
13 | outputs = self.preprocess_batch(batch, batch_idx)
14 |
15 | model_outputs = self.compute_train_logits(x_inputs=outputs["x_inputs"])
16 |
17 | outputs.update(model_outputs)
18 |
19 | # calculate loss
20 | return self.compute_train_loss(
21 | x_logits=outputs["x_logits"],
22 | x_targets=outputs["x_targets"]
23 | )
24 |
25 | def preprocess_batch(self, batch: Tuple[Tuple[Tensor, Tensor], ...], batch_idx: int) -> Dict[str, Tensor]:
26 | # unpack batch
27 | (x_inputs, x_targets), (_, _) = batch
28 |
29 | # load data to device
30 | x_inputs = x_inputs.to(self.device)
31 | x_targets = x_targets.to(self.device)
32 |
33 | # apply augmentations
34 | x_inputs = self.augmenter(x_inputs)
35 |
36 | return dict(
37 | x_inputs=x_inputs,
38 | x_targets=x_targets,
39 | )
40 |
41 | def compute_train_logits(self, x_inputs: Tensor, *args: Tensor) -> Dict[str, Tensor]:
42 | return dict(x_logits=self.model(x_inputs))
43 |
44 | def compute_train_loss(self,
45 | x_logits: Tensor,
46 | x_targets: Tensor,
47 | *args: Tensor,
48 | **kwargs) -> Tuple[Tensor, LossInfoType]:
49 | loss = F.cross_entropy(x_logits, x_targets, reduction="mean")
50 |
51 | log_info = {
52 | "loss": loss.detach().clone(),
53 | "loss_x": loss.detach().clone(),
54 | }
55 |
56 | return loss, {"log": log_info, "plot": {}}
57 |
--------------------------------------------------------------------------------
/utils/loggers/print_logger.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | from .logger import Logger
4 |
5 | # for type hint
6 | from typing import Any, Dict, Optional, Union
7 | from argparse import Namespace
8 |
9 | from torch.nn import Module
10 |
11 |
12 | class PrintLogger(Logger):
13 | def __init__(self,
14 | log_dir: str,
15 | config: Union[Namespace, Dict[str, Any]],
16 | is_display_plots: bool = False):
17 | super().__init__(
18 | log_dir=log_dir,
19 | config=config,
20 | log_info_key_map={
21 | "top1_acc": "mean_acc",
22 | "top5_acc": "mean_top5_acc",
23 | })
24 |
25 | self.is_display_plots = is_display_plots
26 |
27 | def log(self,
28 | log_info: Dict[str, Any],
29 | step: Optional[int] = None,
30 | file=sys.stdout,
31 | flush: bool = False,
32 | sep: str = "\n\t",
33 | prefix: Optional[str] = None,
34 | log_info_override: Optional[Dict[str, Any]] = None,
35 | **kwargs):
36 | # process log_info
37 | log_info = self.process_log_info(log_info, prefix=prefix, log_info_override=log_info_override)
38 |
39 | if len(log_info) == 0:
40 | return
41 |
42 | log_info, plot_info = self.separate_plot(log_info)
43 |
44 | if self.is_display_plots:
45 | # display plots
46 | for key, fig in plot_info.items():
47 | if hasattr(fig, "show"):
48 | fig.show()
49 |
50 | output_str = sep.join(f"{str(key)}: {str(val)}" for key, val in log_info.items())
51 | if step is not None:
52 | print(f"Step {step}:\n\t{output_str}", file=file, flush=flush)
53 | else:
54 | print(output_str, file=file, flush=flush)
55 |
56 | # invoke all log hook functions
57 | self.call_log_hooks(log_info)
58 |
59 | def watch(self, model: Module, **kwargs):
60 | # TODO: implement
61 | pass
62 |
63 | def save(self, output_path: str):
64 | # TODO: implement
65 | pass
66 |
--------------------------------------------------------------------------------
/utils/dataset/svhn_datamodule.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import SVHN
2 |
3 | from .cifar10_datamodule import CIFAR10DataModule
4 | from .utils import per_class_random_split
5 |
6 | # for type hint
7 | from typing import Optional, List
8 | from torch.utils.data import Dataset
9 |
10 |
11 | class SVHNDataModule(CIFAR10DataModule):
12 | num_classes: int = 10
13 |
14 | total_train_size: int = 73_257
15 | total_test_size: int = 26_032
16 |
17 | DATASET = SVHN
18 |
19 | def __init__(self,
20 | data_dir: str,
21 | labeled_train_size: int,
22 | validation_size: int,
23 | unlabeled_train_size: Optional[int] = None,
24 | **kwargs):
25 | super(SVHNDataModule, self).__init__(
26 | data_dir=data_dir,
27 | labeled_train_size=labeled_train_size,
28 | validation_size=validation_size,
29 | unlabeled_train_size=unlabeled_train_size,
30 | **kwargs)
31 |
32 | # dataset stats
33 | # SVHN mean, std values in CHW
34 | self.dataset_mean = [0.4376821, 0.4437697, 0.47280442]
35 | self.dataset_std = [0.19803012, 0.20101562, 0.19703614]
36 |
37 | def prepare_data(self, *args, **kwargs):
38 | self.DATASET(root=self.data_dir, split="train", download=True)
39 | self.DATASET(root=self.data_dir, split="test", download=True)
40 |
41 | def setup(self, stage: Optional[str] = None):
42 | full_train_set = self.DATASET(root=self.data_dir, split="train")
43 | full_test_set = self.DATASET(root=self.data_dir, split="test")
44 |
45 | self.setup_helper(full_train_set=full_train_set, full_test_set=full_test_set, stage=stage)
46 |
47 | def split_dataset(self, dataset: Dataset, **kwargs) -> List[Dataset]:
48 | split_kwargs = dict(lengths=[self.validation_size, self.labeled_train_size],
49 | num_classes=self.num_classes,
50 | uneven_split=True)
51 |
52 | # update split arguments
53 | split_kwargs.update(kwargs)
54 |
55 | return per_class_random_split(dataset, **split_kwargs)
56 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | import logging
5 |
6 | # for type hint
7 | from typing import Union, Dict, Any, Set, Optional
8 | from torch import Tensor
9 |
10 |
11 | def str_to_bool(input_str: str) -> Union[str, bool]:
12 | """
13 | Convert input_str to a boolean if it is "True" or "False" (case-insensitive; surrounding spaces are allowed).
14 |
15 | Args:
16 | input_str:
17 |
18 | Returns: a boolean value if input_str is "True" or "False"; else, return the input string
19 |
20 | """
21 | comp_str = input_str.lower().strip()
22 |
23 | if comp_str == "true":
24 | return True
25 | elif comp_str == "false":
26 | return False
27 | else:
28 | return input_str
29 |
30 |
31 | def get_device(device_id: str) -> torch.device:
32 | # update device
33 | if device_id != "cpu":
34 | if torch.cuda.is_available():
35 | device = torch.device("cuda")
36 | else:
37 | logging.warning(f"device \"{device_id}\" is not available")
38 | device = torch.device("cpu")
39 | else:
40 | device = torch.device("cpu")
41 |
42 | return device
43 |
44 |
45 | def dict_add_prefix(input_dict: Dict[str, Any], prefix: str, separator: str = "/") -> Dict[str, Any]:
46 | return {f"{prefix}{separator}{key}": val for key, val in input_dict.items()}
47 |
48 |
49 | def filter_dict(input_dict: Dict[str, Any], excluded_keys: Optional[Set[str]] = None,
50 | included_keys: Optional[Set[str]] = None) -> Dict[str, Any]:
51 | if excluded_keys is None:
52 | excluded_keys = set()
53 |
54 | if included_keys is not None:
55 | input_dict = {k: v for k, v in input_dict.items() if k in included_keys}
56 |
57 | return {k: v for k, v in input_dict.items() if k not in excluded_keys}
58 |
59 |
60 | def detorch(inputs: Union[Tensor, np.ndarray, float, int, bool]) -> Union[np.ndarray, float, int, bool]:
61 | if isinstance(inputs, Tensor):
62 | outputs = inputs.detach().cpu().clone().numpy()
63 | else:
64 | outputs = inputs
65 |
66 | if isinstance(outputs, np.ndarray) and outputs.size == 1:
67 | outputs = outputs.item()
68 |
69 | return outputs
70 |
--------------------------------------------------------------------------------
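
A quick usage sketch of the helpers above:

    import torch

    from utils.utils import str_to_bool, dict_add_prefix, filter_dict, detorch

    print(str_to_bool(" True "))                               # True
    print(dict_add_prefix({"acc": 0.9}, prefix="val"))         # {'val/acc': 0.9}
    print(filter_dict({"a": 1, "b": 2}, excluded_keys={"b"}))  # {'a': 1}
    print(detorch(torch.tensor(2.0)))                          # 2.0
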
/models/rampup.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | # for type hint
4 | from typing import Dict, Any, Set, Tuple, Optional
5 |
6 |
7 | class RampUp(ABC):
8 | def __init__(self, length: int, current: int = 0):
9 | self.current = current
10 | self.length = length
11 |
12 | @abstractmethod
13 | def __call__(self, current: Optional[int] = None, is_step: bool = True) -> float:
14 | pass
15 |
16 | def state_dict(self) -> Dict[str, Any]:
17 | return dict(current=self.current, length=self.length)
18 |
19 | def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True):
20 | if strict:
21 | is_equal, incompatible_keys = self._verify_state_dict(state_dict)
22 | assert is_equal, f"loaded state dict contains incompatible keys: {incompatible_keys}"
23 |
24 | # for attr_name, attr_value in state_dict.items():
25 | # if attr_name in self.__dict__:
26 | # self.__dict__[attr_name] = attr_value
27 |
28 | self.current = state_dict["current"]
29 | self.length = state_dict["length"]
30 |
31 | def _verify_state_dict(self, state_dict: Dict[str, Any]) -> Tuple[bool, Set[str]]:
32 | self_keys = set(self.__dict__.keys())
33 | state_dict_keys = set(state_dict.keys())
34 | incompatible_keys = self_keys.symmetric_difference(state_dict_keys)
35 | is_equal = (len(incompatible_keys) == 0)
36 |
37 | return is_equal, incompatible_keys
38 |
39 | def _update_step(self, is_step: bool):
40 | if is_step:
41 | self.current += 1
42 |
43 |
44 | class LinearRampUp(RampUp):
45 | def __call__(self, current: Optional[int] = None, is_step: bool = True) -> float:
46 | if current is not None:
47 | self.current = current
48 |
49 | if self.current >= self.length:
50 | ramp_up = 1.0
51 | else:
52 | ramp_up = self.current / self.length
53 |
54 | self._update_step(is_step)
55 |
56 | return ramp_up
57 |
58 |
59 | def get_ramp_up(ramp_up_type: str, length: int, current: int = 0) -> RampUp:
60 | return {
61 | "linear": lambda: LinearRampUp(length, current),
62 | }[ramp_up_type]()
63 |
--------------------------------------------------------------------------------
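
A usage sketch for LinearRampUp: the factor grows from 0 to 1 over `length` calls and then saturates; training code presumably multiplies loss weights such as lambda_u by it:

    from models.rampup import get_ramp_up

    ramp_up = get_ramp_up("linear", length=4)
    print([ramp_up() for _ in range(6)])  # [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
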
/utils/loggers/log_aggregator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from ..utils import detorch
4 |
5 | # for type hint
6 | from typing import List, Any, Dict, Optional, Union, FrozenSet
7 |
8 |
9 | class LogAggregator:
10 | supported_reductions = {
11 | "mean": lambda x: np.mean(x).item(),
12 | "sum": lambda x: np.sum(x).item(),
13 | }
14 |
15 | def __init__(self):
16 | self.log_dict: Dict[str, Union[List[Any], Any]] = dict()
17 | self.plot_dict: Dict[str, Any] = dict()
18 |
19 | def __getitem__(self, k):
20 | return self.log_dict[k]
21 |
22 | @property
23 | def log_keys(self) -> FrozenSet[str]:
24 | return frozenset(self.log_dict.keys())
25 |
26 | @property
27 | def plot_keys(self) -> FrozenSet[str]:
28 | return frozenset(self.plot_dict.keys())
29 |
30 | def clear(self) -> None:
31 | self.log_dict.clear()
32 | self.plot_dict.clear()
33 |
34 | def add_log(self, log_info: Dict[str, Any]) -> None:
35 | conflict_keys = self.plot_keys.intersection(log_info.keys())
36 | assert len(conflict_keys) == 0, f"conflicting keys for plot and log data: {conflict_keys}"
37 |
38 | for k, v in log_info.items():
39 | if k not in self.log_dict:
40 | self.log_dict[k] = list()
41 |
42 | self.log_dict[k].append(detorch(v))
43 |
44 | def add_plot(self, plot_info: Dict[str, Any]) -> None:
45 | conflict_keys = self.log_keys.intersection(plot_info.keys())
46 | assert len(conflict_keys) == 0, f"conflicting keys for plot and log data: {conflict_keys}"
47 |
48 | for k, v in plot_info.items():
49 | self.plot_dict[k] = v
50 |
51 | def aggregate(self, reduction: str = "mean", key_mapping: Optional[Dict[Any, Any]] = None) -> Dict[str, Any]:
52 | assert reduction in self.supported_reductions, f"unsupported reduction method: {reduction}"
53 |
54 | if key_mapping is None:
55 | key_mapping = {}
56 |
57 | reduce_func = self.supported_reductions[reduction]
58 |
59 | output_log_dict = {key_mapping.get(k, k): reduce_func(v) for k, v in self.log_dict.items()}
60 | output_plot_dict = {key_mapping.get(k, k): v for k, v in self.plot_dict.items()}
61 |
62 | output_dict = output_log_dict
63 | output_dict.update(output_plot_dict)
64 |
65 | return output_dict
66 |
--------------------------------------------------------------------------------
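
A usage sketch for LogAggregator: accumulate per-step scalars, then reduce once per logging interval:

    from utils.loggers import LogAggregator

    aggregator = LogAggregator()
    aggregator.add_log({"loss": 1.0})
    aggregator.add_log({"loss": 3.0})
    print(aggregator.aggregate(reduction="mean"))  # {'loss': 2.0}
    aggregator.clear()
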
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from torch.backends import cudnn
5 |
6 | from .cli import get_arg_parser, get_args, update_args, args_to_logger_config
7 | from .dataset import get_dataset, repeater
8 | from .file_io import find_checkpoint_path, read_yaml, find_all_files
9 |
10 | from .timing import timing
11 | from .loggers import Logger, LogAggregator
12 | from .utils import str_to_bool, get_device, dict_add_prefix, filter_dict, detorch
13 |
14 | from . import dataset
15 | from . import cli
16 | from . import loggers
17 | from . import metrics
18 | from . import types
19 |
20 | # for type hint
21 | from typing import Optional
22 | from argparse import Namespace
23 |
24 |
25 | def set_random_seed(seed: Optional[int], is_cudnn_deterministic: bool) -> None:
26 | if seed is not None:
27 | random.seed(seed)
28 | np.random.seed(seed)
29 | torch.manual_seed(seed)
30 |
31 | if is_cudnn_deterministic:
32 | cudnn.deterministic = True
33 | cudnn.benchmark = False
34 |
35 |
36 | def get_logger(args: Namespace) -> Logger:
37 | logger_type = args.logger
38 | config_dict = args_to_logger_config(args)
39 |
40 | if logger_type == "wandb":
41 | from .loggers import WandbLogger
42 | return WandbLogger(
43 | log_dir=args.log_dir,
44 | config=config_dict,
45 | **args.logger_config_dict)
46 | elif logger_type == "nop":
47 | return Logger(
48 | log_dir=args.log_dir,
49 | config=config_dict)
50 | else:
51 | from .loggers import PrintLogger
52 | return PrintLogger(
53 | log_dir=args.log_dir,
54 | config=config_dict,
55 | is_display_plots=args.is_display_plots)
56 |
57 |
58 | __all__ = [
59 | # modules
60 | "dataset",
61 | "cli",
62 | "loggers",
63 | "metrics",
64 | "types",
65 |
66 | # classes
67 | "LogAggregator",
68 | "Logger",
69 |
70 | # functions
71 | "get_arg_parser",
72 | "get_args",
73 | "update_args",
74 | "args_to_logger_config",
75 | "get_dataset",
76 | "repeater",
77 | "find_checkpoint_path",
78 | "read_yaml",
79 | "find_all_files",
80 | "timing",
81 | "get_logger",
82 | "str_to_bool",
83 | "get_device",
84 | "dict_add_prefix",
85 | "filter_dict",
86 | "detorch",
87 | "set_random_seed",
88 | ]
89 |
--------------------------------------------------------------------------------
/models/augmentation/augmenter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kornia import filters as F
3 | import numpy as np
4 |
5 | from random import random
6 |
7 | # for type hint
8 | from typing import Optional, Tuple, Sequence
9 | from torch import Tensor
10 | from torch.nn import Module
11 |
12 |
13 | class RandomAugmentation(Module):
14 | def __init__(self, augmentation: Module, p: float = 0.5, same_on_batch: bool = False):
15 | super().__init__()
16 |
17 | self.prob = p
18 | self.augmentation = augmentation
19 | self.same_on_batch = same_on_batch
20 |
21 | def forward(self, images: Tensor) -> Tensor:
22 | is_batch = len(images) < 4
23 |
24 | if not is_batch or self.same_on_batch:
25 | if random() <= self.prob:
26 | out = self.augmentation(images)
27 | else:
28 | out = images
29 | else:
30 | out = self.augmentation(images)
31 | batch_size = len(images)
32 |
33 | # indices of the samples that should keep their original (un-augmented) values
34 | indices = torch.where(torch.rand(batch_size) > self.prob)
35 | out[indices] = images[indices]
36 |
37 | return out
38 |
39 |
40 | class RandomGaussianBlur(Module):
41 | def __init__(self, kernel_size: Tuple[int, int], min_sigma=0.1, max_sigma=2.0, p=0.5) -> None:
42 | super().__init__()
43 | self.kernel_size = tuple(s + 1 if s % 2 == 0 else s for s in kernel_size) # kernel size must be odd
44 | self.min_sigma = min_sigma
45 | self.max_sigma = max_sigma
46 | self.p = p
47 |
48 | def forward(self, img):
49 | if self.p > random():
50 | sigma = (self.max_sigma - self.min_sigma) * random() + self.min_sigma
51 | return F.gaussian_blur2d(img, kernel_size=self.kernel_size, sigma=(sigma, sigma))
52 | else:
53 | return img
54 |
55 |
56 | class RandomChoice(Module):
57 | def __init__(self, augmentations: Sequence[Module], size: int = 2, p: Optional[Sequence[float]] = None):
58 | super().__init__()
59 |
60 | assert size <= len(augmentations), f"size = {size} should be <= # aug. = {len(augmentations)}"
61 |
62 | self.augmentations = augmentations
63 | self.size = size
64 | self.p = p
65 |
66 | def forward(self, inputs: Tensor) -> Tensor:
67 | indices = np.random.choice(range(len(self.augmentations)), size=self.size, replace=False, p=self.p)
68 |
69 | outputs = inputs
70 | for i in indices:
71 | outputs = self.augmentations[i](outputs)
72 |
73 | return outputs
74 |
--------------------------------------------------------------------------------
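
A composition sketch for the modules above, assuming the kornia 0.5 API pinned in requirements.txt; RandomChoice picks `size` of the listed ops per call and applies them in sequence:

    import torch
    from kornia import augmentation as K

    from models.augmentation.augmenter import (RandomAugmentation, RandomChoice,
                                               RandomGaussianBlur)

    flip = RandomAugmentation(K.RandomHorizontalFlip(p=1.0), p=0.5)
    blur = RandomGaussianBlur(kernel_size=(3, 3), p=0.5)
    pipeline = RandomChoice([flip, blur], size=2)

    batch = torch.rand(8, 3, 32, 32)  # NCHW
    out = pipeline(batch)             # same shape, randomly augmented
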
/utils/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import functional as F
2 | from PIL import Image
3 |
4 | # for type hint
5 | from typing import Union, Tuple, Callable, List
6 | from collections.abc import Iterable
7 | from torch import Tensor
8 | from PIL.Image import Image as PILImage
9 |
10 | _pil_interpolation_to_str = {
11 | Image.NEAREST: 'PIL.Image.NEAREST',
12 | Image.BILINEAR: 'PIL.Image.BILINEAR',
13 | Image.BICUBIC: 'PIL.Image.BICUBIC',
14 | Image.LANCZOS: 'PIL.Image.LANCZOS',
15 | Image.HAMMING: 'PIL.Image.HAMMING',
16 | Image.BOX: 'PIL.Image.BOX',
17 | }
18 |
19 |
20 | class CenterResizedCrop(object):
21 | """Crops the given PIL Image at the center. Then resize to desired shape.
22 |
23 | First, the largest possible crop matching the output aspect ratio is taken
24 | at the center. Then, the cropped image is resized to the desired output size.
25 |
26 | Args:
27 | size (sequence or int): Desired output size of the crop. If size is an
28 | int instead of sequence like (h, w), a square crop (size, size) is
29 | made.
30 | interpolation: Default: PIL.Image.BILINEAR
31 | """
32 |
33 | def __init__(self, size: Union[int, Tuple[int, int]], interpolation: int = Image.BILINEAR):
34 | assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
35 | self.size = (size, size) if isinstance(size, int) else tuple(size)  # normalize to (h, w)
36 | self.interpolation = interpolation
37 |
38 | def __call__(self, img: PILImage):
39 | """
40 | Args:
41 | img (PIL Image): Image to be cropped.
42 |
43 | Returns:
44 | PIL Image: Cropped image.
45 | """
46 | image_width, image_height = img.size
47 | output_height, output_width = self.size
48 |
49 | image_ratio = image_width / image_height
50 | output_ratio = output_width / output_height
51 |
52 | if image_ratio >= output_ratio:
53 | crop_height = int(image_height)
54 | crop_width = int(image_height * output_ratio)
55 | else:
56 | crop_height = int(image_width / output_ratio)
57 | crop_width = int(image_width)
58 |
59 | cropped_img = F.center_crop(img, (crop_height, crop_width))
60 |
61 | return F.resize(cropped_img, size=self.size, interpolation=self.interpolation)
62 |
63 | def __repr__(self):
64 | interpolate_str = _pil_interpolation_to_str[self.interpolation]
65 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
66 | format_string += ', interpolation={0})'.format(interpolate_str)
67 | return format_string
68 |
69 |
70 | __all__ = [
71 | "CenterResizedCrop",
72 | ]
73 |
--------------------------------------------------------------------------------
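
A usage sketch for CenterResizedCrop: crop to the target aspect ratio at the center, then resize:

    from PIL import Image

    from utils.transforms import CenterResizedCrop

    transform = CenterResizedCrop(size=(84, 84))
    img = Image.new("RGB", (100, 80))  # stand-in for a real image
    print(transform(img).size)         # (84, 84)
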
/utils/dataset/miniimagenet.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | from PIL import Image
3 | from torchvision.datasets import VisionDataset
4 | from torchvision.datasets.utils import check_integrity, download_file_from_google_drive
5 |
6 | from pathlib import Path
7 |
8 | # for type hint
9 | from typing import Optional, Callable
10 |
11 |
12 | class MiniImageNet(VisionDataset):
13 | base_folder = 'mini-imagenet'
14 | gdrive_id = '1EKmnUcpipszzBHBRcXxmejuO4pceD4ht'
15 | file_md5 = '3bda5120eb7353dd88e06de46e680146'
16 | filename = 'mini-imagenet.hdf5'
17 |
18 | def __init__(self,
19 | root: str,
20 | train: bool = True,
21 | transform: Optional[Callable] = None,
22 | target_transform: Optional[Callable] = None,
23 | download: bool = False):
24 | super().__init__(root, transform=transform, target_transform=target_transform)
25 | self.train = train
26 | self.root = root
27 |
28 | if download:
29 | self.download()
30 |
31 | if not self._check_integrity():
32 | raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')
33 |
34 | img_key = 'train_img' if self.train else 'test_img'
35 | target_key = 'train_target' if self.train else 'test_target'
36 | with h5py.File(self.data_root / self.filename, "r", swmr=True) as h5_f:
37 | self.data = h5_f[img_key][...]
38 | self.target = h5_f[target_key][...]
39 |
40 | @property
41 | def data_root(self) -> Path:
42 | return Path(self.root) / self.base_folder
43 |
44 | @property
45 | def download_root(self) -> Path:
46 | return self.data_root
47 |
48 | def __len__(self):
49 | return len(self.target)
50 |
51 | def __getitem__(self, idx):
52 | img, target = Image.fromarray(self.data[idx]), self.target[idx]
53 |
54 | if self.transform is not None:
55 | img = self.transform(img)
56 |
57 | if self.target_transform is not None:
58 | target = self.target_transform(target)
59 |
60 | return img, target
61 |
62 | def download(self):
63 | if self._check_integrity():
64 | print('Files already downloaded and verified')
65 | return
66 | download_file_from_google_drive(file_id=self.gdrive_id,
67 | root=str(self.download_root),
68 | filename=self.filename,
69 | md5=self.file_md5)
70 |
71 | def _check_integrity(self):
72 | return check_integrity(fpath=str(self.download_root / self.filename), md5=self.file_md5)
73 |
--------------------------------------------------------------------------------
/models/models/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from .wide_resnet import WideResNet
4 | from .resnet import *
5 | from .ema import EMA, FullEMA
6 | from .utils import unwrap_model
7 |
8 | from . import utils
9 |
10 | # for type hint
11 | from typing import Union
12 | from torch.nn import Module
13 |
14 |
15 | def get_model_size(model: Module) -> int:
16 | return sum(p.numel() for p in model.parameters())
17 |
18 |
19 | def build_model(model_type: str, in_channels: int, out_channels: int, **kwargs) -> Union[Module, EMA]:
20 | if model_type == "wrn28-8":
21 | model = WideResNet(in_channels=in_channels,
22 | out_channels=out_channels,
23 | depth=28,
24 | widening_factor=8,
25 | base_channels=16,
26 | **kwargs)
27 |
28 | elif model_type == "resnet18":
29 | model = resnet18(pretrained=False, progress=True, num_classes=out_channels, **kwargs)
30 |
31 | elif model_type == "resnet50":
32 | model = resnet50(pretrained=False, progress=True, num_classes=out_channels, **kwargs)
33 |
34 | elif model_type == "wrn28-2":
35 | # default model is WRN 28-2
36 | model = WideResNet(in_channels=in_channels,
37 | out_channels=out_channels,
38 | depth=28,
39 | widening_factor=2,
40 | base_channels=16,
41 | **kwargs)
42 |
43 | else:
44 | raise NotImplementedError(f"\"{model_type}\" is not a supported model type")
45 |
46 | print(f'{model_type} Total params: {(get_model_size(model) / 1e6):.2f}M')
47 | return model
48 |
49 |
50 | def build_ema_model(model: Module, ema_type: str, ema_decay: float) -> Union[Module, EMA]:
51 | if ema_decay == 0:
52 | warnings.warn("EMA decay is 0, turn off EMA")
53 | return model
54 |
55 | elif ema_type == "full":
56 | return FullEMA(model, decay=ema_decay)
57 |
58 | elif ema_type == "default":
59 | return EMA(model, decay=ema_decay)
60 |
61 | else:
62 | raise NotImplementedError(f"\"{ema_type}\" is not a supported EMA type ")
63 |
64 |
65 | __all__ = [
66 | # modules,
67 | "utils",
68 |
69 | # classes
70 | "WideResNet",
71 | "ResNet",
72 | "EMA",
73 | "FullEMA",
74 |
75 | # functions
76 | "get_model_size",
77 | "build_model",
78 | "build_ema_model",
79 |
80 | # model functions
81 | "resnet18",
82 | "resnet34",
83 | "resnet50",
84 | "resnet101",
85 | "resnet152",
86 | "resnext50_32x4d",
87 | "resnext101_32x8d",
88 | "wide_resnet50_2",
89 | "wide_resnet101_2",
90 | ]
91 |
--------------------------------------------------------------------------------
/models/mixmatch/simple_mixmatch.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .mixmatch_base import MixMatchBase
4 |
5 | # for type hint
6 | from typing import List, Dict
7 | from torch import Tensor
8 | from torch.nn import Module
9 |
10 |
11 | class SimPLE(MixMatchBase):
12 | def __init__(self,
13 | augmenter: Module,
14 | strong_augmenter: Module,
15 | num_classes: int,
16 | temperature: float,
17 | num_augmentations: int,
18 | num_strong_augmentations: int,
19 | is_strong_augment_x: bool,
20 | train_label_guessing: bool):
21 | super(SimPLE, self).__init__(augmenter=augmenter,
22 | strong_augmenter=strong_augmenter,
23 | num_classes=num_classes,
24 | temperature=temperature,
25 | num_augmentations=num_augmentations,
26 | num_strong_augmentations=num_strong_augmentations,
27 | alpha=0.,
28 | is_strong_augment_x=is_strong_augment_x,
29 | train_label_guessing=train_label_guessing)
30 |
31 | @property
32 | def total_num_augmentations(self) -> int:
33 | return self.num_strong_augmentations
34 |
35 | @torch.no_grad()
36 | def __call__(self,
37 | x_augmented: Tensor,
38 | x_strong_augmented: Tensor,
39 | x_targets_one_hot: Tensor,
40 | u_augmented: List[Tensor],
41 | u_strong_augmented: List[Tensor],
42 | u_true_targets_one_hot: Tensor,
43 | model: Module,
44 | *args,
45 | **kwargs) -> Dict[str, Tensor]:
46 | if self.is_strong_augment_x:
47 | x_inputs = x_strong_augmented
48 | else:
49 | x_inputs = x_augmented
50 | u_inputs = u_strong_augmented
51 |
52 | # label guessing with weakly augmented data
53 | pseudo_label_dict = self.guess_label(u_inputs=u_augmented, model=model)
54 |
55 | return self.postprocess(x_augmented=x_inputs,
56 | x_targets_one_hot=x_targets_one_hot,
57 | u_augmented=u_inputs,
58 | q_guess=pseudo_label_dict["q_guess"],
59 | u_true_targets_one_hot=u_true_targets_one_hot)
60 |
61 | def mixup(self,
62 | x_augmented: Tensor,
63 | x_targets_one_hot: Tensor,
64 | u_augmented: Tensor,
65 | q_guess: Tensor,
66 | q_true: Tensor) -> Dict[str, Tensor]:
67 | # SimPLE does not use mixup
68 | return dict(x_mixed=x_augmented,
69 | p_mixed=x_targets_one_hot,
70 | u_mixed=u_augmented,
71 | q_mixed=q_guess,
72 | q_true_mixed=q_true)
73 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | # for type hint
4 | from typing import Dict, Union, Optional, KeysView, Set, Any
5 |
6 |
7 | class MetricMode(Enum):
8 | MAX = 0
9 | MIN = 1
10 |
11 |
12 | MetricDictType = Dict[str, Union[str, MetricMode, int, float]]
13 |
14 |
15 | class MetricMonitor:
16 | def __init__(self, metrics: Optional[Dict[str, MetricDictType]] = None):
17 | self.metrics: Dict[str, MetricDictType] = dict()
18 |
19 | if metrics is not None:
20 | self.update(metrics)
21 |
22 | def __getitem__(self, key: str) -> MetricDictType:
23 | return self.metrics[key]
24 |
25 | def __setitem__(self, key: str, value: MetricDictType):
26 | self.track(key=key,
27 | log_key=value["key"],
28 | mode=value["mode"],
29 | best_value=value["best_value"])
30 |
31 | def __contains__(self, key: str) -> bool:
32 | return key in self.metrics
33 |
34 | def track(self,
35 | key: str,
36 | best_value: Union[float, int],
37 | mode: MetricMode,
38 | log_key: Optional[str] = None,
39 | prefix: Optional[str] = None):
40 | if log_key is None:
41 | log_key = f"best_{key}"
42 |
43 | if prefix is not None:
44 | key = f"{prefix}/{key}"
45 | log_key = f"{prefix}/{log_key}"
46 |
47 | if key in self.metrics:
48 | # if key exist, update best_value
49 | curr_best_value = self.metrics[key]["best_value"]
50 |
51 | if mode == MetricMode.MIN:
52 | best_value = min(best_value, curr_best_value)
53 | else:
54 | best_value = max(best_value, curr_best_value)
55 |
56 | self.metrics[key] = {
57 | "key": log_key,
58 | "best_value": best_value,
59 | "mode": mode,
60 | }
61 |
62 | def keys(self) -> KeysView[str]:
63 | return self.metrics.keys()
64 |
65 | def update(self, metrics: Dict[str, MetricDictType]):
66 | for key, metric in metrics.items():
67 | self[key] = metric
68 |
69 | def mutual_keys(self, keys: Union[Set[str], KeysView[str]]) -> Set[str]:
70 | return set(keys).intersection(set(self.metrics.keys()))
71 |
72 | def update_metrics(self, log_info: Dict[str, Any]) -> Dict[str, Union[int, float]]:
73 | updated_dict = {}
74 |
75 | for mutual_key in self.mutual_keys(log_info.keys()):
76 | metric_dict = self[mutual_key]
77 | new_value = log_info[mutual_key]
78 |
79 | mode: MetricMode = metric_dict["mode"]
80 | best_value = metric_dict["best_value"]
81 |
82 | if (mode == MetricMode.MAX and new_value > best_value) or \
83 | (mode == MetricMode.MIN and new_value < best_value):
84 | metric_dict["best_value"] = new_value
85 |
86 | # save updated key and value
87 | updated_dict[mutual_key] = new_value
88 |
89 | return updated_dict
90 |
91 | def state_dict(self) -> Dict[str, Any]:
92 | return {k: v["best_value"] for k, v in self.metrics.items()}
93 |
94 | def load_state_dict(self, state_dict: Dict[str, Any]):
95 | self.update_metrics(state_dict)  # only restores metrics already registered via track() or update()
96 |
--------------------------------------------------------------------------------
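
A minimal usage sketch for `MetricMonitor` (the import path follows the repository layout; the metric name and values are illustrative):

```python
from utils.metrics import MetricMonitor, MetricMode

monitor = MetricMonitor()
# register a best-accuracy metric; with prefix, the tracked key becomes "validation/mean_acc"
monitor.track(key="mean_acc", best_value=0., mode=MetricMode.MAX, prefix="validation")

# feed new values, e.g. after an evaluation step; only improvements update the best value
updated = monitor.update_metrics({"validation/mean_acc": 0.87})
assert updated == {"validation/mean_acc": 0.87}
assert monitor["validation/mean_acc"]["best_value"] == 0.87
```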
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .pair_loss import PairLoss
2 | from .utils import bha_coeff, bha_coeff_distance, l2_distance
3 | from .loss import (SupervisedLoss, UnsupervisedLoss, softmax_cross_entropy_loss, bha_coeff_loss,
4 | l2_dist_loss)
5 |
6 | # modules
7 | from . import utils
8 | from . import types
9 | from . import pair_loss
10 | from . import visualization
11 |
12 | # for type hint
13 | from typing import Tuple
14 | from argparse import Namespace
15 |
16 | from .types import SimilarityType, DistanceType, DistanceLossType
17 |
18 |
19 | def get_similarity_metric(similarity_type: str) -> Tuple[SimilarityType, str]:
20 | """
21 |
22 | Args:
23 | similarity_type: the type of the similarity function
24 |
25 | Returns: similarity function, string indicating the type of the similarity (from [logit, prob, feature])
26 |
27 | """
28 | if similarity_type == "bhc":
29 | return bha_coeff, "prob"
30 |
31 | else:
32 | raise NotImplementedError(f"\"{similarity_type}\" is not a supported similarity type")
33 |
34 |
35 | def get_distance_loss_metric(distance_loss_type: str) -> Tuple[DistanceLossType, str]:
36 | """
37 |
38 |
39 | Args:
40 | distance_loss_type: the type of the distance loss function
41 |
42 | Returns: distance loss function, string indicating the type of the loss (from [logit, prob])
43 |
44 | """
45 | if distance_loss_type == "bhc":
46 | return bha_coeff_loss, "logit"
47 |
48 | elif distance_loss_type == "l2":
49 | return l2_dist_loss, "prob"
50 |
51 | elif distance_loss_type == "entropy":
52 | return softmax_cross_entropy_loss, "logit"
53 |
54 | else:
55 | raise NotImplementedError(f"\"{distance_loss_type}\" is not a supported distance loss type")
56 |
57 |
58 | def build_supervised_loss(args: Namespace) -> SupervisedLoss:
59 | return SupervisedLoss(reduction="mean")
60 |
61 |
62 | def build_unsupervised_loss(args: Namespace) -> UnsupervisedLoss:
63 | return UnsupervisedLoss(
64 | loss_type=args.u_loss_type,
65 | loss_thresholded=args.u_loss_thresholded,
66 | confidence_threshold=args.confidence_threshold,
67 | reduction="mean")
68 |
69 |
70 | def build_pair_loss(args: Namespace, reduction: str = "mean") -> PairLoss:
71 | similarity_metric, similarity_type = get_similarity_metric(args.similarity_type)
72 | distance_loss_metric, distance_loss_type = get_distance_loss_metric(args.distance_loss_type)
73 |
74 | return PairLoss(
75 | similarity_metric=similarity_metric,
76 | distance_loss_metric=distance_loss_metric,
77 | confidence_threshold=args.confidence_threshold,
78 | similarity_threshold=args.similarity_threshold,
79 | similarity_type=similarity_type,
80 | distance_loss_type=distance_loss_type,
81 | reduction=reduction)
82 |
83 |
84 | __all__ = [
85 | # modules
86 | "utils",
87 | "types",
88 | "pair_loss",
89 | "visualization",
90 |
91 | # classes
92 | "SupervisedLoss",
93 | "UnsupervisedLoss",
94 |
95 | # functions
96 | "get_similarity_metric",
97 | "get_distance_loss_metric",
98 |
99 | # builder functions
100 | "build_supervised_loss",
101 | "build_unsupervised_loss",
102 | "build_pair_loss",
103 | ]
104 |
--------------------------------------------------------------------------------
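
A sketch of building the Pair Loss from this module; the `Namespace` fields mirror the attributes `build_pair_loss` reads above, with threshold values chosen for illustration:

```python
from argparse import Namespace

from loss import build_pair_loss

args = Namespace(similarity_type="bhc",          # Bhattacharyya coefficient similarity
                 distance_loss_type="entropy",   # softmax cross-entropy distance
                 confidence_threshold=0.95,
                 similarity_threshold=0.9)

pair_loss = build_pair_loss(args)  # PairLoss with similarity_type="prob", distance_loss_type="logit"
```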
/utils/loggers/wandb_logger.py:
--------------------------------------------------------------------------------
1 | import wandb
2 |
3 | import sys
4 |
5 | from .logger import Logger
6 |
7 | # for type hint
8 | from typing import Any, Dict, Optional, Union, Sequence
9 | from argparse import Namespace
10 |
11 | from torch.nn import Module
12 |
13 |
14 | class WandbLogger(Logger):
15 | def __init__(self,
16 | log_dir: str,
17 | config: Union[Namespace, Dict[str, Any]],
18 | name: str,
19 | tags: Sequence[str],
20 | notes: str,
21 | entity: str,
22 | project: str,
23 | mode: str = "offline",
24 | resume: Union[bool, str] = False):
25 | super().__init__(
26 | log_dir=log_dir,
27 | config=config,
28 | log_info_key_map={
29 | "top1_acc": "mean_acc",
30 | "top5_acc": "mean_top5_acc",
31 | })
32 |
33 | self.name = name
34 | self.config = config
35 | self.tags = tags
36 | self.notes = notes
37 | self.entity = entity
38 | self.project = project
39 | self.mode = mode
40 | self.resume = resume
41 |
42 | self.is_init = False
43 |
44 | def _init_wandb(self):
45 | if not self.is_init:
46 | wandb.init(
47 | name=self.name,
48 | config=self.config,
49 | project=self.project,
50 | entity=self.entity,
51 | dir=self.log_dir,
52 | tags=self.tags,
53 | notes=self.notes,
54 | mode=self.mode,
55 | resume=self.resume)
56 | # update is_init flag
57 | self.is_init = True
58 |
59 | # update config if resumed
60 | if bool(self.resume):
61 | self.log(self.config, is_config=True)
62 |
63 | def log(self,
64 | log_info: Dict[str, Any],
65 | step: Optional[int] = None,
66 | is_summary: bool = False,
67 | is_config: bool = False,
68 | is_commit: Optional[bool] = None,
69 | prefix: Optional[str] = None,
70 | log_info_override: Optional[Dict[str, Any]] = None,
71 | **kwargs):
72 | if not self.is_init:
73 | self._init_wandb()
74 |
75 | # process log_info
76 | log_info = self.process_log_info(log_info, prefix=prefix, log_info_override=log_info_override)
77 |
78 | if len(log_info) == 0:
79 | return
80 |
81 | if is_summary:
82 | # for log_info_key, log_info_value in log_info.items():
83 | # wandb.run.summary[log_info_key] = log_info_value
84 | wandb.run.summary.update(log_info)
85 | elif is_config:
86 | wandb.run.config.update(log_info, allow_val_change=True)
87 | else:
88 | log_info, plot_info = self.separate_plot(log_info)
89 | wandb.log(plot_info, commit=False, step=step)
90 | wandb.log(log_info, commit=is_commit, step=step)
91 |
92 | # invoke all log hook functions
93 | self.call_log_hooks(log_info)
94 |
95 | def watch(self, model: Module, **kwargs):
96 | if not self.is_init:
97 | self._init_wandb()
98 |
99 | wandb.watch(model, **kwargs)
100 |
101 | def save(self, output_path: str):
102 | if not self.is_init:
103 | self._init_wandb()
104 |
105 | if sys.platform != "win32":
106 | # TODO: remove this guard once a workaround is found; wandb.save
107 | # creates symlinks, which need elevated privileges on Windows
108 | wandb.save(output_path)
109 |
--------------------------------------------------------------------------------
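
A minimal offline-mode sketch of `WandbLogger` (the entity, project, and run names are placeholders):

```python
from utils.loggers.wandb_logger import WandbLogger

logger = WandbLogger(log_dir="./logs",
                     config={"seed": 42},
                     name="example-run",
                     tags=["example"],
                     notes="offline smoke test",
                     entity="my-entity",
                     project="simple",
                     mode="offline")  # offline mode: nothing is uploaded

# wandb.init is called lazily on the first log
logger.log({"train/loss": 0.5}, step=1)
```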
/loss/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | from .utils import reduce_tensor, bha_coeff_log_prob, l2_distance
5 |
6 | # for type hint
7 | from torch import Tensor
8 |
9 |
10 | def softmax_cross_entropy_loss(logits: Tensor, targets: Tensor, dim: int = 1, reduction: str = 'mean') -> Tensor:
11 | """
12 | :param logits: (labeled_batch_size, num_classes) model output of the labeled data
13 | :param targets: (labeled_batch_size, num_classes) labels distribution for the data
14 | :param dim: the dimension or dimensions to reduce
15 | :param reduction: choose from 'mean', 'sum', and 'none'
16 | :return:
17 | """
18 | loss = -torch.sum(F.log_softmax(logits, dim=dim) * targets, dim=dim)
19 |
20 | return reduce_tensor(loss, reduction)
21 |
22 |
23 | def mse_loss(prob: Tensor, targets: Tensor, reduction: str = 'mean', **kwargs) -> Tensor:
24 | return F.mse_loss(prob, targets, reduction=reduction)
25 |
26 |
27 | def bha_coeff_loss(logits: Tensor, targets: Tensor, dim: int = 1, reduction: str = "none") -> Tensor:
28 | """
29 | Bhattacharyya coefficient loss: 1 - BC(p, q); the more similar p and q, the smaller the loss
30 | :param logits: (batch_size, num_classes) model predictions of the data
31 | :param targets: (batch_size, num_classes) label prob distribution
32 | :param dim: the dimension or dimensions to reduce
33 | :param reduction: reduction method, choose from "sum", "mean", "none"
34 | :return: 1 - Bhattacharyya coefficient of p and q, see https://en.wikipedia.org/wiki/Bhattacharyya_distance
35 | """
36 | log_probs = F.log_softmax(logits, dim=dim)
37 | log_targets = torch.log(targets)
38 |
39 | # since BC(P,Q) is maximized when P and Q are identical, we minimize 1 - BC(P,Q)
40 | return 1. - bha_coeff_log_prob(log_probs, log_targets, dim=dim, reduction=reduction)
41 |
42 |
43 | def l2_dist_loss(probs: Tensor, targets: Tensor, dim: int = 1, reduction: str = "none") -> Tensor:
44 | loss = l2_distance(probs, targets, dim=dim)
45 |
46 | return reduce_tensor(loss, reduction)
47 |
48 |
49 | class SupervisedLoss:
50 | def __init__(self, reduction: str = 'mean'):
51 | self.loss_use_prob = False
52 | self.loss_fn = softmax_cross_entropy_loss
53 |
54 | self.reduction = reduction
55 |
56 | def __call__(self, logits: Tensor, probs: Tensor, targets: Tensor) -> Tensor:
57 | loss_input = probs if self.loss_use_prob else logits
58 | loss = self.loss_fn(loss_input, targets, dim=1, reduction=self.reduction)
59 |
60 | return loss
61 |
62 |
63 | class UnsupervisedLoss:
64 | def __init__(self,
65 | loss_type: str,
66 | loss_thresholded: bool = False,
67 | confidence_threshold: float = 0.,
68 | reduction: str = "mean"):
69 | if loss_type in ["entropy", "cross entropy"]:
70 | self.loss_use_prob = False
71 | self.loss_fn = softmax_cross_entropy_loss
72 | else:
73 | self.loss_use_prob = True
74 | self.loss_fn = mse_loss
75 |
76 | self.loss_thresholded = loss_thresholded
77 | self.confidence_threshold = confidence_threshold
78 | self.reduction = reduction
79 |
80 | def __call__(self, logits: Tensor, probs: Tensor, targets: Tensor) -> Tensor:
81 | loss_input = probs if self.loss_use_prob else logits
82 | loss = self.loss_fn(loss_input, targets, dim=1, reduction="none")
83 |
84 | if self.loss_thresholded:
85 | targets_mask = (targets.max(dim=1).values > self.confidence_threshold)
86 |
87 | if len(loss.shape) > 1:
88 | # mse_loss returns a matrix, need to reshape mask
89 | targets_mask = targets_mask.view(-1, 1)
90 |
91 | loss *= targets_mask.float()
92 |
93 | return reduce_tensor(loss, reduction=self.reduction)
94 |
--------------------------------------------------------------------------------
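
A quick sanity check (sketch): with one-hot targets, `softmax_cross_entropy_loss` reduces to the standard cross-entropy:

```python
import torch
from torch.nn import functional as F

from loss.loss import softmax_cross_entropy_loss

logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
one_hot = F.one_hot(labels, num_classes=10).float()

soft = softmax_cross_entropy_loss(logits, one_hot, dim=1, reduction="mean")
hard = F.cross_entropy(logits, labels)
assert torch.allclose(soft, hard, atol=1e-6)
```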
/loss/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # for type hint
4 | from typing import Union, Optional, Sequence, Callable
5 | from torch import Tensor
6 |
7 | ScalarType = Union[int, float, bool]
8 |
9 |
10 | def reduce_tensor(inputs: Tensor, reduction: str) -> Tensor:
11 | if reduction == 'mean':
12 | return torch.mean(inputs)
13 |
14 | elif reduction == 'sum':
15 | return torch.sum(inputs)
16 |
17 | return inputs
18 |
19 |
20 | def to_tensor(data: Union[ScalarType, Sequence[ScalarType]],
21 | dtype: Optional[torch.dtype] = None,
22 | device: Optional[Union[torch.device, str]] = None,
23 | tensor_like: Optional[Tensor] = None) -> Tensor:
24 | if tensor_like is not None:
25 | dtype = tensor_like.dtype if dtype is None else dtype
26 | device = tensor_like.device if device is None else device
27 |
28 | return torch.tensor(data, dtype=dtype, device=device)
29 |
30 |
31 | def bha_coeff_log_prob(log_p: Tensor, log_q: Tensor, dim: int = 1, reduction: str = "none") -> Tensor:
32 | """
33 | Bhattacharyya coefficient of log(p) and log(q); the more similar the larger the coefficient
34 | :param log_p: (batch_size, num_classes) first log prob distribution
35 | :param log_q: (batch_size, num_classes) second log prob distribution
36 | :param dim: the dimension or dimensions to reduce
37 | :param reduction: reduction method, choose from "sum", "mean", "none"
38 | :return: Bhattacharyya coefficient of p and q, see https://en.wikipedia.org/wiki/Bhattacharyya_distance
39 | """
40 | # numerically unstable version:
41 | # coefficient = torch.sum(torch.sqrt(p * q), dim=dim)
42 | # numerically stable version:
43 | coefficient = torch.sum(torch.exp((log_p + log_q) / 2), dim=dim)
44 |
45 | return reduce_tensor(coefficient, reduction)
46 |
47 |
48 | def bha_coeff(p: Tensor, q: Tensor, dim: int = 1, reduction: str = "none") -> Tensor:
49 | """
50 | Bhattacharyya coefficient of p and q; the more similar the larger the coefficient
51 | :param p: (batch_size, num_classes) first prob distribution
52 | :param q: (batch_size, num_classes) second prob distribution
53 | :param dim: the dimension or dimensions to reduce
54 | :param reduction: reduction method, choose from "sum", "mean", "none"
55 | :return: Bhattacharyya coefficient of p and q, see https://en.wikipedia.org/wiki/Bhattacharyya_distance
56 | """
57 | log_p = torch.log(p)
58 | log_q = torch.log(q)
59 |
60 | return bha_coeff_log_prob(log_p, log_q, dim=dim, reduction=reduction)
61 |
62 |
63 | def bha_coeff_distance(p: Tensor, q: Tensor, dim: int = 1, reduction: str = "none") -> Tensor:
64 | """
65 | Bhattacharyya coefficient distance of p and q: 1 - BC(p, q); the more similar, the smaller the value
66 | :param p: (batch_size, num_classes) model predictions of the data
67 | :param q: (batch_size, num_classes) label prob distribution
68 | :param dim: the dimension or dimensions to reduce
69 | :param reduction: reduction method, choose from "sum", "mean", "none"
70 | :return: 1 - Bhattacharyya coefficient of p and q, see https://en.wikipedia.org/wiki/Bhattacharyya_distance
71 | """
72 | return 1. - bha_coeff(p, q, dim=dim, reduction=reduction)
73 |
74 |
75 | def l2_distance(x: Tensor, y: Tensor, dim: int, **kwargs) -> Tensor:
76 | return torch.norm(x - y, p=2, dim=dim)
77 |
78 |
79 | def pairwise_apply(p: Tensor, q: Tensor, func: Callable, *args, **kwargs) -> Tensor:
80 | """
81 |
82 | Args:
83 | p: (batch_size, num_classes) first prob distribution
84 | q: (batch_size, num_classes) second prob distribution
85 | func: function to be applied on p and q
86 |
87 | Returns: a matrix of pair-wise result between each element of p and q
88 |
89 | """
90 | p = p.unsqueeze(-1)
91 | q = q.T.unsqueeze(0)
92 | return func(p, q, *args, **kwargs)
93 |
--------------------------------------------------------------------------------
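
A small worked example for the Bhattacharyya helpers; for the first row, BC = sqrt(0.9 * 0.8) + sqrt(0.1 * 0.2) ≈ 0.990:

```python
import torch

from loss.utils.utils import bha_coeff, bha_coeff_distance, pairwise_apply

p = torch.tensor([[0.9, 0.1], [0.5, 0.5]])
q = torch.tensor([[0.8, 0.2], [0.5, 0.5]])

print(bha_coeff(p, q, dim=1))           # tensor([0.9899, 1.0000]), row-wise coefficients
print(bha_coeff_distance(p, q, dim=1))  # 1 - coefficient

# pairwise_apply broadcasts every row of p against every row of q: result is (2, 2)
print(pairwise_apply(p, q, bha_coeff, dim=1))
```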
/utils/dataset/cifar10_datamodule.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import CIFAR10
2 |
3 | from .ssl_datamodule import SSLDataModule
4 | from .datasets import LabeledDataset
5 | from .utils import per_class_random_split
6 |
7 | # for type hint
8 | from typing import Optional, Tuple, List
9 | from torch.utils.data import Dataset
10 | from torchvision.datasets import VisionDataset
11 |
12 |
13 | class CIFAR10DataModule(SSLDataModule):
14 | num_classes: int = 10
15 |
16 | total_train_size: int = 50_000
17 | total_test_size: int = 10_000
18 |
19 | DATASET = CIFAR10
20 |
21 | def __init__(self,
22 | data_dir: str,
23 | labeled_train_size: int,
24 | validation_size: int,
25 | unlabeled_train_size: Optional[int] = None,
26 | dims: Optional[Tuple[int, ...]] = None,
27 | **kwargs):
28 | if dims is None:
29 | dims = (3, 32, 32)
30 |
31 | super(CIFAR10DataModule, self).__init__(dims=dims, **kwargs)
32 |
33 | self.data_dir = data_dir
34 |
35 | self.labeled_train_size = labeled_train_size
36 | self.validation_size = validation_size
37 | if unlabeled_train_size is None:
38 | self.unlabeled_train_size = self.total_train_size - self.validation_size - self.labeled_train_size
39 | else:
40 | self.unlabeled_train_size = unlabeled_train_size
41 |
42 | # dataset stats
43 | # CIFAR-10 mean, std values in CHW
44 | self.dataset_mean = [0.44653091, 0.49139968, 0.48215841]
45 | self.dataset_std = [0.26158784, 0.24703223, 0.24348513]
46 |
47 | def prepare_data(self, *args, **kwargs):
48 | self.DATASET(root=self.data_dir, train=True, download=True)
49 | self.DATASET(root=self.data_dir, train=False, download=True)
50 |
51 | def setup(self, stage: Optional[str] = None):
52 | full_train_set = self.DATASET(root=self.data_dir, train=True)
53 | full_test_set = self.DATASET(root=self.data_dir, train=False)
54 |
55 | self.setup_helper(full_train_set=full_train_set, full_test_set=full_test_set, stage=stage)
56 |
57 | def setup_helper(self, full_train_set: VisionDataset, full_test_set: VisionDataset, stage: Optional[str] = None):
58 | self.test_set = LabeledDataset(
59 | full_test_set,
60 | root=full_test_set.root,
61 | min_size=self.test_min_size,
62 | transform=self.test_transform)
63 |
64 | # get subsets
65 | validation_subset, labeled_train_subset, unlabeled_train_subset = self.split_dataset(full_train_set)
66 |
67 | # convert to dataset
68 | self.validation_set = LabeledDataset(
69 | validation_subset,
70 | root=full_train_set.root,
71 | min_size=self.test_min_size,
72 | transform=self.val_transform)
73 | self.labeled_train_set = LabeledDataset(
74 | labeled_train_subset,
75 | root=full_train_set.root,
76 | min_size=self.train_min_size,
77 | transform=self.train_transform)
78 | self.unlabeled_train_set = LabeledDataset(
79 | unlabeled_train_subset,
80 | root=full_train_set.root,
81 | min_size=self.unlabeled_train_min_size,
82 | transform=self.train_transform)
83 |
84 | assert len(self.validation_set.dataset) == len(validation_subset) == self.validation_size
85 | assert len(self.labeled_train_set.dataset) == len(labeled_train_subset) == self.labeled_train_size
86 | assert len(self.unlabeled_train_set.dataset) == len(unlabeled_train_subset) == self.unlabeled_train_size
87 | assert len(self.test_set.dataset) == self.total_test_size
88 |
89 | def split_dataset(self, dataset: Dataset, **kwargs) -> List[Dataset]:
90 | split_kwargs = dict(lengths=[self.validation_size, self.labeled_train_size, self.unlabeled_train_size],
91 | num_classes=self.num_classes,
92 | uneven_split=False)
93 |
94 | # update split arguments
95 | split_kwargs.update(kwargs)
96 |
97 | return per_class_random_split(dataset, **split_kwargs)
98 |
--------------------------------------------------------------------------------
/models/optimization/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.optim.lr_scheduler import LambdaLR, StepLR
3 |
4 | # for type hint
5 | from torch.optim.optimizer import Optimizer
6 |
7 | from .types import LRSchedulerType
8 |
9 |
10 | class CosineDecay:
11 | def __init__(self, max_iter: int, factor: float, min_value: float = 0., max_value: float = 1.):
12 | self.max_iter = max_iter
13 | self.factor = factor
14 |
15 | self.min_value = min_value
16 | self.max_value = max_value
17 |
18 | def __call__(self, curr_step: int) -> float:
19 | max_iter = max(self.max_iter, 1)
20 | curr_step = np.clip(curr_step, 0, max_iter).item()
21 |
22 | return self.compute_cosine_decay(curr_step=curr_step,
23 | max_iter=max_iter,
24 | factor=self.factor,
25 | min_value=self.min_value,
26 | max_value=self.max_value)
27 |
28 | @staticmethod
29 | def compute_cosine_decay(curr_step: int,
30 | max_iter: int,
31 | factor: float,
32 | min_value: float,
33 | max_value: float) -> float:
34 | output = np.cos(factor * np.pi * float(curr_step) / float(max(max_iter, 1)))
35 |
36 | return np.clip(output, min_value, max_value).item()
37 |
38 |
39 | class CosineWarmupDecay(CosineDecay):
40 | def __init__(self,
41 | max_iter: int,
42 | factor: float,
43 | num_warmup_steps: int,
44 | min_value: float = 0.,
45 | max_value: float = 1.):
46 | super(CosineWarmupDecay, self).__init__(max_iter=max_iter,
47 | factor=factor,
48 | min_value=min_value,
49 | max_value=max_value)
50 |
51 | self.num_warmup_steps = max(num_warmup_steps, 0)
52 |
53 | def __call__(self, curr_step: int) -> float:
54 | if curr_step < self.num_warmup_steps:
55 | return float(curr_step) / float(max(self.num_warmup_steps, 1))
56 | else:
57 | max_iter = max(self.max_iter - self.num_warmup_steps, 1)
58 | curr_step = np.clip(curr_step - self.num_warmup_steps, 0, max_iter).item()
59 |
60 | return self.compute_cosine_decay(curr_step=curr_step,
61 | max_iter=max_iter,
62 | factor=self.factor,
63 | min_value=self.min_value,
64 | max_value=self.max_value)
65 |
66 |
67 | def build_lr_scheduler(scheduler_type: str,
68 | optimizer: Optimizer,
69 | max_iter: int,
70 | cosine_factor: float,
71 | step_size: int,
72 | gamma: float,
73 | num_warmup_steps: int,
74 | **kwargs) -> LRSchedulerType:
75 | if scheduler_type == "cosine_decay":
76 | return LambdaLR(optimizer=optimizer,
77 | lr_lambda=CosineDecay(max_iter=max_iter, factor=cosine_factor))
78 |
79 | elif scheduler_type == "cosine_warmup_decay":
80 | return LambdaLR(optimizer=optimizer,
81 | lr_lambda=CosineWarmupDecay(max_iter=max_iter,
82 | factor=cosine_factor,
83 | num_warmup_steps=num_warmup_steps))
84 |
85 | elif scheduler_type == "step_decay":
86 | return StepLR(optimizer=optimizer,
87 | step_size=step_size,
88 | gamma=gamma)
89 |
90 | else:
91 | # fallback: dummy scheduler that keeps the learning rate constant
92 | return LambdaLR(optimizer=optimizer, lr_lambda=lambda curr_iter: 1.0)
93 |
94 |
95 | __all__ = [
96 | # functions
97 | "build_lr_scheduler",
98 |
99 | # classes
100 | "CosineDecay",
101 | ]
102 |
--------------------------------------------------------------------------------
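
A minimal sketch of driving the warmup + cosine schedule (the optimizer and hyper-parameter values are placeholders; `step_size` and `gamma` are unused by this scheduler type but required by the signature):

```python
import torch

from models.optimization.lr_scheduler import build_lr_scheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.03)

scheduler = build_lr_scheduler("cosine_warmup_decay",
                               optimizer=optimizer,
                               max_iter=1_000,
                               cosine_factor=7. / 16.,
                               step_size=0,
                               gamma=1.,
                               num_warmup_steps=50)

for _ in range(100):
    optimizer.step()
    scheduler.step()  # LR ramps up linearly for 50 steps, then follows the cosine curve
```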
/utils/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from torch import distributed
3 | from PIL import Image
4 |
5 | from .utils import repeater, get_batch
6 |
7 | from .miniimagenet import MiniImageNet
8 | from .domainnet_real import DomainNetReal
9 | from .datasets import LabeledDataset
10 |
11 | from .datamodule import DataModule
12 | from .ssl_datamodule import SSLDataModule
13 |
14 | from .cifar10_datamodule import CIFAR10DataModule
15 | from .cifar100_datamodule import CIFAR100DataModule
16 | from .svhn_datamodule import SVHNDataModule
17 | from .miniimagenet_datamodule import MiniImageNetDataModule
18 | from .domainnet_real_datamodule import DomainNetRealDataModule
19 |
20 | from ..transforms import CenterResizedCrop
21 |
22 | # modules
23 | from . import utils
24 |
25 | # for type hint
26 | from typing import Dict, Callable
27 | from argparse import Namespace
28 |
29 |
30 | def _get_world_size() -> int:
31 | if distributed.is_available() and distributed.is_initialized():
32 | return distributed.get_world_size()
33 | else:
34 | return 1
35 |
36 |
37 | def _get_dataset_transforms(args: Namespace) -> Dict[str, Callable]:
38 | if args.data_dims is None and args.dataset == "domainnet-real":
39 | dims = (3, 224, 224)
40 | else:
41 | dims = args.data_dims
42 |
43 | if dims is not None:
44 | dims = tuple(dims)
45 |
46 | return dict(
47 | dims=dims,
48 | train_transform=transforms.Compose([
49 | transforms.RandomResizedCrop(
50 | size=dims[-2:],
51 | scale=(0.08, 1.0),
52 | ratio=(3. / 4, 4. / 3.),
53 | interpolation=Image.BICUBIC),
54 | transforms.ToTensor()]),
55 | val_transform=transforms.Compose([
56 | CenterResizedCrop(dims[-2:], interpolation=Image.BICUBIC),
57 | transforms.ToTensor()]),
58 | test_transform=transforms.Compose([
59 | CenterResizedCrop(dims[-2:], interpolation=Image.BICUBIC),
60 | transforms.ToTensor()]),
61 | )
62 |
63 | else:
64 | return dict(
65 | train_transform=transforms.Compose([transforms.ToTensor()]),
66 | val_transform=transforms.Compose([transforms.ToTensor()]),
67 | test_transform=transforms.Compose([transforms.ToTensor()]),
68 | )
69 |
70 |
71 | def get_dataset(args: Namespace) -> SSLDataModule:
72 | world_size = _get_world_size()
73 |
74 | kwargs = dict(
75 | data_dir=args.data_dir,
76 | labeled_train_size=args.labeled_train_size,
77 | validation_size=args.validation_size,
78 | train_batch_size=args.train_batch_size,
79 | unlabeled_batch_size=args.unlabeled_train_batch_size,
80 | test_batch_size=args.test_batch_size,
81 | num_workers=args.num_workers,
82 | train_min_size=world_size * args.train_batch_size,
83 | unlabeled_train_min_size=world_size * args.unlabeled_train_batch_size,
84 | test_min_size=world_size * args.test_batch_size,
85 | **_get_dataset_transforms(args),
86 | )
87 |
88 | if args.dataset == "cifar10":
89 | return CIFAR10DataModule(**kwargs)
90 |
91 | elif args.dataset == "cifar100":
92 | return CIFAR100DataModule(**kwargs)
93 |
94 | elif args.dataset == "svhn":
95 | return SVHNDataModule(**kwargs)
96 |
97 | elif args.dataset == "miniimagenet":
98 | return MiniImageNetDataModule(**kwargs)
99 |
100 | elif args.dataset == "domainnet-real":
101 | return DomainNetRealDataModule(**kwargs)
102 |
103 | else:
104 | raise NotImplementedError(f"\"{args.dataset}\" is not a supported dataset")
105 |
106 |
107 | __all__ = [
108 | # modules
109 | "utils",
110 | "types",
111 |
112 | # classes
113 | "MiniImageNet",
114 | "DomainNetReal",
115 | "LabeledDataset",
116 |
117 | "DataModule",
118 | "SSLDataModule",
119 | "CIFAR10DataModule",
120 | "CIFAR100DataModule",
121 | "DomainNetRealDataModule",
122 | "MiniImageNetDataModule",
123 | "SVHNDataModule",
124 |
125 | # functions
126 | "repeater",
127 | "get_batch",
128 | "get_dataset",
129 | ]
130 |
--------------------------------------------------------------------------------
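
A hypothetical `Namespace` carrying exactly the fields `get_dataset` reads above (sizes and batch sizes are illustrative):

```python
from argparse import Namespace

from utils.dataset import get_dataset

args = Namespace(dataset="cifar10",
                 data_dir="./data",
                 data_dims=None,                  # no resizing; keep the native 32x32
                 labeled_train_size=4_000,
                 validation_size=5_000,
                 train_batch_size=64,
                 unlabeled_train_batch_size=64,
                 test_batch_size=64,
                 num_workers=4)

datamodule = get_dataset(args)  # returns a CIFAR10DataModule
```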
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils import get_args, timing, set_random_seed, get_device, get_dataset
4 | from models import get_augmenter
5 | from simple_estimator import SimPLEEstimator
6 | from ablation_estimator import AblationEstimator
7 | from trainer import Trainer
8 |
9 | # for type hint
10 | from typing import Optional, Type
11 | from argparse import Namespace
12 | from torch.nn import Module
13 |
14 | from utils.dataset import SSLDataModule
15 |
16 |
17 | def get_estimator_type(estimator_type: str) -> Type[SimPLEEstimator]:
18 | if estimator_type == "ablation":
19 | return AblationEstimator
20 | else:
21 | return SimPLEEstimator
22 |
23 |
24 | def get_estimator(args: Namespace,
25 | augmenter: Module,
26 | strong_augmenter: Module,
27 | val_augmenter: Module,
28 | num_classes: int,
29 | in_channels: int,
30 | device: Optional[torch.device],
31 | args_override: Optional[Namespace] = None,
32 | estimator_type: Optional[Type[SimPLEEstimator]] = None) -> SimPLEEstimator:
33 | if estimator_type is None:
34 | estimator_type = get_estimator_type(args.estimator)
35 |
36 | if args.checkpoint_path is not None and \
37 | (bool(args.logger_config_dict.get("resume", False)) or not args.use_pretrain):
38 | estimator = estimator_type.from_checkpoint(
39 | augmenter=augmenter,
40 | strong_augmenter=strong_augmenter,
41 | val_augmenter=val_augmenter,
42 | checkpoint_path=args.checkpoint_path,
43 | num_classes=num_classes,
44 | in_channels=in_channels,
45 | device=device,
46 | args_override=args_override,
47 | recover_train_progress=True,
48 | recover_random_state=True)
49 | else:
50 | estimator = estimator_type(
51 | args,
52 | augmenter=augmenter,
53 | strong_augmenter=strong_augmenter,
54 | val_augmenter=val_augmenter,
55 | num_classes=num_classes,
56 | in_channels=in_channels,
57 | device=device)
58 |
59 | return estimator
60 |
61 |
62 | @timing
63 | def main(args: Namespace, datamodule: Optional[SSLDataModule] = None, device: Optional[torch.device] = None):
64 | if device is None:
65 | device = get_device(args.device)
66 |
67 | if datamodule is None:
68 | datamodule = get_dataset(args)
69 |
70 | # dataset stats
71 | dataset_mean = datamodule.dataset_mean
72 | dataset_std = datamodule.dataset_std
73 | image_size = datamodule.dims[1:]
74 |
75 | # build augmenters
76 | augmenter = get_augmenter(args.augmenter_type, image_size=image_size,
77 | dataset_mean=dataset_mean, dataset_std=dataset_std)
78 | strong_augmenter = get_augmenter(args.strong_augmenter_type, image_size=image_size,
79 | dataset_mean=dataset_mean, dataset_std=dataset_std)
80 | val_augmenter = get_augmenter("validation", image_size=image_size, dataset_mean=dataset_mean,
81 | dataset_std=dataset_std)
82 |
83 | estimator = get_estimator(args,
84 | augmenter=augmenter,
85 | strong_augmenter=strong_augmenter,
86 | val_augmenter=val_augmenter,
87 | num_classes=datamodule.num_classes,
88 | in_channels=datamodule.dims[0],
89 | device=device,
90 | args_override=args)
91 | trainer = Trainer(estimator, datamodule=datamodule)
92 |
93 | # training
94 | trainer.fit()
95 |
96 | # load the best state
97 | best_checkpoint_path = trainer.saver.find_best_checkpoint_path(ignore_absolute_best=False)
98 |
99 | if best_checkpoint_path is not None:
100 | best_checkpoint = torch.load(str(best_checkpoint_path), map_location=device)
101 | trainer.load_checkpoint(best_checkpoint)
102 |
103 | # evaluation
104 | trainer.test()
105 |
106 |
107 | if __name__ == '__main__':
108 | parsed_args = get_args()
109 |
110 | # fix random seed
111 | set_random_seed(parsed_args.seed, is_cudnn_deterministic=parsed_args.debug_mode)
112 |
113 | main(parsed_args)
114 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SimPLE
2 |
3 | The code for the
4 | paper: "[SimPLE: Similar Pseudo Label Exploitation for Semi-Supervised Classification](https://arxiv.org/abs/2103.16725)"
5 | by [Zijian Hu*](https://www.zijianhu.com/),
6 | [Zhengyu Yang*](https://zhengyuyang.com/),
7 | [Xuefeng Hu](https://xuefenghu.me/), and [Ram Nevatia](https://sites.usc.edu/iris-cvlab/professor-ram-nevatia/).
8 |
9 | ## Abstract
10 |
11 |
12 |
13 | A common situation in classification tasks is that a large amount of data is available for training, but only a small
14 | portion is annotated with class labels. The goal of semi-supervised training, in this context, is to improve
15 | classification accuracy by leveraging information not only from labeled data but also from a large amount of unlabeled
16 | data. Recent works have achieved significant improvements by exploring the consistency constraint between differently
17 | augmented labeled and unlabeled data. Following this path, we propose a novel unsupervised objective that focuses on the
18 | less studied relationship between the high-confidence unlabeled data that are similar to each other. The newly proposed
19 | Pair Loss minimizes the statistical distance between high-confidence pseudo labels with similarity above a certain
20 | threshold. Combining the Pair Loss with the techniques developed by the MixMatch family, our proposed SimPLE algorithm
21 | shows significant performance gains over previous algorithms on CIFAR-100 and Mini-ImageNet, and is on par with the
22 | state-of-the-art methods on CIFAR-10 and SVHN. Furthermore, SimPLE also outperforms the state-of-the-art methods in the
23 | transfer learning setting, where models are initialized by the weights pre-trained on ImageNet or DomainNet-Real.
24 |
25 | ## News
26 |
27 | - [11/8/2021]: bugfix that could result in incorrect data loading in distributed training
28 | - [9/20/2021]: add data_dims option for data resizing
29 | - [8/31/2021]: update the code base for easier extension
30 | - [6/22/2021]: add Mini-ImageNet example
31 | - [6/2/2021]: add animations and fix broken link in README
32 | - [5/30/2021]: initial release
33 |
34 | ## Requirements
35 |
36 | *see [requirements.txt](requirements.txt) for detail*
37 |
38 | - Python 3.6 or newer
39 | - [PyTorch](https://pytorch.org/) 1.6.0 or newer
40 | - [torchvision](https://pytorch.org/docs/stable/torchvision/index.html) 0.7.0 or newer
41 | - [kornia](https://kornia.readthedocs.io/en/latest/augmentation.html) 0.5.0 or newer
42 | - numpy
43 | - scikit-learn
44 | - [plotly](https://plotly.com/python/) 4.0.0 or newer
45 | - wandb 0.9.0 or newer (**optional**, required for logging to [Weights & Biases](https://www.wandb.com/);
46 |   see [utils.loggers.WandbLogger](utils/loggers/wandb_logger.py) for details)
47 |
48 | ### Recommended versions
49 |
50 | |Python|PyTorch|torchvision|kornia|
51 | | --- | --- | --- | --- |
52 | |3.8.5|1.6.0|0.7.0|0.5.0|
53 |
54 | ## Setup
55 |
56 | ### Install dependencies
57 |
58 | using pip:
59 |
60 | ```shell
61 | pip install -r requirements.txt
62 | ```
63 |
64 | or using conda:
65 |
66 | ```shell
67 | conda env create -f environment.yaml
68 | ```
69 |
70 | ## Running
71 |
72 | ### Example
73 |
74 | To replicate Mini-ImageNet results
75 |
76 | ```shell
77 | CUDA_DEVICE_ORDER="PCI_BUS_ID" CUDA_VISIBLE_DEVICES="0" \
78 | python main.py \
79 | @runs/miniimagenet_args.txt
80 | ```
81 |
82 | To replicate CIFAR-10 results
83 |
84 | ```shell
85 | CUDA_DEVICE_ORDER="PCI_BUS_ID" CUDA_VISIBLE_DEVICES="0" \
86 | python main.py \
87 | @runs/cifar10_args.txt
88 | ```
89 |
90 | To replicate CIFAR-100 results (with distributed training)
91 |
92 | ```shell
93 | CUDA_DEVICE_ORDER="PCI_BUS_ID" CUDA_VISIBLE_DEVICES="0,1" \
94 | python -m torch.distributed.launch \
95 | --nproc_per_node=2 main_ddp.py \
96 | @runs/cifar100_args.txt \
97 | --num-epochs 2048 \
98 | --num-step-per-epoch 512
99 | ```
100 |
101 | ## Citation
102 |
103 | ```bibtex
104 | @InProceedings{Hu-2020-SimPLE,
105 | author = {{Hu*}, Zijian and {Yang*}, Zhengyu and Hu, Xuefeng and Nevatia, Ram},
106 | title = {{SimPLE}: {S}imilar {P}seudo {L}abel {E}xploitation for {S}emi-{S}upervised {C}lassification},
107 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
108 | month = {June},
109 | year = {2021},
110 | url = {https://arxiv.org/abs/2103.16725},
111 | }
112 | ```
113 |
--------------------------------------------------------------------------------
/models/models/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from copy import deepcopy
5 | from collections import OrderedDict
6 | import warnings
7 | from contextlib import contextmanager
8 |
9 | from .utils import unwrap_model
10 |
11 | # for type hint
12 | from torch import Tensor
13 | from typing import Dict
14 |
15 |
16 | class EMA(nn.Module):
17 | def __init__(self, model: nn.Module, decay: float):
18 | # adapted from https://fyubang.com/2019/06/01/ema/
19 | super().__init__()
20 | self.decay = decay
21 |
22 | self.model = model
23 | self.shadow = deepcopy(self.model)
24 |
25 | for param in self.shadow.parameters():
26 | param.detach_()
27 |
28 | @torch.no_grad()
29 | def update(self):
30 | if not self.training:
31 | warnings.warn("EMA update should only be called during training")
32 | return
33 |
34 | model_params = OrderedDict(self.model.named_parameters())
35 | shadow_params = OrderedDict(self.shadow.named_parameters())
36 |
37 | # check that both models contain the same set of keys
38 | assert model_params.keys() == shadow_params.keys()
39 |
40 | for name, param in model_params.items():
41 | # see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
42 | # shadow_variable -= (1 - decay) * (shadow_variable - variable)
43 | shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))
44 |
45 | model_buffers = OrderedDict(self.model.named_buffers())
46 | shadow_buffers = OrderedDict(self.shadow.named_buffers())
47 |
48 | # check that both models contain the same set of keys
49 | assert model_buffers.keys() == shadow_buffers.keys()
50 |
51 | for name, buffer in model_buffers.items():
52 | # buffers are copied
53 | shadow_buffers[name].copy_(buffer)
54 |
55 | def forward(self, inputs: Tensor, return_feature: bool = False) -> Tensor:
56 | if self.training:
57 | return self.model(inputs, return_feature)
58 | else:
59 | return self.shadow(inputs, return_feature)
60 |
61 | @contextmanager
62 | def data_parallel_switch(self):
63 | model = self.model
64 | shadow = self.shadow
65 |
66 | self.model = unwrap_model(self.model)
67 | self.shadow = unwrap_model(self.shadow)
68 |
69 | try:
70 | yield
71 | finally:
72 | self.model = model
73 | self.shadow = shadow
74 |
75 | def state_dict(self, destination=None, prefix='', keep_vars=False):
76 | with self.data_parallel_switch():
77 | return super().state_dict(destination, prefix, keep_vars)
78 |
79 | def load_state_dict(self, state_dict: Dict[str, Tensor], strict=True):
80 | with self.data_parallel_switch():
81 | super().load_state_dict(state_dict, strict)
82 |
83 |
84 | class FullEMA(EMA):
85 | excluded_param_suffix = ["num_batches_tracked"]
86 |
87 | def __init__(self, model: nn.Module, decay: float):
88 | super().__init__(model, decay)
89 |
90 | self._excluded_param_names = set()
91 |
92 | @torch.no_grad()
93 | def update(self):
94 | if not self.training:
95 | warnings.warn("EMA update should only be called during training")
96 | return
97 |
98 | model_state_dict = self.model.state_dict()
99 | shadow_state_dict = self.shadow.state_dict()
100 |
101 | # check that both models contain the same set of keys
102 | assert model_state_dict.keys() == shadow_state_dict.keys()
103 |
104 | for name, param in model_state_dict.items():
105 | if name not in self.excluded_param_names:
106 | shadow_state_dict[name].sub_((1. - self.decay) * (shadow_state_dict[name] - param))
107 | else:
108 | shadow_state_dict[name].copy_(param)
109 |
110 | @staticmethod
111 | def is_ema_exclude(param_name: str) -> bool:
112 | return any([param_name.endswith(suffix) for suffix in FullEMA.excluded_param_suffix])
113 |
114 | @property
115 | def excluded_param_names(self):
116 | if len(self._excluded_param_names) == 0:
117 | for name, param in self.model.state_dict().items():
118 | if self.is_ema_exclude(name):
119 | self._excluded_param_names.add(name)
120 |
121 | return self._excluded_param_names
122 |
--------------------------------------------------------------------------------
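
A toy sketch of the EMA wrapper. The repository's backbones take a `return_feature` flag, so the stand-in model mirrors that signature:

```python
import torch
from torch import nn

from models.models.ema import EMA


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x, return_feature: bool = False):
        return self.fc(x)


ema_model = EMA(TinyNet(), decay=0.999)

ema_model.train()
out = ema_model(torch.randn(4, 8))  # training mode: forward through the live model
ema_model.update()                  # pull the shadow weights toward the live weights

ema_model.eval()
out = ema_model(torch.randn(4, 8))  # eval mode: forward through the EMA shadow
```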
/loss/pair_loss/pair_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | from .utils import get_pair_indices
5 |
6 | # for type hint
7 | from typing import Optional
8 | from torch import Tensor
9 |
10 | from ..types import SimilarityType, DistanceLossType
11 |
12 |
13 | class PairLoss:
14 | def __init__(self,
15 | similarity_metric: SimilarityType,
16 | distance_loss_metric: DistanceLossType,
17 | confidence_threshold: float,
18 | similarity_threshold: float,
19 | similarity_type: str,
20 | distance_loss_type: str,
21 | reduction: str = "mean"):
22 | self.confidence_threshold = confidence_threshold
23 | self.similarity_threshold = similarity_threshold
24 |
25 | self.similarity_type = similarity_type
26 | self.distance_loss_type = distance_loss_type
27 |
28 | self.reduction = reduction
29 |
30 | self.similarity_metric = similarity_metric
31 | self.distance_loss_metric = distance_loss_metric
32 |
33 | def __call__(self,
34 | logits: Tensor,
35 | probs: Tensor,
36 | targets: Tensor,
37 | *args,
38 | indices: Optional[Tensor] = None,
39 | **kwargs) -> Tensor:
40 | """
41 |
42 | Args:
43 | logits: (batch_size, num_classes) predictions of batch data
44 | probs: (batch_size, num_classes) softmax probabilities of the logits
45 | targets: (batch_size, num_classes) one-hot labels
46 |
47 | Returns: Pair loss value as a Tensor.
48 |
49 | """
50 | if indices is None:
51 | indices = get_pair_indices(targets, ordered_pair=True)
52 | total_size = len(indices) // 2
53 |
54 | i_indices, j_indices = indices[:, 0], indices[:, 1]
55 | targets_max_prob = targets.max(dim=1).values
56 |
57 | return self.compute_loss(logits_j=logits[j_indices],
58 | probs_j=probs[j_indices],
59 | targets_i=targets[i_indices],
60 | targets_j=targets[j_indices],
61 | targets_i_max_prob=targets_max_prob[i_indices],
62 | total_size=total_size)
63 |
64 | def compute_loss(self,
65 | logits_j: Tensor,
66 | probs_j: Tensor,
67 | targets_i: Tensor,
68 | targets_j: Tensor,
69 | targets_i_max_prob: Tensor,
70 | total_size: int):
71 | # conf_mask should not track gradient
72 | conf_mask = (targets_i_max_prob > self.confidence_threshold).detach().float()
73 |
74 | similarities: Tensor = self.get_similarity(targets_i=targets_i,
75 | targets_j=targets_j,
76 | dim=1)
77 | # sim_mask should not track gradient
78 | sim_mask = F.threshold(similarities, self.similarity_threshold, 0).detach()
79 |
80 | distance = self.get_distance_loss(logits=logits_j,
81 | probs=probs_j,
82 | targets=targets_i,
83 | dim=1,
84 | reduction='none')
85 |
86 | loss = conf_mask * sim_mask * distance
87 |
88 | if self.reduction == "mean":
89 | loss = torch.sum(loss) / total_size
90 | elif self.reduction == "sum":
91 | loss = torch.sum(loss)
92 |
93 | return loss
94 |
95 | def get_similarity(self,
96 | targets_i: Tensor,
97 | targets_j: Tensor,
98 | *args,
99 | **kwargs) -> Tensor:
100 | x, y = targets_i, targets_j
101 |
102 | return self.similarity_metric(x, y, *args, **kwargs)
103 |
104 | def get_distance_loss(self,
105 | logits: Tensor,
106 | probs: Tensor,
107 | targets: Tensor,
108 | *args,
109 | **kwargs) -> Tensor:
110 | if self.distance_loss_type == "prob":
111 | x, y = probs, targets
112 | else:
113 | x, y = logits, targets
114 |
115 | return self.distance_loss_metric(x, y, *args, **kwargs)
116 |
--------------------------------------------------------------------------------
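
A toy sketch wiring `PairLoss` together by hand with the Bhattacharyya similarity and cross-entropy distance (thresholds are illustrative; `build_pair_loss` in `loss/__init__.py` does this from parsed arguments):

```python
import torch

from loss.loss import softmax_cross_entropy_loss
from loss.pair_loss import PairLoss
from loss.utils.utils import bha_coeff

pair_loss = PairLoss(similarity_metric=bha_coeff,
                     distance_loss_metric=softmax_cross_entropy_loss,
                     confidence_threshold=0.95,
                     similarity_threshold=0.9,
                     similarity_type="prob",
                     distance_loss_type="logit")

logits = torch.randn(8, 10)
probs = torch.softmax(logits, dim=1)
targets = torch.softmax(torch.randn(8, 10), dim=1)  # pseudo labels of the unlabeled batch

loss = pair_loss(logits, probs, targets)  # scalar; pairs below either threshold contribute 0
```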
/models/augmentation/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from kornia import augmentation as K
4 | from kornia import filters as F
5 | from torchvision import transforms
6 |
7 | from .augmenter import RandomAugmentation
8 | from .randaugment import RandAugmentNS
9 |
10 | # for type hint
11 | from typing import List, Tuple, Union, Callable
12 | from torch import Tensor
13 | from torch.nn import Module
14 | from PIL.Image import Image as PILImage
15 |
16 | DatasetStatType = List[float]
17 | ImageSizeType = Tuple[int, int]
18 | PaddingInputType = Union[float, Tuple[float, float], Tuple[float, float, float, float]]
19 | ImageType = Union[Tensor, PILImage]
20 |
21 |
22 | def get_augmenter(augmenter_type: str,
23 | image_size: ImageSizeType,
24 | dataset_mean: DatasetStatType,
25 | dataset_std: DatasetStatType,
26 | padding: PaddingInputType = 1. / 8.,
27 | pad_if_needed: bool = False,
28 | subset_size: int = 2) -> Union[Module, Callable]:
29 | """
30 |
31 | Args:
32 | augmenter_type: augmenter type
33 | image_size: (height, width) image size
34 | dataset_mean: dataset mean value in CHW
35 | dataset_std: dataset standard deviation in CHW
36 | padding: percent of image size to pad on each border of the image. If a sequence of length 4 is provided,
37 | it is used to pad left, top, right, bottom borders respectively. If a sequence of length 2 is provided, it is
38 | used to pad left/right, top/bottom borders, respectively.
39 | pad_if_needed: bool flag for RandomCrop "pad_if_needed" option
40 | subset_size: number of augmentations used in subset
41 |
42 | Returns: nn.Module for Kornia augmentation or Callable for torchvision transform
43 |
44 | """
45 | if not isinstance(padding, tuple):
46 | assert isinstance(padding, float)
47 | padding = (padding, padding, padding, padding)
48 |
49 | assert len(padding) == 2 or len(padding) == 4
50 | if len(padding) == 2:
51 | # padding of length 2 is used to pad left/right, top/bottom borders, respectively
52 | # padding of length 4 is used to pad left, top, right, bottom borders respectively
53 | padding = (padding[0], padding[1], padding[0], padding[1])
54 |
55 | # image_size is of shape (h, w); padding values are [left, top, right, bottom] borders
56 | padding = (
57 | int(image_size[1] * padding[0]),
58 | int(image_size[0] * padding[1]),
59 | int(image_size[1] * padding[2]),
60 | int(image_size[0] * padding[3])
61 | )
62 |
63 | augmenter_type = augmenter_type.strip().lower()
64 |
65 | if augmenter_type == "simple":
66 | return nn.Sequential(
67 | K.RandomCrop(size=image_size, padding=padding, pad_if_needed=pad_if_needed,
68 | padding_mode='reflect'),
69 | K.RandomHorizontalFlip(p=0.5),
70 | K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
71 | std=torch.tensor(dataset_std, dtype=torch.float32)),
72 | )
73 |
74 | elif augmenter_type == "fixed":
75 | return nn.Sequential(
76 | K.RandomHorizontalFlip(p=0.5),
77 | # K.RandomVerticalFlip(p=0.2),
78 | K.RandomResizedCrop(size=image_size, scale=(0.8, 1.0), ratio=(1., 1.)),
79 | RandomAugmentation(
80 | p=0.5,
81 | augmentation=F.GaussianBlur2d(
82 | kernel_size=(3, 3),
83 | sigma=(1.5, 1.5),
84 | border_type='constant'
85 | )
86 | ),
87 | K.ColorJitter(contrast=(0.75, 1.5)),
88 | # additive Gaussian noise
89 | K.RandomErasing(p=0.1),
90 | # Multiply
91 | K.RandomAffine(
92 | degrees=(-25., 25.),
93 | translate=(0.2, 0.2),
94 | scale=(0.8, 1.2),
95 | shear=(-8., 8.)
96 | ),
97 | K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
98 | std=torch.tensor(dataset_std, dtype=torch.float32)),
99 | )
100 |
101 | elif augmenter_type in ["validation", "test"]:
102 | return nn.Sequential(
103 | K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
104 | std=torch.tensor(dataset_std, dtype=torch.float32)),
105 | )
106 |
107 | elif augmenter_type == "randaugment":
108 | return nn.Sequential(
109 | K.RandomCrop(size=image_size, padding=padding, pad_if_needed=pad_if_needed,
110 | padding_mode='reflect'),
111 | K.RandomHorizontalFlip(p=0.5),
112 | RandAugmentNS(n=subset_size, m=10),
113 | K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
114 | std=torch.tensor(dataset_std, dtype=torch.float32)),
115 | )
116 |
117 | else:
118 | raise NotImplementedError(f"\"{augmenter_type}\" is not a supported augmenter type")
119 |
120 |
121 | __all__ = [
122 | # modules
123 | # classes
124 | # functions
125 | "get_augmenter",
126 | ]
127 |
--------------------------------------------------------------------------------
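
A minimal sketch of building and applying an augmenter; the Kornia pipelines operate on batched `(B, C, H, W)` tensors, and the dataset statistics below are the CIFAR-10 values used elsewhere in the repository:

```python
import torch

from models.augmentation import get_augmenter

augmenter = get_augmenter("simple",
                          image_size=(32, 32),
                          dataset_mean=[0.44653091, 0.49139968, 0.48215841],
                          dataset_std=[0.26158784, 0.24703223, 0.24348513])

batch = torch.rand(8, 3, 32, 32)
augmented = augmenter(batch)  # random crop + horizontal flip + normalization
```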
/utils/loggers/logger.py:
--------------------------------------------------------------------------------
1 | from .log_aggregator import LogAggregator
2 | from ..utils import dict_add_prefix
3 |
4 | # for type hint
5 | from typing import Any, Dict, Optional, Union, Tuple
6 | from argparse import Namespace
7 |
8 | from torch.nn import Module
9 | from plotly.graph_objects import Figure
10 | from wandb import Histogram
11 |
12 |
13 | class Logger:
14 | def __init__(self,
15 | log_dir: str,
16 | config: Union[Namespace, Dict[str, Any]],
17 | *args,
18 | log_info_key_map: Optional[Dict[str, str]] = None,
19 | **kwargs):
20 | self.log_dir = log_dir
21 | self.config = config
22 |
23 | # hook functions
24 | self.log_hooks = []
25 |
26 | self.metric_smooth_record: Dict[str, Dict[str, Union[str, Optional[float]]]] = {
27 | "train/mean_acc": {
28 | "key": "train/smoothed_acc",
29 | "value": None,
30 | },
31 | "unlabeled/mean_acc": {
32 | "key": "unlabeled/smoothed_acc",
33 | "value": None,
34 | },
35 | "validation/mean_acc": {
36 | "key": "validation/smoothed_acc",
37 | "value": None,
38 | },
39 | "test/mean_acc": {
40 | "key": "test/smoothed_acc",
41 | "value": None,
42 | },
43 | }
44 |
45 | self.smoothing_weight = 0.9
46 |
47 | # init key map
48 | if log_info_key_map is None:
49 | log_info_key_map = dict()
50 |
51 | self.log_info_key_map = log_info_key_map
52 |
53 | # init log aggregator
54 | self.log_aggregator = LogAggregator()
55 |
56 | def log(self, log_info: Dict[str, Any], *args, **kwargs):
57 | pass
58 |
59 | def watch(self, model: Module, *args, **kwargs):
60 | pass
61 |
62 | def save(self, output_path: str):
63 | pass
64 |
65 | def process_log_info(self,
66 | log_info: Dict[str, Any],
67 | *args,
68 | prefix: Optional[str] = None,
69 | log_info_override: Optional[Dict[str, Any]] = None,
70 | **kwargs) -> Dict[str, Any]:
71 | if log_info_override is None:
72 | log_info_override = {}
73 |
74 | # create shallow copy
75 | log_info = dict(log_info)
76 |
77 | # apply override
78 | log_info.update(log_info_override)
79 |
80 | # update keys based on key map
81 | for key in list(log_info.keys()):
82 | if key in self.log_info_key_map:
83 | new_key = self.log_info_key_map[key]
84 | log_info[new_key] = log_info.pop(key)
85 |
86 | if bool(prefix):
87 | # prepend prefix to info_dict keys
88 | log_info = dict_add_prefix(log_info, prefix)
89 |
90 | # apply smoothing
91 | smoothed_metrics = self.smooth_metrics(log_info)
92 | log_info.update(smoothed_metrics)
93 |
94 | return log_info
95 |
96 | def register_log_hook(self, func, *args, **kwargs):
97 | self.log_hooks.append([func, args, kwargs])
98 |
99 | def call_log_hooks(self, log_info: Dict[str, Any]):
100 | for (func, args, kwargs) in self.log_hooks:
101 | func(log_info=log_info, *args, **kwargs)
102 |
103 | @staticmethod
104 | def separate_plot(input_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
105 | log_info: Dict[str, Any] = dict()
106 | plot_info: Dict[str, Any] = dict()
107 |
108 | for k, v in input_dict.items():
109 | if isinstance(v, Figure) or isinstance(v, Histogram):
110 | plot_info[k] = v
111 | else:
112 | log_info[k] = v
113 |
114 | return log_info, plot_info
115 |
116 | def smooth_metrics(self, log_info: Dict[str, Any]) -> Dict[str, Any]:
117 | # see https://stackoverflow.com/a/49357445/5838091
118 | smoothed_metrics = dict()
119 |
120 | for key, value in log_info.items():
121 | if key not in self.metric_smooth_record:
122 | continue
123 |
124 | smoothed_metric_dict = self.metric_smooth_record[key]
125 | if smoothed_metric_dict["value"] is None:
126 | smoothed_metric_dict["value"] = value
127 | else:
128 | smoothed_metric_dict["value"] = smoothed_metric_dict["value"] * self.smoothing_weight + \
129 | (1 - self.smoothing_weight) * value
130 |
131 | smoothed_metrics[smoothed_metric_dict["key"]] = smoothed_metric_dict["value"]
132 |
133 | return smoothed_metrics
134 |
135 | def accumulate_log(self, log_info: Optional[Dict[str, Any]] = None, plot_info: Optional[Dict[str, Any]] = None):
136 | if log_info is not None:
137 | self.log_aggregator.add_log(log_info)
138 |
139 | if plot_info is not None:
140 | self.log_aggregator.add_plot(plot_info)
141 |
142 | def aggregate_log(self, reduction: str = "mean") -> Dict[str, Any]:
143 | return self.log_aggregator.aggregate(reduction=reduction)
144 |
145 | def reset_aggregator(self):
146 | self.log_aggregator.clear()
147 |
--------------------------------------------------------------------------------
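
A short sketch of the base `Logger`'s prefixing and exponential smoothing (assuming `dict_add_prefix` joins with `/`, which the smoothing keys above imply):

```python
from utils.loggers.logger import Logger

logger = Logger(log_dir="./logs", config={})

info = logger.process_log_info({"mean_acc": 0.8}, prefix="validation")
# info now holds "validation/mean_acc" plus "validation/smoothed_acc",
# an exponential moving average with smoothing weight 0.9
print(info)
```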
/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import onnx
3 | import numpy as np
4 | from .models.utils import *
5 |
6 | # for type hint
7 | from typing import Optional, Sequence, List, Generator, Any, Union, Set, Tuple, Dict
8 | from torch import Tensor
9 | from torch.nn import Module, Parameter
10 |
11 |
12 | @torch.no_grad()
13 | def get_accuracy(output: Tensor, target: Tensor, top_k: Sequence[int] = (1,)) -> List[Tensor]:
14 | # see https://discuss.pytorch.org/t/imagenet-example-accuracy-calculation/7840
15 | max_k = max(top_k)
16 | batch_size = target.size(0)
17 |
18 | _, pred = output.topk(max_k, dim=1, largest=True, sorted=True)
19 | correct = pred.eq(target.view(-1, 1).expand_as(pred))
20 |
21 | res = []
22 | for k in top_k:
23 | correct_k = (correct[:, :k].sum(dim=1, keepdim=False) > 0).float()
24 | res.append(correct_k.sum() / batch_size)
25 | return res
26 |
27 |
28 | def interleave_offsets(batch_size: int, num_unlabeled: int) -> List[int]:
29 | # TODO: scrutinize
30 | groups = [batch_size // (num_unlabeled + 1)] * (num_unlabeled + 1)
31 | for x in range(batch_size - sum(groups)):
32 | groups[-x - 1] += 1
33 | offsets = [0] + np.cumsum(groups).tolist()
34 | assert offsets[-1] == batch_size
35 | return offsets
36 |
37 |
38 | def interleave(xy: Sequence[Tensor], batch_size: int) -> List[Tensor]:
39 | # TODO: scrutinize
40 | num_unlabeled = len(xy) - 1
41 | offsets = interleave_offsets(batch_size, num_unlabeled)
42 | xy = [[v[offsets[p]:offsets[p + 1]] for p in range(num_unlabeled + 1)] for v in xy]
43 | for i in range(1, num_unlabeled + 1):
44 | xy[0][i], xy[i][i] = xy[i][i], xy[0][i]
45 | return [torch.cat(v, dim=0) for v in xy]
46 |
47 |
48 | def get_gradient_norm(model: Module, grad_enabled: bool = False) -> Tensor:
49 | with torch.set_grad_enabled(grad_enabled):
50 | return sum([(p.grad.detach() ** 2).sum() for p in model.parameters() if p.grad is not None])  # note: squared L2 norm
51 |
52 |
53 | def get_weight_norm(model: Module, grad_enabled: bool = False) -> Tensor:
54 | with torch.set_grad_enabled(grad_enabled):
55 | return sum([(p.detach() ** 2).sum() for p in model.parameters() if p.data is not None])  # note: squared L2 norm
56 |
57 |
58 | def split_classifier_params(model: Module, classifier_prefix: Union[str, Set[str]]) \
59 | -> Tuple[List[Parameter], List[Parameter]]:
60 | if not isinstance(classifier_prefix, Set):
61 | classifier_prefix = {classifier_prefix}
62 |
63 | # build tuple for multiple prefix matching
64 | classifier_prefix = tuple(sorted(f"{prefix}." for prefix in classifier_prefix))
65 |
66 | embedder_weights = []
67 | classifier_weights = []
68 |
69 | for k, v in model.named_parameters():
70 | if k.startswith(classifier_prefix):
71 | classifier_weights.append(v)
72 | else:
73 | embedder_weights.append(v)
74 |
75 | return embedder_weights, classifier_weights
76 |
77 |
78 | def set_model_mode(model: Module, mode: Optional[bool]) -> Generator[Any, Any, None]:
79 | """
80 | A context manager to temporarily set the training mode of `model` to `mode`, resetting it when we exit the
81 | with-block. A no-op if mode is None
82 |
83 | Args:
84 | model: the model
85 | mode: a bool or None
86 |
87 | Returns:
88 |
89 | """
90 | if hasattr(onnx, "select_model_mode_for_export"):
91 | # In PyTorch 1.6+, set_training is changed to select_model_mode_for_export
92 | return onnx.select_model_mode_for_export(model, mode)
93 | else:
94 | return onnx.set_training(model, mode)
95 |
96 |
97 | def consume_prefix_in_state_dict_if_present(state_dict: Dict[str, Any], prefix: str):
98 | r"""copied from https://github.com/pytorch/pytorch/blob/255494c2aa1fcee7e605a6905be72e5b8ccf4646/torch/nn/modules/utils.py#L37-L67
99 |
100 | Strip the prefix in state_dict, if any.
101 | ..note::
102 | Given a `state_dict` from a DP/DDP model, a local model can load it by applying
103 | `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling
104 | :meth:`torch.nn.Module.load_state_dict`.
105 | Args:
106 | state_dict (OrderedDict): a state-dict to be loaded to the model.
107 | prefix (str): prefix.
108 | """
109 | keys = sorted(state_dict.keys())
110 | for key in keys:
111 | if key.startswith(prefix):
112 | newkey = key[len(prefix):]
113 | state_dict[newkey] = state_dict.pop(key)
114 |
115 | # also strip the prefix in metadata if any.
116 | if "_metadata" in state_dict:
117 | metadata = state_dict["_metadata"]
118 | for key in list(metadata.keys()):
119 | # for the metadata dict, the key can be:
120 | # '': for the DDP module, which we want to remove.
121 | # 'module': for the actual model.
122 | # 'module.xx.xx': for the rest.
123 |
124 | if len(key) == 0:
125 | continue
126 | newkey = key[len(prefix):]
127 | metadata[newkey] = metadata.pop(key)
128 |
129 |
130 | __all__ = [
131 | "get_accuracy",
132 | "interleave",
133 | "get_gradient_norm",
134 | "get_weight_norm",
135 | "split_classifier_params",
136 | "set_model_mode",
137 | "unwrap_model",
138 | "consume_prefix_in_state_dict_if_present",
139 | ]
140 |
--------------------------------------------------------------------------------
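
A worked example for `get_accuracy` on a toy batch: the first sample is correct at top-1; the second only enters at top-2:

```python
import torch

from models.utils import get_accuracy

output = torch.tensor([[2., 1., 0.],
                       [0., 1., 2.]])
target = torch.tensor([0, 1])

top1, top2 = get_accuracy(output, target, top_k=(1, 2))
print(top1.item(), top2.item())  # 0.5 1.0
```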
/utils/dataset/domainnet_real.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import ImageFolder
2 | from torchvision.datasets.utils import check_integrity, download_url, extract_archive
3 |
4 | from pathlib import Path
5 | from typing import NamedTuple
6 | import re
7 | import os
8 | import shutil
9 |
10 | from .utils import get_directory_size
11 |
12 | # for type hint
13 | from typing import Optional
14 |
15 | FileMeta = NamedTuple("FileMeta", [("filename", str), ("url", str), ("md5", Optional[str])])
16 |
17 |
18 | class DomainNetReal(ImageFolder):
19 | base_folder = 'domainnet-real'
20 |
21 | data_file_meta = FileMeta(
22 | filename="domainnet-real.zip",
23 | url="http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip",
24 | md5="dcc47055e8935767784b7162e7c7cca6")
25 |
26 | train_label_file_meta = FileMeta(
27 | filename="domainnet-real_train.txt",
28 | url="http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/real_train.txt",
29 | md5="8ebf02c2075fadd564705f0dc7cd6291")
30 |
31 | test_label_file_meta = FileMeta(
32 | filename="domainnet-real_test.txt",
33 | url="http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/real_test.txt",
34 | md5="6098816791c3ebed543c71ffa11b9054")
35 |
36 | label_file_pattern = re.compile(r"(.*) (\d+)")
37 |
38 | # extracted file size in Bytes
39 | DATA_FOLDER_SIZE = 6_234_186_058
40 | TRAIN_FOLDER_SIZE = 4_301_431_405
41 | TEST_FOLDER_SIZE = 1_860_180_803
42 |
43 | def __init__(self,
44 | root: str,
45 | train: bool = True,
46 | download: bool = False,
47 | **kwargs):
48 | self.root = root
49 | self.train = train
50 |
51 | if download:
52 | self.download()
53 |
54 | if not self._check_integrity():
55 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
56 |
57 | # parse dataset
58 | self.parse_archives()
59 |
60 | super(DomainNetReal, self).__init__(root=str(self.split_folder), **kwargs)
61 | self.root = root
62 |
63 | @property
64 | def data_root(self) -> Path:
65 | return Path(self.root) / self.base_folder
66 |
67 | @property
68 | def download_root(self) -> Path:
69 | return Path(self.root)
70 |
71 | @property
72 | def data_folder(self) -> Path:
73 | return self.data_root / "real"
74 |
75 | @property
76 | def split_folder(self) -> Path:
77 | return self.data_root / ("train" if self.train else "test")
78 |
79 | @property
80 | def SPLIT_FOLDER_SIZE(self) -> int:
81 | return self.TRAIN_FOLDER_SIZE if self.train else self.TEST_FOLDER_SIZE
82 |
83 | @property
84 | def label_file_meta(self) -> FileMeta:
85 | return self.train_label_file_meta if self.train else self.test_label_file_meta
86 |
87 | def download(self) -> None:
88 | if self._check_integrity():
89 | print('Files already downloaded and verified')
90 | return
91 |
92 | # remove old files
93 | shutil.rmtree(self.split_folder, ignore_errors=True)
94 | shutil.rmtree(self.data_folder, ignore_errors=True)
95 |
96 | for file_meta in (self.data_file_meta, self.label_file_meta):
97 | download_url(url=file_meta.url,
98 | root=str(self.download_root),
99 | filename=file_meta.filename,
100 | md5=file_meta.md5)
101 |
102 | def _check_integrity(self) -> bool:
103 | for file_meta in (self.data_file_meta, self.label_file_meta):
104 | if not check_integrity(fpath=str(self.download_root / file_meta.filename), md5=file_meta.md5):
105 | return False
106 |
107 | return True
108 |
109 | def parse_archives(self) -> None:
110 | if not self.split_folder.is_dir() or get_directory_size(self.split_folder) < self.SPLIT_FOLDER_SIZE:
111 | # if split_folder does not exist or is not large enough
112 | self.parse_data_archive()
113 |
114 | # remove old files
115 | shutil.rmtree(self.split_folder, ignore_errors=True)
116 |
117 | with open(self.download_root / self.label_file_meta.filename, "r") as f:
118 | file_content = [line.strip() for line in f.readlines()]
119 |
120 | for line in file_content:
121 | search_result = self.label_file_pattern.search(line)
122 | assert search_result is not None, f"{self.label_file_meta.filename} contains invalid line \"{line}\""
123 |
124 | image_path = Path(search_result.group(1))
125 | image_relative_path = Path(*image_path.parts[1:])
126 |
127 | source_path = self.data_root / image_path
128 | target_path = self.split_folder / image_relative_path
129 |
130 | if not target_path.is_file():
131 | target_path.parent.mkdir(parents=True, exist_ok=True)
132 | os.link(src=source_path.absolute(), dst=target_path.absolute())
133 |
134 | def parse_data_archive(self) -> None:
135 | if not self.data_folder.is_dir() or get_directory_size(self.data_folder) < self.DATA_FOLDER_SIZE:
136 | # if data_folder does not exist or is not large enough
137 | # remove old files
138 | shutil.rmtree(self.data_folder, ignore_errors=True)
139 |
140 | print(f"extracting {self.data_file_meta.filename}...")
141 | extract_archive(from_path=str(self.download_root / self.data_file_meta.filename),
142 | to_path=str(self.data_root))
143 |
--------------------------------------------------------------------------------
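
A sketch of how the dataset above might be instantiated (the root path below is hypothetical; per the size constants above, the extracted data occupies roughly 6 GB):

    from torchvision import transforms

    from utils.dataset.domainnet_real import DomainNetReal

    train_set = DomainNetReal(root="/tmp/data",  # hypothetical data root
                              train=True,
                              download=True,  # fetch and verify the archives on first use
                              transform=transforms.ToTensor())
    print(len(train_set), len(train_set.classes))  # ImageFolder exposes `classes`
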
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from functools import partial
4 |
5 | from .augmentation import get_augmenter
6 | from .rampup import RampUp, LinearRampUp, get_ramp_up
7 | from .utils import (get_accuracy, interleave, unwrap_model, split_classifier_params)
8 | from .models import EMA, build_model, build_ema_model
9 | from .optimization import build_optimizer, build_lr_scheduler
10 |
11 | from . import augmentation
12 | from . import rampup
13 | from . import utils
14 | from . import mixmatch
15 | from . import optimization
16 | from . import types
17 | from . import models
18 |
19 | # for type hint
20 | from typing import Optional, Union, Set
21 | from argparse import Namespace
22 | from torch.nn import Module
23 |
24 | from .types import OptimizerParametersType, MixMatchFunctionType
25 |
26 |
27 | def load_pretrain(model: Module,
28 | checkpoint_path: str,
29 | allowed_prefix: str,
30 | ignored_prefix: str,
31 | device: torch.device,
32 | checkpoint_key: Optional[str] = None):
33 | full_allowed_prefix = f"{allowed_prefix}." if bool(allowed_prefix) else allowed_prefix
34 | full_ignored_prefix = f"{ignored_prefix}." if bool(ignored_prefix) else ignored_prefix
35 |
36 | checkpoint = torch.load(checkpoint_path, map_location=device)
37 | if checkpoint_key is not None:
38 | pretrain_state_dict = checkpoint[checkpoint_key]
39 | else:
40 | pretrain_state_dict = checkpoint
41 |
42 | shadow = None
43 | if isinstance(model, EMA):
44 | shadow = unwrap_model(model.shadow)
45 | model = unwrap_model(model.model)
46 |
47 | state_dict = model.state_dict()
48 |
49 | for name, param in pretrain_state_dict.items():
50 | if name.startswith(full_allowed_prefix) and not name.startswith(full_ignored_prefix):
51 | name = name[len(full_allowed_prefix):]
52 |
53 | assert name in state_dict.keys()
54 | state_dict[name] = param
55 |
56 | # load pretrain model
57 | model.load_state_dict(state_dict)
58 | if shadow is not None:
59 | shadow.load_state_dict(state_dict)
60 |
61 |
62 | def get_trainable_params(model: Module,
63 | learning_rate: float,
64 | feature_learning_rate: Optional[float],
65 | classifier_prefix: Union[str, Set[str]] = 'fc',
66 | requires_grad_only: bool = True) -> OptimizerParametersType:
67 | if feature_learning_rate is not None:
68 | embedder_weights, classifier_weights = split_classifier_params(model, classifier_prefix)
69 |
70 | if requires_grad_only:
71 | # keep only the parameters that require grad
72 | embedder_weights = [param for param in embedder_weights if param.requires_grad]
73 | classifier_weights = [param for param in classifier_weights if param.requires_grad]
74 |
75 | params = [dict(params=embedder_weights, lr=feature_learning_rate),
76 | dict(params=classifier_weights, lr=learning_rate)]
77 | else:
78 | params = model.parameters()
79 |
80 | if requires_grad_only:
81 | # keep only the parameters that require grad
82 | params = [param for param in params if param.requires_grad]
83 |
84 | return params
85 |
86 |
87 | def get_mixmatch_function(args: Namespace,
88 | num_classes: int,
89 | augmenter: Module,
90 | strong_augmenter: Module) -> MixMatchFunctionType:
91 | if args.mixmatch_type == "simple":
92 | from .mixmatch import SimPLE
93 |
94 | return SimPLE(augmenter=augmenter,
95 | strong_augmenter=strong_augmenter,
96 | num_classes=num_classes,
97 | temperature=args.t,
98 | num_augmentations=args.k,
99 | num_strong_augmentations=args.k_strong,
100 | is_strong_augment_x=False,
101 | train_label_guessing=False)
102 |
103 | elif args.mixmatch_type == "enhanced":
104 | from .mixmatch import MixMatchEnhanced
105 |
106 | return MixMatchEnhanced(augmenter=augmenter,
107 | strong_augmenter=strong_augmenter,
108 | num_classes=num_classes,
109 | temperature=args.t,
110 | num_augmentations=args.k,
111 | num_strong_augmentations=args.k_strong,
112 | alpha=args.alpha,
113 | is_strong_augment_x=False,
114 | train_label_guessing=False)
115 |
116 | elif args.mixmatch_type == "mixmatch":
117 | from .mixmatch import MixMatch
118 |
119 | return MixMatch(augmenter=augmenter,
120 | num_classes=num_classes,
121 | temperature=args.t,
122 | num_augmentations=args.k,
123 | alpha=args.alpha,
124 | train_label_guessing=False)
125 |
126 | else:
127 | raise NotImplementedError(f"{args.mixmatch_type} is not a supported mixmatch type")
128 |
129 |
130 | __all__ = [
131 | # modules
132 | "augmentation",
133 | "rampup",
134 | "utils",
135 | "mixmatch",
136 | "models",
137 | "optimization",
138 | "types",
139 |
140 | # classes
141 | "EMA",
142 |
143 | # functions
144 | "interleave",
145 | "get_augmenter",
146 | "get_ramp_up",
147 | "get_accuracy",
148 | "build_model",
149 | "build_ema_model",
150 | "build_optimizer",
151 | "build_lr_scheduler",
152 | "load_pretrain",
153 | "get_trainable_params",
154 | "get_mixmatch_function",
155 | ]
156 |
--------------------------------------------------------------------------------
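
A small sketch of the two-group optimizer setup produced by get_trainable_params above (the toy model and learning-rate values are illustrative, not from the repository):

    from torch import nn
    from torch.optim import SGD

    from models import get_trainable_params

    model = nn.Sequential()
    model.add_module("backbone", nn.Linear(16, 8))
    model.add_module("fc", nn.Linear(8, 4))

    # backbone parameters train at feature_learning_rate, the `fc` head at learning_rate
    params = get_trainable_params(model,
                                  learning_rate=0.02,
                                  feature_learning_rate=0.002,
                                  classifier_prefix="fc")
    optimizer = SGD(params, lr=0.02, momentum=0.9)  # per-group lr overrides the default
    assert [group["lr"] for group in optimizer.param_groups] == [0.002, 0.02]
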
/utils/dataset/ssl_datamodule.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from pathlib import Path
3 |
4 | from .datamodule import DataModule
5 | from .utils import get_batch
6 |
7 | # for type hint
8 | from typing import Optional, Callable, Tuple, List, Union
9 | from torch.utils.data import Dataset, Subset, DataLoader
10 | from torchvision.datasets import ImageFolder
11 |
12 | from .datasets import LabeledDataset
13 | from .utils import BatchGeneratorType
14 |
15 | DatasetType = Union[LabeledDataset, Dataset]
16 | DataLoaderType = Union[DataLoader, List[DataLoader]]
17 |
18 |
19 | class SSLDataModule(DataModule, ABC):
20 | num_classes: int = 1
21 |
22 | total_train_size: int = 1
23 | total_test_size: int = 1
24 |
25 | DATASET = type(Dataset)
26 |
27 | def __init__(self,
28 | train_batch_size: int,
29 | unlabeled_batch_size: int,
30 | test_batch_size: int,
31 | num_workers: int,
32 | train_min_size: int = 0,
33 | unlabeled_train_min_size: int = 0,
34 | test_min_size: int = 0,
35 | train_transform: Optional[Callable] = None,
36 | val_transform: Optional[Callable] = None,
37 | test_transform: Optional[Callable] = None,
38 | dims: Optional[Tuple[int, ...]] = None):
39 | super(SSLDataModule, self).__init__(train_transform=train_transform,
40 | val_transform=val_transform,
41 | test_transform=test_transform,
42 | dims=dims)
43 | self.train_batch_size = train_batch_size
44 | self.unlabeled_batch_size = unlabeled_batch_size
45 | self.test_batch_size = test_batch_size
46 | self.num_workers = num_workers
47 |
48 | self.train_min_size = max(train_min_size, 0)
49 | self.unlabeled_train_min_size = max(unlabeled_train_min_size, 0)
50 | self.test_min_size = max(test_min_size, 0)
51 |
52 | self._labeled_train_set: Optional[DatasetType] = None
53 | self._unlabeled_train_set: Optional[DatasetType] = None
54 | self._validation_set: Optional[DatasetType] = None
55 | self._test_set: Optional[DatasetType] = None
56 |
57 | # dataset stats
58 | self.dataset_mean: Optional[List[float]] = None
59 | self.dataset_std: Optional[List[float]] = None
60 |
61 | @property
62 | def labeled_train_set(self) -> Optional[DatasetType]:
63 | return self._labeled_train_set
64 |
65 | @labeled_train_set.setter
66 | def labeled_train_set(self, dataset: Optional[DatasetType]) -> None:
67 | self._labeled_train_set = dataset
68 |
69 | @property
70 | def unlabeled_train_set(self) -> Optional[DatasetType]:
71 | return self._unlabeled_train_set
72 |
73 | @unlabeled_train_set.setter
74 | def unlabeled_train_set(self, dataset: Optional[DatasetType]) -> None:
75 | self._unlabeled_train_set = dataset
76 |
77 | @property
78 | def validation_set(self) -> Optional[DatasetType]:
79 | return self._validation_set
80 |
81 | @validation_set.setter
82 | def validation_set(self, dataset: Optional[DatasetType]) -> None:
83 | self._validation_set = dataset
84 |
85 | @property
86 | def test_set(self) -> Optional[DatasetType]:
87 | return self._test_set
88 |
89 | @test_set.setter
90 | def test_set(self, dataset: Optional[DatasetType]) -> None:
91 | self._test_set = dataset
92 |
93 | def train_dataloader(self, **kwargs) -> DataLoaderType:
94 | return [DataLoader(self.labeled_train_set,
95 | shuffle=True,
96 | batch_size=self.train_batch_size,
97 | num_workers=self.num_workers,
98 | drop_last=True,
99 | **kwargs),
100 | DataLoader(self.unlabeled_train_set,
101 | shuffle=True,
102 | batch_size=self.unlabeled_batch_size,
103 | num_workers=self.num_workers,
104 | drop_last=True,
105 | **kwargs)]
106 |
107 | def val_dataloader(self, **kwargs) -> Optional[DataLoaderType]:
108 | return DataLoader(self.validation_set,
109 | batch_size=self.test_batch_size,
110 | shuffle=False,
111 | num_workers=self.num_workers,
112 | drop_last=False,
113 | **kwargs) if self.validation_set is not None else None
114 |
115 | def test_dataloader(self, **kwargs) -> Optional[DataLoaderType]:
116 | return DataLoader(self.test_set,
117 | batch_size=self.test_batch_size,
118 | shuffle=False,
119 | num_workers=self.num_workers,
120 | drop_last=False,
121 | **kwargs) if self.test_set is not None else None
122 |
123 | def get_train_batch(self, train_loaders: List[DataLoader], **kwargs) -> BatchGeneratorType:
124 | return get_batch(train_loaders, **kwargs)
125 |
126 | def save_split_info(self, output_path: Union[str, Path]) -> None:
127 | for subset, filename in [(self.labeled_train_set, "labeled_train_set.txt"),
128 | (self.unlabeled_train_set, "unlabeled_train_set.txt"),
129 | (self.validation_set, "validation_set.txt")]:
130 | if hasattr(subset, "dataset") and isinstance(subset.dataset, Subset):
131 | # save split info if subset is manually split
132 | full_set = subset.dataset.dataset
133 | indices = subset.dataset.indices
134 |
135 | if isinstance(full_set, ImageFolder):
136 | # save file paths
137 | split_info = [str(Path(full_set.imgs[i][0]).relative_to(full_set.root)) + "\n" for i in indices]
138 | else:
139 | # save index values
140 | split_info = [f"{i}\n" for i in indices]
141 |
142 | with open(Path(output_path) / filename, "w") as f:
143 | f.writelines(split_info)
144 |
--------------------------------------------------------------------------------
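
The paired labeled/unlabeled loaders above are consumed through get_batch from utils/dataset/utils.py; a self-contained sketch of its semantics (toy tensors, illustrative sizes):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from utils.dataset.utils import get_batch

    labeled = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
    unlabeled = TensorDataset(torch.randn(32, 3), torch.randint(0, 2, (32,)))

    loaders = [DataLoader(labeled, batch_size=4), DataLoader(unlabeled, batch_size=16)]

    # is_repeat=True cycles the shorter loader so that exactly max_iter steps are produced
    for idx, ((x, y), (u, u_true)) in get_batch(loaders, max_iter=3, is_repeat=True):
        print(idx, x.shape, u.shape)  # 0..2, torch.Size([4, 3]), torch.Size([16, 3])
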
/models/augmentation/randaugment.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | from torch.nn import functional as F
5 | from kornia.geometry import transform as T
6 | from kornia import enhance as E
7 | from kornia.augmentation import functional as KF
8 |
9 | # for type hint
10 | from typing import Any, Optional, List, Tuple, Callable
11 | from torch import Tensor
12 | from torch.nn import Module
13 |
14 |
15 | # Affine
16 | def translate_x(x: Tensor, v: float) -> Tensor:
17 | B, C, H, W = x.shape
18 | return T.translate(x, torch.tensor([[v * W, 0]], device=x.device, dtype=x.dtype))
19 |
20 |
21 | def translate_y(x: Tensor, v: float) -> Tensor:
22 | B, C, H, W = x.shape
23 | return T.translate(x, torch.tensor([[0, v * H]], device=x.device, dtype=x.dtype))
24 |
25 |
26 | def shear_x(x: Tensor, v: float) -> Tensor:
27 | return T.shear(x, torch.tensor([[v, 0.0]], device=x.device, dtype=x.dtype))
28 |
29 |
30 | def shear_y(x: Tensor, v: float) -> Tensor:
31 | return T.shear(x, torch.tensor([[0.0, v]], device=x.device, dtype=x.dtype))
32 |
33 |
34 | def rotate(x: Tensor, v: float) -> Tensor:
35 | return T.rotate(x, torch.tensor([v], device=x.device, dtype=x.dtype))
36 |
37 |
38 | def auto_contrast(x: Tensor, _: Any) -> Tensor:
39 | B, C, H, W = x.shape
40 |
41 | x_min = x.view(B, C, -1).min(-1)[0].view(B, C, 1, 1)
42 | x_max = x.view(B, C, -1).max(-1)[0].view(B, C, 1, 1)
43 |
44 | x_out = (x - x_min) / torch.clamp(x_max - x_min, min=1e-9, max=1)
45 |
46 | return x_out.expand_as(x)
47 |
48 |
49 | def invert(x: Tensor, _: Any) -> Tensor:
50 | return 1.0 - x
51 |
52 |
53 | def equalize(x: Tensor, _: Any) -> Tensor:
54 | return KF.apply_equalize(x, params=dict(batch_prob=torch.tensor([1.0] * len(x), dtype=x.dtype, device=x.device)))
55 |
56 |
57 | def flip(x: Tensor, _: Any) -> Tensor:
58 | return T.hflip(x)
59 |
60 |
61 | def solarize(x: Tensor, v: float) -> Tensor:
62 | x[x < v] = 1 - x[x < v]
63 | return x
64 |
65 |
66 | def brightness(x: Tensor, v: float) -> Tensor:
67 | return E.adjust_brightness(x, v)
68 |
69 |
70 | def color(x: Tensor, v: float) -> Tensor:
71 | return E.adjust_saturation(x, v)
72 |
73 |
74 | def contrast(x: Tensor, v: float) -> Tensor:
75 | return E.adjust_contrast(x, v)
76 |
77 |
78 | def sharpness(x: Tensor, v: float) -> Tensor:
79 | return KF.apply_sharpness(x, params=dict(sharpness_factor=v))
80 |
81 |
82 | def identity(x: Tensor, _: Any) -> Tensor:
83 | return x
84 |
85 |
86 | def posterize(x: Tensor, v: float) -> Tensor:
87 | v = int(v)
88 | return E.posterize(x, v)
89 |
90 |
91 | def cutout(x: Tensor, v: float) -> Tensor:
92 | B, C, H, W = x.shape
93 |
94 | x_v = int(v * W)
95 | y_v = int(v * H)
96 |
97 | x_idx = (np.random.uniform(low=0, high=W - x_v, size=(B, 1, 1, 1)) + np.arange(x_v).reshape((1, 1, 1, -1))).astype(int)  # floor to valid integer pixel indices
98 | y_idx = (np.random.uniform(low=0, high=H - y_v, size=(B, 1, 1, 1)) + np.arange(y_v).reshape((1, 1, -1, 1))).astype(int)
99 |
100 | x[np.arange(B).reshape((B, 1, 1, 1)), np.arange(C).reshape((1, C, 1, 1)), y_idx, x_idx] = 0.5
101 | return x
102 |
103 |
104 | def cutout_pad(x: Tensor, v: float) -> Tensor:
105 | B, C, H, W = x.shape
106 |
107 | x = F.pad(x, [int(v * W / 2), int(v * W / 2), int(v * H / 2), int(v * H / 2)])
108 |
109 | x = cutout(x, v / (1 + v))
110 |
111 | x = T.center_crop(x, (H, W))
112 |
113 | return x
114 |
115 |
116 | class RandAugment(Module):
117 | def __init__(self, n: int, m: int, augmentation_pool: Optional[List[Tuple[Callable, float, float]]] = None):
118 | """
119 |
120 | Args:
121 | n: number of transformations
122 | m: magnitude
123 | augmentation_pool: transformation pool
124 | """
125 | super().__init__()
126 |
127 | self.n = n
128 | self.m = m
129 | if augmentation_pool is not None:
130 | self.augmentation_pool = augmentation_pool
131 | else:
132 | self.augmentation_pool = [
133 | (auto_contrast, np.nan, np.nan),
134 | (brightness, 0.05, 0.95),
135 | (color, 0.05, 0.95),
136 | (contrast, 0.05, 0.95),
137 | (cutout, 0, 0.3),
138 | (equalize, np.nan, np.nan),
139 | (identity, np.nan, np.nan),
140 | (posterize, 4, 8),
141 | (rotate, -30, 30),
142 | (sharpness, 0.05, 0.95),
143 | (shear_x, -0.3, 0.3),
144 | (shear_y, -0.3, 0.3),
145 | (solarize, 0.0, 1.0),
146 | (translate_x, -0.3, 0.3),
147 | (translate_y, -0.3, 0.3),
148 | ]
149 |
150 | def forward(self, x):
151 | assert len(x.shape) == 4
152 |
153 | ops = random.choices(self.augmentation_pool, k=self.n)
154 |
155 | for op, min_v, max_v in ops:
156 | v = random.randint(1, self.m + 1) / 10 * (max_v - min_v) + min_v
157 | x = op(x, v)
158 |
159 | return x
160 |
161 |
162 | class RandAugmentNS(RandAugment):
163 | def __init__(self, n: int, m: int):
164 | super(RandAugmentNS, self).__init__(
165 | n=n,
166 | m=m,
167 | augmentation_pool=[
168 | (auto_contrast, np.nan, np.nan),
169 | (brightness, 0.05, 0.95),
170 | (color, 0.05, 0.95),
171 | (contrast, 0.05, 0.95),
172 | (equalize, np.nan, np.nan),
173 | (identity, np.nan, np.nan),
174 | (posterize, 4, 8),
175 | (rotate, -30, 30),
176 | (sharpness, 0.05, 0.95),
177 | (shear_x, -0.3, 0.3),
178 | (shear_y, -0.3, 0.3),
179 | (solarize, 0.0, 1.0),
180 | (translate_x, -0.3, 0.3),
181 | (translate_y, -0.3, 0.3),
182 | ])
183 |
184 | def forward(self, x: Tensor) -> Tensor:
185 | assert len(x.shape) == 4
186 |
187 | ops = random.choices(self.augmentation_pool, k=self.n)
188 |
189 | for op, min_v, max_v in ops:
190 | v = random.randint(1, self.m + 1) / 10 * (max_v - min_v) + min_v
191 | if random.random() < 0.5:
192 | x = op(x, v)
193 |
194 | x = cutout(x, 0.5)
195 | return x
196 |
--------------------------------------------------------------------------------
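
A minimal sketch applying the batched augmenters above (kornia 0.5.0, as pinned in requirements.txt; note that some ops such as solarize and cutout modify their input in place, hence the clones):

    import torch

    from models.augmentation.randaugment import RandAugment, RandAugmentNS

    x = torch.rand(4, 3, 32, 32)  # a batch in [0, 1], shape (B, C, H, W)

    weak = RandAugment(n=2, m=10)      # 2 random ops at magnitude 10
    strong = RandAugmentNS(n=2, m=10)  # applies each op with probability 0.5, then cutout

    x_weak = weak(x.clone())
    x_strong = strong(x.clone())
    assert x_weak.shape == x_strong.shape == x.shape
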
/models/mixmatch/mixmatch_base.py:
--------------------------------------------------------------------------------
1 | """
2 | code inspired by https://github.com/gan3sh500/mixmatch-pytorch and
3 | https://github.com/google-research/mixmatch
4 | """
5 | import torch
6 | import numpy as np
7 | from torch.nn import functional as F
8 |
9 | from .utils import label_guessing, sharpen
10 |
11 | # for type hint
12 | from typing import Optional, Dict, Union, List, Sequence
13 | from torch import Tensor
14 | from torch.nn import Module
15 |
16 |
17 | class MixMatchBase:
18 | def __init__(self,
19 | augmenter: Module,
20 | strong_augmenter: Optional[Module],
21 | num_classes: int,
22 | temperature: float,
23 | num_augmentations: int,
24 | num_strong_augmentations: int,
25 | alpha: float,
26 | is_strong_augment_x: bool,
27 | train_label_guessing: bool):
28 | # callables
29 | self.augmenter = augmenter
30 | self.strong_augmenter = strong_augmenter
31 |
32 | # parameters
33 | self.num_classes = num_classes
34 | self.temperature = temperature
35 | self.alpha = alpha
36 |
37 | self.num_augmentations = num_augmentations
38 | self.num_strong_augmentations = num_strong_augmentations
39 |
40 | # flags
41 | self.train_label_guessing = train_label_guessing
42 | self.is_strong_augment_x = is_strong_augment_x
43 |
44 | @property
45 | def total_num_augmentations(self) -> int:
46 | return self.num_augmentations + self.num_strong_augmentations
47 |
48 | @torch.no_grad()
49 | def __call__(self,
50 | x_augmented: Tensor,
51 | x_strong_augmented: Optional[Tensor],
52 | x_targets_one_hot: Tensor,
53 | u_augmented: List[Tensor],
54 | u_strong_augmented: List[Tensor],
55 | u_true_targets_one_hot: Tensor,
56 | model: Module,
57 | *args,
58 | **kwargs) -> Dict[str, Tensor]:
59 | if self.is_strong_augment_x:
60 | x_inputs = x_strong_augmented
61 | else:
62 | x_inputs = x_augmented
63 | u_inputs = u_augmented + u_strong_augmented
64 |
65 | # label guessing with weakly augmented data
66 | pseudo_label_dict = self.guess_label(u_inputs=u_augmented, model=model)
67 |
68 | return self.postprocess(x_augmented=x_inputs,
69 | x_targets_one_hot=x_targets_one_hot,
70 | u_augmented=u_inputs,
71 | q_guess=pseudo_label_dict["q_guess"],
72 | u_true_targets_one_hot=u_true_targets_one_hot)
73 |
74 | @torch.no_grad()
75 | def preprocess(self,
76 | x_inputs: Tensor,
77 | x_strong_inputs: Tensor,
78 | x_targets: Tensor,
79 | u_inputs: Tensor,
80 | u_strong_inputs: Tensor,
81 | u_true_targets: Tensor) -> Dict[str, Union[Optional[Tensor], List[Tensor]]]:
82 | # convert targets to one-hot
83 | x_targets_one_hot = F.one_hot(x_targets, num_classes=self.num_classes).type_as(x_inputs)
84 | u_true_targets_one_hot = F.one_hot(u_true_targets, num_classes=self.num_classes).type_as(x_inputs)
85 |
86 | # apply augmentations
87 | x_augmented = self.augmenter(x_inputs)
88 | u_augmented = [self.augmenter(u_inputs) for _ in range(self.num_augmentations)]
89 |
90 | if self.strong_augmenter is not None:
91 | x_strong_augmented = self.strong_augmenter(x_strong_inputs)
92 | u_strong_augmented = [self.strong_augmenter(u_strong_inputs) for _ in range(self.num_strong_augmentations)]
93 | else:
94 | x_strong_augmented = None
95 | u_strong_augmented = []
96 |
97 | return dict(x_augmented=x_augmented,
98 | x_strong_augmented=x_strong_augmented,
99 | x_targets_one_hot=x_targets_one_hot,
100 | u_augmented=u_augmented,
101 | u_strong_augmented=u_strong_augmented,
102 | u_true_targets_one_hot=u_true_targets_one_hot)
103 |
104 | def guess_label(self, u_inputs: Sequence[Tensor], model: Module) -> Dict[str, Tensor]:
105 | # label guessing
106 | q_guess = label_guessing(batches=u_inputs, model=model, is_train_mode=self.train_label_guessing)
107 | q_guess = sharpen(q_guess, self.temperature)
108 |
109 | return dict(q_guess=q_guess)
110 |
111 | def postprocess(self,
112 | x_augmented: Tensor,
113 | x_targets_one_hot: Tensor,
114 | u_augmented: List[Tensor],
115 | q_guess: Tensor,
116 | u_true_targets_one_hot: Tensor) -> Dict[str, Tensor]:
117 | # concat the unlabeled data and targets
118 | u_augmented = torch.cat(u_augmented, dim=0)
119 | q_guess = torch.cat([q_guess for _ in range(self.total_num_augmentations)], dim=0)
120 | q_true = torch.cat([u_true_targets_one_hot for _ in range(self.total_num_augmentations)], dim=0)
121 | assert len(u_augmented) == len(q_guess) == len(q_true)
122 |
123 | return self.mixup(x_augmented=x_augmented,
124 | x_targets_one_hot=x_targets_one_hot,
125 | u_augmented=u_augmented,
126 | q_guess=q_guess,
127 | q_true=q_true)
128 |
129 | def mixup(self,
130 | x_augmented: Tensor,
131 | x_targets_one_hot: Tensor,
132 | u_augmented: Tensor,
133 | q_guess: Tensor,
134 | q_true: Tensor) -> Dict[str, Tensor]:
135 | # random shuffle according to the paper
136 | indices = list(range(len(x_augmented) + len(u_augmented)))
137 | np.random.shuffle(indices)
138 |
139 | # MixUp
140 | wx = torch.cat([x_augmented, u_augmented], dim=0)
141 | wy = torch.cat([x_targets_one_hot, q_guess], dim=0)
142 | wq = torch.cat([x_targets_one_hot, q_true], dim=0)
143 | assert len(wx) == len(wy) == len(wq)
144 | assert len(wx) == len(x_augmented) + len(u_augmented)
145 | assert len(wy) == len(x_targets_one_hot) + len(q_guess)
146 | wx_shuffled = wx[indices]
147 | wy_shuffled = wy[indices]
148 | wq_shuffled = wq[indices]
149 | assert len(wx) == len(wx_shuffled)
150 | assert len(wy) == len(wy_shuffled)
151 | assert len(wq) == len(wq_shuffled)
152 |
153 | # the official implementation uses the same lambda, sampled from Beta(alpha, alpha), for both
154 | # the labeled and unlabeled inputs
155 | lam = np.random.beta(self.alpha, self.alpha)
156 | lam = max(lam, 1 - lam)
157 |
158 | wx_mixed = lam * wx + (1 - lam) * wx_shuffled
159 | wy_mixed = lam * wy + (1 - lam) * wy_shuffled
160 | wq_mixed = lam * wq + (1 - lam) * wq_shuffled
161 |
162 | x_mixed, p_mixed = wx_mixed[:len(x_augmented)], wy_mixed[:len(x_augmented)]
163 | u_mixed, q_mixed = wx_mixed[len(x_augmented):], wy_mixed[len(x_augmented):]
164 | q_true_mixed = wq_mixed[len(x_augmented):]
165 |
166 | return dict(x_mixed=x_mixed,
167 | p_mixed=p_mixed,
168 | u_mixed=u_mixed,
169 | q_mixed=q_mixed,
170 | q_true_mixed=q_true_mixed)
171 |
--------------------------------------------------------------------------------
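
The MixUp step above draws one lambda from Beta(alpha, alpha) and keeps lambda >= 0.5, so every mixed sample stays dominated by its first argument; a worked sketch of just that convention (alpha = 0.75 is an illustrative value, not taken from this file):

    import numpy as np
    import torch

    alpha = 0.75
    lam = np.random.beta(alpha, alpha)
    lam = max(lam, 1 - lam)  # lambda' = max(lambda, 1 - lambda) >= 0.5

    x = torch.zeros(2, 3)          # stand-in for the original batch
    x_shuffled = torch.ones(2, 3)  # stand-in for the shuffled batch

    x_mixed = lam * x + (1 - lam) * x_shuffled
    assert torch.all(x_mixed <= 0.5)  # the original batch dominates the mixture
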
/loss/visualization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from kornia.utils import confusion_matrix
4 | import plotly.figure_factory as ff
5 | import wandb
6 |
7 | from .utils import to_tensor
8 | from .pair_loss.utils import get_pair_indices
9 |
10 | # for type hint
11 | from typing import Tuple, Union, Dict
12 | from torch import Tensor
13 | from plotly.graph_objects import Figure
14 |
15 | from .types import SimilarityType, LogDictType, PlotDictType, LossInfoType
16 |
17 |
18 | def _detorch(t: Tensor) -> np.ndarray:
19 | return t.detach().cpu().clone().numpy()
20 |
21 |
22 | @torch.no_grad()
23 | def get_pair_info(targets: Tensor,
24 | true_targets: Tensor,
25 | similarity_metric: SimilarityType,
26 | confidence_threshold: float,
27 | similarity_threshold: float,
28 | return_plot_info: bool = False) -> LossInfoType:
29 | indices = get_pair_indices(targets, ordered_pair=True)
30 | i_indices, j_indices = indices[:, 0], indices[:, 1]
31 | targets_i, targets_j = targets[i_indices], targets[j_indices]
32 |
33 | targets_max_prob = targets.max(dim=1).values
34 | true_labels = true_targets.argmax(dim=1)
35 |
36 | similarities: Tensor = similarity_metric(targets_i=targets_i, targets_j=targets_j, dim=1)
37 |
38 | conf_mask: Tensor = targets_max_prob[i_indices] > confidence_threshold
39 | sim_mask: Tensor = (similarities > similarity_threshold)
40 | final_mask: Tensor = (conf_mask & sim_mask)
41 | true_pair_mask: Tensor = (true_labels[i_indices] == true_labels[j_indices])
42 |
43 | log_info, plot_info = get_pair_loss_info(
44 | conf_mask=conf_mask,
45 | sim_mask=sim_mask,
46 | final_mask=final_mask,
47 | true_pair_mask=true_pair_mask,
48 | indices=indices,
49 | return_plot_info=return_plot_info)
50 |
51 | # log extra metrics
52 | extra_log_info, extra_plot_info = get_pair_extra_info(
53 | targets_max_prob=targets_max_prob,
54 | i_indices=i_indices,
55 | j_indices=j_indices,
56 | similarities=similarities,
57 | final_mask=final_mask)
58 |
59 | log_info.update(extra_log_info)
60 | plot_info.update(extra_plot_info)
61 |
62 | return {"log": log_info, "plot": plot_info}
63 |
64 |
65 | @torch.no_grad()
66 | def get_pair_loss_info(conf_mask: Tensor,
67 | sim_mask: Tensor,
68 | final_mask: Tensor,
69 | true_pair_mask: Tensor,
70 | indices: Tensor,
71 | return_plot_info: bool = False) -> Tuple[LogDictType, PlotDictType]:
72 | # prepare log info
73 | total_high_conf = conf_mask.sum()
74 | total_high_sim = sim_mask.sum()
75 | total_thresholded = final_mask.sum()
76 | total_size = len(indices)
77 |
78 | total_true_positive = (true_pair_mask & final_mask).sum().float()
79 | if total_thresholded == 0:
80 | true_pair_given_thresholded = to_tensor(0., tensor_like=total_true_positive)
81 | else:
82 | true_pair_given_thresholded = total_true_positive / total_thresholded
83 |
84 | matrix = get_confusion_matrix(y_true=true_pair_mask, y_pred=final_mask)
85 | normalized_matrix = matrix / matrix.sum()
86 |
87 | log_info = {
88 | "high_conf_ratio": (total_high_conf.float() / total_size),
89 | "high_sim_ratio": (total_high_sim.float() / total_size),
90 | "thresholded_ratio": (total_thresholded.float() / total_size),
91 | "true_pair_given_thresholded": true_pair_given_thresholded,
92 |
93 | # see https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
94 | # In binary classification, the count of true negatives is C[0,0], false negatives is C[1,0],
95 | # true positives is C[1,1] and false positives is C[0,1].
96 | "true_negative_pair_ratio": normalized_matrix[0, 0],
97 | "false_positives_pair_ratio": normalized_matrix[0, 1],
98 | "false_negatives_pair_ratio": normalized_matrix[1, 0],
99 | "true_positives_pair_ratio": normalized_matrix[1, 1],
100 | }
101 |
102 | plot_info = {}
103 | if return_plot_info:
104 | plot_info["pair_confusion_matrix"] = visualize(_detorch(matrix), _detorch(normalized_matrix))
105 |
106 | return {f"pair_loss/{k}": v for k, v in log_info.items()}, \
107 | {f"pair_loss/{k}": v for k, v in plot_info.items()}
108 |
109 |
110 | def get_confusion_matrix(y_true: Tensor, y_pred: Tensor) -> Tensor:
111 | """
112 | Args:
113 | y_true: (batch size,) a vector where each element is the class label
114 | y_pred: (batch size,) predicted class labels
115 |
116 | Returns:
117 |
118 | """
119 | matrix = confusion_matrix(y_pred.view(1, -1), y_true.view(1, -1), num_classes=2, normalized=False)
120 | matrix.squeeze_()
121 |
122 | return matrix
123 |
124 |
125 | def visualize(matrix: np.ndarray, normalized_matrix: np.ndarray, **kwargs) -> Figure:
126 | # see https://matplotlib.org/faq/usage_faq.html#coding-styles
127 | # see https://plotly.com/python/annotated-heatmap/
128 | # see https://plotly.com/python/builtin-colorscales/
129 | annotation_text = [
130 | [f"True Negative: {normalized_matrix[0, 0]:.3%}", f"False Positive: {normalized_matrix[0, 1]:.3%}"],
131 | [f"False Negative: {normalized_matrix[1, 0]:.3%}", f"True Positive: {normalized_matrix[1, 1]:.3%}"]
132 | ]
133 | x_axis_text = ["Predicted False", "Predicted True"]
134 | y_axis_text = ["False", "True"]
135 | fig = ff.create_annotated_heatmap(matrix, x=x_axis_text, y=y_axis_text, annotation_text=annotation_text,
136 | colorscale='Plotly3')
137 | fig.update_layout({
138 | 'xaxis': {'title': {'text': "Predicted"}},
139 | 'yaxis': {'title': {'text': "Ground Truth"}},
140 | })
141 |
142 | return fig
143 |
144 |
145 | @torch.no_grad()
146 | def get_pair_extra_info(targets_max_prob: Tensor,
147 | i_indices: Tensor,
148 | j_indices: Tensor,
149 | similarities: Tensor,
150 | final_mask: Tensor) -> Tuple[LogDictType, PlotDictType]:
151 | def mean_std_max_min(t: Union[Tensor, np.ndarray], prefix: str = "") -> Dict[str, Union[Tensor, np.ndarray]]:
152 | return {
153 | f"{prefix}/mean": t.mean() if t.numel() > 0 else to_tensor(0, tensor_like=t),
154 | f"{prefix}/std": t.std() if t.numel() > 0 else to_tensor(0, tensor_like=t),
155 | f"{prefix}/max": t.max() if t.numel() > 0 else to_tensor(0, tensor_like=t),
156 | f"{prefix}/min": t.min() if t.numel() > 0 else to_tensor(0, tensor_like=t),
157 | }
158 |
159 | targets_i_max_prob = targets_max_prob[i_indices]
160 | targets_j_max_prob = targets_max_prob[j_indices]
161 |
162 | selected_sim = similarities[final_mask]
163 | selected_i_conf = targets_i_max_prob[final_mask]
164 | selected_j_conf = targets_j_max_prob[final_mask]
165 |
166 | selected_i_conf_stat = mean_std_max_min(selected_i_conf, prefix="selected_i_conf")
167 | selected_j_conf_stat = mean_std_max_min(selected_j_conf, prefix="selected_j_conf")
168 | selected_sim_stat = mean_std_max_min(selected_sim, prefix="selected_sim")
169 |
170 | selected_i_conf_hist = wandb.Histogram(_detorch(selected_i_conf))
171 | selected_j_conf_hist = wandb.Histogram(_detorch(selected_j_conf))
172 | selected_sim_hist = wandb.Histogram(_detorch(selected_sim))
173 |
174 | log_info = {
175 | **selected_i_conf_stat,
176 | **selected_j_conf_stat,
177 | **selected_sim_stat,
178 | }
179 | plot_info = {
180 | "selected_i_conf_hist": selected_i_conf_hist,
181 | "selected_j_conf_hist": selected_j_conf_hist,
182 | "selected_sim_hist": selected_sim_hist,
183 | }
184 |
185 | return {f"pair_loss/{k}": v for k, v in log_info.items()}, \
186 | {f"pair_loss/{k}": v for k, v in plot_info.items()}
187 |
--------------------------------------------------------------------------------
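
A small sketch of get_confusion_matrix on boolean pair masks; as the comment above notes, the layout follows the sklearn convention with ground truth on the rows:

    import torch

    from loss.visualization import get_confusion_matrix

    true_pair_mask = torch.tensor([1, 1, 0, 0])  # ground truth: pairs whose true labels match
    final_mask = torch.tensor([1, 0, 1, 0])      # prediction: pairs passing both thresholds

    matrix = get_confusion_matrix(y_true=true_pair_mask, y_pred=final_mask)
    # [[TN, FP],    here: [[1, 1],
    #  [FN, TP]]           [1, 1]]
    print(matrix)
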
/models/models/wide_resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Ref:
3 | https://github.com/hysts/pytorch_wrn/blob/master/wrn.py
4 | https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
5 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
6 | https://github.com/szagoruyko/functional-zoo/blob/master/wide-resnet-50-2-export.ipynb
7 | https://github.com/szagoruyko/wide-residual-networks/tree/master/pytorch
8 |
9 | to get the name of the layer in state_dict, see the following example
10 | >>> wrn = WideResNet(
11 | ...     in_channels=1,
12 | ...     out_channels=5,
13 | ...     base_channels=4,
14 | ...     widening_factor=10,
15 | ...     drop_rate=0,
16 | ...     depth=10
17 | ... )
18 | >>>
19 | >>> # print the state_dict keys
20 | >>> d = wrn.state_dict()
21 | >>> dl = list(d.keys())
22 | >>> for idx, n in enumerate(dl):
23 | ...     print("{} -> {}".format(idx, n))
24 | """
25 | import torch
26 | from torch import nn
27 | import numpy as np
28 |
29 | # for type hint
30 | from typing import Union, Tuple, Sequence, Optional
31 | from torch import Tensor
32 |
33 |
34 | class BasicBlock(nn.Module):
35 | def __init__(self,
36 | in_channels: int,
37 | out_channels: int,
38 | stride: int,
39 | drop_rate: float = 0.0,
40 | activate_before_residual: bool = False,
41 | batch_norm_momentum: float = 0.001):
42 | super().__init__()
43 |
44 | self.in_channels = in_channels
45 | self.out_channels = out_channels
46 | self.drop_rate = drop_rate
47 | self.equal_in_out = (in_channels == out_channels)
48 | self.activate_before_residual = activate_before_residual
49 |
50 | # see https://github.com/pytorch/examples/issues/289 regarding different convention for batch norm momentum
51 | self.bn1 = nn.BatchNorm2d(self.in_channels, momentum=batch_norm_momentum)
52 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
53 | self.conv1 = nn.Conv2d(
54 | self.in_channels,
55 | self.out_channels,
56 | kernel_size=3,
57 | stride=stride,
58 | padding=1,
59 | bias=False)
60 |
61 | self.bn2 = nn.BatchNorm2d(self.out_channels, momentum=batch_norm_momentum)
62 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
63 | self.conv2 = nn.Conv2d(
64 | self.out_channels,
65 | self.out_channels,
66 | kernel_size=3,
67 | stride=1,
68 | padding=1,
69 | bias=False)
70 |
71 | if not self.equal_in_out:
72 | self.conv_shortcut = nn.Conv2d(
73 | self.in_channels,
74 | self.out_channels,
75 | kernel_size=1,
76 | stride=stride,
77 | padding=0,
78 | bias=False)
79 | else:
80 | self.conv_shortcut = None
81 |
82 | if drop_rate > 0:
83 | self.dropout = nn.Dropout(p=self.drop_rate)
84 |
85 | def forward(self, inputs):
86 | if not self.equal_in_out and self.activate_before_residual:
87 | inputs = self.bn1(inputs)
88 | inputs = self.relu1(inputs)
89 | outputs = self.conv1(inputs)
90 | else:
91 | outputs = self.bn1(inputs)
92 | outputs = self.relu1(outputs)
93 | outputs = self.conv1(outputs)
94 |
95 | outputs = self.bn2(outputs)
96 | outputs = self.relu2(outputs)
97 | outputs = self.conv2(outputs)
98 |
99 | if self.drop_rate > 0:
100 | outputs = self.dropout(outputs)
101 |
102 | if not self.equal_in_out:
103 | inputs = self.conv_shortcut(inputs)
104 |
105 | return torch.add(inputs, outputs)
106 |
107 |
108 | class NetworkBlock(nn.Module):
109 | def __init__(self,
110 | in_channels: int,
111 | out_channels: int,
112 | num_blocks: int,
113 | stride: int,
114 | drop_rate: float = 0.0,
115 | activate_before_residual: bool = False,
116 | basic_block: type(BasicBlock) = BasicBlock,
117 | **block_kwargs):
118 | super().__init__()
119 | self.layer = self._make_layer(basic_block, in_channels, out_channels, num_blocks, stride, drop_rate,
120 | activate_before_residual, **block_kwargs)
121 |
122 | @staticmethod
123 | def _make_layer(block: type(BasicBlock),
124 | in_channels: int,
125 | out_channels: int,
126 | num_blocks: int,
127 | stride: int,
128 | drop_rate: float,
129 | activate_before_residual: bool,
130 | **block_kwargs) -> nn.Module:
131 | layers = []
132 | for i in range(int(num_blocks)):
133 | if i == 0:
134 | layers.append(
135 | block(in_channels,
136 | out_channels,
137 | stride=stride,
138 | drop_rate=drop_rate,
139 | activate_before_residual=activate_before_residual,
140 | **block_kwargs))
141 | else:
142 | layers.append(
143 | block(out_channels,
144 | out_channels,
145 | stride=1,
146 | drop_rate=drop_rate,
147 | activate_before_residual=activate_before_residual,
148 | **block_kwargs))
149 | return nn.Sequential(*layers)
150 |
151 | def forward(self, x):
152 | return self.layer(x)
153 |
154 |
155 | class WideResNet(nn.Module):
156 | def __init__(self,
157 | in_channels: int,
158 | out_channels: int,
159 | depth: int = 28,
160 | widening_factor: int = 2,
161 | base_channels: int = 16,
162 | drop_rate: float = 0.0,
163 | batch_norm_momentum: float = 0.001,
164 | n_channels: Optional[Sequence[int]] = None):
165 | super().__init__()
166 |
167 | if n_channels is None:
168 | # interpolate channels
169 | n_channels = [
170 | int(base_channels),
171 | int(base_channels * widening_factor),
172 | int(base_channels * 2 * widening_factor),
173 | int(base_channels * 4 * widening_factor),
174 | ]
175 | assert len(n_channels) == 4
176 | self.n_channels = n_channels
177 |
178 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
179 | self.num_blocks = (depth - 4) // 6
180 |
181 | if self.num_blocks == 0:
182 | self.fc_in_features = self.n_channels[0]
183 | else:
184 | self.fc_in_features = self.n_channels[3]
185 |
186 | self.conv = nn.Conv2d(
187 | in_channels,
188 | self.n_channels[0],
189 | kernel_size=3,
190 | stride=1,
191 | padding=1,
192 | bias=False)
193 |
194 | self.block1 = NetworkBlock(
195 | in_channels=self.n_channels[0],
196 | out_channels=self.n_channels[1],
197 | num_blocks=self.num_blocks,
198 | stride=1,
199 | drop_rate=drop_rate,
200 | activate_before_residual=True,
201 | basic_block=BasicBlock,
202 | batch_norm_momentum=batch_norm_momentum)
203 |
204 | self.block2 = NetworkBlock(
205 | in_channels=self.n_channels[1],
206 | out_channels=self.n_channels[2],
207 | num_blocks=self.num_blocks,
208 | stride=2,
209 | drop_rate=drop_rate,
210 | basic_block=BasicBlock,
211 | batch_norm_momentum=batch_norm_momentum)
212 |
213 | self.block3 = NetworkBlock(
214 | in_channels=self.n_channels[2],
215 | out_channels=self.n_channels[3],
216 | num_blocks=self.num_blocks,
217 | stride=2,
218 | drop_rate=drop_rate,
219 | basic_block=BasicBlock,
220 | batch_norm_momentum=batch_norm_momentum)
221 |
222 | self.bn = nn.BatchNorm2d(num_features=self.fc_in_features, momentum=batch_norm_momentum)
223 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
224 | # self.pool = nn.AvgPool2d(kernel_size=8)
225 | self.pool = nn.AdaptiveAvgPool2d(output_size=1)
226 | self.fc = nn.Linear(self.fc_in_features, out_channels)
227 |
228 | # initialization
229 | for m in self.modules():
230 | if isinstance(m, nn.Conv2d):
231 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
232 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
233 | m.weight.data.normal_(0, np.sqrt(2. / n))
234 | elif isinstance(m, nn.BatchNorm2d):
235 | m.weight.data.fill_(1)
236 | m.bias.data.zero_()
237 | elif isinstance(m, nn.Linear):
238 | nn.init.xavier_normal_(m.weight.data)
239 | m.bias.data.zero_()
240 |
241 | def forward(self, inputs: Tensor, return_feature: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
242 | outputs = self.conv(inputs)
243 | outputs = self.block1(outputs)
244 | outputs = self.block2(outputs)
245 | outputs = self.block3(outputs)
246 | outputs = self.bn(outputs)
247 | outputs = self.relu(outputs)
248 | outputs = self.pool(outputs)
249 | # outputs = outputs.view(-1, self.n_channels[3])
250 | features = outputs.view(outputs.size(0), -1)
251 |
252 | outputs = self.fc(features)
253 |
254 | if return_feature:
255 | return features, outputs
256 |
257 | return outputs
258 |
--------------------------------------------------------------------------------
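
A sketch constructing the WRN-28-2 configuration commonly used at CIFAR scale (the 32x32 input size is illustrative; any spatial size works thanks to the adaptive average pool):

    import torch

    from models.models.wide_resnet import WideResNet

    # depth = 28 gives (28 - 4) / 6 = 4 blocks per stage; widening_factor = 2 makes WRN-28-2
    wrn = WideResNet(in_channels=3, out_channels=10, depth=28, widening_factor=2)

    x = torch.randn(2, 3, 32, 32)
    logits = wrn(x)
    features, logits_again = wrn(x, return_feature=True)
    assert logits.shape == (2, 10)
    assert features.shape == (2, wrn.fc_in_features)  # 16 * 4 * 2 = 128 features
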
/utils/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import Dataset, DataLoader, Subset
3 | from tqdm import tqdm
4 |
5 | from itertools import repeat, combinations
6 | from pathlib import Path
7 |
8 | # for type hint
9 | from typing import Optional, Generator, Tuple, Union, List, Sequence, Any
10 | from torch import Tensor
11 | from PIL.Image import Image
12 | from torchvision.datasets import VisionDataset
13 |
14 | BatchType = Union[Tuple[Tuple[Tensor, Tensor], ...], Tuple[Tensor, Tensor]]
15 | LoaderType = Union[DataLoader, Generator[BatchType, None, None]]
16 | BatchGeneratorType = Generator[Tuple[int, BatchType], None, None]
17 |
18 |
19 | def repeater(data_loader: DataLoader) -> Generator[Tuple[Tensor, Tensor], None, None]:
20 | for loader in repeat(data_loader):
21 | for data in loader:
22 | yield data
23 |
24 |
25 | def get_batch(loaders: Sequence[LoaderType], max_iter: int, is_repeat: bool = False) \
26 | -> BatchGeneratorType:
27 | if is_repeat:
28 | loaders = [repeater(loader) for loader in loaders]
29 |
30 | if len(loaders) == 1:
31 | combined_loaders = loaders[0]
32 | else:
33 | combined_loaders = zip(*loaders)
34 |
35 | for idx, batch in enumerate(combined_loaders):
36 | if idx >= max_iter:
37 | return
38 | yield idx, batch
39 |
40 |
41 | def get_targets(dataset: Union[Dataset, VisionDataset]) -> List[Any]:
42 | if hasattr(dataset, 'targets'):
43 | return dataset.targets
44 |
45 | # TODO: handle very large dataset
46 | return [target for (_, target) in dataset]
47 |
48 |
49 | def get_class_indices(dataset: Dataset) -> List[np.ndarray]:
50 | targets = np.asarray(get_targets(dataset))
51 | num_classes = max(targets) + 1
52 |
53 | return [np.where(targets == i)[0] for i in range(num_classes)]
54 |
55 |
56 | def to_indices_or_sections(subset_sizes: Sequence[int]) -> List[int]:
57 | outputs = [sum(subset_sizes[:i]) + size for i, size in enumerate(subset_sizes)]
58 |
59 | return outputs
60 |
61 |
62 | def safe_round(inputs: np.ndarray, target_sums: np.ndarray, min_value: int = 0) -> np.ndarray:
63 | """
64 | Round an array while maintaining its sum. The rounding difference is redistributed
65 | according to the value distribution of the input array.
66 |
67 | Args:
68 | inputs: input numpy array
69 | target_sums: target sum values
70 | min_value: each element in the output will be at least min_value
71 |
72 | Returns: numpy array where each element is at least min_value
73 |
74 | """
75 | assert len(inputs) == len(target_sums)
76 | rounded = np.around(inputs).astype(int)
77 | rounded = np.maximum(rounded, min_value)
78 |
79 | outputs = np.zeros_like(inputs, dtype=int)
80 |
81 | for i in range(len(inputs)):
82 | # TODO: use more efficient implementation
83 | rounded_row = rounded[i]
84 | row_target_sum = target_sums[i]
85 |
86 | # subtract by min_value so that rounding adjustment do not affect the min
87 | row_outputs = rounded_row - min_value
88 | adjusted_target_sum = row_target_sum - (len(rounded_row) * min_value)
89 |
90 | round_error = adjusted_target_sum - row_outputs.sum()
91 |
92 | if round_error != 0:
93 | # repeat each index by its corresponding value; this makes the sampling follow the value distribution
94 | extended_idx = np.repeat(np.arange(row_outputs.size), row_outputs)
95 | selected_idx = np.random.choice(extended_idx, abs(round_error))
96 |
97 | unique_idx, idx_counts = np.unique(selected_idx, return_counts=True)
98 | row_outputs[unique_idx] += np.copysign(idx_counts, round_error).astype(int)
99 |
100 | assert row_outputs.sum() == adjusted_target_sum
101 |
102 | # add back the subtracted min_value
103 | row_outputs += min_value
104 |
105 | assert row_outputs.sum() == row_target_sum
106 | outputs[i] = row_outputs
107 |
108 | assert np.all(outputs.sum(axis=1) == target_sums)
109 | assert np.all(outputs >= min_value)
110 | return outputs
111 |
112 |
113 | def per_class_random_split_by_ratio(dataset: Dataset,
114 | ratios: Sequence[float],
115 | num_classes: int,
116 | uneven_split: bool = False,
117 | min_value: int = 1) -> List[Dataset]:
118 | """Split the dataset based on the given ratios.
119 |
120 | Args:
121 | dataset: dataset to split
122 | ratios: fraction of data in each subset
123 | num_classes: number of classes in dataset
124 | uneven_split: if True, will return len(ratios) + 1 subsets where the last subset holds the remaining data
125 | min_value: min value in each class of each subset
126 |
127 | Returns: if uneven_split is False, returns len(ratios) subsets where the ith subset has size ratios[i] *
128 | len(dataset); if uneven_split is True, returns len(ratios) + 1 subsets where the last subset holds
129 | the remainder.
130 |
131 | """
132 | ratios = list(ratios)
133 | if uneven_split:
134 | ratios.append(1. - sum(ratios))
135 |
136 | assert sum(ratios) == 1.
137 | ratios = np.asarray(ratios)
138 | assert np.all(ratios >= 0.)
139 |
140 | # shape (num_classes, # samples in this class); each row is the indices of samples in that class
141 | class_indices = get_class_indices(dataset)
142 | assert len(class_indices) == num_classes
143 |
144 | # shape (num_classes,); each element is the number of samples in that class
145 | class_lengths = np.asarray([len(class_idx) for class_idx in class_indices])
146 | assert class_lengths.sum() == len(dataset)
147 |
148 | # shape (num_subset, num_classes); each row is a list of class sizes for the corresponding subset
149 | subset_class_lengths = ratios.reshape(-1, 1) * class_lengths
150 |
151 | subset_class_lengths = safe_round(subset_class_lengths.T, target_sums=class_lengths, min_value=min_value)
152 | subset_class_lengths = subset_class_lengths.transpose()
153 |
154 | return per_class_random_split_helper(dataset=dataset,
155 | class_indices=class_indices,
156 | subset_class_lengths=subset_class_lengths)
157 |
158 |
159 | def per_class_random_split(dataset: Dataset, lengths: Sequence[int], num_classes: int, uneven_split: bool = False) \
160 | -> List[Dataset]:
161 | """Split the dataset evenly across all classes
162 |
163 | Args:
164 | dataset: dataset to split
165 | lengths: length for each subset
166 | num_classes: number of classes in dataset
167 | uneven_split: if True, will return len(lengths) + 1 subsets where the size for the last subset may not be the
168 | same for all classes
169 |
170 | Returns: len(lengths) subsets where the ith subset has size lengths[i]; if uneven_split is True, will return
171 | len(lengths) + 1 subsets where the size for the last subset may not be the same for all classes.
172 |
173 | """
174 | # see https://github.com/pytorch/vision/issues/168#issuecomment-319659360
175 | # and https://github.com/pytorch/vision/issues/168#issuecomment-398734285 for detail
176 | total_length = sum(lengths)
177 |
178 | if uneven_split:
179 | assert 0 < total_length <= len(dataset), f"Expecting 0 < length <= {len(dataset)} but got {total_length}"
180 | else:
181 | if len(lengths) <= 1:
182 | return [dataset]
183 |
184 | assert total_length == len(dataset), "Sum of input lengths does not equal the length of the input dataset"
185 |
186 | subset_num_per_class, remainders = np.divmod(lengths, num_classes)
187 | assert np.all(remainders == 0), f"Subset sizes are not divisible by the number of classes ({num_classes})"
188 |
189 | # shape (num_classes, # samples in this class); each row is the indices of samples in that class
190 | class_indices = get_class_indices(dataset)
191 | assert len(class_indices) == num_classes
192 |
193 | # shape (num_classes,); each element is the number of samples in that class
194 | class_lengths = np.asarray([len(class_idx) for class_idx in class_indices])
195 |
196 | # shape (num_subset, num_classes); each row is the class lengths of this subset
197 | subset_class_lengths = np.tile(subset_num_per_class.reshape(-1, 1), num_classes)
198 | if uneven_split:
199 | # interpolate last subset's class lengths
200 | last_subset_class_lengths = class_lengths - subset_class_lengths.sum(axis=0)
201 | subset_class_lengths = np.vstack((subset_class_lengths, last_subset_class_lengths))
202 |
203 | return per_class_random_split_helper(dataset=dataset,
204 | class_indices=class_indices,
205 | subset_class_lengths=subset_class_lengths)
206 |
207 |
208 | def per_class_random_split_helper(dataset: Dataset,
209 | class_indices: List[np.ndarray],
210 | subset_class_lengths: np.ndarray) -> List[Dataset]:
211 | """
212 |
213 | Args:
214 | dataset:
215 | class_indices: shape (num_classes, # samples in this class);
216 | each row is the indices of samples in that class
217 | subset_class_lengths: shape (num_subset, num_classes);
218 | each row is a list of class sizes for that subset
219 |
220 | Returns:
221 |
222 | """
223 | class_lengths = np.asarray([len(class_idx) for class_idx in class_indices])
224 | assert np.all(subset_class_lengths.sum(axis=0) == class_lengths)
225 |
226 | num_subsets, num_classes = subset_class_lengths.shape
227 |
228 | subset_indices = [list() for _ in range(num_subsets)]
229 |
230 | for i in range(num_classes):
231 | # ith column contains the class length for each subset
232 | subset_num_per_class = subset_class_lengths[:, i]
233 |
234 | indices = class_indices[i]
235 | np.random.shuffle(indices)
236 |
237 | indices_or_sections = to_indices_or_sections(subset_num_per_class)[:-1]
238 | class_subset_indices = np.split(indices, indices_or_sections)
239 | for subset_i, idx in enumerate(class_subset_indices): subset_indices[subset_i].extend(idx)
240 |
241 | # check if index subsets matches the desired lengths
242 | subset_lengths = subset_class_lengths.sum(axis=1)
243 | assert all(len(subset_idx) == subset_lengths[i] for i, subset_idx in enumerate(subset_indices))
244 |
245 | # check if index subsets in subset_indices are unique
246 | subset_idx_sets = [set(subset_idx) for subset_idx in subset_indices]
247 | assert all(len(idx_set) == len(subset_indices[i]) for i, idx_set in enumerate(subset_idx_sets))
248 |
249 | # check if index subsets are mutually exclusive
250 | combos = combinations(subset_idx_sets, 2)
251 | assert all(combo[0].isdisjoint(combo[1]) for combo in combos)
252 |
253 | return [Subset(dataset, subset_idx) for subset_idx in subset_indices]
254 |
255 |
256 | def get_data_shape(data_point: Union[np.ndarray, Tensor, Image]):
257 | data_shape = np.asarray(data_point).shape
258 | if isinstance(data_point, Tensor) and len(data_shape) >= 3:
259 | # torch tensor has channel (C, H, W, ...), swap channel to (H, W, ..., C)
260 | data_shape = np.roll(data_shape, -1)
261 |
262 | return data_shape
263 |
264 |
265 | def get_directory_size(dirname: Union[Path, str]) -> int:
266 | return sum(f.stat().st_size for f in Path(dirname).rglob("*") if f.is_file())
267 |
--------------------------------------------------------------------------------
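
A self-contained sketch of per_class_random_split producing a labeled/unlabeled/rest partition (the toy dataset is hypothetical; each length must be divisible by num_classes):

    import numpy as np
    from torch.utils.data import Dataset

    from utils.dataset.utils import per_class_random_split

    class ToyDataset(Dataset):
        def __init__(self, targets):
            self.targets = list(targets)

        def __getitem__(self, index):
            return np.zeros(4), self.targets[index]

        def __len__(self):
            return len(self.targets)

    dataset = ToyDataset([0, 1] * 10)  # 20 samples, 10 per class

    # uneven_split=True appends a final subset holding the remaining 12 samples
    labeled, unlabeled, rest = per_class_random_split(dataset, lengths=[4, 4], num_classes=2, uneven_split=True)
    assert (len(labeled), len(unlabeled), len(rest)) == (4, 4, 12)
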
/models/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | References: torchvision.models.resnet.py
3 | """
4 | import torch
5 | from torch import nn
6 | from torchvision.models import ResNet as BaseResNet
7 | from torchvision.models.utils import load_state_dict_from_url
8 |
9 | # for type hint
10 | from typing import Union, Tuple
11 | from torch import Tensor
12 |
13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
14 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
15 | 'wide_resnet50_2', 'wide_resnet101_2']
16 |
17 | model_urls = {
18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
27 | }
28 |
29 |
30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
31 | """3x3 convolution with padding"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
33 | padding=dilation, groups=groups, bias=False, dilation=dilation)
34 |
35 |
36 | def conv1x1(in_planes, out_planes, stride=1):
37 | """1x1 convolution"""
38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
39 |
40 |
41 | class BasicBlock(nn.Module):
42 | expansion = 1
43 |
44 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
45 | base_width=64, dilation=1, norm_layer=None):
46 | super(BasicBlock, self).__init__()
47 | if norm_layer is None:
48 | norm_layer = nn.BatchNorm2d
49 | if groups != 1 or base_width != 64:
50 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
51 | if dilation > 1:
52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
53 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
54 | self.conv1 = conv3x3(inplanes, planes, stride)
55 | self.bn1 = norm_layer(planes)
56 | self.relu = nn.ReLU(inplace=True)
57 | self.conv2 = conv3x3(planes, planes)
58 | self.bn2 = norm_layer(planes)
59 | self.downsample = downsample
60 | self.stride = stride
61 |
62 | def forward(self, x):
63 | identity = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 |
72 | if self.downsample is not None:
73 | identity = self.downsample(x)
74 |
75 | out += identity
76 | out = self.relu(out)
77 |
78 | return out
79 |
80 |
81 | class Bottleneck(nn.Module):
82 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
83 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
84 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
85 | # This variant is also known as ResNet V1.5 and improves accuracy according to
86 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
87 |
88 | expansion = 4
89 |
90 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
91 | base_width=64, dilation=1, norm_layer=None):
92 | super(Bottleneck, self).__init__()
93 | if norm_layer is None:
94 | norm_layer = nn.BatchNorm2d
95 | width = int(planes * (base_width / 64.)) * groups
96 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
97 | self.conv1 = conv1x1(inplanes, width)
98 | self.bn1 = norm_layer(width)
99 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
100 | self.bn2 = norm_layer(width)
101 | self.conv3 = conv1x1(width, planes * self.expansion)
102 | self.bn3 = norm_layer(planes * self.expansion)
103 | self.relu = nn.ReLU(inplace=True)
104 | self.downsample = downsample
105 | self.stride = stride
106 |
107 | def forward(self, x):
108 | identity = x
109 |
110 | out = self.conv1(x)
111 | out = self.bn1(out)
112 | out = self.relu(out)
113 |
114 | out = self.conv2(out)
115 | out = self.bn2(out)
116 | out = self.relu(out)
117 |
118 | out = self.conv3(out)
119 | out = self.bn3(out)
120 |
121 | if self.downsample is not None:
122 | identity = self.downsample(x)
123 |
124 | out += identity
125 | out = self.relu(out)
126 |
127 | return out
128 |
129 |
130 | class ResNet(BaseResNet):
131 | def forward(self, x, return_feature: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
132 | x = self.conv1(x)
133 | x = self.bn1(x)
134 | x = self.relu(x)
135 | x = self.maxpool(x)
136 |
137 | x = self.layer1(x)
138 | x = self.layer2(x)
139 | x = self.layer3(x)
140 | x = self.layer4(x)
141 |
142 | x = self.avgpool(x)
143 | features = torch.flatten(x, 1)
144 | logits = self.fc(features)
145 |
146 | if return_feature:
147 | return features, logits
148 | else:
149 | return logits
150 |
151 |
152 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
153 | model = ResNet(block, layers, **kwargs)
154 | if pretrained:
155 | state_dict = load_state_dict_from_url(model_urls[arch],
156 | progress=progress)
157 | model.load_state_dict(state_dict)
158 | return model
159 |
160 |
161 | def resnet18(pretrained=False, progress=True, **kwargs):
162 | r"""ResNet-18 model from
163 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
164 |
165 | Args:
166 | pretrained (bool): If True, returns a model pre-trained on ImageNet
167 | progress (bool): If True, displays a progress bar of the download to stderr
168 | """
169 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
170 | **kwargs)
171 |
172 |
173 | def resnet34(pretrained=False, progress=True, **kwargs):
174 | r"""ResNet-34 model from
175 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
176 |
177 | Args:
178 | pretrained (bool): If True, returns a model pre-trained on ImageNet
179 | progress (bool): If True, displays a progress bar of the download to stderr
180 | """
181 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
182 | **kwargs)
183 |
184 |
185 | def resnet50(pretrained=False, progress=True, **kwargs):
186 | r"""ResNet-50 model from
187 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
188 |
189 | Args:
190 | pretrained (bool): If True, returns a model pre-trained on ImageNet
191 | progress (bool): If True, displays a progress bar of the download to stderr
192 | """
193 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
194 | **kwargs)
195 |
196 |
197 | def resnet101(pretrained=False, progress=True, **kwargs):
198 | r"""ResNet-101 model from
199 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
200 |
201 | Args:
202 | pretrained (bool): If True, returns a model pre-trained on ImageNet
203 | progress (bool): If True, displays a progress bar of the download to stderr
204 | """
205 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
206 | **kwargs)
207 |
208 |
209 | def resnet152(pretrained=False, progress=True, **kwargs):
210 | r"""ResNet-152 model from
211 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
212 |
213 | Args:
214 | pretrained (bool): If True, returns a model pre-trained on ImageNet
215 | progress (bool): If True, displays a progress bar of the download to stderr
216 | """
217 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
218 | **kwargs)
219 |
220 |
221 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
222 | r"""ResNeXt-50 32x4d model from
223 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
224 |
225 | Args:
226 | pretrained (bool): If True, returns a model pre-trained on ImageNet
227 | progress (bool): If True, displays a progress bar of the download to stderr
228 | """
229 | kwargs['groups'] = 32
230 | kwargs['width_per_group'] = 4
231 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
232 | pretrained, progress, **kwargs)
233 |
234 |
235 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
236 | r"""ResNeXt-101 32x8d model from
237 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
238 |
239 | Args:
240 | pretrained (bool): If True, returns a model pre-trained on ImageNet
241 | progress (bool): If True, displays a progress bar of the download to stderr
242 | """
243 | kwargs['groups'] = 32
244 | kwargs['width_per_group'] = 8
245 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
246 | pretrained, progress, **kwargs)
247 |
248 |
249 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
250 | r"""Wide ResNet-50-2 model from
251 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
252 |
253 |     The model is the same as ResNet except that the number of bottleneck channels
254 |     is twice as large in every block. The number of channels in outer 1x1
255 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
256 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
257 |
258 | Args:
259 | pretrained (bool): If True, returns a model pre-trained on ImageNet
260 | progress (bool): If True, displays a progress bar of the download to stderr
261 | """
262 | kwargs['width_per_group'] = 64 * 2
263 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
264 | pretrained, progress, **kwargs)
265 |
266 |
267 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
268 | r"""Wide ResNet-101-2 model from
269 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
270 |
271 |     The model is the same as ResNet except that the number of bottleneck channels
272 |     is twice as large in every block. The number of channels in outer 1x1
273 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
274 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
275 |
276 | Args:
277 | pretrained (bool): If True, returns a model pre-trained on ImageNet
278 | progress (bool): If True, displays a progress bar of the download to stderr
279 | """
280 | kwargs['width_per_group'] = 64 * 2
281 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
282 | pretrained, progress, **kwargs)
283 |
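284 | # --- Editorial usage sketch (not part of the vendored torchvision code) ---
285 | # The forward() override above can optionally return the pooled penultimate
286 | # features together with the logits. Assuming a hypothetical 100-class setup:
287 | #
288 | #     model = resnet50(pretrained=False, num_classes=100)
289 | #     x = torch.randn(2, 3, 224, 224)
290 | #     features, logits = model(x, return_feature=True)  # features: (2, 2048)
291 | #     logits = model(x)                                  # logits only: (2, 100)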
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/checkpoint_saver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from pathlib import Path
5 | import re
6 | import warnings
7 |
8 | from utils import find_checkpoint_path, find_all_files
9 | from utils.metrics import MetricMode, MetricMonitor
10 |
11 | # for type hint
12 | from typing import Dict, Any, Optional, Union
13 | from simple_estimator import SimPLEEstimator
14 | from utils import Logger
15 |
16 |
17 | class CheckpointSaver:
18 | def __init__(self,
19 | estimator: SimPLEEstimator,
20 | logger: Logger,
21 | checkpoint_metric: str,
22 | best_checkpoint_str: str,
23 | best_checkpoint_pattern: str,
24 | latest_checkpoint_str: str,
25 | latest_checkpoint_pattern: str,
26 | delayed_best_model_saving: bool = True):
27 |         """Track monitored metrics and save the best/latest model checkpoints.
28 |
29 | Args:
30 | estimator: Estimator, used to get experiment related data
31 | logger: Logger, used for logging
32 | checkpoint_metric: save model when the best value of this key has changed.
33 |             best_checkpoint_str: path str format used to save the best checkpoint file
34 |             best_checkpoint_pattern: regex pattern used to find the best checkpoint file
35 |             latest_checkpoint_str: path str format used to save the latest checkpoint file
36 |             latest_checkpoint_pattern: regex pattern used to find the latest checkpoint file
37 | delayed_best_model_saving: if True, save best model after calling save_latest_checkpoint()
38 | """
39 | self.absolute_best_path = "best_checkpoint.pth"
40 |
41 | # metrics to keep track of
42 | self.monitor = MetricMonitor()
43 | self.monitor.track(key="mean_acc",
44 | best_value=-np.inf,
45 | mode=MetricMode.MAX,
46 | prefix="test")
47 | self.monitor.track(key="mean_acc",
48 | best_value=-np.inf,
49 | mode=MetricMode.MAX,
50 | prefix="validation")
51 |
52 | self.checkpoint_metric = checkpoint_metric
53 |
54 | # checkpoint path patterns
55 | self.best_checkpoint_str = best_checkpoint_str
56 | self.best_checkpoint_pattern = re.compile(best_checkpoint_pattern)
57 |
58 | self.latest_checkpoint_str = latest_checkpoint_str
59 | self.latest_checkpoint_pattern = re.compile(latest_checkpoint_pattern)
60 |
61 | # save estimator and logger
62 | # this will recover best metrics and register log hooks
63 | self.estimator = estimator
64 | self.logger = logger
65 |
66 | # assign flags
67 | self.delayed_save_best_model = delayed_best_model_saving
68 | self.is_best_model = False
69 |
70 | @property
71 | def estimator(self) -> SimPLEEstimator:
72 | return self._estimator
73 |
74 | @estimator.setter
75 | def estimator(self, estimator: SimPLEEstimator) -> None:
76 | self._estimator = estimator
77 |
78 | # recover best value
79 | checkpoint_path = self.estimator.exp_args.checkpoint_path
80 | if checkpoint_path is not None:
81 | print(f"Recovering best metrics from {checkpoint_path}...")
82 | self.recover_metrics(torch.load(checkpoint_path, map_location=self.device))
83 |
84 | @property
85 | def logger(self) -> Logger:
86 | return self._logger
87 |
88 | @logger.setter
89 | def logger(self, logger: Logger) -> None:
90 | self._logger = logger
91 |
92 | # register log hooks
93 | print("Registering log hooks...")
94 | self.logger.register_log_hook(self.update_best_metric, logger=self.logger)
95 |
96 | @property
97 | def checkpoint_metric(self) -> str:
98 | return self._checkpoint_metric
99 |
100 | @checkpoint_metric.setter
101 | def checkpoint_metric(self, checkpoint_metric: str) -> None:
102 | assert checkpoint_metric in self.monitor, f"{checkpoint_metric} is not in metric monitor"
103 |
104 | self._checkpoint_metric = checkpoint_metric
105 |
106 | @property
107 | def log_dir(self) -> str:
108 | return self.estimator.exp_args.log_dir
109 |
110 | @property
111 | def best_full_checkpoint_str(self) -> str:
112 | return str(Path(self.log_dir) / self.best_checkpoint_str)
113 |
114 | @property
115 | def latest_full_checkpoint_str(self) -> str:
116 | return str(Path(self.log_dir) / self.latest_checkpoint_str)
117 |
118 | @property
119 | def device(self) -> torch.device:
120 | return self.estimator.device
121 |
122 | @property
123 | def global_step(self) -> int:
124 | return self.estimator.global_step
125 |
126 | @property
127 | def num_latest_checkpoints_kept(self) -> Optional[int]:
128 | return self.estimator.exp_args.num_latest_checkpoints_kept
129 |
130 | @property
131 | def is_save_latest_checkpoint(self) -> bool:
132 | return self.num_latest_checkpoints_kept is None or self.num_latest_checkpoints_kept > 0
133 |
134 | @property
135 | def is_remove_old_checkpoint(self) -> bool:
136 | return self.num_latest_checkpoints_kept is not None and self.num_latest_checkpoints_kept > 0
137 |
138 | def save_checkpoint(self,
139 | checkpoint: Dict[str, Any],
140 | checkpoint_path: Union[str, Path],
141 | is_logger_save: bool = False) -> Path:
142 | checkpoint_path = str(checkpoint_path)
143 | torch.save(checkpoint, checkpoint_path)
144 |
145 | print(f"Checkpoint saved to \"{checkpoint_path}\"", flush=True)
146 |
147 | if is_logger_save:
148 | self.logger.save(checkpoint_path)
149 |
150 | return Path(checkpoint_path)
151 |
152 | def save_best_checkpoint(self,
153 |                              checkpoint: Optional[Dict[str, Any]] = None,
154 | is_logger_save: bool = False,
155 | **kwargs) -> Path:
156 | if checkpoint is None:
157 | checkpoint = self.get_checkpoint()
158 |
159 | checkpoint_path = self.save_checkpoint(checkpoint_path=self.best_full_checkpoint_str.format(**kwargs),
160 | checkpoint=checkpoint,
161 | is_logger_save=is_logger_save)
162 | # reset flag
163 | self.is_best_model = False
164 |
165 | return checkpoint_path
166 |
167 | def save_latest_checkpoint(self,
168 |                                checkpoint: Optional[Dict[str, Any]] = None,
169 | is_logger_save: bool = False,
170 | **kwargs) -> Optional[Path]:
171 | checkpoint_path: Optional[Path] = None
172 |
173 | if self.is_save_latest_checkpoint:
174 | if checkpoint is None:
175 | checkpoint = self.get_checkpoint()
176 |
177 | # save new checkpoint
178 | checkpoint_path = self.save_checkpoint(checkpoint_path=self.latest_full_checkpoint_str.format(**kwargs),
179 | checkpoint=checkpoint,
180 | is_logger_save=is_logger_save)
181 |
182 | # cleanup old checkpoints
183 | self.cleanup_checkpoints()
184 |
185 | if self.delayed_save_best_model and self.is_best_model:
186 | self.save_best_checkpoint(**kwargs)
187 |
188 | return checkpoint_path
189 |
190 | def get_checkpoint(self) -> Dict[str, Any]:
191 | checkpoint = self.estimator.get_checkpoint()
192 |
193 | # add best metrics
194 | checkpoint.update({"monitor_state": self.monitor.state_dict()})
195 |
196 | return checkpoint
197 |
198 | def update_best_checkpoint(self) -> None:
199 |         """
200 |         Update the metrics stored in the best checkpoint on disk
201 |
202 |         Returns:
203 |             None; the refreshed checkpoint is written to ``absolute_best_path`` under ``log_dir``
204 |         """
205 | best_checkpoint_path = self.find_best_checkpoint_path()
206 |
207 | if best_checkpoint_path is None:
208 | warnings.warn("Cannot find best checkpoint")
209 | return
210 |
211 | best_checkpoint_path = str(best_checkpoint_path)
212 | best_checkpoint = torch.load(best_checkpoint_path, map_location=self.device)
213 |
214 | # update best metrics
215 | best_checkpoint.update({"monitor_state": self.monitor.state_dict()})
216 |
217 | self.save_checkpoint(checkpoint_path=str(Path(self.log_dir) / self.absolute_best_path),
218 | checkpoint=best_checkpoint)
219 |
220 | def find_best_checkpoint_path(self, checkpoint_dir: Optional[str] = None, ignore_absolute_best: bool = True) \
221 | -> Optional[Path]:
222 | if checkpoint_dir is None:
223 | checkpoint_dir = self.log_dir
224 |
225 | abs_best_path = Path(checkpoint_dir) / self.absolute_best_path
226 |
227 | if not ignore_absolute_best and abs_best_path.is_file():
228 | # if not ignoring absolute best path and the path is a file, return the absolute best file path
229 | return abs_best_path
230 |
231 | checkpoint_path = find_checkpoint_path(checkpoint_dir, step_filter=self.best_checkpoint_pattern)
232 |
233 | if checkpoint_path is None:
234 | checkpoint_path = self.find_latest_checkpoint_path(checkpoint_dir=checkpoint_dir)
235 |
236 | return checkpoint_path
237 |
238 | def find_latest_checkpoint_path(self, checkpoint_dir: Optional[str] = None) -> Optional[Path]:
239 | if checkpoint_dir is None:
240 | checkpoint_dir = self.log_dir
241 |
242 | return find_checkpoint_path(checkpoint_dir, step_filter=self.latest_checkpoint_pattern)
243 |
244 | def update_best_metric(self, log_info: Dict[str, Any], logger: Logger) -> None:
245 | updated_dict = self.monitor.update_metrics(log_info)
246 |
247 | for updated_key, new_best_value in updated_dict.items():
248 | metric_dict = self.monitor[updated_key]
249 |
250 | translated_key = metric_dict["key"]
251 |
252 |             # new_best_value improved on the previous best; log the updated value
253 | logger.log({translated_key: new_best_value}, step=self.global_step)
254 |
255 | if self.checkpoint_metric == updated_key:
256 | self.is_best_model = True
257 |
258 | if self.is_best_model and not self.delayed_save_best_model:
259 |             # delayed saving is disabled, so save the best checkpoint immediately
260 | self.save_best_checkpoint(global_step=self.global_step)
261 |
262 | def recover_checkpoint(self, checkpoint: Dict[str, Any], recover_optimizer: bool = True,
263 | recover_train_progress: bool = True) -> None:
264 | self.recover_metrics(checkpoint=checkpoint)
265 |
266 | self.estimator.load_checkpoint(checkpoint=checkpoint,
267 | recover_optimizer=recover_optimizer,
268 | recover_train_progress=recover_train_progress)
269 |
270 | def recover_metrics(self, checkpoint: Dict[str, Any]) -> None:
271 | if "monitor_state" in checkpoint:
272 | monitor_state = checkpoint["monitor_state"]
273 | else:
274 | # for backward compatibility
275 | monitor_state = {
276 | "validation/mean_acc": checkpoint.get("best_val_acc", -np.inf),
277 | "test/mean_acc": checkpoint.get("best_test_acc", -np.inf),
278 | }
279 |
280 | self.monitor.load_state_dict(monitor_state)
281 |
282 | def cleanup_checkpoints(self) -> None:
283 | if not self.is_remove_old_checkpoint:
284 |             # do nothing if latest checkpoints are not being saved or if all checkpoints are kept
285 | return
286 |
287 | checkpoint_paths = find_all_files(checkpoint_dir=self.log_dir,
288 | search_pattern=self.latest_checkpoint_pattern)
289 |
290 |         # sort by recency (largest step first); assumes the pattern's first capture group holds the global step
291 | checkpoint_paths.sort(key=lambda x: int(re.search(self.latest_checkpoint_pattern, x.name).group(1)),
292 | reverse=True)
293 |
294 | # remove old checkpoints
295 | for checkpoint_path in checkpoint_paths[self.num_latest_checkpoints_kept:]:
296 | print(f"Removing old checkpoint \"{checkpoint_path}\"", flush=True)
297 | checkpoint_path.unlink()
298 |
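299 | # --- Editorial usage sketch (not part of the original module) ---
300 | # The path formats and regex patterns below are hypothetical examples; the only
301 | # structural requirement is that each pattern's first capture group holds the
302 | # global step so that cleanup_checkpoints() can sort checkpoints by recency.
303 | #
304 | #     saver = CheckpointSaver(
305 | #         estimator=estimator,  # a configured SimPLEEstimator
306 | #         logger=logger,        # a configured Logger
307 | #         checkpoint_metric="validation/mean_acc",
308 | #         best_checkpoint_str="best_checkpoint_{global_step}.pth",
309 | #         best_checkpoint_pattern=r"best_checkpoint_(\d+)\.pth",
310 | #         latest_checkpoint_str="checkpoint_{global_step}.pth",
311 | #         latest_checkpoint_pattern=r"checkpoint_(\d+)\.pth",
312 | #     )
313 | #     saver.save_latest_checkpoint(global_step=saver.global_step)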
--------------------------------------------------------------------------------