├── utils ├── transforms │ ├── __init__.py │ └── transforms.py ├── dataset │ ├── types.py │ ├── cifar100_datamodule.py │ ├── miniimagenet_datamodule.py │ ├── datasets.py │ ├── datamodule.py │ ├── domainnet_real_datamodule.py │ ├── svhn_datamodule.py │ ├── miniimagenet.py │ ├── cifar10_datamodule.py │ ├── __init__.py │ ├── domainnet_real.py │ ├── ssl_datamodule.py │ └── utils.py ├── types.py ├── loggers │ ├── __init__.py │ ├── print_logger.py │ ├── log_aggregator.py │ ├── wandb_logger.py │ └── logger.py ├── timing.py ├── file_io.py ├── utils.py ├── __init__.py └── metrics.py ├── media └── pairloss_anim.png ├── requirements.txt ├── loss ├── pair_loss │ ├── __init__.py │ ├── utils.py │ └── pair_loss.py ├── utils │ ├── __init__.py │ └── utils.py ├── types.py ├── __init__.py ├── loss.py └── visualization.py ├── models ├── mixmatch │ ├── types.py │ ├── __init__.py │ ├── mixmatch.py │ ├── utils.py │ ├── simple_mixmatch.py │ └── mixmatch_base.py ├── types.py ├── models │ ├── utils.py │ ├── __init__.py │ ├── ema.py │ ├── wide_resnet.py │ └── resnet.py ├── optimization │ ├── types.py │ ├── __init__.py │ └── lr_scheduler.py ├── rampup.py ├── augmentation │ ├── augmenter.py │ ├── __init__.py │ └── randaugment.py ├── utils.py └── __init__.py ├── runs ├── miniimagenet_args.txt ├── cifar10_args.txt └── cifar100_args.txt ├── .gitignore ├── environment.yaml ├── .gitattributes ├── example_logger_config.yaml ├── main_ddp.py ├── ablation_estimator.py ├── main.py ├── README.md ├── LICENSE └── checkpoint_saver.py /utils/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | -------------------------------------------------------------------------------- /media/pairloss_anim.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijian-hu/SimPLE/HEAD/media/pairloss_anim.png -------------------------------------------------------------------------------- /utils/dataset/types.py: -------------------------------------------------------------------------------- 1 | from .utils import BatchType, LoaderType, BatchGeneratorType 2 | 3 | __all__ = [ 4 | # types 5 | "BatchType", 6 | "LoaderType", 7 | "BatchGeneratorType", 8 | ] 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | future 2 | h5py 3 | matplotlib 4 | numpy 5 | pillow 6 | plotly>=4.0.0 7 | pyyaml 8 | pandas 9 | scikit-learn 10 | scipy 11 | torch>=1.6.0,<=1.9.0 12 | torchvision>=0.7.0,<=0.10.0 13 | kornia==0.5.0 14 | wandb 15 | tqdm 16 | -------------------------------------------------------------------------------- /loss/pair_loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .pair_loss import PairLoss 2 | 3 | # modules 4 | from . 
import utils 5 | 6 | 7 | __all__ = [ 8 | # modules 9 | "utils", 10 | 11 | # classes 12 | "PairLoss", 13 | 14 | # functions 15 | ] 16 | -------------------------------------------------------------------------------- /utils/types.py: -------------------------------------------------------------------------------- 1 | from .dataset.types import BatchType, LoaderType, BatchGeneratorType 2 | from .metrics import MetricDictType 3 | 4 | __all__ = [ 5 | # types 6 | "BatchType", 7 | "LoaderType", 8 | "BatchGeneratorType", 9 | "MetricDictType", 10 | ] 11 | -------------------------------------------------------------------------------- /models/mixmatch/types.py: -------------------------------------------------------------------------------- 1 | from .mixmatch import MixMatch 2 | from .simple_mixmatch import SimPLE 3 | from .mixmatch_base import MixMatchBase as MixMatchEnhanced 4 | 5 | from typing import Union 6 | 7 | MixMatchFunctionType = Union[MixMatch, MixMatchEnhanced, SimPLE] 8 | 9 | __all__ = [ 10 | "MixMatchFunctionType", 11 | ] 12 | -------------------------------------------------------------------------------- /utils/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import Logger 2 | from .wandb_logger import WandbLogger 3 | from .print_logger import PrintLogger 4 | 5 | from .log_aggregator import LogAggregator 6 | 7 | __all__ = [ 8 | # classes 9 | "Logger", 10 | "WandbLogger", 11 | "PrintLogger", 12 | "LogAggregator", 13 | ] 14 | -------------------------------------------------------------------------------- /models/types.py: -------------------------------------------------------------------------------- 1 | from .optimization.types import LRSchedulerType, ParametersType, ParametersGroupType, OptimizerParametersType 2 | from .mixmatch.types import MixMatchFunctionType 3 | 4 | __all__ = [ 5 | "LRSchedulerType", 6 | "ParametersType", 7 | "ParametersGroupType", 8 | "OptimizerParametersType", 9 | "MixMatchFunctionType", 10 | ] 11 | -------------------------------------------------------------------------------- /runs/miniimagenet_args.txt: -------------------------------------------------------------------------------- 1 | -e 2 | simple 3 | --num-epochs 4 | 2048 5 | --dataset 6 | miniimagenet 7 | --labeled-train-size 8 | 4000 9 | --validation-size 10 | 7200 11 | --batch-size 12 | 16 13 | --K-strong 14 | 7 15 | --lambda-u 16 | 300 17 | --lambda-pair 18 | 300 19 | --conf-threshold 20 | 0.95 21 | --sim-threshold 22 | 0.9 23 | --ema-type 24 | full 25 | --max-grad-norm 26 | 5 -------------------------------------------------------------------------------- /runs/cifar10_args.txt: -------------------------------------------------------------------------------- 1 | -e 2 | simple 3 | --num-epochs 4 | 1024 5 | --dataset 6 | cifar10 7 | --K-strong 8 | 7 9 | --labeled-train-size 10 | 4000 11 | --lr 12 | 0.03 13 | --weight-decay 14 | 5e-4 15 | --lambda-u 16 | 75 17 | --lambda-pair 18 | 75 19 | --conf-threshold 20 | 0.95 21 | --sim-threshold 22 | 0.9 23 | --max-grad-norm 24 | 5 25 | --optimizer-type 26 | sgd 27 | --lr-scheduler-type 28 | cosine_decay -------------------------------------------------------------------------------- /runs/cifar100_args.txt: -------------------------------------------------------------------------------- 1 | -e 2 | simple 3 | --dataset 4 | cifar100 5 | --K-strong 6 | 4 7 | --labeled-train-size 8 | 10000 9 | --lr 10 | 0.03 11 | --weight-decay 12 | 0.001 13 | --lambda-u 14 | 150 15 | --lambda-pair 16 | 150 
17 | --conf-threshold 18 | 0.95 19 | --sim-threshold 20 | 0.9 21 | --max-grad-norm 22 | 5 23 | --model-type 24 | wrn28-8 25 | --optimizer-type 26 | sgd 27 | --lr-scheduler-type 28 | cosine_decay -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # CMake 2 | cmake-build-*/ 3 | 4 | # JetBrains IDE 5 | .idea/ 6 | 7 | # VS Code 8 | .vscode/ 9 | 10 | # Mac 11 | *.DS_Store 12 | 13 | # Node.js 14 | node_modules/ 15 | package-lock.json 16 | 17 | # Python 18 | __pycache__/ 19 | *.pyc 20 | 21 | # Jupyter notebook 22 | .ipynb_checkpoints 23 | 24 | # pyenv files 25 | .python-version 26 | 27 | # data files 28 | data/ 29 | logs/ 30 | wandb/ 31 | lightning_logs/ 32 | -------------------------------------------------------------------------------- /models/mixmatch/__init__.py: -------------------------------------------------------------------------------- 1 | from .mixmatch import MixMatch 2 | from .mixmatch_base import MixMatchBase as MixMatchEnhanced 3 | from .simple_mixmatch import SimPLE 4 | 5 | # modules 6 | from . import utils 7 | from . import types 8 | 9 | __all__ = [ 10 | # modules 11 | "utils", 12 | "types", 13 | 14 | # classes 15 | "MixMatch", 16 | "MixMatchEnhanced", 17 | "SimPLE", 18 | 19 | # functions 20 | ] 21 | -------------------------------------------------------------------------------- /loss/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import (reduce_tensor, to_tensor, bha_coeff_log_prob, bha_coeff, bha_coeff_distance, 2 | l2_distance, pairwise_apply) 3 | 4 | __all__ = [ 5 | # modules 6 | # classes 7 | # functions 8 | "reduce_tensor", 9 | "to_tensor", 10 | "bha_coeff_log_prob", 11 | "bha_coeff", 12 | "bha_coeff_distance", 13 | "l2_distance", 14 | "pairwise_apply", 15 | ] 16 | -------------------------------------------------------------------------------- /models/models/utils.py: -------------------------------------------------------------------------------- 1 | # for type hint 2 | from typing import Union 3 | 4 | from torch.nn import Module, DataParallel 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | ModelType = Union[Module, DataParallel, DistributedDataParallel] 8 | 9 | 10 | def unwrap_model(model: ModelType) -> Module: 11 | if hasattr(model, "module"): 12 | return model.module 13 | else: 14 | return model 15 | 16 | 17 | __all__ = [ 18 | "unwrap_model", 19 | ] 20 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: simple 2 | channels: 3 | - defaults 4 | - pytorch 5 | - conda-forge 6 | dependencies: 7 | - python=3.8 8 | - pip 9 | - pillow 10 | - tqdm 11 | - pyyaml 12 | - matplotlib 13 | - future 14 | - h5py 15 | - numpy 16 | - plotly 17 | - pandas 18 | - scipy 19 | - scikit-learn 20 | - cudatoolkit=10.2 21 | - pytorch::torchvision 22 | - pytorch::pytorch=1.6.0 23 | - pip: 24 | - kornia==0.5.0 25 | - wandb 26 | -------------------------------------------------------------------------------- /models/optimization/types.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Iterable, Dict 2 | from torch.optim.lr_scheduler import LambdaLR, StepLR 3 | from torch.nn import Parameter 4 | 5 | LRSchedulerType = Union[LambdaLR, StepLR] 6 | ParametersType = 
Iterable[Parameter] 7 | ParametersGroupType = Iterable[Dict[str, Union[Parameter, float, int]]] 8 | OptimizerParametersType = Union[ParametersType, ParametersGroupType] 9 | 10 | __all__ = [ 11 | "LRSchedulerType", 12 | "ParametersType", 13 | "ParametersGroupType", 14 | "OptimizerParametersType", 15 | ] 16 | -------------------------------------------------------------------------------- /utils/timing.py: -------------------------------------------------------------------------------- 1 | from timeit import default_timer as timer 2 | from functools import wraps 3 | from datetime import timedelta 4 | 5 | 6 | def timing(func): 7 | # see https://stackoverflow.com/a/27737385/5838091 8 | @wraps(func) 9 | def wrap(*args, **kwargs): 10 | start_time = timer() 11 | 12 | result = func(*args, **kwargs) 13 | 14 | time_elapsed = timer() - start_time 15 | print(f"Total time for {func.__name__}: {str(timedelta(seconds=time_elapsed))}", flush=True) 16 | 17 | return result 18 | 19 | return wrap 20 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # This is an example git attribute file 2 | 3 | # Set the default behavior, in case people don't have core.autocrlf set. 4 | * text=auto 5 | 6 | # Explicitly declare text files you want to always be normalized and converted 7 | # to native line endings on checkout. 8 | *.c text 9 | *.h text 10 | *.py text 11 | 12 | # Declare files that will always have CRLF line endings on checkout. 13 | *.sln text eol=crlf 14 | 15 | # Declare files that will always have LF line endings on checkout. 16 | *.sh text eol=lf 17 | 18 | # Denote all files that are truly binary and should not be modified. 19 | *.png binary 20 | *.jpg binary 21 | -------------------------------------------------------------------------------- /example_logger_config.yaml: -------------------------------------------------------------------------------- 1 | # A short display name for this run, which is how you'll identify this run in the UI 2 | name: null 3 | 4 | # An entity is a username or team name where you're sending runs 5 | entity: null 6 | 7 | # The name of the project where you're sending the new run 8 | project: null 9 | 10 | # A list of strings, which will populate the list of tags on this run in the UI 11 | tags: [ ] 12 | 13 | # A longer description of the run, like a -m commit message in git 14 | notes: null 15 | 16 | # if set to True, the run auto resumes; can also be a unique string for manual resuming 17 | resume: False 18 | 19 | # Can be "online", "offline" or "disabled". 
Defaults to online 20 | mode: online 21 | -------------------------------------------------------------------------------- /loss/types.py: -------------------------------------------------------------------------------- 1 | # for type hint 2 | from typing import Union, Dict 3 | from torch import Tensor 4 | from plotly.graph_objects import Figure 5 | from wandb import Histogram 6 | 7 | from .utils import (bha_coeff, bha_coeff_distance, l2_distance) 8 | from .loss import softmax_cross_entropy_loss, bha_coeff_loss, l2_dist_loss 9 | 10 | LogDictType = Dict[str, Tensor] 11 | PlotDictType = Dict[str, Union[Figure, Histogram]] 12 | LossInfoType = Union[Dict[str, Union[LogDictType, PlotDictType]], LogDictType] 13 | 14 | SimilarityType = Union[bha_coeff] 15 | DistanceType = Union[bha_coeff_distance, l2_distance] 16 | DistanceLossType = Union[softmax_cross_entropy_loss, l2_dist_loss, bha_coeff_loss] 17 | 18 | __all__ = [ 19 | "LogDictType", 20 | "PlotDictType", 21 | "LossInfoType", 22 | "SimilarityType", 23 | "DistanceType", 24 | "DistanceLossType", 25 | ] 26 | -------------------------------------------------------------------------------- /loss/pair_loss/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # for type hint 4 | from torch import Tensor 5 | 6 | 7 | def get_pair_indices(inputs: Tensor, ordered_pair: bool = False) -> Tensor: 8 | """ 9 | Get pair indices between each element in input tensor 10 | 11 | Args: 12 | inputs: input tensor 13 | ordered_pair: if True, will return ordered pairs. (e.g. both inputs[i,j] and inputs[j,i] are included) 14 | 15 | Returns: a tensor of shape (K, 2) where K = choose(len(inputs),2) if ordered_pair is False. 16 | Else K = 2 * choose(len(inputs),2). Each row corresponds to two indices in inputs. 17 | 18 | """ 19 | indices = torch.combinations(torch.tensor(range(len(inputs))), r=2) 20 | 21 | if ordered_pair: 22 | # make pairs ordered (e.g. 
both (0,1) and (1,0) are included) 23 | indices = torch.cat((indices, indices[:, [1, 0]]), dim=0)  # e.g. [[0, 1], [0, 2], [1, 2]] -> append the swapped pairs [[1, 0], [2, 0], [2, 1]] 24 | 25 | return indices 26 | -------------------------------------------------------------------------------- /models/mixmatch/mixmatch.py: -------------------------------------------------------------------------------- 1 | from .mixmatch_base import MixMatchBase 2 | 3 | # for type hint 4 | from torch.nn import Module 5 | 6 | 7 | class MixMatch(MixMatchBase): 8 | def __init__(self, 9 | augmenter: Module, 10 | num_classes: int, 11 | temperature: float, 12 | num_augmentations: int, 13 | alpha: float, 14 | train_label_guessing: bool): 15 | super(MixMatch, self).__init__(augmenter=augmenter, 16 | strong_augmenter=None, 17 | num_classes=num_classes, 18 | temperature=temperature, 19 | num_augmentations=num_augmentations, 20 | num_strong_augmentations=0, 21 | alpha=alpha, 22 | is_strong_augment_x=False, 23 | train_label_guessing=train_label_guessing) 24 | -------------------------------------------------------------------------------- /utils/dataset/cifar100_datamodule.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR100 2 | 3 | from .cifar10_datamodule import CIFAR10DataModule 4 | 5 | # for type hint 6 | from typing import Optional 7 | 8 | 9 | class CIFAR100DataModule(CIFAR10DataModule): 10 | num_classes: int = 100 11 | 12 | total_train_size: int = 50_000 13 | total_test_size: int = 10_000 14 | 15 | DATASET = CIFAR100 16 | 17 | def __init__(self, 18 | data_dir: str, 19 | labeled_train_size: int, 20 | validation_size: int, 21 | unlabeled_train_size: Optional[int] = None, 22 | **kwargs): 23 | super(CIFAR100DataModule, self).__init__( 24 | data_dir=data_dir, 25 | labeled_train_size=labeled_train_size, 26 | validation_size=validation_size, 27 | unlabeled_train_size=unlabeled_train_size, 28 | **kwargs) 29 | 30 | # dataset stats 31 | # CIFAR-100 mean, std values in CHW 32 | self.dataset_mean = [0.44091784, 0.50707516, 0.48654887] 33 | self.dataset_std = [0.27615047, 0.26733429, 0.25643846] 34 | -------------------------------------------------------------------------------- /main_ddp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributed 3 | 4 | import warnings 5 | 6 | from main import main as main_single_thread 7 | from utils import get_args, timing, set_random_seed 8 | 9 | # for type hint 10 | from typing import Optional 11 | from argparse import Namespace 12 | from utils.dataset import SSLDataModule 13 | 14 | IS_DISTRIBUTED_AVAILABLE = distributed.is_available() 15 | 16 | 17 | @timing 18 | def main(args: Namespace, datamodule: Optional[SSLDataModule] = None): 19 | if IS_DISTRIBUTED_AVAILABLE and torch.cuda.is_available() and torch.cuda.device_count() > 1: 20 | distributed.init_process_group(backend='nccl') 21 | 22 | torch.cuda.set_device(args.local_rank) 23 | device = torch.device("cuda", args.local_rank) 24 | else: 25 | warnings.warn("Cannot initialize PyTorch distributed training, falling back to single GPU training") 26 | device = None 27 | 28 | main_single_thread(args=args, datamodule=datamodule, device=device) 29 | 30 | 31 | if __name__ == '__main__': 32 | parsed_args = get_args() 33 | 34 | # fix random seed 35 | set_random_seed(parsed_args.seed, is_cudnn_deterministic=parsed_args.debug_mode) 36 | 37 | main(parsed_args) 38 | -------------------------------------------------------------------------------- /models/mixmatch/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn import functional as F 4 | 5 | from functools import reduce 6 | 7 | from ..utils import set_model_mode 8 | 9 | # for type hint 10 | from typing import Sequence, Tuple 11 | from torch import Tensor 12 | from torch.nn import Module 13 | 14 | 15 | def label_guessing(batches: Sequence[Tensor], model: Module, is_train_mode: bool = True) -> Tensor: 16 | with set_model_mode(model, is_train_mode): 17 | with torch.no_grad(): 18 | probs = [F.softmax(model(batch), dim=1) for batch in batches] 19 | mean_prob = reduce(lambda x, y: x + y, probs) / len(batches) 20 | 21 | return mean_prob 22 | 23 | 24 | def sharpen(x: Tensor, temperature: float) -> Tensor: 25 | sharpened_x = x ** (1 / temperature)  # temperature < 1 sharpens the distribution toward its mode 26 | return sharpened_x / sharpened_x.sum(dim=1, keepdim=True) 27 | 28 | 29 | def mixup(x1: Tensor, x2: Tensor, y1: Tensor, y2: Tensor, alpha: float) -> Tuple[Tensor, Tensor]: 30 | # lambda is a reserved word in Python, so the Beta(alpha, alpha) mixing coefficient is named lam 31 | lam = np.random.beta(alpha, alpha) 32 | lam = max(lam, 1 - lam)  # keep lam >= 0.5 so the first input dominates the mix 33 | x = lam * x1 + (1 - lam) * x2 34 | y = lam * y1 + (1 - lam) * y2 35 | return x, y 36 | -------------------------------------------------------------------------------- /models/optimization/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.optim import SGD, AdamW 2 | 3 | from .lr_scheduler import build_lr_scheduler 4 | 5 | from . import lr_scheduler 6 | from . import types 7 | 8 | # for type hint 9 | from torch.optim.optimizer import Optimizer 10 | 11 | from .types import OptimizerParametersType 12 | 13 | 14 | def build_optimizer(optimizer_type: str, 15 | params: OptimizerParametersType, 16 | learning_rate: float, 17 | weight_decay: float, 18 | momentum: float) -> Optimizer: 19 | if optimizer_type == "sgd": 20 | optimizer = SGD(params, 21 | lr=learning_rate, 22 | weight_decay=weight_decay, 23 | momentum=momentum, 24 | nesterov=True) 25 | 26 | elif optimizer_type == "adamw": 27 | optimizer = AdamW(params, lr=learning_rate, weight_decay=weight_decay) 28 | 29 | else: 30 | raise NotImplementedError(f"\"{optimizer_type}\" is not a supported optimizer type") 31 | 32 | return optimizer 33 | 34 | 35 | __all__ = [ 36 | # classes 37 | # modules 38 | "lr_scheduler", 39 | "types", 40 | 41 | # functions 42 | "build_optimizer", 43 | "build_lr_scheduler", 44 | ] 45 | -------------------------------------------------------------------------------- /utils/dataset/miniimagenet_datamodule.py: -------------------------------------------------------------------------------- 1 | from .miniimagenet import MiniImageNet 2 | from .cifar10_datamodule import CIFAR10DataModule 3 | 4 | # for type hint 5 | from typing import Optional, Tuple 6 | 7 | 8 | class MiniImageNetDataModule(CIFAR10DataModule): 9 | num_classes: int = 100 10 | 11 | total_train_size: int = 50_000 12 | total_test_size: int = 10_000 13 | 14 | DATASET = MiniImageNet 15 | 16 | def __init__(self, 17 | data_dir: str, 18 | labeled_train_size: int, 19 | validation_size: int, 20 | unlabeled_train_size: Optional[int] = None, 21 | dims: Optional[Tuple[int, ...]] = None, 22 | **kwargs): 23 | if dims is None: 24 | dims = (3, 84, 84) 25 | 26 | super(MiniImageNetDataModule, self).__init__( 27 | data_dir=data_dir, 28 | labeled_train_size=labeled_train_size, 29 | validation_size=validation_size, 30 | unlabeled_train_size=unlabeled_train_size, 31 | dims=dims, 32 | **kwargs) 33 | 34 | # 
dataset stats 35 | # Mini-ImageNet mean, std values in CHW 36 | self.dataset_mean = [0.40233998, 0.47269102, 0.44823737] 37 | self.dataset_std = [0.2884859, 0.28327602, 0.27511246] 38 | -------------------------------------------------------------------------------- /utils/file_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import yaml 3 | 4 | from pathlib import Path 5 | import re 6 | 7 | # for type hint 8 | from typing import Union, Pattern, Optional, Dict, Any, List 9 | 10 | 11 | def find_checkpoint_path(checkpoint_dir: Union[str, Path], step_filter: Union[Pattern, str]) -> Optional[Path]: 12 | checkpoint_dir_path = Path(checkpoint_dir) 13 | output_file = None 14 | max_step_num = -np.inf 15 | 16 | for file_item in checkpoint_dir_path.iterdir(): 17 | if not file_item.is_file(): 18 | continue 19 | 20 | search_result = re.search(step_filter, file_item.name) 21 | if search_result is None: 22 | continue 23 | 24 | step_num = int(search_result.group(1)) 25 | if step_num > max_step_num: 26 | max_step_num = step_num 27 | output_file = file_item 28 | 29 | return output_file 30 | 31 | 32 | def read_yaml(path: Union[str, Path]) -> Dict[str, Any]: 33 | with open(path, "r") as f: 34 | return yaml.safe_load(f) 35 | 36 | 37 | def find_all_files(checkpoint_dir: Union[str, Path], search_pattern: Union[Pattern, str]) -> List[Path]: 38 | checkpoint_dir_path = Path(checkpoint_dir) 39 | 40 | return [file_item for file_item in checkpoint_dir_path.iterdir() 41 | if file_item.is_file() and re.search(pattern=search_pattern, string=file_item.name) is not None] 42 | -------------------------------------------------------------------------------- /utils/dataset/datasets.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, Subset 2 | 3 | # for type hint 4 | from torchvision.datasets import VisionDataset 5 | from typing import Union, Optional, Callable 6 | 7 | 8 | class LabeledDataset(VisionDataset): 9 | def __init__(self, 10 | dataset: Union[Dataset, Subset], 11 | root: str, 12 | min_size: int = 0, 13 | transforms: Optional[Callable] = None, 14 | transform: Optional[Callable] = None, 15 | target_transform: Optional[Callable] = None): 16 | super().__init__(root, transforms, transform, target_transform) 17 | 18 | self.dataset = dataset 19 | 20 | self.min_size = min_size 21 | 22 | @property 23 | def min_size(self) -> int: 24 | return self._min_size if len(self.dataset) > 0 else 0 25 | 26 | @min_size.setter 27 | def min_size(self, min_size: int) -> None: 28 | if min_size < 0: 29 | raise ValueError(f"only non-negative min_size is allowed") 30 | 31 | self._min_size = min_size 32 | 33 | def __getitem__(self, index: int): 34 | img, target = self.dataset[index % len(self.dataset)] 35 | 36 | if self.transform is not None: 37 | img = self.transform(img) 38 | 39 | if self.target_transform is not None: 40 | target = self.target_transform(target) 41 | 42 | return img, target 43 | 44 | def __len__(self): 45 | return max(len(self.dataset), self.min_size) 46 | -------------------------------------------------------------------------------- /utils/dataset/datamodule.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | # for type hint 4 | from typing import Optional, Tuple, Callable, Union, List, Generator, Any 5 | from torch.utils.data import DataLoader 6 | 7 | 8 | class DataModule(ABC): 9 | def __init__(self, 10 | 
train_transform: Optional[Callable] = None, 11 | val_transform: Optional[Callable] = None, 12 | test_transform: Optional[Callable] = None, 13 | dims: Optional[Tuple[int, ...]] = None): 14 | self.train_transform = train_transform 15 | self.val_transform = val_transform 16 | self.test_transform = test_transform 17 | self.dims = dims 18 | 19 | @abstractmethod 20 | def prepare_data(self, *args, **kwargs): 21 | # download, split, etc... 22 | # only called on 1 GPU/TPU in distributed 23 | pass 24 | 25 | @abstractmethod 26 | def setup(self, stage: Optional[str] = None): 27 | # make assignments here (val/train/test split) 28 | # called on every process in DDP 29 | pass 30 | 31 | @abstractmethod 32 | def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: 33 | pass 34 | 35 | @abstractmethod 36 | def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: 37 | pass 38 | 39 | @abstractmethod 40 | def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: 41 | pass 42 | 43 | def get_train_batch(self, *args, **kwargs) -> Generator[Any, Any, Any]: 44 | pass 45 | -------------------------------------------------------------------------------- /utils/dataset/domainnet_real_datamodule.py: -------------------------------------------------------------------------------- 1 | from .cifar10_datamodule import CIFAR10DataModule 2 | from .domainnet_real import DomainNetReal 3 | from .utils import per_class_random_split 4 | 5 | # for type hint 6 | from typing import Optional, Tuple, List 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class DomainNetRealDataModule(CIFAR10DataModule): 11 | num_classes: int = 345 12 | 13 | total_train_size: int = 120_906 14 | total_test_size: int = 52_041 15 | 16 | DATASET = DomainNetReal 17 | 18 | def __init__(self, 19 | data_dir: str, 20 | labeled_train_size: int, 21 | validation_size: int, 22 | unlabeled_train_size: Optional[int] = None, 23 | dims: Optional[Tuple[int, ...]] = None, 24 | **kwargs): 25 | if dims is None: 26 | dims = (3, 224, 224) 27 | 28 | super(DomainNetRealDataModule, self).__init__( 29 | data_dir=data_dir, 30 | labeled_train_size=labeled_train_size, 31 | validation_size=validation_size, 32 | unlabeled_train_size=unlabeled_train_size, 33 | dims=dims, 34 | **kwargs) 35 | 36 | # dataset stats 37 | # DomainNet-Real mean, std values in CHW 38 | self.dataset_mean = [0.54873651, 0.60511086, 0.5840634] 39 | self.dataset_std = [0.33955591, 0.32637834, 0.31887854] 40 | 41 | def split_dataset(self, dataset: Dataset, **kwargs) -> List[Dataset]: 42 | split_kwargs = dict(lengths=[self.validation_size, self.labeled_train_size], 43 | num_classes=self.num_classes, 44 | uneven_split=True) 45 | 46 | # update split arguments 47 | split_kwargs.update(kwargs) 48 | 49 | return per_class_random_split(dataset, **split_kwargs) 50 | -------------------------------------------------------------------------------- /ablation_estimator.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | 3 | from simple_estimator import SimPLEEstimator 4 | 5 | # for type hint 6 | from typing import Tuple, Dict 7 | from torch import Tensor 8 | from loss.types import LossInfoType 9 | 10 | 11 | class AblationEstimator(SimPLEEstimator): 12 | def training_step(self, batch: Tuple[Tuple[Tensor, Tensor], ...], batch_idx: int) -> Tuple[Tensor, LossInfoType]: 13 | outputs = self.preprocess_batch(batch, batch_idx) 14 | 15 | model_outputs = 
self.compute_train_logits(x_inputs=outputs["x_inputs"]) 16 | 17 | outputs.update(model_outputs) 18 | 19 | # calculate loss 20 | return self.compute_train_loss( 21 | x_logits=outputs["x_logits"], 22 | x_targets=outputs["x_targets"] 23 | ) 24 | 25 | def preprocess_batch(self, batch: Tuple[Tuple[Tensor, Tensor], ...], batch_idx: int) -> Dict[str, Tensor]: 26 | # unpack batch 27 | (x_inputs, x_targets), (_, _) = batch 28 | 29 | # load data to device 30 | x_inputs = x_inputs.to(self.device) 31 | x_targets = x_targets.to(self.device) 32 | 33 | # apply augmentations 34 | x_inputs = self.augmenter(x_inputs) 35 | 36 | return dict( 37 | x_inputs=x_inputs, 38 | x_targets=x_targets, 39 | ) 40 | 41 | def compute_train_logits(self, x_inputs: Tensor, *args: Tensor) -> Dict[str, Tensor]: 42 | return dict(x_logits=self.model(x_inputs)) 43 | 44 | def compute_train_loss(self, 45 | x_logits: Tensor, 46 | x_targets: Tensor, 47 | *args: Tensor, 48 | **kwargs) -> Tuple[Tensor, LossInfoType]: 49 | loss = F.cross_entropy(x_logits, x_targets, reduction="mean") 50 | 51 | log_info = { 52 | "loss": loss.detach().clone(), 53 | "loss_x": loss.detach().clone(), 54 | } 55 | 56 | return loss, {"log": log_info, "plot": {}} 57 | -------------------------------------------------------------------------------- /utils/loggers/print_logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from .logger import Logger 4 | 5 | # for type hint 6 | from typing import Any, Dict, Optional, Union 7 | from argparse import Namespace 8 | 9 | from torch.nn import Module 10 | 11 | 12 | class PrintLogger(Logger): 13 | def __init__(self, 14 | log_dir: str, 15 | config: Union[Namespace, Dict[str, Any]], 16 | is_display_plots: bool = False): 17 | super().__init__( 18 | log_dir=log_dir, 19 | config=config, 20 | log_info_key_map={ 21 | "top1_acc": "mean_acc", 22 | "top5_acc": "mean_top5_acc", 23 | }) 24 | 25 | self.is_display_plots = is_display_plots 26 | 27 | def log(self, 28 | log_info: Dict[str, Any], 29 | step: Optional[int] = None, 30 | file=sys.stdout, 31 | flush: bool = False, 32 | sep: str = "\n\t", 33 | prefix: Optional[str] = None, 34 | log_info_override: Optional[Dict[str, Any]] = None, 35 | **kwargs): 36 | # process log_info 37 | log_info = self.process_log_info(log_info, prefix=prefix, log_info_override=log_info_override) 38 | 39 | if len(log_info) == 0: 40 | return 41 | 42 | log_info, plot_info = self.separate_plot(log_info) 43 | 44 | if self.is_display_plots: 45 | # display plots 46 | for key, fig in plot_info.items(): 47 | if hasattr(fig, "show"): 48 | fig.show() 49 | 50 | output_str = sep.join(f"{str(key)}: {str(val)}" for key, val in log_info.items()) 51 | if step is not None: 52 | print(f"Step {step}:\n\t{output_str}", file=file, flush=flush) 53 | else: 54 | print(output_str, file=file, flush=flush) 55 | 56 | # invoke all log hook functions 57 | self.call_log_hooks(log_info) 58 | 59 | def watch(self, model: Module, **kwargs): 60 | # TODO: implement 61 | pass 62 | 63 | def save(self, output_path: str): 64 | # TODO: implement 65 | pass 66 | -------------------------------------------------------------------------------- /utils/dataset/svhn_datamodule.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import SVHN 2 | 3 | from .cifar10_datamodule import CIFAR10DataModule 4 | from .utils import per_class_random_split 5 | 6 | # for type hint 7 | from typing import Optional, List 8 | from torch.utils.data import 
Dataset 9 | 10 | 11 | class SVHNDataModule(CIFAR10DataModule): 12 | num_classes: int = 10 13 | 14 | total_train_size: int = 73_257 15 | total_test_size: int = 26_032 16 | 17 | DATASET = SVHN 18 | 19 | def __init__(self, 20 | data_dir: str, 21 | labeled_train_size: int, 22 | validation_size: int, 23 | unlabeled_train_size: Optional[int] = None, 24 | **kwargs): 25 | super(SVHNDataModule, self).__init__( 26 | data_dir=data_dir, 27 | labeled_train_size=labeled_train_size, 28 | validation_size=validation_size, 29 | unlabeled_train_size=unlabeled_train_size, 30 | **kwargs) 31 | 32 | # dataset stats 33 | # SVHN mean, std values in CHW 34 | self.dataset_mean = [0.4376821, 0.4437697, 0.47280442] 35 | self.dataset_std = [0.19803012, 0.20101562, 0.19703614] 36 | 37 | def prepare_data(self, *args, **kwargs): 38 | self.DATASET(root=self.data_dir, split="train", download=True) 39 | self.DATASET(root=self.data_dir, split="test", download=True) 40 | 41 | def setup(self, stage: Optional[str] = None): 42 | full_train_set = self.DATASET(root=self.data_dir, split="train") 43 | full_test_set = self.DATASET(root=self.data_dir, split="test") 44 | 45 | self.setup_helper(full_train_set=full_train_set, full_test_set=full_test_set, stage=stage) 46 | 47 | def split_dataset(self, dataset: Dataset, **kwargs) -> List[Dataset]: 48 | split_kwargs = dict(lengths=[self.validation_size, self.labeled_train_size], 49 | num_classes=self.num_classes, 50 | uneven_split=True) 51 | 52 | # update split arguments 53 | split_kwargs.update(kwargs) 54 | 55 | return per_class_random_split(dataset, **split_kwargs) 56 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import logging 5 | 6 | # for type hint 7 | from typing import Union, Dict, Any, Set, Optional 8 | from torch import Tensor 9 | 10 | 11 | def str_to_bool(input_str: str) -> Union[str, bool]: 12 | """ 13 | Convert input_str to a boolean if it is "True" or "False" (case-insensitive, ignoring surrounding spaces);
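otherwise, return input_str unchanged. Example (illustrative doctest, mirroring the implementation below): >>> str_to_bool(" True ") True >>> str_to_bool("yes") 'yes'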
14 | 15 | Args: 16 | input_str: the string to convert 17 | 18 | Returns: a boolean value if input_str is "True" or "False"; otherwise, the original input string 19 | 20 | """ 21 | comp_str = input_str.lower().strip() 22 | 23 | if comp_str == "true": 24 | return True 25 | elif comp_str == "false": 26 | return False 27 | else: 28 | return input_str 29 | 30 | 31 | def get_device(device_id: str) -> torch.device: 32 | # update device 33 | if device_id != "cpu": 34 | if torch.cuda.is_available(): 35 | device = torch.device("cuda") 36 | else: 37 | logging.warning(f"device \"{device_id}\" is not available") 38 | device = torch.device("cpu") 39 | 40 | else: 41 | device = torch.device("cpu") 42 | 43 | return device 44 | 45 | 46 | def dict_add_prefix(input_dict: Dict[str, Any], prefix: str, separator: str = "/") -> Dict[str, Any]: 47 | return {f"{prefix}{separator}{key}": val for key, val in input_dict.items()} 48 | 49 | 50 | def filter_dict(input_dict: Dict[str, Any], excluded_keys: Optional[Set[str]] = None, 51 | included_keys: Optional[Set[str]] = None) -> Dict[str, Any]: 52 | if excluded_keys is None: 53 | excluded_keys = set() 54 | 55 | if included_keys is not None: 56 | input_dict = {k: v for k, v in input_dict.items() if k in included_keys} 57 | 58 | return {k: v for k, v in input_dict.items() if k not in excluded_keys} 59 | 60 | 61 | def detorch(inputs: Union[Tensor, np.ndarray, float, int, bool]) -> Union[np.ndarray, float, int, bool]: 62 | if isinstance(inputs, Tensor): 63 | outputs = inputs.detach().cpu().clone().numpy() 64 | else: 65 | outputs = inputs 66 | 67 | if isinstance(outputs, np.ndarray) and outputs.size == 1: 68 | outputs = outputs.item() 69 | 70 | return outputs 71 | -------------------------------------------------------------------------------- /models/rampup.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | # for type hint 4 | from typing import Dict, Any, Set, Tuple, Optional 5 | 6 | 7 | class RampUp(ABC): 8 | def __init__(self, length: int, current: int = 0): 9 | self.current = current 10 | self.length = length 11 | 12 | @abstractmethod 13 | def __call__(self, current: Optional[int] = None, is_step: bool = True) -> float: 14 | pass 15 | 16 | def state_dict(self) -> Dict[str, Any]: 17 | return dict(current=self.current, length=self.length) 18 | 19 | def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True): 20 | if strict: 21 | is_equal, incompatible_keys = self._verify_state_dict(state_dict) 22 | assert is_equal, f"loaded state dict contains incompatible keys: {incompatible_keys}" 23 | 24 | # for attr_name, attr_value in state_dict.items(): 25 | # if attr_name in self.__dict__: 26 | # self.__dict__[attr_name] = attr_value 27 | 28 | self.current = state_dict["current"] 29 | self.length = state_dict["length"] 30 | 31 | def _verify_state_dict(self, state_dict: Dict[str, Any]) -> Tuple[bool, Set[str]]: 32 | self_keys = set(self.__dict__.keys()) 33 | state_dict_keys = set(state_dict.keys()) 34 | incompatible_keys = self_keys.union(state_dict_keys) - self_keys.intersection(state_dict_keys) 35 | is_equal = (len(incompatible_keys) == 0) 36 | 37 | return is_equal, incompatible_keys 38 | 39 | def _update_step(self, is_step: bool): 40 | if is_step: 41 | self.current += 1 42 | 43 | 44 | class LinearRampUp(RampUp): 45 | def __call__(self, current: Optional[int] = None, is_step: bool = True) -> float: 46 | if current is not None: 47 | self.current = current 48 | 49 | if self.current >= 
self.length: 50 | ramp_up = 1.0 51 | else: 52 | ramp_up = self.current / self.length 53 | 54 | self._update_step(is_step) 55 | 56 | return ramp_up 57 | 58 | 59 | def get_ramp_up(ramp_up_type: str, length: int, current: int = 0) -> RampUp: 60 | return { 61 | "linear": lambda: LinearRampUp(length, current), 62 | }[ramp_up_type]() 63 | -------------------------------------------------------------------------------- /utils/loggers/log_aggregator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..utils import detorch 4 | 5 | # for type hint 6 | from typing import List, Any, Dict, Optional, Union, FrozenSet 7 | 8 | 9 | class LogAggregator: 10 | supported_reductions = { 11 | "mean": lambda x: np.mean(x).item(), 12 | "sum": lambda x: np.sum(x).item(), 13 | } 14 | 15 | def __init__(self): 16 | self.log_dict: Dict[str, Union[List[Any], Any]] = dict() 17 | self.plot_dict: Dict[str, Any] = dict() 18 | 19 | def __getitem__(self, k): 20 | return self.log_dict[k] 21 | 22 | @property 23 | def log_keys(self) -> FrozenSet[str]: 24 | return frozenset(self.log_dict.keys()) 25 | 26 | @property 27 | def plot_keys(self) -> FrozenSet[str]: 28 | return frozenset(self.plot_dict.keys()) 29 | 30 | def clear(self) -> None: 31 | self.log_dict.clear() 32 | self.plot_dict.clear() 33 | 34 | def add_log(self, log_info: Dict[str, Any]) -> None: 35 | conflict_keys = self.plot_keys.intersection(log_info.keys()) 36 | assert len(conflict_keys) == 0, f"conflicting keys for plot and log data: {conflict_keys}" 37 | 38 | for k, v in log_info.items(): 39 | if k not in self.log_dict: 40 | self.log_dict[k] = list() 41 | 42 | self.log_dict[k].append(detorch(v)) 43 | 44 | def add_plot(self, plot_info: Dict[str, Any]) -> None: 45 | conflict_keys = self.log_keys.intersection(plot_info.keys()) 46 | assert len(conflict_keys) == 0, f"conflicting keys for plot and log data: {conflict_keys}" 47 | 48 | for k, v in plot_info.items(): 49 | self.plot_dict[k] = v 50 | 51 | def aggregate(self, reduction: str = "mean", key_mapping: Optional[Dict[Any, Any]] = None) -> Dict[str, Any]: 52 | assert reduction in self.supported_reductions, f"unsupported reduction method: {reduction}" 53 | 54 | if key_mapping is None: 55 | key_mapping = {} 56 | 57 | reduce_func = self.supported_reductions[reduction] 58 | 59 | output_log_dict = {key_mapping.get(k, k): reduce_func(v) for k, v in self.log_dict.items()} 60 | output_plot_dict = {key_mapping.get(k, k): v for k, v in self.plot_dict.items()} 61 | 62 | output_dict = output_log_dict 63 | output_dict.update(output_plot_dict)  # merge the plot entries into the aggregated output 64 | 65 | return output_dict 66 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch.backends import cudnn 5 | 6 | from .cli import get_arg_parser, get_args, update_args, args_to_logger_config 7 | from .dataset import get_dataset, repeater 8 | from .file_io import find_checkpoint_path, read_yaml, find_all_files 9 | 10 | from .timing import timing 11 | from .loggers import Logger, LogAggregator 12 | from .utils import str_to_bool, get_device, dict_add_prefix, filter_dict, detorch 13 | 14 | from . import dataset 15 | from . import cli 16 | from . import loggers 17 | from . import metrics 18 | from . 
import types 19 | 20 | # for type hint 21 | from typing import Optional 22 | from argparse import Namespace 23 | 24 | 25 | def set_random_seed(seed: Optional[int], is_cudnn_deterministic: bool) -> None: 26 | if seed is not None: 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | 31 | if is_cudnn_deterministic: 32 | cudnn.deterministic = True 33 | cudnn.benchmark = False 34 | 35 | 36 | def get_logger(args: Namespace) -> Logger: 37 | logger_type = args.logger 38 | config_dict = args_to_logger_config(args) 39 | 40 | if logger_type == "wandb": 41 | from .loggers import WandbLogger 42 | return WandbLogger( 43 | log_dir=args.log_dir, 44 | config=config_dict, 45 | **args.logger_config_dict) 46 | elif logger_type == "nop": 47 | return Logger( 48 | log_dir=args.log_dir, 49 | config=config_dict) 50 | else: 51 | from .loggers import PrintLogger 52 | return PrintLogger( 53 | log_dir=args.log_dir, 54 | config=config_dict, 55 | is_display_plots=args.is_display_plots) 56 | 57 | 58 | __all__ = [ 59 | # modules 60 | "dataset", 61 | "cli", 62 | "loggers", 63 | "metrics", 64 | "types", 65 | 66 | # classes 67 | "LogAggregator", 68 | "Logger", 69 | 70 | # functions 71 | "get_arg_parser", 72 | "get_args", 73 | "update_args", 74 | "args_to_logger_config", 75 | "get_dataset", 76 | "repeater", 77 | "find_checkpoint_path", 78 | "read_yaml", 79 | "find_all_files", 80 | "timing", 81 | "get_logger", 82 | "str_to_bool", 83 | "get_device", 84 | "dict_add_prefix", 85 | "filter_dict", 86 | "detorch", 87 | "set_random_seed", 88 | ] 89 | -------------------------------------------------------------------------------- /models/augmentation/augmenter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia import filters as F 3 | import numpy as np 4 | 5 | from random import random 6 | 7 | # for type hint 8 | from typing import Optional, Tuple, Sequence 9 | from torch import Tensor 10 | from torch.nn import Module 11 | 12 | 13 | class RandomAugmentation(Module): 14 | def __init__(self, augmentation: Module, p: float = 0.5, same_on_batch: bool = False): 15 | super().__init__() 16 | 17 | self.prob = p 18 | self.augmentation = augmentation 19 | self.same_on_batch = same_on_batch 20 | 21 | def forward(self, images: Tensor) -> Tensor: 22 | is_batch = images.dim() == 4  # batched input has shape (B, C, H, W) 23 | 24 | if not is_batch or self.same_on_batch: 25 | if random() <= self.prob: 26 | out = self.augmentation(images) 27 | else: 28 | out = images 29 | else: 30 | out = self.augmentation(images) 31 | batch_size = len(images) 32 | 33 | # get the indices of data which shouldn't apply augmentation 34 | indices = torch.where(torch.rand(batch_size) > self.prob) 35 | out[indices] = images[indices] 36 | 37 | return out 38 | 39 | 40 | class RandomGaussianBlur(Module): 41 | def __init__(self, kernel_size: Tuple[int, int], min_sigma=0.1, max_sigma=2.0, p=0.5) -> None: 42 | super().__init__() 43 | self.kernel_size = tuple(s + 1 if s % 2 == 0 else s for s in kernel_size) # kernel size must be odd 44 | self.min_sigma = min_sigma 45 | self.max_sigma = max_sigma 46 | self.p = p 47 | 48 | def forward(self, img): 49 | if self.p > random(): 50 | sigma = (self.max_sigma - self.min_sigma) * random() + self.min_sigma 51 | return F.gaussian_blur2d(img, kernel_size=self.kernel_size, sigma=(sigma, sigma)) 52 | else: 53 | return img 54 | 55 | 56 | class RandomChoice(Module): 57 | def __init__(self, augmentations: Sequence[Module], size: int = 2, p: Optional[Sequence[float]] = None): 58 | 
super().__init__() 59 | 60 | assert size <= len(augmentations), f"size = {size} should be <= # aug. = {len(augmentations)}" 61 | 62 | self.augmentations = augmentations 63 | self.size = size 64 | self.p = p 65 | 66 | def forward(self, inputs: Tensor) -> Tensor: 67 | indices = np.random.choice(range(len(self.augmentations)), size=self.size, replace=False, p=self.p) 68 | 69 | outputs = inputs 70 | for i in indices: 71 | outputs = self.augmentations[i](outputs) 72 | 73 | return outputs 74 | -------------------------------------------------------------------------------- /utils/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import functional as F 2 | from PIL import Image 3 | 4 | # for type hint 5 | from typing import Union, Tuple, Callable, List 6 | from collections.abc import Iterable 7 | from torch import Tensor 8 | from PIL.Image import Image as PILImage 9 | 10 | _pil_interpolation_to_str = { 11 | Image.NEAREST: 'PIL.Image.NEAREST', 12 | Image.BILINEAR: 'PIL.Image.BILINEAR', 13 | Image.BICUBIC: 'PIL.Image.BICUBIC', 14 | Image.LANCZOS: 'PIL.Image.LANCZOS', 15 | Image.HAMMING: 'PIL.Image.HAMMING', 16 | Image.BOX: 'PIL.Image.BOX', 17 | } 18 | 19 | 20 | class CenterResizedCrop(object): 21 | """Crop the given PIL Image at the center, then resize it to the desired shape. 22 | 23 | First, the largest possible center crop matching the output aspect ratio is 24 | taken. Then, the cropped image is resized to the desired output size. 25 | 26 | Args: 27 | size (sequence or int): Desired output size of the crop. If size is an 28 | int instead of sequence like (h, w), a square crop (size, size) is 29 | made. 30 | interpolation: interpolation mode. Default: PIL.Image.BILINEAR 31 | """ 32 | 33 | def __init__(self, size: Union[int, Tuple[int, int]], interpolation: int = Image.BILINEAR): 34 | assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2) 35 | self.size = (size, size) if isinstance(size, int) else size  # normalize to (h, w) so __call__ can unpack it 36 | self.interpolation = interpolation 37 | 38 | def __call__(self, img: PILImage): 39 | """ 40 | Args: 41 | img (PIL Image): Image to be cropped. 42 | 43 | Returns: 44 | PIL Image: Cropped image.
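Example (illustrative): a 600x400 (width x height) input with size=(224, 224) is center-cropped to the largest region with the 1:1 output ratio (400x400) and then resized to 224x224.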
45 | """ 46 | image_width, image_height = img.size 47 | output_height, output_width = self.size 48 | 49 | image_ratio = image_width / image_height 50 | output_ratio = output_width / output_height 51 | 52 | if image_ratio >= output_ratio: 53 | crop_height = int(image_height) 54 | crop_width = int(image_height * output_ratio) 55 | else: 56 | crop_height = int(image_width / output_ratio) 57 | crop_width = int(image_width) 58 | 59 | cropped_img = F.center_crop(img, (crop_height, crop_width)) 60 | 61 | return F.resize(cropped_img, size=self.size, interpolation=self.interpolation) 62 | 63 | def __repr__(self): 64 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 65 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 66 | format_string += ', interpolation={0})'.format(interpolate_str) 67 | return format_string 68 | 69 | 70 | __all__ = [ 71 | "CenterResizedCrop", 72 | ] 73 | -------------------------------------------------------------------------------- /utils/dataset/miniimagenet.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | from PIL import Image 3 | from torchvision.datasets import VisionDataset 4 | from torchvision.datasets.utils import check_integrity, download_file_from_google_drive 5 | 6 | from pathlib import Path 7 | 8 | # for type hint 9 | from typing import Optional, Callable 10 | 11 | 12 | class MiniImageNet(VisionDataset): 13 | base_folder = 'mini-imagenet' 14 | gdrive_id = '1EKmnUcpipszzBHBRcXxmejuO4pceD4ht' 15 | file_md5 = '3bda5120eb7353dd88e06de46e680146' 16 | filename = 'mini-imagenet.hdf5' 17 | 18 | def __init__(self, 19 | root: str, 20 | train: bool = True, 21 | transform: Optional[Callable] = None, 22 | target_transform: Optional[Callable] = None, 23 | download: bool = False): 24 | super().__init__(root, transform=transform, target_transform=target_transform) 25 | self.train = train 26 | self.root = root 27 | 28 | if download: 29 | self.download() 30 | 31 | if not self._check_integrity(): 32 | raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it') 33 | 34 | img_key = 'train_img' if self.train else 'test_img' 35 | target_key = 'train_target' if self.train else 'test_target' 36 | with h5py.File(self.data_root / self.filename, "r", swmr=True) as h5_f: 37 | self.data = h5_f[img_key][...] 38 | self.target = h5_f[target_key][...] 
39 | 40 | @property 41 | def data_root(self) -> Path: 42 | return Path(self.root) / self.base_folder 43 | 44 | @property 45 | def download_root(self) -> Path: 46 | return self.data_root 47 | 48 | def __len__(self): 49 | return len(self.target) 50 | 51 | def __getitem__(self, idx): 52 | img, target = Image.fromarray(self.data[idx]), self.target[idx] 53 | 54 | if self.transform is not None: 55 | img = self.transform(img) 56 | 57 | if self.target_transform is not None: 58 | target = self.target_transform(target) 59 | 60 | return img, target 61 | 62 | def download(self): 63 | if self._check_integrity(): 64 | print('Files already downloaded and verified') 65 | return 66 | download_file_from_google_drive(file_id=self.gdrive_id, 67 | root=str(self.download_root), 68 | filename=self.filename, 69 | md5=self.file_md5) 70 | 71 | def _check_integrity(self): 72 | return check_integrity(fpath=str(self.download_root / self.filename), md5=self.file_md5) 73 | -------------------------------------------------------------------------------- /models/models/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from .wide_resnet import WideResNet 4 | from .resnet import * 5 | from .ema import EMA, FullEMA 6 | from .utils import unwrap_model 7 | 8 | from . import utils 9 | 10 | # for type hint 11 | from typing import Union 12 | from torch.nn import Module 13 | 14 | 15 | def get_model_size(model: Module) -> int: 16 | return sum(p.numel() for p in model.parameters()) 17 | 18 | 19 | def build_model(model_type: str, in_channels: int, out_channels: int, **kwargs) -> Union[Module, EMA]: 20 | if model_type == "wrn28-8": 21 | model = WideResNet(in_channels=in_channels, 22 | out_channels=out_channels, 23 | depth=28, 24 | widening_factor=8, 25 | base_channels=16, 26 | **kwargs) 27 | 28 | elif model_type == "resnet18": 29 | model = resnet18(pretrained=False, progress=True, num_classes=out_channels, **kwargs) 30 | 31 | elif model_type == "resnet50": 32 | model = resnet50(pretrained=False, progress=True, num_classes=out_channels, **kwargs) 33 | 34 | elif model_type == "wrn28-2": 35 | # default model is WRN 28-2 36 | model = WideResNet(in_channels=in_channels, 37 | out_channels=out_channels, 38 | depth=28, 39 | widening_factor=2, 40 | base_channels=16, 41 | **kwargs) 42 | 43 | else: 44 | raise NotImplementedError(f"\"{model_type}\" is not a supported model type") 45 | 46 | print(f'{model_type} Total params: {(get_model_size(model) / 1e6):.2f}M') 47 | return model 48 | 49 | 50 | def build_ema_model(model: Module, ema_type: str, ema_decay: float) -> Union[Module, EMA]: 51 | if ema_decay == 0: 52 | warnings.warn("EMA decay is 0, turn off EMA") 53 | return model 54 | 55 | elif ema_type == "full": 56 | return FullEMA(model, decay=ema_decay) 57 | 58 | elif ema_type == "default": 59 | return EMA(model, decay=ema_decay) 60 | 61 | else: 62 | raise NotImplementedError(f"\"{ema_type}\" is not a supported EMA type ") 63 | 64 | 65 | __all__ = [ 66 | # modules, 67 | "utils", 68 | 69 | # classes 70 | "WideResNet", 71 | "ResNet", 72 | "EMA", 73 | "FullEMA", 74 | 75 | # functions 76 | "get_model_size", 77 | "build_model", 78 | "build_ema_model", 79 | 80 | # model functions 81 | "resnet18", 82 | "resnet34", 83 | "resnet50", 84 | "resnet101", 85 | "resnet152", 86 | "resnext50_32x4d", 87 | "resnext101_32x8d", 88 | "wide_resnet50_2", 89 | "wide_resnet101_2", 90 | ] 91 | -------------------------------------------------------------------------------- 
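# A minimal usage sketch of the two factory functions above (illustrative, not a file from the
# repository; the model type, channel counts, and EMA decay are example values and assume the
# repository root is on PYTHONPATH):
#
#     from models.models import build_model, build_ema_model
#
#     model = build_model("wrn28-2", in_channels=3, out_channels=10)  # default WRN 28-2 backbone
#     ema_model = build_ema_model(model, ema_type="default", ema_decay=0.999)  # EMA-averaged copy for evaluation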
/models/mixmatch/simple_mixmatch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .mixmatch_base import MixMatchBase 4 | 5 | # for type hint 6 | from typing import List, Dict 7 | from torch import Tensor 8 | from torch.nn import Module 9 | 10 | 11 | class SimPLE(MixMatchBase): 12 | def __init__(self, 13 | augmenter: Module, 14 | strong_augmenter: Module, 15 | num_classes: int, 16 | temperature: float, 17 | num_augmentations: int, 18 | num_strong_augmentations: int, 19 | is_strong_augment_x: bool, 20 | train_label_guessing: bool): 21 | super(SimPLE, self).__init__(augmenter=augmenter, 22 | strong_augmenter=strong_augmenter, 23 | num_classes=num_classes, 24 | temperature=temperature, 25 | num_augmentations=num_augmentations, 26 | num_strong_augmentations=num_strong_augmentations, 27 | alpha=0., 28 | is_strong_augment_x=is_strong_augment_x, 29 | train_label_guessing=train_label_guessing) 30 | 31 | @property 32 | def total_num_augmentations(self) -> int: 33 | return self.num_strong_augmentations 34 | 35 | @torch.no_grad() 36 | def __call__(self, 37 | x_augmented: Tensor, 38 | x_strong_augmented: Tensor, 39 | x_targets_one_hot: Tensor, 40 | u_augmented: List[Tensor], 41 | u_strong_augmented: List[Tensor], 42 | u_true_targets_one_hot: Tensor, 43 | model: Module, 44 | *args, 45 | **kwargs) -> Dict[str, Tensor]: 46 | if self.is_strong_augment_x: 47 | x_inputs = x_strong_augmented 48 | else: 49 | x_inputs = x_augmented 50 | u_inputs = u_strong_augmented 51 | 52 | # label guessing with weakly augmented data 53 | pseudo_label_dict = self.guess_label(u_inputs=u_augmented, model=model) 54 | 55 | return self.postprocess(x_augmented=x_inputs, 56 | x_targets_one_hot=x_targets_one_hot, 57 | u_augmented=u_inputs, 58 | q_guess=pseudo_label_dict["q_guess"], 59 | u_true_targets_one_hot=u_true_targets_one_hot) 60 | 61 | def mixup(self, 62 | x_augmented: Tensor, 63 | x_targets_one_hot: Tensor, 64 | u_augmented: Tensor, 65 | q_guess: Tensor, 66 | q_true: Tensor) -> Dict[str, Tensor]: 67 | # SimPLE do not use mixup 68 | return dict(x_mixed=x_augmented, 69 | p_mixed=x_targets_one_hot, 70 | u_mixed=u_augmented, 71 | q_mixed=q_guess, 72 | q_true_mixed=q_true) 73 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | # for type hint 4 | from typing import Dict, Union, Optional, KeysView, Set, Any 5 | 6 | 7 | class MetricMode(Enum): 8 | MAX = 0 9 | MIN = 1 10 | 11 | 12 | MetricDictType = Dict[str, Union[str, MetricMode, int, float]] 13 | 14 | 15 | class MetricMonitor: 16 | def __init__(self, metrics: Optional[Dict[str, MetricDictType]] = None): 17 | self.metrics: Dict[str, MetricDictType] = dict() 18 | 19 | if metrics is not None: 20 | self.update(metrics) 21 | 22 | def __getitem__(self, key: str) -> MetricDictType: 23 | return self.metrics[key] 24 | 25 | def __setitem__(self, key: str, value: MetricDictType): 26 | self.track(key=key, 27 | log_key=value["key"], 28 | mode=value["mode"], 29 | best_value=value["best_value"]) 30 | 31 | def __contains__(self, key: str) -> bool: 32 | return key in self.metrics 33 | 34 | def track(self, 35 | key: str, 36 | best_value: Union[float, int], 37 | mode: MetricMode, 38 | log_key: Optional[str] = None, 39 | prefix: Optional[str] = None): 40 | if log_key is None: 41 | log_key = f"best_{key}" 42 | 43 | if prefix is not None: 44 | key = f"{prefix}/{key}" 45 | 
log_key = f"{prefix}/{log_key}" 46 | 47 | if key in self.metrics: 48 | # if key exist, update best_value 49 | curr_best_value = self.metrics[key]["best_value"] 50 | 51 | if mode == MetricMode.MIN: 52 | best_value = min(best_value, curr_best_value) 53 | else: 54 | best_value = max(best_value, curr_best_value) 55 | 56 | self.metrics[key] = { 57 | "key": log_key, 58 | "best_value": best_value, 59 | "mode": mode, 60 | } 61 | 62 | def keys(self) -> KeysView[str]: 63 | return self.metrics.keys() 64 | 65 | def update(self, metrics: Dict[str, MetricDictType]): 66 | for key, metric in metrics.items(): 67 | self[key] = metric 68 | 69 | def mutual_keys(self, keys: Union[Set[str], KeysView[str]]) -> Set[str]: 70 | return set(keys).intersection(set(self.metrics.keys())) 71 | 72 | def update_metrics(self, log_info: Dict[str, Any]) -> Dict[str, Union[int, float]]: 73 | updated_dict = {} 74 | 75 | for mutual_key in self.mutual_keys(log_info.keys()): 76 | metric_dict = self[mutual_key] 77 | new_value = log_info[mutual_key] 78 | 79 | mode: MetricMode = metric_dict["mode"] 80 | best_value = metric_dict["best_value"] 81 | 82 | if (mode == MetricMode.MAX and new_value > best_value) or \ 83 | (mode == MetricMode.MIN and new_value < best_value): 84 | metric_dict["best_value"] = new_value 85 | 86 | # save updated key and value 87 | updated_dict[mutual_key] = new_value 88 | 89 | return updated_dict 90 | 91 | def state_dict(self) -> Dict[str, Any]: 92 | return {k: v["best_value"] for k, v in self.metrics.items()} 93 | 94 | def load_state_dict(self, state_dict: Dict[str, Any]): 95 | self.update_metrics(state_dict) 96 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .pair_loss import PairLoss 2 | from .utils import bha_coeff, bha_coeff_distance, l2_distance 3 | from .loss import (SupervisedLoss, UnsupervisedLoss, softmax_cross_entropy_loss, bha_coeff_loss, 4 | l2_dist_loss) 5 | 6 | # modules 7 | from . import utils 8 | from . import types 9 | from . import pair_loss 10 | from . 
import visualization 11 | 12 | # for type hint 13 | from typing import Tuple 14 | from argparse import Namespace 15 | 16 | from .types import SimilarityType, DistanceType, DistanceLossType 17 | 18 | 19 | def get_similarity_metric(similarity_type: str) -> Tuple[SimilarityType, str]: 20 | """ 21 | 22 | Args: 23 | similarity_type: the type of the similarity function 24 | 25 | Returns: similarity function, string indicating the type of the similarity (from [logit, prob, feature]) 26 | 27 | """ 28 | if similarity_type == "bhc": 29 | return bha_coeff, "prob" 30 | 31 | else: 32 | raise NotImplementedError(f"\"{similarity_type}\" is not a supported similarity type") 33 | 34 | 35 | def get_distance_loss_metric(distance_loss_type: str) -> Tuple[DistanceLossType, str]: 36 | """ 37 | 38 | 39 | Args: 40 | distance_loss_type: the type of the distance loss function 41 | 42 | Returns: distance loss function, string indicating the type of the loss (from [logit, prob]) 43 | 44 | """ 45 | if distance_loss_type == "bhc": 46 | return bha_coeff_loss, "logit" 47 | 48 | elif distance_loss_type == "l2": 49 | return l2_dist_loss, "prob" 50 | 51 | elif distance_loss_type == "entropy": 52 | return softmax_cross_entropy_loss, "logit" 53 | 54 | else: 55 | raise NotImplementedError(f"\"{distance_loss_type}\" is not a supported distance loss type") 56 | 57 | 58 | def build_supervised_loss(args: Namespace) -> SupervisedLoss: 59 | return SupervisedLoss(reduction="mean") 60 | 61 | 62 | def build_unsupervised_loss(args: Namespace) -> UnsupervisedLoss: 63 | return UnsupervisedLoss( 64 | loss_type=args.u_loss_type, 65 | loss_thresholded=args.u_loss_thresholded, 66 | confidence_threshold=args.confidence_threshold, 67 | reduction="mean") 68 | 69 | 70 | def build_pair_loss(args: Namespace, reduction: str = "mean") -> PairLoss: 71 | similarity_metric, similarity_type = get_similarity_metric(args.similarity_type) 72 | distance_loss_metric, distance_loss_type = get_distance_loss_metric(args.distance_loss_type) 73 | 74 | return PairLoss( 75 | similarity_metric=similarity_metric, 76 | distance_loss_metric=distance_loss_metric, 77 | confidence_threshold=args.confidence_threshold, 78 | similarity_threshold=args.similarity_threshold, 79 | similarity_type=similarity_type, 80 | distance_loss_type=distance_loss_type, 81 | reduction=reduction) 82 | 83 | 84 | __all__ = [ 85 | # modules 86 | "utils", 87 | "types", 88 | "pair_loss", 89 | "visualization", 90 | 91 | # classes 92 | "SupervisedLoss", 93 | "UnsupervisedLoss", 94 | 95 | # functions 96 | "get_similarity_metric", 97 | "get_distance_loss_metric", 98 | 99 | # loss functions 100 | "build_supervised_loss", 101 | "build_unsupervised_loss", 102 | "build_pair_loss", 103 | ] 104 | -------------------------------------------------------------------------------- /utils/loggers/wandb_logger.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | 3 | import sys 4 | 5 | from .logger import Logger 6 | 7 | # for type hint 8 | from typing import Any, Dict, Optional, Union, Sequence 9 | from argparse import Namespace 10 | 11 | from torch.nn import Module 12 | 13 | 14 | class WandbLogger(Logger): 15 | def __init__(self, 16 | log_dir: str, 17 | config: Union[Namespace, Dict[str, Any]], 18 | name: str, 19 | tags: Sequence[str], 20 | notes: str, 21 | entity: str, 22 | project: str, 23 | mode: str = "offline", 24 | resume: Union[bool, str] = False): 25 | super().__init__( 26 | log_dir=log_dir, 27 | config=config, 28 | log_info_key_map={ 29 | "top1_acc": 
"mean_acc", 30 | "top5_acc": "mean_top5_acc", 31 | }) 32 | 33 | self.name = name 34 | self.config = config 35 | self.tags = tags 36 | self.notes = notes 37 | self.entity = entity 38 | self.project = project 39 | self.mode = mode 40 | self.resume = resume 41 | 42 | self.is_init = False 43 | 44 | def _init_wandb(self): 45 | if not self.is_init: 46 | wandb.init( 47 | name=self.name, 48 | config=self.config, 49 | project=self.project, 50 | entity=self.entity, 51 | dir=self.log_dir, 52 | tags=self.tags, 53 | notes=self.notes, 54 | mode=self.mode, 55 | resume=self.resume) 56 | # update is_init flag 57 | self.is_init = True 58 | 59 | # update config if resumed 60 | if bool(self.resume): 61 | self.log(self.config, is_config=True) 62 | 63 | def log(self, 64 | log_info: Dict[str, Any], 65 | step: Optional[int] = None, 66 | is_summary: bool = False, 67 | is_config: bool = False, 68 | is_commit: Optional[bool] = None, 69 | prefix: Optional[str] = None, 70 | log_info_override: Optional[Dict[str, Any]] = None, 71 | **kwargs): 72 | if not self.is_init: 73 | self._init_wandb() 74 | 75 | # process log_info 76 | log_info = self.process_log_info(log_info, prefix=prefix, log_info_override=log_info_override) 77 | 78 | if len(log_info) == 0: 79 | return 80 | 81 | if is_summary: 82 | # for log_info_key, log_info_value in log_info.items(): 83 | # wandb.run.summary[log_info_key] = log_info_value 84 | wandb.run.summary.update(log_info) 85 | elif is_config: 86 | wandb.run.config.update(log_info, allow_val_change=True) 87 | else: 88 | log_info, plot_info = self.separate_plot(log_info) 89 | wandb.log(plot_info, commit=False, step=step) 90 | wandb.log(log_info, commit=is_commit, step=step) 91 | 92 | # invoke all log hook functions 93 | self.call_log_hooks(log_info) 94 | 95 | def watch(self, model: Module, **kwargs): 96 | if not self.is_init: 97 | self._init_wandb() 98 | 99 | wandb.watch(model, **kwargs) 100 | 101 | def save(self, output_path: str): 102 | if not self.is_init: 103 | self._init_wandb() 104 | 105 | if sys.platform != "win32": 106 | # TODO: remove the if condition once found a solution 107 | # Windows requires elevated access to 108 | wandb.save(output_path) 109 | -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from .utils import reduce_tensor, bha_coeff_log_prob, l2_distance 5 | 6 | # for type hint 7 | from torch import Tensor 8 | 9 | 10 | def softmax_cross_entropy_loss(logits: Tensor, targets: Tensor, dim: int = 1, reduction: str = 'mean') -> Tensor: 11 | """ 12 | :param logits: (labeled_batch_size, num_classes) model output of the labeled data 13 | :param targets: (labeled_batch_size, num_classes) labels distribution for the data 14 | :param dim: the dimension or dimensions to reduce 15 | :param reduction: choose from 'mean', 'sum', and 'none' 16 | :return: 17 | """ 18 | loss = -torch.sum(F.log_softmax(logits, dim=dim) * targets, dim=dim) 19 | 20 | return reduce_tensor(loss, reduction) 21 | 22 | 23 | def mse_loss(prob: Tensor, targets: Tensor, reduction: str = 'mean', **kwargs) -> Tensor: 24 | return F.mse_loss(prob, targets, reduction=reduction) 25 | 26 | 27 | def bha_coeff_loss(logits: Tensor, targets: Tensor, dim: int = 1, reduction: str = "none") -> Tensor: 28 | """ 29 | Bhattacharyya coefficient of p and q; the more similar the larger the coefficient 30 | :param logits: (batch_size, num_classes) model 
predictions of the data 31 | :param targets: (batch_size, num_classes) label prob distribution 32 | :param dim: the dimension or dimensions to reduce 33 | :param reduction: reduction method, choose from "sum", "mean", "none" 34 | :return: Bhattacharyya coefficient loss of p and q, see https://en.wikipedia.org/wiki/Bhattacharyya_distance 35 | """ 36 | log_probs = F.log_softmax(logits, dim=dim) 37 | log_targets = torch.log(targets) 38 | 39 | # since BC(P,Q) is maximized when P and Q are the same, we minimize 1 - B(P,Q) 40 | return 1. - bha_coeff_log_prob(log_probs, log_targets, dim=dim, reduction=reduction) 41 | 42 | 43 | def l2_dist_loss(probs: Tensor, targets: Tensor, dim: int = 1, reduction: str = "none") -> Tensor: 44 | loss = l2_distance(probs, targets, dim=dim) 45 | 46 | return reduce_tensor(loss, reduction) 47 | 48 | 49 | class SupervisedLoss: 50 | def __init__(self, reduction: str = 'mean'): 51 | self.loss_use_prob = False 52 | self.loss_fn = softmax_cross_entropy_loss 53 | 54 | self.reduction = reduction 55 | 56 | def __call__(self, logits: Tensor, probs: Tensor, targets: Tensor) -> Tensor: 57 | loss_input = probs if self.loss_use_prob else logits 58 | loss = self.loss_fn(loss_input, targets, dim=1, reduction=self.reduction) 59 | 60 | return loss 61 | 62 | 63 | class UnsupervisedLoss: 64 | def __init__(self, 65 | loss_type: str, 66 | loss_thresholded: bool = False, 67 | confidence_threshold: float = 0., 68 | reduction: str = "mean"): 69 | if loss_type in ["entropy", "cross entropy"]: 70 | self.loss_use_prob = False 71 | self.loss_fn = softmax_cross_entropy_loss 72 | else: 73 | self.loss_use_prob = True 74 | self.loss_fn = mse_loss 75 | 76 | self.loss_thresholded = loss_thresholded 77 | self.confidence_threshold = confidence_threshold 78 | self.reduction = reduction 79 | 80 | def __call__(self, logits: Tensor, probs: Tensor, targets: Tensor) -> Tensor: 81 | loss_input = probs if self.loss_use_prob else logits 82 | loss = self.loss_fn(loss_input, targets, dim=1, reduction="none") 83 | 84 | if self.loss_thresholded: 85 | targets_mask = (targets.max(dim=1).values > self.confidence_threshold) 86 | 87 | if len(loss.shape) > 1: 88 | # mse_loss returns a matrix, need to reshape mask 89 | targets_mask = targets_mask.view(-1, 1) 90 | 91 | loss *= targets_mask.float() 92 | 93 | return reduce_tensor(loss, reduction=self.reduction) 94 | -------------------------------------------------------------------------------- /loss/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # for type hint 4 | from typing import Union, Optional, Sequence, Callable 5 | from torch import Tensor 6 | 7 | ScalarType = Union[int, float, bool] 8 | 9 | 10 | def reduce_tensor(inputs: Tensor, reduction: str) -> Tensor: 11 | if reduction == 'mean': 12 | return torch.mean(inputs) 13 | 14 | elif reduction == 'sum': 15 | return torch.sum(inputs) 16 | 17 | return inputs 18 | 19 | 20 | def to_tensor(data: Union[ScalarType, Sequence[ScalarType]], 21 | dtype: Optional[torch.dtype] = None, 22 | device: Optional[Union[torch.device, str]] = None, 23 | tensor_like: Optional[Tensor] = None) -> Tensor: 24 | if tensor_like is not None: 25 | dtype = tensor_like.dtype if dtype is None else dtype 26 | device = tensor_like.device if device is None else device 27 | 28 | return torch.tensor(data, dtype=dtype, device=device) 29 | 30 | 31 | def bha_coeff_log_prob(log_p: Tensor, log_q: Tensor, dim: int = 1, reduction: str = "none") -> Tensor: 32 | """ 33 | Bhattacharyya coefficient
of log(p) and log(q); the more similar the larger the coefficient 34 | :param log_p: (batch_size, num_classes) first log prob distribution 35 | :param log_q: (batch_size, num_classes) second log prob distribution 36 | :param dim: the dimension or dimensions to reduce 37 | :param reduction: reduction method, choose from "sum", "mean", "none" 38 | :return: Bhattacharyya coefficient of p and q, see https://en.wikipedia.org/wiki/Bhattacharyya_distance 39 | """ 40 | # numerically unstable version 41 | # coefficient = torch.sum(torch.sqrt(p * q), dim=dim) 42 | # numerically stable version 43 | coefficient = torch.sum(torch.exp((log_p + log_q) / 2), dim=dim) 44 | 45 | return reduce_tensor(coefficient, reduction) 46 | 47 | 48 | def bha_coeff(p: Tensor, q: Tensor, dim: int = 1, reduction: str = "none") -> Tensor: 49 | """ 50 | Bhattacharyya coefficient of p and q; the more similar the larger the coefficient 51 | :param p: (batch_size, num_classes) first prob distribution 52 | :param q: (batch_size, num_classes) second prob distribution 53 | :param dim: the dimension or dimensions to reduce 54 | :param reduction: reduction method, choose from "sum", "mean", "none" 55 | :return: Bhattacharyya coefficient of p and q, see https://en.wikipedia.org/wiki/Bhattacharyya_distance 56 | """ 57 | log_p = torch.log(p) 58 | log_q = torch.log(q) 59 | 60 | return bha_coeff_log_prob(log_p, log_q, dim=dim, reduction=reduction) 61 | 62 | 63 | def bha_coeff_distance(p: Tensor, q: Tensor, dim: int = 1, reduction: str = "none") -> Tensor: 64 | """ 65 | Bhattacharyya distance of p and q: 1 - BC(p, q); the more similar the distributions, the smaller the distance 66 | :param p: (batch_size, num_classes) model predictions of the data 67 | :param q: (batch_size, num_classes) label prob distribution 68 | :param dim: the dimension or dimensions to reduce 69 | :param reduction: reduction method, choose from "sum", "mean", "none" 70 | :return: Bhattacharyya distance of p and q, see https://en.wikipedia.org/wiki/Bhattacharyya_distance 71 | """ 72 | return 1.
- bha_coeff(p, q, dim=dim, reduction=reduction) 73 | 74 | 75 | def l2_distance(x: Tensor, y: Tensor, dim: int, **kwargs) -> Tensor: 76 | return torch.norm(x - y, p=2, dim=dim) 77 | 78 | 79 | def pairwise_apply(p: Tensor, q: Tensor, func: Callable, *args, **kwargs) -> Tensor: 80 | """ 81 | 82 | Args: 83 | p: (batch_size, num_classes) first prob distribution 84 | q: (batch_size, num_classes) second prob distribution 85 | func: function to be applied on p and q 86 | 87 | Returns: a matrix of pair-wise result between each element of p and q 88 | 89 | """ 90 | p = p.unsqueeze(-1) 91 | q = q.T.unsqueeze(0) 92 | return func(p, q, *args, **kwargs) 93 | -------------------------------------------------------------------------------- /utils/dataset/cifar10_datamodule.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR10 2 | 3 | from .ssl_datamodule import SSLDataModule 4 | from .datasets import LabeledDataset 5 | from .utils import per_class_random_split 6 | 7 | # for type hint 8 | from typing import Optional, Tuple, List 9 | from torch.utils.data import Dataset 10 | from torchvision.datasets import VisionDataset 11 | 12 | 13 | class CIFAR10DataModule(SSLDataModule): 14 | num_classes: int = 10 15 | 16 | total_train_size: int = 50_000 17 | total_test_size: int = 10_000 18 | 19 | DATASET = CIFAR10 20 | 21 | def __init__(self, 22 | data_dir: str, 23 | labeled_train_size: int, 24 | validation_size: int, 25 | unlabeled_train_size: Optional[int] = None, 26 | dims: Optional[Tuple[int, ...]] = None, 27 | **kwargs): 28 | if dims is None: 29 | dims = (3, 32, 32) 30 | 31 | super(CIFAR10DataModule, self).__init__(dims=dims, **kwargs) 32 | 33 | self.data_dir = data_dir 34 | 35 | self.labeled_train_size = labeled_train_size 36 | self.validation_size = validation_size 37 | if unlabeled_train_size is None: 38 | self.unlabeled_train_size = self.total_train_size - self.validation_size - self.labeled_train_size 39 | else: 40 | self.unlabeled_train_size = unlabeled_train_size 41 | 42 | # dataset stats 43 | # CIFAR-10 mean, std values in CHW 44 | self.dataset_mean = [0.44653091, 0.49139968, 0.48215841] 45 | self.dataset_std = [0.26158784, 0.24703223, 0.24348513] 46 | 47 | def prepare_data(self, *args, **kwargs): 48 | self.DATASET(root=self.data_dir, train=True, download=True) 49 | self.DATASET(root=self.data_dir, train=False, download=True) 50 | 51 | def setup(self, stage: Optional[str] = None): 52 | full_train_set = self.DATASET(root=self.data_dir, train=True) 53 | full_test_set = self.DATASET(root=self.data_dir, train=False) 54 | 55 | self.setup_helper(full_train_set=full_train_set, full_test_set=full_test_set, stage=stage) 56 | 57 | def setup_helper(self, full_train_set: VisionDataset, full_test_set: VisionDataset, stage: Optional[str] = None): 58 | self.test_set = LabeledDataset( 59 | full_test_set, 60 | root=full_test_set.root, 61 | min_size=self.test_min_size, 62 | transform=self.test_transform) 63 | 64 | # get subsets 65 | validation_subset, labeled_train_subset, unlabeled_train_subset = self.split_dataset(full_train_set) 66 | 67 | # convert to dataset 68 | self.validation_set = LabeledDataset( 69 | validation_subset, 70 | root=full_train_set.root, 71 | min_size=self.test_min_size, 72 | transform=self.val_transform) 73 | self.labeled_train_set = LabeledDataset( 74 | labeled_train_subset, 75 | root=full_train_set.root, 76 | min_size=self.train_min_size, 77 | transform=self.train_transform) 78 | self.unlabeled_train_set = LabeledDataset( 79 
| unlabeled_train_subset, 80 | root=full_train_set.root, 81 | min_size=self.unlabeled_train_min_size, 82 | transform=self.train_transform) 83 | 84 | assert len(self.validation_set.dataset) == len(validation_subset) == self.validation_size 85 | assert len(self.labeled_train_set.dataset) == len(labeled_train_subset) == self.labeled_train_size 86 | assert len(self.unlabeled_train_set.dataset) == len(unlabeled_train_subset) == self.unlabeled_train_size 87 | assert len(self.test_set.dataset) == self.total_test_size 88 | 89 | def split_dataset(self, dataset: Dataset, **kwargs) -> List[Dataset]: 90 | split_kwargs = dict(lengths=[self.validation_size, self.labeled_train_size, self.unlabeled_train_size], 91 | num_classes=self.num_classes, 92 | uneven_split=False) 93 | 94 | # update split arguments 95 | split_kwargs.update(kwargs) 96 | 97 | return per_class_random_split(dataset, **split_kwargs) 98 | -------------------------------------------------------------------------------- /models/optimization/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.optim.lr_scheduler import LambdaLR, StepLR 3 | 4 | # for type hint 5 | from torch.optim.optimizer import Optimizer 6 | 7 | from .types import LRSchedulerType 8 | 9 | 10 | class CosineDecay: 11 | def __init__(self, max_iter: int, factor: float, min_value: float = 0., max_value: float = 1.): 12 | self.max_iter = max_iter 13 | self.factor = factor 14 | 15 | self.min_value = min_value 16 | self.max_value = max_value 17 | 18 | def __call__(self, curr_step: int) -> float: 19 | max_iter = max(self.max_iter, 1) 20 | curr_step = np.clip(curr_step, 0, max_iter).item() 21 | 22 | return self.compute_cosine_decay(curr_step=curr_step, 23 | max_iter=max_iter, 24 | factor=self.factor, 25 | min_value=self.min_value, 26 | max_value=self.max_value) 27 | 28 | @staticmethod 29 | def compute_cosine_decay(curr_step: int, 30 | max_iter: int, 31 | factor: float, 32 | min_value: float, 33 | max_value: float) -> float: 34 | output = np.cos(factor * np.pi * float(curr_step) / float(max(max_iter, 1))) 35 | 36 | return np.clip(output, min_value, max_value).item() 37 | 38 | 39 | class CosineWarmupDecay(CosineDecay): 40 | def __init__(self, 41 | max_iter: int, 42 | factor: float, 43 | num_warmup_steps: int, 44 | min_value: float = 0., 45 | max_value: float = 1.): 46 | super(CosineWarmupDecay, self).__init__(max_iter=max_iter, 47 | factor=factor, 48 | min_value=min_value, 49 | max_value=max_value) 50 | 51 | self.num_warmup_steps = max(num_warmup_steps, 0) 52 | 53 | def __call__(self, curr_step: int) -> float: 54 | if curr_step < self.num_warmup_steps: 55 | return float(curr_step) / float(max(self.num_warmup_steps, 1)) 56 | else: 57 | max_iter = max(self.max_iter - self.num_warmup_steps, 1) 58 | curr_step = np.clip(curr_step - self.num_warmup_steps, 0, max_iter).item() 59 | 60 | return self.compute_cosine_decay(curr_step=curr_step, 61 | max_iter=max_iter, 62 | factor=self.factor, 63 | min_value=self.min_value, 64 | max_value=self.max_value) 65 | 66 | 67 | def build_lr_scheduler(scheduler_type: str, 68 | optimizer: Optimizer, 69 | max_iter: int, 70 | cosine_factor: float, 71 | step_size: int, 72 | gamma: float, 73 | num_warmup_steps: int, 74 | **kwargs) -> LRSchedulerType: 75 | if scheduler_type == "cosine_decay": 76 | return LambdaLR(optimizer=optimizer, 77 | lr_lambda=CosineDecay(max_iter=max_iter, factor=cosine_factor)) 78 | 79 | elif scheduler_type == "cosine_warmup_decay": 80 | return 
LambdaLR(optimizer=optimizer, 81 | lr_lambda=CosineWarmupDecay(max_iter=max_iter, 82 | factor=cosine_factor, 83 | num_warmup_steps=num_warmup_steps)) 84 | 85 | elif scheduler_type == "step_decay": 86 | return StepLR(optimizer=optimizer, 87 | step_size=step_size, 88 | gamma=gamma) 89 | 90 | else: 91 | # dummy scheduler 92 | return LambdaLR(optimizer=optimizer, lr_lambda=lambda curr_iter: 1.0) 93 | 94 | 95 | __all__ = [ 96 | # functions 97 | "build_lr_scheduler", 98 | 99 | # classes 100 | "CosineDecay", 101 | ] 102 | -------------------------------------------------------------------------------- /utils/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torch import distributed 3 | from PIL import Image 4 | 5 | from .utils import repeater, get_batch 6 | 7 | from .miniimagenet import MiniImageNet 8 | from .domainnet_real import DomainNetReal 9 | from .datasets import LabeledDataset 10 | 11 | from .datamodule import DataModule 12 | from .ssl_datamodule import SSLDataModule 13 | 14 | from .cifar10_datamodule import CIFAR10DataModule 15 | from .cifar100_datamodule import CIFAR100DataModule 16 | from .svhn_datamodule import SVHNDataModule 17 | from .miniimagenet_datamodule import MiniImageNetDataModule 18 | from .domainnet_real_datamodule import DomainNetRealDataModule 19 | 20 | from ..transforms import CenterResizedCrop 21 | 22 | # modules 23 | from . import utils 24 | 25 | # for type hint 26 | from typing import Dict, Callable 27 | from argparse import Namespace 28 | 29 | 30 | def _get_world_size() -> int: 31 | if distributed.is_available() and distributed.is_initialized(): 32 | return distributed.get_world_size() 33 | else: 34 | return 1 35 | 36 | 37 | def _get_dataset_transforms(args: Namespace) -> Dict[str, Callable]: 38 | if args.data_dims is None and args.dataset == "domainnet-real": 39 | dims = (3, 224, 224) 40 | else: 41 | dims = args.data_dims 42 | 43 | if dims is not None: 44 | dims = tuple(dims) 45 | 46 | return dict( 47 | dims=dims, 48 | train_transform=transforms.Compose([ 49 | transforms.RandomResizedCrop( 50 | size=dims[-2:], 51 | scale=(0.08, 1.0), 52 | ratio=(3. / 4, 4. 
/ 3.), 53 | interpolation=Image.BICUBIC), 54 | transforms.ToTensor()]), 55 | val_transform=transforms.Compose([ 56 | CenterResizedCrop(dims[-2:], interpolation=Image.BICUBIC), 57 | transforms.ToTensor()]), 58 | test_transform=transforms.Compose([ 59 | CenterResizedCrop(dims[-2:], interpolation=Image.BICUBIC), 60 | transforms.ToTensor()]), 61 | ) 62 | 63 | else: 64 | return dict( 65 | train_transform=transforms.Compose([transforms.ToTensor()]), 66 | val_transform=transforms.Compose([transforms.ToTensor()]), 67 | test_transform=transforms.Compose([transforms.ToTensor()]), 68 | ) 69 | 70 | 71 | def get_dataset(args: Namespace) -> SSLDataModule: 72 | world_size = _get_world_size() 73 | 74 | kwargs = dict( 75 | data_dir=args.data_dir, 76 | labeled_train_size=args.labeled_train_size, 77 | validation_size=args.validation_size, 78 | train_batch_size=args.train_batch_size, 79 | unlabeled_batch_size=args.unlabeled_train_batch_size, 80 | test_batch_size=args.test_batch_size, 81 | num_workers=args.num_workers, 82 | train_min_size=world_size * args.train_batch_size, 83 | unlabeled_train_min_size=world_size * args.unlabeled_train_batch_size, 84 | test_min_size=world_size * args.test_batch_size, 85 | **_get_dataset_transforms(args), 86 | ) 87 | 88 | if args.dataset == "cifar10": 89 | return CIFAR10DataModule(**kwargs) 90 | 91 | elif args.dataset == "cifar100": 92 | return CIFAR100DataModule(**kwargs) 93 | 94 | elif args.dataset == "svhn": 95 | return SVHNDataModule(**kwargs) 96 | 97 | elif args.dataset == "miniimagenet": 98 | return MiniImageNetDataModule(**kwargs) 99 | 100 | elif args.dataset == "domainnet-real": 101 | return DomainNetRealDataModule(**kwargs) 102 | 103 | else: 104 | raise NotImplementedError(f"\"{args.dataset}\" is not a supported dataset") 105 | 106 | 107 | __all__ = [ 108 | # modules 109 | "utils", 110 | "types", 111 | 112 | # classes 113 | "MiniImageNet", 114 | "DomainNetReal", 115 | "LabeledDataset", 116 | 117 | "DataModule", 118 | "SSLDataModule", 119 | "CIFAR10DataModule", 120 | "CIFAR100DataModule", 121 | "DomainNetRealDataModule", 122 | "MiniImageNetDataModule", 123 | "SVHNDataModule", 124 | 125 | # functions 126 | "repeater", 127 | "get_batch", 128 | "get_dataset", 129 | ] 130 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils import get_args, timing, set_random_seed, get_device, get_dataset 4 | from models import get_augmenter 5 | from simple_estimator import SimPLEEstimator 6 | from ablation_estimator import AblationEstimator 7 | from trainer import Trainer 8 | 9 | # for type hint 10 | from typing import Optional, Type 11 | from argparse import Namespace 12 | from torch.nn import Module 13 | 14 | from utils.dataset import SSLDataModule 15 | 16 | 17 | def get_estimator_type(estimator_type: str) -> Type[SimPLEEstimator]: 18 | if estimator_type == "ablation": 19 | return AblationEstimator 20 | else: 21 | return SimPLEEstimator 22 | 23 | 24 | def get_estimator(args: Namespace, 25 | augmenter: Module, 26 | strong_augmenter: Module, 27 | val_augmenter: Module, 28 | num_classes: int, 29 | in_channels: int, 30 | device: Optional[torch.device], 31 | args_override: Optional[Namespace] = None, 32 | estimator_type: Optional[Type[SimPLEEstimator]] = None): 33 | if estimator_type is None: 34 | estimator_type = get_estimator_type(args.estimator) 35 | 36 | if args.checkpoint_path is not None and \ 37 |
(bool(args.logger_config_dict.get("resume", False)) or not args.use_pretrain): 38 | estimator = estimator_type.from_checkpoint( 39 | augmenter=augmenter, 40 | strong_augmenter=strong_augmenter, 41 | val_augmenter=val_augmenter, 42 | checkpoint_path=args.checkpoint_path, 43 | num_classes=num_classes, 44 | in_channels=in_channels, 45 | device=device, 46 | args_override=args_override, 47 | recover_train_progress=True, 48 | recover_random_state=True) 49 | else: 50 | estimator = estimator_type( 51 | args, 52 | augmenter=augmenter, 53 | strong_augmenter=strong_augmenter, 54 | val_augmenter=val_augmenter, 55 | num_classes=num_classes, 56 | in_channels=in_channels, 57 | device=device) 58 | 59 | return estimator 60 | 61 | 62 | @timing 63 | def main(args: Namespace, datamodule: Optional[SSLDataModule] = None, device: Optional[torch.device] = None): 64 | if device is None: 65 | device = get_device(args.device) 66 | 67 | if datamodule is None: 68 | datamodule = get_dataset(args) 69 | 70 | # dataset stats 71 | dataset_mean = datamodule.dataset_mean 72 | dataset_std = datamodule.dataset_std 73 | image_size = datamodule.dims[1:] 74 | 75 | # build augmenters 76 | augmenter = get_augmenter(args.augmenter_type, image_size=image_size, 77 | dataset_mean=dataset_mean, dataset_std=dataset_std) 78 | strong_augmenter = get_augmenter(args.strong_augmenter_type, image_size=image_size, 79 | dataset_mean=dataset_mean, dataset_std=dataset_std) 80 | val_augmenter = get_augmenter("validation", image_size=image_size, dataset_mean=dataset_mean, 81 | dataset_std=dataset_std) 82 | 83 | estimator = get_estimator(args, 84 | augmenter=augmenter, 85 | strong_augmenter=strong_augmenter, 86 | val_augmenter=val_augmenter, 87 | num_classes=datamodule.num_classes, 88 | in_channels=datamodule.dims[0], 89 | device=device, 90 | args_override=args) 91 | trainer = Trainer(estimator, datamodule=datamodule) 92 | 93 | # training 94 | trainer.fit() 95 | 96 | # load the best state 97 | best_checkpoint_path = trainer.saver.find_best_checkpoint_path(ignore_absolute_best=False) 98 | 99 | if best_checkpoint_path is not None: 100 | best_checkpoint = torch.load(str(best_checkpoint_path), map_location=device) 101 | trainer.load_checkpoint(best_checkpoint) 102 | 103 | # evaluation 104 | trainer.test() 105 | 106 | 107 | if __name__ == '__main__': 108 | parsed_args = get_args() 109 | 110 | # fix random seed 111 | set_random_seed(parsed_args.seed, is_cudnn_deterministic=parsed_args.debug_mode) 112 | 113 | main(parsed_args) 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SimPLE 2 | 3 | The code for the 4 | paper: "[SimPLE: Similar Pseudo Label Exploitation for Semi-Supervised Classification](https://arxiv.org/abs/2103.16725)" 5 | by [Zijian Hu*](https://www.zijianhu.com/), 6 | [Zhengyu Yang*](https://zhengyuyang.com/), 7 | [Xuefeng Hu](https://xuefenghu.me/), and [Ram Nevatia](https://sites.usc.edu/iris-cvlab/professor-ram-nevatia/). 8 | 9 | ## Abstract 10 | 11 | ![Pair Loss](media/pairloss_anim.png) 12 | 13 | A common classification task situation is where one has a large amount of data available for training, but only a small 14 | portion is annotated with class labels. The goal of semi-supervised training, in this context, is to improve 15 | classification accuracy by leveraging information not only from labeled data but also from a large amount of unlabeled 16 | data.
Recent works have achieved significant improvements by exploring the consistency constraint between differently 17 | augmented labeled and unlabeled data. Following this path, we propose a novel unsupervised objective that focuses on the 18 | less studied relationship between the high confidence unlabeled data that are similar to each other. The newly proposed 19 | Pair Loss minimizes the statistical distance between high confidence pseudo labels with similarity above a certain 20 | threshold. Combining the Pair Loss with the techniques developed by the MixMatch family, our proposed SimPLE algorithm 21 | shows significant performance gains over previous algorithms on CIFAR-100 and Mini-ImageNet, and is on par with the 22 | state-of-the-art methods on CIFAR-10 and SVHN. Furthermore, SimPLE also outperforms the state-of-the-art methods in the 23 | transfer learning setting, where models are initialized by the weights pre-trained on ImageNet or DomainNet-Real. 24 | 25 | ## News 26 | 27 | [11/8/2021]: fix a bug that could result in incorrect data loading in distributed training<br>
28 | [9/20/2021]: add data_dims option for data resizing<br>
29 | [8/31/2021]: update the code base for easier extension<br>
30 | [6/22/2021]: add Mini-ImageNet example<br>
31 | [6/2/2021]: add animations and fix broken link in README<br>
32 | [5/30/2021]: initial release 33 | 34 | ## Requirements 35 | 36 | *see [requirements.txt](requirements.txt) for details* 37 | 38 | - Python 3.6 or newer 39 | - [PyTorch](https://pytorch.org/) 1.6.0 or newer 40 | - [torchvision](https://pytorch.org/docs/stable/torchvision/index.html) 0.7.0 or newer 41 | - [kornia](https://kornia.readthedocs.io/en/latest/augmentation.html) 0.5.0 or newer 42 | - numpy 43 | - scikit-learn 44 | - [plotly](https://plotly.com/python/) 4.0.0 or newer 45 | - wandb 0.9.0 or newer (**optional**, required for logging to [Weights & Biases](https://www.wandb.com/); 46 | see [utils.loggers.WandbLogger](utils/loggers/wandb_logger.py) for details) 47 | 48 | ### Recommended versions 49 | 50 | |Python|PyTorch|torchvision|kornia| 51 | | --- | --- | --- | --- | 52 | |3.8.5|1.6.0|0.7.0|0.5.0| 53 | 54 | ## Setup 55 | 56 | ### Install dependencies 57 | 58 | using pip: 59 | 60 | ```shell 61 | pip install -r requirements.txt 62 | ``` 63 | 64 | or using conda: 65 | 66 | ```shell 67 | conda env create -f environment.yaml 68 | ``` 69 | 70 | ## Running 71 | 72 | ### Example 73 | 74 | To replicate Mini-ImageNet results 75 | 76 | ```shell 77 | CUDA_DEVICE_ORDER="PCI_BUS_ID" CUDA_VISIBLE_DEVICES="0" \ 78 | python main.py \ 79 | @runs/miniimagenet_args.txt 80 | ``` 81 | 82 | To replicate CIFAR-10 results 83 | 84 | ```shell 85 | CUDA_DEVICE_ORDER="PCI_BUS_ID" CUDA_VISIBLE_DEVICES="0" \ 86 | python main.py \ 87 | @runs/cifar10_args.txt 88 | ``` 89 | 90 | To replicate CIFAR-100 results (with distributed training) 91 | 92 | ```shell 93 | CUDA_DEVICE_ORDER="PCI_BUS_ID" CUDA_VISIBLE_DEVICES="0,1" \ 94 | python -m torch.distributed.launch \ 95 | --nproc_per_node=2 main_ddp.py \ 96 | @runs/cifar100_args.txt \ 97 | --num-epochs 2048 \ 98 | --num-step-per-epoch 512 99 | ``` 100 | 101 | ## Citation 102 | 103 | ```bibtex 104 | @InProceedings{Hu-2020-SimPLE, 105 | author = {{Hu*}, Zijian and {Yang*}, Zhengyu and Hu, Xuefeng and Nevatia, Ram}, 106 | title = {{SimPLE}: {S}imilar {P}seudo {L}abel {E}xploitation for {S}emi-{S}upervised {C}lassification}, 107 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 108 | month = {June}, 109 | year = {2021}, 110 | url = {https://arxiv.org/abs/2103.16725}, 111 | } 112 | ``` 113 | -------------------------------------------------------------------------------- /models/models/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from copy import deepcopy 5 | from collections import OrderedDict 6 | import warnings 7 | from contextlib import contextmanager 8 | 9 | from .utils import unwrap_model 10 | 11 | # for type hint 12 | from torch import Tensor 13 | from typing import Dict 14 | 15 | 16 | class EMA(nn.Module): 17 | def __init__(self, model: nn.Module, decay: float): 18 | # adapted from https://fyubang.com/2019/06/01/ema/ 19 | super().__init__() 20 | self.decay = decay 21 | 22 | self.model = model 23 | self.shadow = deepcopy(self.model) 24 | 25 | for param in self.shadow.parameters(): 26 | param.detach_() 27 | 28 | @torch.no_grad() 29 | def update(self): 30 | if not self.training: 31 | warnings.warn("EMA update should only be called during training") 32 | return 33 | 34 | model_params = OrderedDict(self.model.named_parameters()) 35 | shadow_params = OrderedDict(self.shadow.named_parameters()) 36 | 37 | # check that both models contain the same set of keys 38 | assert model_params.keys() == shadow_params.keys() 39 | 40 | for name, param in
model_params.items(): 41 | # see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 42 | # shadow_variable -= (1 - decay) * (shadow_variable - variable) 43 | shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param)) 44 | 45 | model_buffers = OrderedDict(self.model.named_buffers()) 46 | shadow_buffers = OrderedDict(self.shadow.named_buffers()) 47 | 48 | # check that both models contain the same set of keys 49 | assert model_buffers.keys() == shadow_buffers.keys() 50 | 51 | for name, buffer in model_buffers.items(): 52 | # buffers are copied 53 | shadow_buffers[name].copy_(buffer) 54 | 55 | def forward(self, inputs: Tensor, return_feature: bool = False) -> Tensor: 56 | if self.training: 57 | return self.model(inputs, return_feature) 58 | else: 59 | return self.shadow(inputs, return_feature) 60 | 61 | @contextmanager 62 | def data_parallel_switch(self): 63 | model = self.model 64 | shadow = self.shadow 65 | 66 | self.model = unwrap_model(self.model) 67 | self.shadow = unwrap_model(self.shadow) 68 | 69 | try: 70 | yield 71 | finally: 72 | self.model = model 73 | self.shadow = shadow 74 | 75 | def state_dict(self, destination=None, prefix='', keep_vars=False): 76 | with self.data_parallel_switch(): 77 | return super().state_dict(destination, prefix, keep_vars) 78 | 79 | def load_state_dict(self, state_dict: Dict[str, Tensor], strict=True): 80 | with self.data_parallel_switch(): 81 | super().load_state_dict(state_dict, strict) 82 | 83 | 84 | class FullEMA(EMA): 85 | excluded_param_suffix = ["num_batches_tracked"] 86 | 87 | def __init__(self, model: nn.Module, decay: float): 88 | super().__init__(model, decay) 89 | 90 | self._excluded_param_names = set() 91 | 92 | @torch.no_grad() 93 | def update(self): 94 | if not self.training: 95 | warnings.warn("EMA update should only be called during training") 96 | return 97 | 98 | model_state_dict = self.model.state_dict() 99 | shadow_state_dict = self.shadow.state_dict() 100 | 101 | # check that both models contain the same set of keys 102 | assert model_state_dict.keys() == shadow_state_dict.keys() 103 | 104 | for name, param in model_state_dict.items(): 105 | if name not in self.excluded_param_names: 106 | shadow_state_dict[name].sub_((1.
- self.decay) * (shadow_state_dict[name] - param)) 107 | else: 108 | shadow_state_dict[name].copy_(param) 109 | 110 | @staticmethod 111 | def is_ema_exclude(param_name: str) -> bool: 112 | return any([param_name.endswith(suffix) for suffix in FullEMA.excluded_param_suffix]) 113 | 114 | @property 115 | def excluded_param_names(self): 116 | if len(self._excluded_param_names) == 0: 117 | for name, param in self.model.state_dict().items(): 118 | if self.is_ema_exclude(name): 119 | self._excluded_param_names.add(name) 120 | 121 | return self._excluded_param_names 122 | -------------------------------------------------------------------------------- /loss/pair_loss/pair_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from .utils import get_pair_indices 5 | 6 | # for type hint 7 | from typing import Optional 8 | from torch import Tensor 9 | 10 | from ..types import SimilarityType, DistanceLossType 11 | 12 | 13 | class PairLoss: 14 | def __init__(self, 15 | similarity_metric: SimilarityType, 16 | distance_loss_metric: DistanceLossType, 17 | confidence_threshold: float, 18 | similarity_threshold: float, 19 | similarity_type: str, 20 | distance_loss_type: str, 21 | reduction: str = "mean"): 22 | self.confidence_threshold = confidence_threshold 23 | self.similarity_threshold = similarity_threshold 24 | 25 | self.similarity_type = similarity_type 26 | self.distance_loss_type = distance_loss_type 27 | 28 | self.reduction = reduction 29 | 30 | self.similarity_metric = similarity_metric 31 | self.distance_loss_metric = distance_loss_metric 32 | 33 | def __call__(self, 34 | logits: Tensor, 35 | probs: Tensor, 36 | targets: Tensor, 37 | *args, 38 | indices: Optional[Tensor] = None, 39 | **kwargs) -> Tensor: 40 | """ 41 | 42 | Args: 43 | logits: (batch_size, num_classes) model predictions of the batch data 44 | probs: (batch_size, num_classes) softmax probabilities of the logits 45 | targets: (batch_size, num_classes) one-hot labels 46 | 47 | Returns: Pair loss value as a Tensor.
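Example (an illustrative sketch; the metric choices and threshold values here are hypothetical, not prescribed defaults):
    >>> import torch
    >>> from loss.utils import bha_coeff
    >>> from loss.loss import bha_coeff_loss
    >>> pair_loss = PairLoss(similarity_metric=bha_coeff,
    ...                      distance_loss_metric=bha_coeff_loss,
    ...                      confidence_threshold=0.95,
    ...                      similarity_threshold=0.9,
    ...                      similarity_type="prob",
    ...                      distance_loss_type="logit")
    >>> logits = torch.randn(4, 10)
    >>> probs = torch.softmax(logits, dim=1)
    >>> targets = torch.eye(10)[torch.tensor([0, 0, 1, 2])]  # one-hot pseudo labels
    >>> loss = pair_loss(logits, probs, targets)  # scalar tensor with reduction="mean"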
48 | 49 | """ 50 | if indices is None: 51 | indices = get_pair_indices(targets, ordered_pair=True) 52 | total_size = len(indices) // 2 53 | 54 | i_indices, j_indices = indices[:, 0], indices[:, 1] 55 | targets_max_prob = targets.max(dim=1).values 56 | 57 | return self.compute_loss(logits_j=logits[j_indices], 58 | probs_j=probs[j_indices], 59 | targets_i=targets[i_indices], 60 | targets_j=targets[j_indices], 61 | targets_i_max_prob=targets_max_prob[i_indices], 62 | total_size=total_size) 63 | 64 | def compute_loss(self, 65 | logits_j: Tensor, 66 | probs_j: Tensor, 67 | targets_i: Tensor, 68 | targets_j: Tensor, 69 | targets_i_max_prob: Tensor, 70 | total_size: int): 71 | # conf_mask should not track gradient 72 | conf_mask = (targets_i_max_prob > self.confidence_threshold).detach().float() 73 | 74 | similarities: Tensor = self.get_similarity(targets_i=targets_i, 75 | targets_j=targets_j, 76 | dim=1) 77 | # sim_mask should not track gradient 78 | sim_mask = F.threshold(similarities, self.similarity_threshold, 0).detach() 79 | 80 | distance = self.get_distance_loss(logits=logits_j, 81 | probs=probs_j, 82 | targets=targets_i, 83 | dim=1, 84 | reduction='none') 85 | 86 | loss = conf_mask * sim_mask * distance 87 | 88 | if self.reduction == "mean": 89 | loss = torch.sum(loss) / total_size 90 | elif self.reduction == "sum": 91 | loss = torch.sum(loss) 92 | 93 | return loss 94 | 95 | def get_similarity(self, 96 | targets_i: Tensor, 97 | targets_j: Tensor, 98 | *args, 99 | **kwargs) -> Tensor: 100 | x, y = targets_i, targets_j 101 | 102 | return self.similarity_metric(x, y, *args, **kwargs) 103 | 104 | def get_distance_loss(self, 105 | logits: Tensor, 106 | probs: Tensor, 107 | targets: Tensor, 108 | *args, 109 | **kwargs) -> Tensor: 110 | if self.distance_loss_type == "prob": 111 | x, y = probs, targets 112 | else: 113 | x, y = logits, targets 114 | 115 | return self.distance_loss_metric(x, y, *args, **kwargs) 116 | -------------------------------------------------------------------------------- /models/augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from kornia import augmentation as K 4 | from kornia import filters as F 5 | from torchvision import transforms 6 | 7 | from .augmenter import RandomAugmentation 8 | from .randaugment import RandAugmentNS 9 | 10 | # for type hint 11 | from typing import List, Tuple, Union, Callable 12 | from torch import Tensor 13 | from torch.nn import Module 14 | from PIL.Image import Image as PILImage 15 | 16 | DatasetStatType = List[float] 17 | ImageSizeType = Tuple[int, int] 18 | PaddingInputType = Union[float, Tuple[float, float], Tuple[float, float, float, float]] 19 | ImageType = Union[Tensor, PILImage] 20 | 21 | 22 | def get_augmenter(augmenter_type: str, 23 | image_size: ImageSizeType, 24 | dataset_mean: DatasetStatType, 25 | dataset_std: DatasetStatType, 26 | padding: PaddingInputType = 1. / 8., 27 | pad_if_needed: bool = False, 28 | subset_size: int = 2) -> Union[Module, Callable]: 29 | """ 30 | 31 | Args: 32 | augmenter_type: augmenter type 33 | image_size: (height, width) image size 34 | dataset_mean: dataset mean value in CHW 35 | dataset_std: dataset standard deviation in CHW 36 | padding: percent of image size to pad on each border of the image. If a sequence of length 4 is provided, 37 | it is used to pad left, top, right, bottom borders respectively. 
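(For example, with the default padding=1/8 and image_size=(32, 32), each border is padded by int(32 * 1/8) = 4 pixels.)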
If a sequence of length 2 is provided, it is 38 | used to pad left/right, top/bottom borders, respectively. 39 | pad_if_needed: bool flag for RandomCrop "pad_if_needed" option 40 | subset_size: number of augmentations used in subset 41 | 42 | Returns: nn.Module for Kornia augmentation or Callable for torchvision transform 43 | 44 | """ 45 | if not isinstance(padding, tuple): 46 | assert isinstance(padding, float) 47 | padding = (padding, padding, padding, padding) 48 | 49 | assert len(padding) == 2 or len(padding) == 4 50 | if len(padding) == 2: 51 | # padding of length 2 is used to pad left/right, top/bottom borders, respectively 52 | # padding of length 4 is used to pad left, top, right, bottom borders respectively 53 | padding = (padding[0], padding[1], padding[0], padding[1]) 54 | 55 | # image_size is of shape (h,w); padding values is [left, top, right, bottom] borders 56 | padding = ( 57 | int(image_size[1] * padding[0]), 58 | int(image_size[0] * padding[1]), 59 | int(image_size[1] * padding[2]), 60 | int(image_size[0] * padding[3]) 61 | ) 62 | 63 | augmenter_type = augmenter_type.strip().lower() 64 | 65 | if augmenter_type == "simple": 66 | return nn.Sequential( 67 | K.RandomCrop(size=image_size, padding=padding, pad_if_needed=pad_if_needed, 68 | padding_mode='reflect'), 69 | K.RandomHorizontalFlip(p=0.5), 70 | K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32), 71 | std=torch.tensor(dataset_std, dtype=torch.float32)), 72 | ) 73 | 74 | elif augmenter_type == "fixed": 75 | return nn.Sequential( 76 | K.RandomHorizontalFlip(p=0.5), 77 | # K.RandomVerticalFlip(p=0.2), 78 | K.RandomResizedCrop(size=image_size, scale=(0.8, 1.0), ratio=(1., 1.)), 79 | RandomAugmentation( 80 | p=0.5, 81 | augmentation=F.GaussianBlur2d( 82 | kernel_size=(3, 3), 83 | sigma=(1.5, 1.5), 84 | border_type='constant' 85 | ) 86 | ), 87 | K.ColorJitter(contrast=(0.75, 1.5)), 88 | # additive Gaussian noise 89 | K.RandomErasing(p=0.1), 90 | # Multiply 91 | K.RandomAffine( 92 | degrees=(-25., 25.), 93 | translate=(0.2, 0.2), 94 | scale=(0.8, 1.2), 95 | shear=(-8., 8.) 
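# descriptive note: affine jitter with rotation within ±25 degrees, translation up to 20% of
# height/width, scaling in [0.8, 1.2], and shear within ±8 degrees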
96 | ), 97 | K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32), 98 | std=torch.tensor(dataset_std, dtype=torch.float32)), 99 | ) 100 | 101 | elif augmenter_type in ["validation", "test"]: 102 | return nn.Sequential( 103 | K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32), 104 | std=torch.tensor(dataset_std, dtype=torch.float32)), 105 | ) 106 | 107 | elif augmenter_type == "randaugment": 108 | return nn.Sequential( 109 | K.RandomCrop(size=image_size, padding=padding, pad_if_needed=pad_if_needed, 110 | padding_mode='reflect'), 111 | K.RandomHorizontalFlip(p=0.5), 112 | RandAugmentNS(n=subset_size, m=10), 113 | K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32), 114 | std=torch.tensor(dataset_std, dtype=torch.float32)), 115 | ) 116 | 117 | else: 118 | raise NotImplementedError(f"\"{augmenter_type}\" is not a supported augmenter type") 119 | 120 | 121 | __all__ = [ 122 | # modules 123 | # classes 124 | # functions 125 | "get_augmenter", 126 | ] 127 | -------------------------------------------------------------------------------- /utils/loggers/logger.py: -------------------------------------------------------------------------------- 1 | from .log_aggregator import LogAggregator 2 | from ..utils import dict_add_prefix 3 | 4 | # for type hint 5 | from typing import Any, Dict, Optional, Union, Tuple 6 | from argparse import Namespace 7 | 8 | from torch.nn import Module 9 | from plotly.graph_objects import Figure 10 | from wandb import Histogram 11 | 12 | 13 | class Logger: 14 | def __init__(self, 15 | log_dir: str, 16 | config: Union[Namespace, Dict[str, Any]], 17 | *args, 18 | log_info_key_map: Optional[Dict[str, str]] = None, 19 | **kwargs): 20 | self.log_dir = log_dir 21 | self.config = config 22 | 23 | # hook functions 24 | self.log_hooks = [] 25 | 26 | self.metric_smooth_record: Dict[str, Dict[str, Union[str, Optional[float]]]] = { 27 | "train/mean_acc": { 28 | "key": "train/smoothed_acc", 29 | "value": None, 30 | }, 31 | "unlabeled/mean_acc": { 32 | "key": "unlabeled/smoothed_acc", 33 | "value": None, 34 | }, 35 | "validation/mean_acc": { 36 | "key": "validation/smoothed_acc", 37 | "value": None, 38 | }, 39 | "test/mean_acc": { 40 | "key": "test/smoothed_acc", 41 | "value": None, 42 | }, 43 | } 44 | 45 | self.smoothing_weight = 0.9 46 | 47 | # init key map 48 | if log_info_key_map is None: 49 | log_info_key_map = dict() 50 | 51 | self.log_info_key_map = log_info_key_map 52 | 53 | # init log accumulator 54 | self.log_aggregator = LogAggregator() 55 | 56 | def log(self, log_info: Dict[str, Any], *args, **kwargs): 57 | pass 58 | 59 | def watch(self, model: Module, *args, **kwargs): 60 | pass 61 | 62 | def save(self, output_path: str): 63 | pass 64 | 65 | def process_log_info(self, 66 | log_info: Dict[str, Any], 67 | *args, 68 | prefix: Optional[str] = None, 69 | log_info_override: Optional[Dict[str, Any]] = None, 70 | **kwargs) -> Dict[str, Any]: 71 | if log_info_override is None: 72 | log_info_override = {} 73 | 74 | # create shallow copy 75 | log_info = dict(log_info) 76 | 77 | # apply override 78 | log_info.update(log_info_override) 79 | 80 | # update keys based on key map 81 | for key in list(log_info.keys()): 82 | if key in self.log_info_key_map: 83 | new_key = self.log_info_key_map[key] 84 | log_info[new_key] = log_info.pop(key) 85 | 86 | if bool(prefix): 87 | # prepend prefix to info_dict keys 88 | log_info = dict_add_prefix(log_info, prefix) 89 | 90 | # apply smoothing 91 | smoothed_metrics = self.smooth_metrics(log_info) 92 | 
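# exponential moving average with weight self.smoothing_weight = 0.9:
# new_smoothed = 0.9 * old_smoothed + 0.1 * value
# (illustrative numbers: old 0.80, incoming value 0.90 -> new 0.81)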
log_info.update(smoothed_metrics) 93 | 94 | return log_info 95 | 96 | def register_log_hook(self, func, *args, **kwargs): 97 | self.log_hooks.append([func, args, kwargs]) 98 | 99 | def call_log_hooks(self, log_info: Dict[str, Any]): 100 | for (func, args, kwargs) in self.log_hooks: 101 | func(log_info=log_info, *args, **kwargs) 102 | 103 | @staticmethod 104 | def separate_plot(input_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: 105 | log_info: Dict[str, Any] = dict() 106 | plot_info: Dict[str, Any] = dict() 107 | 108 | for k, v in input_dict.items(): 109 | if isinstance(v, Figure) or isinstance(v, Histogram): 110 | plot_info[k] = v 111 | else: 112 | log_info[k] = v 113 | 114 | return log_info, plot_info 115 | 116 | def smooth_metrics(self, log_info: Dict[str, Any]) -> Dict[str, Any]: 117 | # see https://stackoverflow.com/a/49357445/5838091 118 | smoothed_metrics = dict() 119 | 120 | for key, value in log_info.items(): 121 | if key not in self.metric_smooth_record: 122 | continue 123 | 124 | smoothed_metric_dict = self.metric_smooth_record[key] 125 | if smoothed_metric_dict["value"] is None: 126 | smoothed_metric_dict["value"] = value 127 | else: 128 | smoothed_metric_dict["value"] = smoothed_metric_dict["value"] * self.smoothing_weight + \ 129 | (1 - self.smoothing_weight) * value 130 | 131 | smoothed_metrics[smoothed_metric_dict["key"]] = smoothed_metric_dict["value"] 132 | 133 | return smoothed_metrics 134 | 135 | def accumulate_log(self, log_info: Optional[Dict[str, Any]] = None, plot_info: Optional[Dict[str, Any]] = None): 136 | if log_info is not None: 137 | self.log_aggregator.add_log(log_info) 138 | 139 | if plot_info is not None: 140 | self.log_aggregator.add_plot(plot_info) 141 | 142 | def aggregate_log(self, reduction: str = "mean") -> Dict[str, Any]: 143 | return self.log_aggregator.aggregate(reduction=reduction) 144 | 145 | def reset_aggregator(self): 146 | self.log_aggregator.clear() 147 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import onnx 3 | import numpy as np 4 | from .models.utils import * 5 | 6 | # for type hint 7 | from typing import Optional, Sequence, List, Generator, Any, Union, Set, Tuple, Dict 8 | from torch import Tensor 9 | from torch.nn import Module, Parameter 10 | 11 | 12 | @torch.no_grad() 13 | def get_accuracy(output: Tensor, target: Tensor, top_k: Sequence[int] = (1,)) -> List[Tensor]: 14 | # see https://discuss.pytorch.org/t/imagenet-example-accuracy-calculation/7840 15 | max_k = max(top_k) 16 | batch_size = target.size(0) 17 | 18 | _, pred = output.topk(max_k, dim=1, largest=True, sorted=True) 19 | correct = pred.eq(target.view(-1, 1).expand_as(pred)) 20 | 21 | res = [] 22 | for k in top_k: 23 | correct_k = (correct[:, :k].sum(dim=1, keepdim=False) > 0).float() 24 | res.append(correct_k.sum() / batch_size) 25 | return res 26 | 27 | 28 | def interleave_offsets(batch_size: int, num_unlabeled: int) -> List[int]: 29 | # TODO: scrutiny 30 | groups = [batch_size // (num_unlabeled + 1)] * (num_unlabeled + 1) 31 | for x in range(batch_size - sum(groups)): 32 | groups[-x - 1] += 1 33 | offsets = [0] + np.cumsum(groups).tolist() 34 | assert offsets[-1] == batch_size 35 | return offsets 36 | 37 | 38 | def interleave(xy: Sequence[Tensor], batch_size: int) -> List[Tensor]: 39 | # TODO: scrutiny 40 | num_unlabeled = len(xy) - 1 41 | offsets = interleave_offsets(batch_size, 
num_unlabeled) 42 | xy = [[v[offsets[p]:offsets[p + 1]] for p in range(num_unlabeled + 1)] for v in xy] 43 | for i in range(1, num_unlabeled + 1): 44 | xy[0][i], xy[i][i] = xy[i][i], xy[0][i] 45 | return [torch.cat(v, dim=0) for v in xy] 46 | 47 | 48 | def get_gradient_norm(model: Module, grad_enabled: bool = False) -> Tensor: 49 | with torch.set_grad_enabled(grad_enabled): 50 | return sum([(p.grad.detach() ** 2).sum() for p in model.parameters() if p.grad is not None]) 51 | 52 | 53 | def get_weight_norm(model: Module, grad_enabled: bool = False) -> Tensor: 54 | with torch.set_grad_enabled(grad_enabled): 55 | return sum([(p.detach() ** 2).sum() for p in model.parameters() if p.data is not None]) 56 | 57 | 58 | def split_classifier_params(model: Module, classifier_prefix: Union[str, Set[str]]) \ 59 | -> Tuple[List[Parameter], List[Parameter]]: 60 | if not isinstance(classifier_prefix, Set): 61 | classifier_prefix = {classifier_prefix} 62 | 63 | # build tuple for multiple prefix matching 64 | classifier_prefix = tuple(sorted(f"{prefix}." for prefix in classifier_prefix)) 65 | 66 | embedder_weights = [] 67 | classifier_weights = [] 68 | 69 | for k, v in model.named_parameters(): 70 | if k.startswith(classifier_prefix): 71 | classifier_weights.append(v) 72 | else: 73 | embedder_weights.append(v) 74 | 75 | return embedder_weights, classifier_weights 76 | 77 | 78 | def set_model_mode(model: Module, mode: Optional[bool]) -> Generator[Any, Any, None]: 79 | """ 80 | A context manager to temporarily set the training mode of 'model' to 'mode', resetting it when we exit the 81 | with-block. A no-op if mode is None. 82 | 83 | Args: 84 | model: the model 85 | mode: a bool or None 86 | 87 | Returns: 88 | 89 | """ 90 | if hasattr(onnx, "select_model_mode_for_export"): 91 | # In PyTorch 1.6+, set_training is changed to select_model_mode_for_export 92 | return onnx.select_model_mode_for_export(model, mode) 93 | else: 94 | return onnx.set_training(model, mode) 95 | 96 | 97 | def consume_prefix_in_state_dict_if_present(state_dict: Dict[str, Any], prefix: str): 98 | r"""copied from https://github.com/pytorch/pytorch/blob/255494c2aa1fcee7e605a6905be72e5b8ccf4646/torch/nn/modules/utils.py#L37-L67 99 | 100 | Strip the prefix in state_dict, if any. 101 | ..note:: 102 | Given a `state_dict` from a DP/DDP model, a local model can load it by applying 103 | `consume_prefix_in_state_dict_if_present(state_dict, "module.")` before calling 104 | :meth:`torch.nn.Module.load_state_dict`. 105 | Args: 106 | state_dict (OrderedDict): a state-dict to be loaded to the model. 107 | prefix (str): prefix. 108 | """ 109 | keys = sorted(state_dict.keys()) 110 | for key in keys: 111 | if key.startswith(prefix): 112 | newkey = key[len(prefix):] 113 | state_dict[newkey] = state_dict.pop(key) 114 | 115 | # also strip the prefix in metadata if any. 116 | if "_metadata" in state_dict: 117 | metadata = state_dict["_metadata"] 118 | for key in list(metadata.keys()): 119 | # for the metadata dict, the key can be: 120 | # '': for the DDP module, which we want to remove. 121 | # 'module': for the actual model. 122 | # 'module.xx.xx': for the rest.
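# e.g. with prefix="module.": the metadata key "module.fc" becomes "fc",
# while the root entry '' is kept as-is by the `continue` below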
123 | 124 | if len(key) == 0: 125 | continue 126 | newkey = key[len(prefix):] 127 | metadata[newkey] = metadata.pop(key) 128 | 129 | 130 | __all__ = [ 131 | "get_accuracy", 132 | "interleave", 133 | "get_gradient_norm", 134 | "get_weight_norm", 135 | "split_classifier_params", 136 | "set_model_mode", 137 | "unwrap_model", 138 | "consume_prefix_in_state_dict_if_present", 139 | ] 140 | -------------------------------------------------------------------------------- /utils/dataset/domainnet_real.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import ImageFolder 2 | from torchvision.datasets.utils import check_integrity, download_url, extract_archive 3 | 4 | from pathlib import Path 5 | from typing import NamedTuple 6 | import re 7 | import os 8 | import shutil 9 | 10 | from .utils import get_directory_size 11 | 12 | # for type hint 13 | from typing import Optional 14 | 15 | FileMeta = NamedTuple("FileMeta", [("filename", str), ("url", str), ("md5", Optional[str])]) 16 | 17 | 18 | class DomainNetReal(ImageFolder): 19 | base_folder = 'domainnet-real' 20 | 21 | data_file_meta = FileMeta( 22 | filename="domainnet-real.zip", 23 | url="http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip", 24 | md5="dcc47055e8935767784b7162e7c7cca6") 25 | 26 | train_label_file_meta = FileMeta( 27 | filename="domainnet-real_train.txt", 28 | url="http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/real_train.txt", 29 | md5="8ebf02c2075fadd564705f0dc7cd6291") 30 | 31 | test_label_file_meta = FileMeta( 32 | filename="domainnet-real_test.txt", 33 | url="http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/real_test.txt", 34 | md5="6098816791c3ebed543c71ffa11b9054") 35 | 36 | label_file_pattern = re.compile(r"(.*) (\d+)") 37 | 38 | # extracted file size in Bytes 39 | DATA_FOLDER_SIZE = 6_234_186_058 40 | TRAIN_FOLDER_SIZE = 4_301_431_405 41 | TEST_FOLDER_SIZE = 1_860_180_803 42 | 43 | def __init__(self, 44 | root: str, 45 | train: bool = True, 46 | download: bool = False, 47 | **kwargs): 48 | self.root = root 49 | self.train = train 50 | 51 | if download: 52 | self.download() 53 | 54 | if not self._check_integrity(): 55 | raise RuntimeError("Dataset not found or corrupted. 
You can use download=True to download it") 56 | 57 | # parse dataset 58 | self.parse_archives() 59 | 60 | super(DomainNetReal, self).__init__(root=str(self.split_folder), **kwargs) 61 | self.root = root 62 | 63 | @property 64 | def data_root(self) -> Path: 65 | return Path(self.root) / self.base_folder 66 | 67 | @property 68 | def download_root(self) -> Path: 69 | return Path(self.root) 70 | 71 | @property 72 | def data_folder(self) -> Path: 73 | return self.data_root / "real" 74 | 75 | @property 76 | def split_folder(self) -> Path: 77 | return self.data_root / ("train" if self.train else "test") 78 | 79 | @property 80 | def SPLIT_FOLDER_SIZE(self) -> int: 81 | return self.TRAIN_FOLDER_SIZE if self.train else self.TEST_FOLDER_SIZE 82 | 83 | @property 84 | def label_file_meta(self) -> FileMeta: 85 | return self.train_label_file_meta if self.train else self.test_label_file_meta 86 | 87 | def download(self) -> None: 88 | if self._check_integrity(): 89 | print('Files already downloaded and verified') 90 | return 91 | 92 | # remove old files 93 | shutil.rmtree(self.split_folder, ignore_errors=True) 94 | shutil.rmtree(self.data_folder, ignore_errors=True) 95 | 96 | for file_meta in (self.data_file_meta, self.label_file_meta): 97 | download_url(url=file_meta.url, 98 | root=str(self.download_root), 99 | filename=file_meta.filename, 100 | md5=file_meta.md5) 101 | 102 | def _check_integrity(self) -> bool: 103 | for file_meta in (self.data_file_meta, self.label_file_meta): 104 | if not check_integrity(fpath=str(self.download_root / file_meta.filename), md5=file_meta.md5): 105 | return False 106 | 107 | return True 108 | 109 | def parse_archives(self) -> None: 110 | if not self.split_folder.is_dir() or get_directory_size(self.split_folder) < self.SPLIT_FOLDER_SIZE: 111 | # if split_folder does not exist or is not large enough 112 | self.parse_data_archive() 113 | 114 | # remove old files 115 | shutil.rmtree(self.split_folder, ignore_errors=True) 116 | 117 | with open(self.download_root / self.label_file_meta.filename, "r") as f: 118 | file_content = [line.strip() for line in f.readlines()] 119 | 120 | for line in file_content: 121 | search_result = self.label_file_pattern.search(line) 122 | assert search_result is not None, f"{self.label_file_meta.filename} contains invalid line \"{line}\"" 123 | 124 | image_path = Path(search_result.group(1)) 125 | image_relative_path = Path(*image_path.parts[1:]) 126 | 127 | source_path = self.data_root / image_path 128 | target_path = self.split_folder / image_relative_path 129 | 130 | if not target_path.is_file(): 131 | target_path.parent.mkdir(parents=True, exist_ok=True) 132 | os.link(src=source_path.absolute(), dst=target_path.absolute()) 133 | 134 | def parse_data_archive(self) -> None: 135 | if not self.data_folder.is_dir() or get_directory_size(self.data_folder) < self.DATA_FOLDER_SIZE: 136 | # if data_folder does not exist or is not large enough 137 | # remove old files 138 | shutil.rmtree(self.data_folder, ignore_errors=True) 139 | 140 | print(f"extracting {self.data_file_meta.filename}...") 141 | extract_archive(from_path=str(self.download_root / self.data_file_meta.filename), 142 | to_path=str(self.data_root)) 143 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from functools import partial 4 | 5 | from .augmentation import get_augmenter 6 | from .rampup import RampUp, LinearRampUp, get_ramp_up 7 | from
.utils import (get_accuracy, interleave, unwrap_model, split_classifier_params) 8 | from .models import EMA, build_model, build_ema_model 9 | from .optimization import build_optimizer, build_lr_scheduler 10 | 11 | from . import augmentation 12 | from . import rampup 13 | from . import utils 14 | from . import mixmatch 15 | from . import optimization 16 | from . import types 17 | from . import models 18 | 19 | # for type hint 20 | from typing import Optional, Union, Set 21 | from argparse import Namespace 22 | from torch.nn import Module 23 | 24 | from .types import OptimizerParametersType, MixMatchFunctionType 25 | 26 | 27 | def load_pretrain(model: Module, 28 | checkpoint_path: str, 29 | allowed_prefix: str, 30 | ignored_prefix: str, 31 | device: torch.device, 32 | checkpoint_key: Optional[str] = None): 33 | full_allowed_prefix = f"{allowed_prefix}." if bool(allowed_prefix) else allowed_prefix 34 | full_ignored_prefix = f"{ignored_prefix}." if bool(ignored_prefix) else ignored_prefix 35 | 36 | checkpoint = torch.load(checkpoint_path, map_location=device) 37 | if checkpoint_key is not None: 38 | pretrain_state_dict = checkpoint[checkpoint_key] 39 | else: 40 | pretrain_state_dict = checkpoint 41 | 42 | shadow = None 43 | if isinstance(model, EMA): 44 | shadow = unwrap_model(model.shadow) 45 | model = unwrap_model(model.model) 46 | 47 | state_dict = model.state_dict() 48 | 49 | for name, param in pretrain_state_dict.items(): 50 | if name.startswith(full_allowed_prefix) and not (full_ignored_prefix and name.startswith(full_ignored_prefix)):  # an empty ignored_prefix must ignore nothing, since str.startswith("") is always True 51 | name = name[len(full_allowed_prefix):] 52 | 53 | assert name in state_dict.keys() 54 | state_dict[name] = param 55 | 56 | # load the pretrained weights 57 | model.load_state_dict(state_dict) 58 | if shadow is not None: 59 | shadow.load_state_dict(state_dict) 60 | 61 | 62 | def get_trainable_params(model: Module, 63 | learning_rate: float, 64 | feature_learning_rate: Optional[float], 65 | classifier_prefix: Union[str, Set[str]] = 'fc', 66 | requires_grad_only: bool = True) -> OptimizerParametersType: 67 | if feature_learning_rate is not None: 68 | embedder_weights, classifier_weights = split_classifier_params(model, classifier_prefix) 69 | 70 | if requires_grad_only: 71 | # keep only the parameters that require grad 72 | embedder_weights = [param for param in embedder_weights if param.requires_grad] 73 | classifier_weights = [param for param in classifier_weights if param.requires_grad] 74 | 75 | params = [dict(params=embedder_weights, lr=feature_learning_rate), 76 | dict(params=classifier_weights, lr=learning_rate)] 77 | else: 78 | params = model.parameters() 79 | 80 | if requires_grad_only: 81 | # keep only the parameters that require grad 82 | params = [param for param in params if param.requires_grad] 83 | 84 | return params 85 | 86 | 87 | def get_mixmatch_function(args: Namespace, 88 | num_classes: int, 89 | augmenter: Module, 90 | strong_augmenter: Module) -> MixMatchFunctionType: 91 | if args.mixmatch_type == "simple": 92 | from .mixmatch import SimPLE 93 | 94 | return SimPLE(augmenter=augmenter, 95 | strong_augmenter=strong_augmenter, 96 | num_classes=num_classes, 97 | temperature=args.t, 98 | num_augmentations=args.k, 99 | num_strong_augmentations=args.k_strong, 100 | is_strong_augment_x=False, 101 | train_label_guessing=False) 102 | 103 | elif args.mixmatch_type == "enhanced": 104 | from .mixmatch import MixMatchEnhanced 105 | 106 | return MixMatchEnhanced(augmenter=augmenter, 107 | strong_augmenter=strong_augmenter, 108 | num_classes=num_classes, 109 | temperature=args.t,
110 | num_augmentations=args.k, 111 | num_strong_augmentations=args.k_strong, 112 | alpha=args.alpha, 113 | is_strong_augment_x=False, 114 | train_label_guessing=False) 115 | 116 | elif args.mixmatch_type == "mixmatch": 117 | from .mixmatch import MixMatch 118 | 119 | return MixMatch(augmenter=augmenter, 120 | num_classes=num_classes, 121 | temperature=args.t, 122 | num_augmentations=args.k, 123 | alpha=args.alpha, 124 | train_label_guessing=False) 125 | 126 | else: 127 | raise NotImplementedError(f"{args.mixmatch_type} is not a supported mixmatch type") 128 | 129 | 130 | __all__ = [ 131 | # modules 132 | "augmentation", 133 | "rampup", 134 | "utils", 135 | "mixmatch", 136 | "models", 137 | "optimization", 138 | "types", 139 | 140 | # classes 141 | "EMA", 142 | 143 | # functions 144 | "interleave", 145 | "get_augmenter", 146 | "get_ramp_up", 147 | "get_accuracy", 148 | "build_model", 149 | "build_ema_model", 150 | "build_optimizer", 151 | "build_lr_scheduler", 152 | "load_pretrain", 153 | "get_trainable_params", 154 | "get_mixmatch_function", 155 | ] 156 | -------------------------------------------------------------------------------- /utils/dataset/ssl_datamodule.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from pathlib import Path 3 | 4 | from .datamodule import DataModule 5 | from .utils import get_batch 6 | 7 | # for type hint 8 | from typing import Optional, Callable, Tuple, List, Union 9 | from torch.utils.data import Dataset, Subset, DataLoader 10 | from torchvision.datasets import ImageFolder 11 | 12 | from .datasets import LabeledDataset 13 | from .utils import BatchGeneratorType 14 | 15 | DatasetType = Union[LabeledDataset, Dataset] 16 | DataLoaderType = Union[DataLoader, List[DataLoader]] 17 | 18 | 19 | class SSLDataModule(DataModule, ABC): 20 | num_classes: int = 1 21 | 22 | total_train_size: int = 1 23 | total_test_size: int = 1 24 | 25 | DATASET = type(Dataset) 26 | 27 | def __init__(self, 28 | train_batch_size: int, 29 | unlabeled_batch_size: int, 30 | test_batch_size: int, 31 | num_workers: int, 32 | train_min_size: int = 0, 33 | unlabeled_train_min_size: int = 0, 34 | test_min_size: int = 0, 35 | train_transform: Optional[Callable] = None, 36 | val_transform: Optional[Callable] = None, 37 | test_transform: Optional[Callable] = None, 38 | dims: Optional[Tuple[int, ...]] = None): 39 | super(SSLDataModule, self).__init__(train_transform=train_transform, 40 | val_transform=val_transform, 41 | test_transform=test_transform, 42 | dims=dims) 43 | self.train_batch_size = train_batch_size 44 | self.unlabeled_batch_size = unlabeled_batch_size 45 | self.test_batch_size = test_batch_size 46 | self.num_workers = num_workers 47 | 48 | self.train_min_size = max(train_min_size, 0) 49 | self.unlabeled_train_min_size = max(unlabeled_train_min_size, 0) 50 | self.test_min_size = max(test_min_size, 0) 51 | 52 | self._labeled_train_set: Optional[DatasetType] = None 53 | self._unlabeled_train_set: Optional[DatasetType] = None 54 | self._validation_set: Optional[DatasetType] = None 55 | self._test_set: Optional[DatasetType] = None 56 | 57 | # dataset stats 58 | self.dataset_mean: Optional[List[float]] = None 59 | self.dataset_std: Optional[List[float]] = None 60 | 61 | @property 62 | def labeled_train_set(self) -> Optional[DatasetType]: 63 | return self._labeled_train_set 64 | 65 | @labeled_train_set.setter 66 | def labeled_train_set(self, dataset: Optional[DatasetType]) -> None: 67 | self._labeled_train_set = dataset 68 | 69 | 
@property 70 | def unlabeled_train_set(self) -> Optional[DatasetType]: 71 | return self._unlabeled_train_set 72 | 73 | @unlabeled_train_set.setter 74 | def unlabeled_train_set(self, dataset: Optional[DatasetType]) -> None: 75 | self._unlabeled_train_set = dataset 76 | 77 | @property 78 | def validation_set(self) -> Optional[DatasetType]: 79 | return self._validation_set 80 | 81 | @validation_set.setter 82 | def validation_set(self, dataset: Optional[DatasetType]) -> None: 83 | self._validation_set = dataset 84 | 85 | @property 86 | def test_set(self) -> Optional[DatasetType]: 87 | return self._test_set 88 | 89 | @test_set.setter 90 | def test_set(self, dataset: Optional[DatasetType]) -> None: 91 | self._test_set = dataset 92 | 93 | def train_dataloader(self, **kwargs) -> DataLoaderType: 94 | return [DataLoader(self.labeled_train_set, 95 | shuffle=True, 96 | batch_size=self.train_batch_size, 97 | num_workers=self.num_workers, 98 | drop_last=True, 99 | **kwargs), 100 | DataLoader(self.unlabeled_train_set, 101 | shuffle=True, 102 | batch_size=self.unlabeled_batch_size, 103 | num_workers=self.num_workers, 104 | drop_last=True, 105 | **kwargs)] 106 | 107 | def val_dataloader(self, **kwargs) -> Optional[DataLoaderType]: 108 | return DataLoader(self.validation_set, 109 | batch_size=self.test_batch_size, 110 | shuffle=False, 111 | num_workers=self.num_workers, 112 | drop_last=False, 113 | **kwargs) if self.validation_set is not None else None 114 | 115 | def test_dataloader(self, **kwargs) -> Optional[DataLoaderType]: 116 | return DataLoader(self.test_set, 117 | batch_size=self.test_batch_size, 118 | shuffle=False, 119 | num_workers=self.num_workers, 120 | drop_last=False, 121 | **kwargs) if self.test_set is not None else None 122 | 123 | def get_train_batch(self, train_loaders: List[DataLoader], **kwargs) -> BatchGeneratorType: 124 | return get_batch(train_loaders, **kwargs) 125 | 126 | def save_split_info(self, output_path: Union[str, Path]) -> None: 127 | for subset, filename in [(self.labeled_train_set, "labeled_train_set.txt"), 128 | (self.unlabeled_train_set, "unlabeled_train_set.txt"), 129 | (self.validation_set, "validation_set.txt")]: 130 | if hasattr(subset, "dataset") and isinstance(subset.dataset, Subset): 131 | # save split info if subset is manually split 132 | full_set = subset.dataset.dataset 133 | indices = subset.dataset.indices 134 | 135 | if isinstance(full_set, ImageFolder): 136 | # save file paths 137 | split_info = [str(Path(full_set.imgs[i][0]).relative_to(full_set.root)) + "\n" for i in indices] 138 | else: 139 | # save index values 140 | split_info = [f"{i}\n" for i in indices] 141 | 142 | with open(Path(output_path) / filename, "w") as f: 143 | f.writelines(split_info) 144 | -------------------------------------------------------------------------------- /models/augmentation/randaugment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from torch.nn import functional as F 5 | from kornia.geometry import transform as T 6 | from kornia import enhance as E 7 | from kornia.augmentation import functional as KF 8 | 9 | # for type hint 10 | from typing import Any, Optional, List, Tuple, Callable 11 | from torch import Tensor 12 | from torch.nn import Module 13 | 14 | 15 | # Affine 16 | def translate_x(x: Tensor, v: float) -> Tensor: 17 | B, C, H, W = x.shape 18 | return T.translate(x, torch.tensor([[v * W, 0]], device=x.device, dtype=x.dtype)) 19 | 20 | 21 | def translate_y(x: 
Tensor, v: float) -> Tensor: 22 | B, C, H, W = x.shape 23 | return T.translate(x, torch.tensor([[0, v * H]], device=x.device, dtype=x.dtype)) 24 | 25 | 26 | def shear_x(x: Tensor, v: float) -> Tensor: 27 | return T.shear(x, torch.tensor([[v, 0.0]], device=x.device, dtype=x.dtype)) 28 | 29 | 30 | def shear_y(x: Tensor, v: float) -> Tensor: 31 | return T.shear(x, torch.tensor([[0.0, v]], device=x.device, dtype=x.dtype)) 32 | 33 | 34 | def rotate(x: Tensor, v: float) -> Tensor: 35 | return T.rotate(x, torch.tensor([v], device=x.device, dtype=x.dtype)) 36 | 37 | 38 | def auto_contrast(x: Tensor, _: Any) -> Tensor: 39 | B, C, H, W = x.shape 40 | 41 | x_min = x.view(B, C, -1).min(-1)[0].view(B, C, 1, 1) 42 | x_max = x.view(B, C, -1).max(-1)[0].view(B, C, 1, 1) 43 | 44 | x_out = (x - x_min) / torch.clamp(x_max - x_min, min=1e-9, max=1) 45 | 46 | return x_out.expand_as(x) 47 | 48 | 49 | def invert(x: Tensor, _: Any) -> Tensor: 50 | return 1.0 - x 51 | 52 | 53 | def equalize(x: Tensor, _: Any) -> Tensor: 54 | return KF.apply_equalize(x, params=dict(batch_prob=torch.tensor([1.0] * len(x), dtype=x.dtype, device=x.device))) 55 | 56 | 57 | def flip(x: Tensor, _: Any) -> Tensor: 58 | return T.hflip(x) 59 | 60 | 61 | def solarize(x: Tensor, v: float) -> Tensor: 62 | x[x < v] = 1 - x[x < v] 63 | return x 64 | 65 | 66 | def brightness(x: Tensor, v: float) -> Tensor: 67 | return E.adjust_brightness(x, v) 68 | 69 | 70 | def color(x: Tensor, v: float) -> Tensor: 71 | return E.adjust_saturation(x, v) 72 | 73 | 74 | def contrast(x: Tensor, v: float) -> Tensor: 75 | return E.adjust_contrast(x, v) 76 | 77 | 78 | def sharpness(x: Tensor, v: float) -> Tensor: 79 | return KF.apply_sharpness(x, params=dict(sharpness_factor=v)) 80 | 81 | 82 | def identity(x: Tensor, _: Any) -> Tensor: 83 | return x 84 | 85 | 86 | def posterize(x: Tensor, v: float) -> Tensor: 87 | v = int(v) 88 | return E.posterize(x, v) 89 | 90 | 91 | def cutout(x: Tensor, v: float) -> Tensor: 92 | B, C, H, W = x.shape 93 | 94 | x_v = int(v * W) 95 | y_v = int(v * H) 96 | 97 | x_idx = (np.random.uniform(low=0, high=W - x_v, size=(B, 1, 1, 1)) + np.arange(x_v).reshape((1, 1, 1, -1))).astype(int)  # cast to int: float arrays cannot be used to index a tensor 98 | y_idx = (np.random.uniform(low=0, high=H - y_v, size=(B, 1, 1, 1)) + np.arange(y_v).reshape((1, 1, -1, 1))).astype(int) 99 | 100 | x[np.arange(B).reshape((B, 1, 1, 1)), np.arange(C).reshape((1, C, 1, 1)), y_idx, x_idx] = 0.5 101 | return x 102 | 103 | 104 | def cutout_pad(x: Tensor, v: float) -> Tensor: 105 | B, C, H, W = x.shape 106 | 107 | x = F.pad(x, [int(v * W / 2), int(v * W / 2), int(v * H / 2), int(v * H / 2)]) 108 | 109 | x = cutout(x, v / (1 + v)) 110 | 111 | x = T.center_crop(x, (H, W)) 112 | 113 | return x 114 | 115 | 116 | class RandAugment(Module): 117 | def __init__(self, n: int, m: int, augmentation_pool: Optional[List[Tuple[Callable, float, float]]] = None): 118 | """ 119 | 120 | Args: 121 | n: number of transformations 122 | m: magnitude 123 | augmentation_pool: transformation pool 124 | """ 125 | super().__init__() 126 | 127 | self.n = n 128 | self.m = m 129 | if augmentation_pool is not None: 130 | self.augmentation_pool = augmentation_pool 131 | else: 132 | self.augmentation_pool = [ 133 | (auto_contrast, np.nan, np.nan), 134 | (brightness, 0.05, 0.95), 135 | (color, 0.05, 0.95), 136 | (contrast, 0.05, 0.95), 137 | (cutout, 0, 0.3), 138 | (equalize, np.nan, np.nan), 139 | (identity, np.nan, np.nan), 140 | (posterize, 4, 8), 141 | (rotate, -30, 30), 142 | (sharpness, 0.05, 0.95), 143 | (shear_x, -0.3, 0.3), 144 | (shear_y, -0.3, 0.3), 145 | (solarize, 0.0, 1.0),
146 | (translate_x, -0.3, 0.3), 147 | (translate_y, -0.3, 0.3), 148 | ] 149 | 150 | def forward(self, x): 151 | assert len(x.shape) == 4 152 | 153 | ops = random.choices(self.augmentation_pool, k=self.n) 154 | 155 | for op, min_v, max_v in ops: 156 | v = random.randint(1, self.m + 1) / 10 * (max_v - min_v) + min_v 157 | x = op(x, v) 158 | 159 | return x 160 | 161 | 162 | class RandAugmentNS(RandAugment): 163 | def __init__(self, n: int, m: int): 164 | super(RandAugmentNS, self).__init__( 165 | n=n, 166 | m=m, 167 | augmentation_pool=[ 168 | (auto_contrast, np.nan, np.nan), 169 | (brightness, 0.05, 0.95), 170 | (color, 0.05, 0.95), 171 | (contrast, 0.05, 0.95), 172 | (equalize, np.nan, np.nan), 173 | (identity, np.nan, np.nan), 174 | (posterize, 4, 8), 175 | (rotate, -30, 30), 176 | (sharpness, 0.05, 0.95), 177 | (shear_x, -0.3, 0.3), 178 | (shear_y, -0.3, 0.3), 179 | (solarize, 0.0, 1.0), 180 | (translate_x, -0.3, 0.3), 181 | (translate_y, -0.3, 0.3), 182 | ]) 183 | 184 | def forward(self, x: Tensor) -> Tensor: 185 | assert len(x.shape) == 4 186 | 187 | ops = random.choices(self.augmentation_pool, k=self.n) 188 | 189 | for op, min_v, max_v in ops: 190 | v = random.randint(1, self.m + 1) / 10 * (max_v - min_v) + min_v 191 | if random.random() < 0.5: 192 | x = op(x, v) 193 | 194 | x = cutout(x, 0.5) 195 | return x 196 | -------------------------------------------------------------------------------- /models/mixmatch/mixmatch_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | code inspired by https://github.com/gan3sh500/mixmatch-pytorch and 3 | https://github.com/google-research/mixmatch 4 | """ 5 | import torch 6 | import numpy as np 7 | from torch.nn import functional as F 8 | 9 | from .utils import label_guessing, sharpen 10 | 11 | # for type hint 12 | from typing import Optional, Dict, Union, List, Sequence 13 | from torch import Tensor 14 | from torch.nn import Module 15 | 16 | 17 | class MixMatchBase: 18 | def __init__(self, 19 | augmenter: Module, 20 | strong_augmenter: Optional[Module], 21 | num_classes: int, 22 | temperature: float, 23 | num_augmentations: int, 24 | num_strong_augmentations: int, 25 | alpha: float, 26 | is_strong_augment_x: bool, 27 | train_label_guessing: bool): 28 | # callables 29 | self.augmenter = augmenter 30 | self.strong_augmenter = strong_augmenter 31 | 32 | # parameters 33 | self.num_classes = num_classes 34 | self.temperature = temperature 35 | self.alpha = alpha 36 | 37 | self.num_augmentations = num_augmentations 38 | self.num_strong_augmentations = num_strong_augmentations 39 | 40 | # flags 41 | self.train_label_guessing = train_label_guessing 42 | self.is_strong_augment_x = is_strong_augment_x 43 | 44 | @property 45 | def total_num_augmentations(self) -> int: 46 | return self.num_augmentations + self.num_strong_augmentations 47 | 48 | @torch.no_grad() 49 | def __call__(self, 50 | x_augmented: Tensor, 51 | x_strong_augmented: Optional[Tensor], 52 | x_targets_one_hot: Tensor, 53 | u_augmented: List[Tensor], 54 | u_strong_augmented: List[Tensor], 55 | u_true_targets_one_hot: Tensor, 56 | model: Module, 57 | *args, 58 | **kwargs) -> Dict[str, Tensor]: 59 | if self.is_strong_augment_x: 60 | x_inputs = x_strong_augmented 61 | else: 62 | x_inputs = x_augmented 63 | u_inputs = u_augmented + u_strong_augmented 64 | 65 | # label guessing with weakly augmented data 66 | pseudo_label_dict = self.guess_label(u_inputs=u_augmented, model=model) 67 | 68 | return self.postprocess(x_augmented=x_inputs, 69 | 
x_targets_one_hot=x_targets_one_hot, 70 | u_augmented=u_inputs, 71 | q_guess=pseudo_label_dict["q_guess"], 72 | u_true_targets_one_hot=u_true_targets_one_hot) 73 | 74 | @torch.no_grad() 75 | def preprocess(self, 76 | x_inputs: Tensor, 77 | x_strong_inputs: Tensor, 78 | x_targets: Tensor, 79 | u_inputs: Tensor, 80 | u_strong_inputs: Tensor, 81 | u_true_targets: Tensor) -> Dict[str, Union[Optional[Tensor], List[Tensor]]]: 82 | # convert targets to one-hot 83 | x_targets_one_hot = F.one_hot(x_targets, num_classes=self.num_classes).type_as(x_inputs) 84 | u_true_targets_one_hot = F.one_hot(u_true_targets, num_classes=self.num_classes).type_as(x_inputs) 85 | 86 | # apply augmentations 87 | x_augmented = self.augmenter(x_inputs) 88 | u_augmented = [self.augmenter(u_inputs) for _ in range(self.num_augmentations)] 89 | 90 | if self.strong_augmenter is not None: 91 | x_strong_augmented = self.strong_augmenter(x_strong_inputs) 92 | u_strong_augmented = [self.strong_augmenter(u_strong_inputs) for _ in range(self.num_strong_augmentations)] 93 | else: 94 | x_strong_augmented = None 95 | u_strong_augmented = [] 96 | 97 | return dict(x_augmented=x_augmented, 98 | x_strong_augmented=x_strong_augmented, 99 | x_targets_one_hot=x_targets_one_hot, 100 | u_augmented=u_augmented, 101 | u_strong_augmented=u_strong_augmented, 102 | u_true_targets_one_hot=u_true_targets_one_hot) 103 | 104 | def guess_label(self, u_inputs: Sequence[Tensor], model: Module) -> Dict[str, Tensor]: 105 | # label guessing 106 | q_guess = label_guessing(batches=u_inputs, model=model, is_train_mode=self.train_label_guessing) 107 | q_guess = sharpen(q_guess, self.temperature) 108 | 109 | return dict(q_guess=q_guess) 110 | 111 | def postprocess(self, 112 | x_augmented: Tensor, 113 | x_targets_one_hot: Tensor, 114 | u_augmented: List[Tensor], 115 | q_guess: Tensor, 116 | u_true_targets_one_hot: Tensor) -> Dict[str, Tensor]: 117 | # concat the unlabeled data and targets 118 | u_augmented = torch.cat(u_augmented, dim=0) 119 | q_guess = torch.cat([q_guess for _ in range(self.total_num_augmentations)], dim=0) 120 | q_true = torch.cat([u_true_targets_one_hot for _ in range(self.total_num_augmentations)], dim=0) 121 | assert len(u_augmented) == len(q_guess) == len(q_true) 122 | 123 | return self.mixup(x_augmented=x_augmented, 124 | x_targets_one_hot=x_targets_one_hot, 125 | u_augmented=u_augmented, 126 | q_guess=q_guess, 127 | q_true=q_true) 128 | 129 | def mixup(self, 130 | x_augmented: Tensor, 131 | x_targets_one_hot: Tensor, 132 | u_augmented: Tensor, 133 | q_guess: Tensor, 134 | q_true: Tensor) -> Dict[str, Tensor]: 135 | # random shuffle according to the paper 136 | indices = list(range(len(x_augmented) + len(u_augmented))) 137 | np.random.shuffle(indices) 138 | 139 | # MixUp 140 | wx = torch.cat([x_augmented, u_augmented], dim=0) 141 | wy = torch.cat([x_targets_one_hot, q_guess], dim=0) 142 | wq = torch.cat([x_targets_one_hot, q_true], dim=0) 143 | assert len(wx) == len(wy) == len(wq) 144 | assert len(wx) == len(x_augmented) + len(u_augmented) 145 | assert len(wy) == len(x_targets_one_hot) + len(q_guess) 146 | wx_shuffled = wx[indices] 147 | wy_shuffled = wy[indices] 148 | wq_shuffled = wq[indices] 149 | assert len(wx) == len(wx_shuffled) 150 | assert len(wy) == len(wy_shuffled) 151 | assert len(wq) == len(wq_shuffled) 152 | 153 | # the official version uses the same lambda, sampled from Beta(alpha, alpha), for both 154 | # labeled and unlabeled inputs 155 | lam = np.random.beta(self.alpha, self.alpha) 156 | lam = max(lam, 1 - lam) 157 | 158 |
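# illustrative numeric sketch (hypothetical values, not from a real run): with alpha = 0.75,
# a Beta(0.75, 0.75) draw might give lam = 0.31, and max(0.31, 1 - 0.31) = 0.69, so each mixed
# sample below stays 69% its own (un-shuffled) input and only 31% the shuffled partner; this
# keeps x_mixed and u_mixed dominated by their original targets p_mixed / q_mixed.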
wx_mixed = lam * wx + (1 - lam) * wx_shuffled 159 | wy_mixed = lam * wy + (1 - lam) * wy_shuffled 160 | wq_mixed = lam * wq + (1 - lam) * wq_shuffled 161 | 162 | x_mixed, p_mixed = wx_mixed[:len(x_augmented)], wy_mixed[:len(x_augmented)] 163 | u_mixed, q_mixed = wx_mixed[len(x_augmented):], wy_mixed[len(x_augmented):] 164 | q_true_mixed = wq_mixed[len(x_augmented):] 165 | 166 | return dict(x_mixed=x_mixed, 167 | p_mixed=p_mixed, 168 | u_mixed=u_mixed, 169 | q_mixed=q_mixed, 170 | q_true_mixed=q_true_mixed) 171 | -------------------------------------------------------------------------------- /loss/visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from kornia.utils import confusion_matrix 4 | import plotly.figure_factory as ff 5 | import wandb 6 | 7 | from .utils import to_tensor 8 | from .pair_loss.utils import get_pair_indices 9 | 10 | # for type hint 11 | from typing import Tuple, Union, Dict 12 | from torch import Tensor 13 | from plotly.graph_objects import Figure 14 | 15 | from .types import SimilarityType, LogDictType, PlotDictType, LossInfoType 16 | 17 | 18 | def _detorch(t: Tensor) -> np.ndarray: 19 | return t.detach().cpu().clone().numpy() 20 | 21 | 22 | @torch.no_grad() 23 | def get_pair_info(targets: Tensor, 24 | true_targets: Tensor, 25 | similarity_metric: SimilarityType, 26 | confidence_threshold: float, 27 | similarity_threshold: float, 28 | return_plot_info: bool = False) -> LossInfoType: 29 | indices = get_pair_indices(targets, ordered_pair=True) 30 | i_indices, j_indices = indices[:, 0], indices[:, 1] 31 | targets_i, targets_j = targets[i_indices], targets[j_indices] 32 | 33 | targets_max_prob = targets.max(dim=1).values 34 | true_labels = true_targets.argmax(dim=1) 35 | 36 | similarities: Tensor = similarity_metric(targets_i=targets_i, targets_j=targets_j, dim=1) 37 | 38 | conf_mask: Tensor = targets_max_prob[i_indices] > confidence_threshold 39 | sim_mask: Tensor = (similarities > similarity_threshold) 40 | final_mask: Tensor = (conf_mask & sim_mask) 41 | true_pair_mask: Tensor = (true_labels[i_indices] == true_labels[j_indices]) 42 | 43 | log_info, plot_info = get_pair_loss_info( 44 | conf_mask=conf_mask, 45 | sim_mask=sim_mask, 46 | final_mask=final_mask, 47 | true_pair_mask=true_pair_mask, 48 | indices=indices, 49 | return_plot_info=return_plot_info) 50 | 51 | # log extra metrics 52 | extra_log_info, extra_plot_info = get_pair_extra_info( 53 | targets_max_prob=targets_max_prob, 54 | i_indices=i_indices, 55 | j_indices=j_indices, 56 | similarities=similarities, 57 | final_mask=final_mask) 58 | 59 | log_info.update(extra_log_info) 60 | plot_info.update(extra_plot_info) 61 | 62 | return {"log": log_info, "plot": plot_info} 63 | 64 | 65 | @torch.no_grad() 66 | def get_pair_loss_info(conf_mask: Tensor, 67 | sim_mask: Tensor, 68 | final_mask: Tensor, 69 | true_pair_mask: Tensor, 70 | indices: Tensor, 71 | return_plot_info: bool = False) -> Tuple[LogDictType, PlotDictType]: 72 | # prepare log info 73 | total_high_conf = conf_mask.sum() 74 | total_high_sim = sim_mask.sum() 75 | total_thresholded = final_mask.sum() 76 | total_size = len(indices) 77 | 78 | total_true_positive = (true_pair_mask & final_mask).sum().float() 79 | if total_thresholded == 0: 80 | true_pair_given_thresholded = to_tensor(0., tensor_like=total_true_positive) 81 | else: 82 | true_pair_given_thresholded = total_true_positive / total_thresholded 83 | 84 | matrix = get_confusion_matrix(y_true=true_pair_mask, 
y_pred=final_mask) 85 | normalized_matrix = matrix / matrix.sum() 86 | 87 | log_info = { 88 | "high_conf_ratio": (total_high_conf.float() / total_size), 89 | "high_sim_ratio": (total_high_sim.float() / total_size), 90 | "thresholded_ratio": (total_thresholded.float() / total_size), 91 | "true_pair_given_thresholded": true_pair_given_thresholded, 92 | 93 | # see https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html 94 | # In binary classification, the count of true negatives is C[0,0], false negatives is C[1,0], 95 | # true positives is C[1,1] and false positives is C[0,1]. 96 | "true_negative_pair_ratio": normalized_matrix[0, 0], 97 | "false_positives_pair_ratio": normalized_matrix[0, 1], 98 | "false_negatives_pair_ratio": normalized_matrix[1, 0], 99 | "true_positives_pair_ratio": normalized_matrix[1, 1], 100 | } 101 | 102 | plot_info = {} 103 | if return_plot_info: 104 | plot_info["pair_confusion_matrix"] = visualize(_detorch(matrix), _detorch(normalized_matrix)) 105 | 106 | return {f"pair_loss/{k}": v for k, v in log_info.items()}, \ 107 | {f"pair_loss/{k}": v for k, v in plot_info.items()} 108 | 109 | 110 | def get_confusion_matrix(y_true: Tensor, y_pred: Tensor) -> Tensor: 111 | """ 112 | Args: 113 | y_true: (batch size,) a vector where each element is the class label 114 | y_pred: (batch size,) predicted class labels 115 | 116 | Returns: a (2, 2) confusion matrix where rows are ground-truth labels and columns are predictions 117 | 118 | """ 119 | matrix = confusion_matrix(y_pred.view(1, -1), y_true.view(1, -1), num_classes=2, normalized=False) 120 | matrix.squeeze_() 121 | 122 | return matrix 123 | 124 | 125 | def visualize(matrix: np.ndarray, normalized_matrix: np.ndarray, **kwargs) -> Figure: 126 | # see https://matplotlib.org/faq/usage_faq.html#coding-styles 127 | # see https://plotly.com/python/annotated-heatmap/ 128 | # see https://plotly.com/python/builtin-colorscales/ 129 | annotation_text = [ 130 | [f"True Negative: {normalized_matrix[0, 0]:.3%}", f"False Positive: {normalized_matrix[0, 1]:.3%}"], 131 | [f"False Negative: {normalized_matrix[1, 0]:.3%}", f"True Positive: {normalized_matrix[1, 1]:.3%}"] 132 | ] 133 | x_axis_text = ["Predicted False", "Predicted True"] 134 | y_axis_text = ["False", "True"] 135 | fig = ff.create_annotated_heatmap(matrix, x=x_axis_text, y=y_axis_text, annotation_text=annotation_text, 136 | colorscale='Plotly3') 137 | fig.update_layout({ 138 | 'xaxis': {'title': {'text': "Predicted"}}, 139 | 'yaxis': {'title': {'text': "Ground Truth"}}, 140 | }) 141 | 142 | return fig 143 | 144 | 145 | @torch.no_grad() 146 | def get_pair_extra_info(targets_max_prob: Tensor, 147 | i_indices: Tensor, 148 | j_indices: Tensor, 149 | similarities: Tensor, 150 | final_mask: Tensor) -> Tuple[LogDictType, PlotDictType]: 151 | def mean_std_max_min(t: Union[Tensor, np.ndarray], prefix: str = "") -> Dict[str, Union[Tensor, np.ndarray]]: 152 | return { 153 | f"{prefix}/mean": t.mean() if t.numel() > 0 else to_tensor(0, tensor_like=t), 154 | f"{prefix}/std": t.std() if t.numel() > 0 else to_tensor(0, tensor_like=t), 155 | f"{prefix}/max": t.max() if t.numel() > 0 else to_tensor(0, tensor_like=t), 156 | f"{prefix}/min": t.min() if t.numel() > 0 else to_tensor(0, tensor_like=t), 157 | } 158 | 159 | targets_i_max_prob = targets_max_prob[i_indices] 160 | targets_j_max_prob = targets_max_prob[j_indices] 161 | 162 | selected_sim = similarities[final_mask] 163 | selected_i_conf = targets_i_max_prob[final_mask] 164 | selected_j_conf = targets_j_max_prob[final_mask] 165 | 166 | selected_i_conf_stat = mean_std_max_min(selected_i_conf,
prefix="selected_i_conf") 167 | selected_j_conf_stat = mean_std_max_min(selected_j_conf, prefix="selected_j_conf") 168 | selected_sim_stat = mean_std_max_min(selected_sim, prefix="selected_sim") 169 | 170 | selected_i_conf_hist = wandb.Histogram(_detorch(selected_i_conf)) 171 | selected_j_conf_hist = wandb.Histogram(_detorch(selected_j_conf)) 172 | selected_sim_hist = wandb.Histogram(_detorch(selected_sim)) 173 | 174 | log_info = { 175 | **selected_i_conf_stat, 176 | **selected_j_conf_stat, 177 | **selected_sim_stat, 178 | } 179 | plot_info = { 180 | "selected_i_conf_hist": selected_i_conf_hist, 181 | "selected_j_conf_hist": selected_j_conf_hist, 182 | "selected_sim_hist": selected_sim_hist, 183 | } 184 | 185 | return {f"pair_loss/{k}": v for k, v in log_info.items()}, \ 186 | {f"pair_loss/{k}": v for k, v in plot_info.items()} 187 | -------------------------------------------------------------------------------- /models/models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ref: 3 | https://github.com/hysts/pytorch_wrn/blob/master/wrn.py 4 | https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py 5 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 6 | https://github.com/szagoruyko/functional-zoo/blob/master/wide-resnet-50-2-export.ipynb 7 | https://github.com/szagoruyko/wide-residual-networks/tree/master/pytorch 8 | 9 | to get the name of the layer in state_dict, see the following example 10 | >>> wrn = WideResNet( 11 | >>> in_channels=1, 12 | >>> out_channels=5, 13 | >>> base_channels=4, 14 | >>> widening_factor=10, 15 | >>> drop_rate=0, 16 | >>> depth=10 17 | >>> ) 18 | >>> 19 | >>> # print the state_dict keys 20 | >>> d = wrn.state_dict() 21 | >>> dl = list(d.keys()) 22 | >>> for idx, n in enumerate(dl): 23 | >>> print("{} -> {}".format(idx, n)) 24 | """ 25 | import torch 26 | from torch import nn 27 | import numpy as np 28 | 29 | # for type hint 30 | from typing import Union, Tuple, Sequence, Optional 31 | from torch import Tensor 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | def __init__(self, 36 | in_channels: int, 37 | out_channels: int, 38 | stride: int, 39 | drop_rate: float = 0.0, 40 | activate_before_residual: bool = False, 41 | batch_norm_momentum: float = 0.001): 42 | super().__init__() 43 | 44 | self.in_channels = in_channels 45 | self.out_channels = out_channels 46 | self.drop_rate = drop_rate 47 | self.equal_in_out = (in_channels == out_channels) 48 | self.activate_before_residual = activate_before_residual 49 | 50 | # see https://github.com/pytorch/examples/issues/289 regarding different convention for batch norm momentum 51 | self.bn1 = nn.BatchNorm2d(self.in_channels, momentum=batch_norm_momentum) 52 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 53 | self.conv1 = nn.Conv2d( 54 | self.in_channels, 55 | self.out_channels, 56 | kernel_size=3, 57 | stride=stride, 58 | padding=1, 59 | bias=False) 60 | 61 | self.bn2 = nn.BatchNorm2d(self.out_channels, momentum=batch_norm_momentum) 62 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True) 63 | self.conv2 = nn.Conv2d( 64 | self.out_channels, 65 | self.out_channels, 66 | kernel_size=3, 67 | stride=1, 68 | padding=1, 69 | bias=False) 70 | 71 | if not self.equal_in_out: 72 | self.conv_shortcut = nn.Conv2d( 73 | self.in_channels, 74 | self.out_channels, 75 | kernel_size=1, 76 | stride=stride, 77 | padding=0, 78 | bias=False) 79 | else: 80 | self.conv_shortcut = None 81 | 82 | if drop_rate > 0: 83 | 
self.dropout = nn.Dropout(p=self.drop_rate) 84 | 85 | def forward(self, inputs): 86 | if not self.equal_in_out and self.activate_before_residual: 87 | inputs = self.bn1(inputs) 88 | inputs = self.relu1(inputs) 89 | outputs = self.conv1(inputs) 90 | else: 91 | outputs = self.bn1(inputs) 92 | outputs = self.relu1(outputs) 93 | outputs = self.conv1(outputs) 94 | 95 | outputs = self.bn2(outputs) 96 | outputs = self.relu2(outputs) 97 | outputs = self.conv2(outputs) 98 | 99 | if self.drop_rate > 0: 100 | outputs = self.dropout(outputs) 101 | 102 | if not self.equal_in_out: 103 | inputs = self.conv_shortcut(inputs) 104 | 105 | return torch.add(inputs, outputs) 106 | 107 | 108 | class NetworkBlock(nn.Module): 109 | def __init__(self, 110 | in_channels: int, 111 | out_channels: int, 112 | num_blocks: int, 113 | stride: int, 114 | drop_rate: float = 0.0, 115 | activate_before_residual: bool = False, 116 | basic_block: type(BasicBlock) = BasicBlock, 117 | **block_kwargs): 118 | super().__init__() 119 | self.layer = self._make_layer(basic_block, in_channels, out_channels, num_blocks, stride, drop_rate, 120 | activate_before_residual, **block_kwargs) 121 | 122 | @staticmethod 123 | def _make_layer(block: type(BasicBlock), 124 | in_channels: int, 125 | out_channels: int, 126 | num_blocks: int, 127 | stride: int, 128 | drop_rate: float, 129 | activate_before_residual: bool, 130 | **block_kwargs) -> nn.Module: 131 | layers = [] 132 | for i in range(int(num_blocks)): 133 | if i == 0: 134 | layers.append( 135 | block(in_channels, 136 | out_channels, 137 | stride=stride, 138 | drop_rate=drop_rate, 139 | activate_before_residual=activate_before_residual, 140 | **block_kwargs)) 141 | else: 142 | layers.append( 143 | block(out_channels, 144 | out_channels, 145 | stride=1, 146 | drop_rate=drop_rate, 147 | activate_before_residual=activate_before_residual, 148 | **block_kwargs)) 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | return self.layer(x) 153 | 154 | 155 | class WideResNet(nn.Module): 156 | def __init__(self, 157 | in_channels: int, 158 | out_channels: int, 159 | depth: int = 28, 160 | widening_factor: int = 2, 161 | base_channels: int = 16, 162 | drop_rate: float = 0.0, 163 | batch_norm_momentum: float = 0.001, 164 | n_channels: Optional[Sequence[int]] = None): 165 | super().__init__() 166 | 167 | if n_channels is None: 168 | # interpolate channels 169 | n_channels = [ 170 | int(base_channels), 171 | int(base_channels * widening_factor), 172 | int(base_channels * 2 * widening_factor), 173 | int(base_channels * 4 * widening_factor), 174 | ] 175 | assert len(n_channels) == 4 176 | self.n_channels = n_channels 177 | 178 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 179 | self.num_blocks = (depth - 4) // 6 180 | 181 | if self.num_blocks == 0: 182 | self.fc_in_features = self.n_channels[0] 183 | else: 184 | self.fc_in_features = self.n_channels[3] 185 | 186 | self.conv = nn.Conv2d( 187 | in_channels, 188 | self.n_channels[0], 189 | kernel_size=3, 190 | stride=1, 191 | padding=1, 192 | bias=False) 193 | 194 | self.block1 = NetworkBlock( 195 | in_channels=self.n_channels[0], 196 | out_channels=self.n_channels[1], 197 | num_blocks=self.num_blocks, 198 | stride=1, 199 | drop_rate=drop_rate, 200 | activate_before_residual=True, 201 | basic_block=BasicBlock, 202 | batch_norm_momentum=batch_norm_momentum) 203 | 204 | self.block2 = NetworkBlock( 205 | in_channels=self.n_channels[1], 206 | out_channels=self.n_channels[2], 207 | num_blocks=self.num_blocks, 208 | stride=2, 209 | 
drop_rate=drop_rate, 210 | basic_block=BasicBlock, 211 | batch_norm_momentum=batch_norm_momentum) 212 | 213 | self.block3 = NetworkBlock( 214 | in_channels=self.n_channels[2], 215 | out_channels=self.n_channels[3], 216 | num_blocks=self.num_blocks, 217 | stride=2, 218 | drop_rate=drop_rate, 219 | basic_block=BasicBlock, 220 | batch_norm_momentum=batch_norm_momentum) 221 | 222 | self.bn = nn.BatchNorm2d(num_features=self.fc_in_features, momentum=batch_norm_momentum) 223 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 224 | # self.pool = nn.AvgPool2d(kernel_size=8) 225 | self.pool = nn.AdaptiveAvgPool2d(output_size=1) 226 | self.fc = nn.Linear(self.fc_in_features, out_channels) 227 | 228 | # initialization 229 | for m in self.modules(): 230 | if isinstance(m, nn.Conv2d): 231 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 232 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 233 | m.weight.data.normal_(0, np.sqrt(2. / n)) 234 | elif isinstance(m, nn.BatchNorm2d): 235 | m.weight.data.fill_(1) 236 | m.bias.data.zero_() 237 | elif isinstance(m, nn.Linear): 238 | nn.init.xavier_normal_(m.weight.data) 239 | m.bias.data.zero_() 240 | 241 | def forward(self, inputs: Tensor, return_feature: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]: 242 | outputs = self.conv(inputs) 243 | outputs = self.block1(outputs) 244 | outputs = self.block2(outputs) 245 | outputs = self.block3(outputs) 246 | outputs = self.bn(outputs) 247 | outputs = self.relu(outputs) 248 | outputs = self.pool(outputs) 249 | # outputs = outputs.view(-1, self.n_channels[3]) 250 | features = outputs.view(outputs.size(0), -1) 251 | 252 | outputs = self.fc(features) 253 | 254 | if return_feature: 255 | return features, outputs 256 | 257 | return outputs 258 | -------------------------------------------------------------------------------- /utils/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset, DataLoader, Subset 3 | from tqdm import tqdm 4 | 5 | from itertools import repeat, combinations 6 | from pathlib import Path 7 | 8 | # for type hint 9 | from typing import Optional, Generator, Tuple, Union, List, Sequence, Any 10 | from torch import Tensor 11 | from PIL.Image import Image 12 | from torchvision.datasets import VisionDataset 13 | 14 | BatchType = Union[Tuple[Tuple[Tensor, Tensor], ...], Tuple[Tensor, Tensor]] 15 | LoaderType = Union[DataLoader, Generator[BatchType, None, None]] 16 | BatchGeneratorType = Generator[Tuple[int, BatchType], None, None] 17 | 18 | 19 | def repeater(data_loader: DataLoader) -> Generator[Tuple[Tensor, Tensor], None, None]: 20 | for loader in repeat(data_loader): 21 | for data in loader: 22 | yield data 23 | 24 | 25 | def get_batch(loaders: Sequence[LoaderType], max_iter: int, is_repeat: bool = False) \ 26 | -> BatchGeneratorType: 27 | if is_repeat: 28 | loaders = [repeater(loader) for loader in loaders] 29 | 30 | if len(loaders) == 1: 31 | combined_loaders = loaders[0] 32 | else: 33 | combined_loaders = zip(*loaders) 34 | 35 | for idx, batch in enumerate(combined_loaders): 36 | if idx >= max_iter: 37 | return 38 | yield idx, batch 39 | 40 | 41 | def get_targets(dataset: Union[Dataset, VisionDataset]) -> List[Any]: 42 | if hasattr(dataset, 'targets'): 43 | return dataset.targets 44 | 45 | # TODO: handle very large dataset 46 | return [target for (_, target) in dataset] 47 | 48 | 49 | def get_class_indices(dataset: Dataset) -> List[np.ndarray]: 50 
| targets = np.asarray(get_targets(dataset)) 51 | num_classes = max(targets) + 1 52 | 53 | return [np.where(targets == i)[0] for i in range(num_classes)] 54 | 55 | 56 | def to_indices_or_sections(subset_sizes: Sequence[int]) -> List[int]: 57 | outputs = [sum(subset_sizes[:i]) + size for i, size in enumerate(subset_sizes)] 58 | 59 | return outputs 60 | 61 | 62 | def safe_round(inputs: np.ndarray, target_sums: np.ndarray, min_value: int = 0) -> np.ndarray: 63 | """ 64 | Round an array row-wise while maintaining each row's sum. The rounding difference is redistributed based on 65 | the value distribution of the input array 66 | 67 | Args: 68 | inputs: input numpy array 69 | target_sums: target sum values 70 | min_value: each element in the output will be at least min_value 71 | 72 | Returns: numpy array where each element is at least min_value 73 | 74 | """ 75 | assert len(inputs) == len(target_sums) 76 | rounded = np.around(inputs).astype(int) 77 | rounded = np.maximum(rounded, min_value) 78 | 79 | outputs = np.zeros_like(inputs, dtype=int) 80 | 81 | for i in range(len(inputs)): 82 | # TODO: use more efficient implementation 83 | rounded_row = rounded[i] 84 | row_target_sum = target_sums[i] 85 | 86 | # subtract min_value so that the rounding adjustment does not affect the min 87 | row_outputs = rounded_row - min_value 88 | adjusted_target_sum = row_target_sum - (len(rounded_row) * min_value) 89 | 90 | round_error = adjusted_target_sum - row_outputs.sum() 91 | 92 | if round_error != 0: 93 | # repeat each index by its corresponding value; this makes the sampling follow the value distribution 94 | extended_idx = np.repeat(np.arange(row_outputs.size), row_outputs) 95 | selected_idx = np.random.choice(extended_idx, abs(round_error)) 96 | 97 | unique_idx, idx_counts = np.unique(selected_idx, return_counts=True) 98 | row_outputs[unique_idx] += np.copysign(idx_counts, round_error).astype(int) 99 | 100 | assert row_outputs.sum() == adjusted_target_sum 101 | 102 | # add back the subtracted min_value 103 | row_outputs += min_value 104 | 105 | assert row_outputs.sum() == row_target_sum 106 | outputs[i] = row_outputs 107 | 108 | assert np.all(outputs.sum(axis=1) == target_sums) 109 | assert np.all(outputs >= min_value) 110 | return outputs 111 | 112 | 113 | def per_class_random_split_by_ratio(dataset: Dataset, 114 | ratios: Sequence[float], 115 | num_classes: int, 116 | uneven_split: bool = False, 117 | min_value: int = 1) -> List[Dataset]: 118 | """Split the dataset based on the given ratios. 119 | 120 | Args: 121 | dataset: dataset to split 122 | ratios: fraction of data in each subset 123 | num_classes: number of classes in dataset 124 | uneven_split: if True, will return len(ratios) + 1 subsets where the last subset takes the remaining ratio 125 | min_value: min value in each class of each subset 126 | 127 | Returns: if uneven_split = False, returns len(ratios) subsets where the ith subset has size ratios[i] * 128 | len(dataset); if uneven_split = True, returns len(ratios) + 1 subsets where the last 129 | subset takes the remaining data. 130 | 131 | """ 132 | ratios = list(ratios) 133 | if uneven_split: 134 | ratios.append(1. - sum(ratios)) 135 | 136 | assert sum(ratios) == 1. 137 | ratios = np.asarray(ratios) 138 | assert np.all(ratios >= 0.)
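# worked example of the rounding step below (hypothetical numbers, for illustration only): with
# ratios = (0.1, 0.9) and a class of 37 samples, the raw per-subset row is [3.7, 33.3];
# safe_round (defined above) keeps every entry >= min_value and redistributes any rounding error
# in proportion to the rounded values, returning e.g. [4, 33], which still sums to 37.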
139 | 140 | # shape (num_classes, # samples in this class); each row is the indices of samples in that class 141 | class_indices = get_class_indices(dataset) 142 | assert len(class_indices) == num_classes 143 | 144 | # shape (num_classes,); each element is the number of samples in that class 145 | class_lengths = np.asarray([len(class_idx) for class_idx in class_indices]) 146 | assert class_lengths.sum() == len(dataset) 147 | 148 | # shape (num_subset, num_classes); each row is a list of class sizes for the corresponding subset 149 | subset_class_lengths = ratios.reshape(-1, 1) * class_lengths 150 | 151 | subset_class_lengths = safe_round(subset_class_lengths.T, target_sums=class_lengths, min_value=min_value) 152 | subset_class_lengths = subset_class_lengths.transpose() 153 | 154 | return per_class_random_split_helper(dataset=dataset, 155 | class_indices=class_indices, 156 | subset_class_lengths=subset_class_lengths) 157 | 158 | 159 | def per_class_random_split(dataset: Dataset, lengths: Sequence[int], num_classes: int, uneven_split: bool = False) \ 160 | -> List[Dataset]: 161 | """Split the dataset evenly across all classes 162 | 163 | Args: 164 | dataset: dataset to split 165 | lengths: length for each subset 166 | num_classes: number of classes in dataset 167 | uneven_split: if True, will return len(lengths) + 1 subsets where the size for the last subset may not be the 168 | same for all classes 169 | 170 | Returns: len(lengths) subsets where the ith subset has size lengths[i]; if uneven_split = True, will return 171 | len(lengths) + 1 subsets where the size for the last subset may not be the same for all classes. 172 | 173 | """ 174 | # see https://github.com/pytorch/vision/issues/168#issuecomment-319659360 175 | # and https://github.com/pytorch/vision/issues/168#issuecomment-398734285 for detail 176 | total_length = sum(lengths) 177 | 178 | if uneven_split: 179 | assert 0 < total_length <= len(dataset), f"Expecting 0 < length <= {len(dataset)} but got {total_length}" 180 | else: 181 | if len(lengths) <= 1: 182 | return [dataset] 183 | 184 | assert total_length == len(dataset), "Sum of input lengths does not equal the length of the input dataset" 185 | 186 | subset_num_per_class, remainders = np.divmod(lengths, num_classes) 187 | assert np.all(remainders == 0), f"Subset sizes are not divisible by the number of classes ({num_classes})" 188 | 189 | # shape (num_classes, # samples in this class); each row is the indices of samples in that class 190 | class_indices = get_class_indices(dataset) 191 | assert len(class_indices) == num_classes 192 | 193 | # shape (num_classes,); each element is the number of samples in that class 194 | class_lengths = np.asarray([len(class_idx) for class_idx in class_indices]) 195 | 196 | # shape (num_subset, num_classes); each row is the class lengths of this subset 197 | subset_class_lengths = np.tile(subset_num_per_class.reshape(-1, 1), num_classes) 198 | if uneven_split: 199 | # the last subset takes the remaining class lengths 200 | last_subset_class_lengths = class_lengths - subset_class_lengths.sum(axis=0) 201 | subset_class_lengths = np.vstack((subset_class_lengths, last_subset_class_lengths)) 202 | 203 | return per_class_random_split_helper(dataset=dataset, 204 | class_indices=class_indices, 205 | subset_class_lengths=subset_class_lengths) 206 | 207 | 208 | def per_class_random_split_helper(dataset: Dataset, 209 | class_indices: List[np.ndarray], 210 | subset_class_lengths: np.ndarray) -> List[Dataset]: 211 | """ 212 | 213 | Args: 214 | dataset: dataset to split 215 |
class_indices: shape (num_classes, # samples in this class); 216 | each row is the indices of samples in that class 217 | subset_class_lengths: shape (num_subset, num_classes); 218 | each row is a list of class sizes for that subset 219 | 220 | Returns: a list of len(subset_class_lengths) mutually exclusive Subsets of dataset 221 | 222 | """ 223 | class_lengths = np.asarray([len(class_idx) for class_idx in class_indices]) 224 | assert np.all(subset_class_lengths.sum(axis=0) == class_lengths) 225 | 226 | num_subsets, num_classes = subset_class_lengths.shape 227 | 228 | subset_indices = [list() for _ in range(num_subsets)] 229 | 230 | for i in range(num_classes): 231 | # ith column contains the class length for each subset 232 | subset_num_per_class = subset_class_lengths[:, i] 233 | 234 | indices = class_indices[i] 235 | np.random.shuffle(indices) 236 | 237 | indices_or_sections = to_indices_or_sections(subset_num_per_class)[:-1] 238 | class_subset_indices = np.split(indices, indices_or_sections) 239 | [subset_indices[i].extend(idx) for i, idx in enumerate(class_subset_indices)] 240 | 241 | # check if index subsets match the desired lengths 242 | subset_lengths = subset_class_lengths.sum(axis=1) 243 | assert all(len(subset_idx) == subset_lengths[i] for i, subset_idx in enumerate(subset_indices)) 244 | 245 | # check if index subsets in subset_indices are unique 246 | subset_idx_sets = [set(subset_idx) for subset_idx in subset_indices] 247 | assert all(len(idx_set) == len(subset_indices[i]) for i, idx_set in enumerate(subset_idx_sets)) 248 | 249 | # check if index subsets are mutually exclusive 250 | combos = combinations(subset_idx_sets, 2) 251 | assert all(combo[0].isdisjoint(combo[1]) for combo in combos) 252 | 253 | return [Subset(dataset, subset_idx) for subset_idx in subset_indices] 254 | 255 | 256 | def get_data_shape(data_point: Union[np.ndarray, Tensor, Image]): 257 | data_shape = np.asarray(data_point).shape 258 | if isinstance(data_point, Tensor) and len(data_shape) >= 3: 259 | # torch tensor has channel (C, H, W, ...), swap channel to (H, W, ..., C) 260 | data_shape = np.roll(data_shape, -1) 261 | 262 | return data_shape 263 | 264 | 265 | def get_directory_size(dirname: Union[Path, str]) -> int: 266 | return sum(f.stat().st_size for f in Path(dirname).rglob("*") if f.is_file()) 267 | -------------------------------------------------------------------------------- /models/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | References: torchvision.models.resnet.py 3 | """ 4 | import torch 5 | from torch import nn 6 | from torchvision.models import ResNet as BaseResNet 7 | from torchvision.models.utils import load_state_dict_from_url 8 | 9 | # for type hint 10 | from typing import Union, Tuple 11 | from torch import Tensor 12 | 13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 14 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 15 | 'wide_resnet50_2', 'wide_resnet101_2'] 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 25
| 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 27 | } 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 31 | """3x3 convolution with padding""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=dilation, groups=groups, bias=False, dilation=dilation) 34 | 35 | 36 | def conv1x1(in_planes, out_planes, stride=1): 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 45 | base_width=64, dilation=1, norm_layer=None): 46 | super(BasicBlock, self).__init__() 47 | if norm_layer is None: 48 | norm_layer = nn.BatchNorm2d 49 | if groups != 1 or base_width != 64: 50 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 51 | if dilation > 1: 52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 53 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 54 | self.conv1 = conv3x3(inplanes, planes, stride) 55 | self.bn1 = norm_layer(planes) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.conv2 = conv3x3(planes, planes) 58 | self.bn2 = norm_layer(planes) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | identity = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 83 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 84 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 85 | # This variant is also known as ResNet V1.5 and improves accuracy according to 86 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
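# e.g. (illustrative numbers, not from the original source): a stride-2 Bottleneck(inplanes=256, planes=128)
# reduces 256 -> 128 channels via conv1 (1x1, stride 1), downsamples via conv2 (3x3, stride 2),
# then expands 128 -> 512 via conv3 (1x1); the externally supplied downsample shortcut matches
# the identity branch to 512 channels at stride 2.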
87 | 88 | expansion = 4 89 | 90 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 91 | base_width=64, dilation=1, norm_layer=None): 92 | super(Bottleneck, self).__init__() 93 | if norm_layer is None: 94 | norm_layer = nn.BatchNorm2d 95 | width = int(planes * (base_width / 64.)) * groups 96 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 97 | self.conv1 = conv1x1(inplanes, width) 98 | self.bn1 = norm_layer(width) 99 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 100 | self.bn2 = norm_layer(width) 101 | self.conv3 = conv1x1(width, planes * self.expansion) 102 | self.bn3 = norm_layer(planes * self.expansion) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.downsample = downsample 105 | self.stride = stride 106 | 107 | def forward(self, x): 108 | identity = x 109 | 110 | out = self.conv1(x) 111 | out = self.bn1(out) 112 | out = self.relu(out) 113 | 114 | out = self.conv2(out) 115 | out = self.bn2(out) 116 | out = self.relu(out) 117 | 118 | out = self.conv3(out) 119 | out = self.bn3(out) 120 | 121 | if self.downsample is not None: 122 | identity = self.downsample(x) 123 | 124 | out += identity 125 | out = self.relu(out) 126 | 127 | return out 128 | 129 | 130 | class ResNet(BaseResNet): 131 | def forward(self, x, return_feature: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]: 132 | x = self.conv1(x) 133 | x = self.bn1(x) 134 | x = self.relu(x) 135 | x = self.maxpool(x) 136 | 137 | x = self.layer1(x) 138 | x = self.layer2(x) 139 | x = self.layer3(x) 140 | x = self.layer4(x) 141 | 142 | x = self.avgpool(x) 143 | features = torch.flatten(x, 1) 144 | logits = self.fc(features) 145 | 146 | if return_feature: 147 | return features, logits 148 | else: 149 | return logits 150 | 151 | 152 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 153 | model = ResNet(block, layers, **kwargs) 154 | if pretrained: 155 | state_dict = load_state_dict_from_url(model_urls[arch], 156 | progress=progress) 157 | model.load_state_dict(state_dict) 158 | return model 159 | 160 | 161 | def resnet18(pretrained=False, progress=True, **kwargs): 162 | r"""ResNet-18 model from 163 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 164 | 165 | Args: 166 | pretrained (bool): If True, returns a model pre-trained on ImageNet 167 | progress (bool): If True, displays a progress bar of the download to stderr 168 | """ 169 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 170 | **kwargs) 171 | 172 | 173 | def resnet34(pretrained=False, progress=True, **kwargs): 174 | r"""ResNet-34 model from 175 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 176 | 177 | Args: 178 | pretrained (bool): If True, returns a model pre-trained on ImageNet 179 | progress (bool): If True, displays a progress bar of the download to stderr 180 | """ 181 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 182 | **kwargs) 183 | 184 | 185 | def resnet50(pretrained=False, progress=True, **kwargs): 186 | r"""ResNet-50 model from 187 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 188 | 189 | Args: 190 | pretrained (bool): If True, returns a model pre-trained on ImageNet 191 | progress (bool): If True, displays a progress bar of the download to stderr 192 | """ 193 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 194 | **kwargs) 195 | 196 | 197 | def resnet101(pretrained=False, progress=True, **kwargs): 198 | r"""ResNet-101 model from 199 | `"Deep Residual Learning for Image
Recognition" <https://arxiv.org/abs/1512.03385>`_ 200 | 201 | Args: 202 | pretrained (bool): If True, returns a model pre-trained on ImageNet 203 | progress (bool): If True, displays a progress bar of the download to stderr 204 | """ 205 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 206 | **kwargs) 207 | 208 | 209 | def resnet152(pretrained=False, progress=True, **kwargs): 210 | r"""ResNet-152 model from 211 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_ 212 | 213 | Args: 214 | pretrained (bool): If True, returns a model pre-trained on ImageNet 215 | progress (bool): If True, displays a progress bar of the download to stderr 216 | """ 217 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 218 | **kwargs) 219 | 220 | 221 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 222 | r"""ResNeXt-50 32x4d model from 223 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_ 224 | 225 | Args: 226 | pretrained (bool): If True, returns a model pre-trained on ImageNet 227 | progress (bool): If True, displays a progress bar of the download to stderr 228 | """ 229 | kwargs['groups'] = 32 230 | kwargs['width_per_group'] = 4 231 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 232 | pretrained, progress, **kwargs) 233 | 234 | 235 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 236 | r"""ResNeXt-101 32x8d model from 237 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_ 238 | 239 | Args: 240 | pretrained (bool): If True, returns a model pre-trained on ImageNet 241 | progress (bool): If True, displays a progress bar of the download to stderr 242 | """ 243 | kwargs['groups'] = 32 244 | kwargs['width_per_group'] = 8 245 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 246 | pretrained, progress, **kwargs) 247 | 248 | 249 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 250 | r"""Wide ResNet-50-2 model from 251 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_ 252 | 253 | The model is the same as ResNet except for the bottleneck number of channels 254 | which is twice larger in every block. The number of channels in outer 1x1 255 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 256 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 257 | 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If True, displays a progress bar of the download to stderr 261 | """ 262 | kwargs['width_per_group'] = 64 * 2 263 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 264 | pretrained, progress, **kwargs) 265 | 266 | 267 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 268 | r"""Wide ResNet-101-2 model from 269 | `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_ 270 | 271 | The model is the same as ResNet except for the bottleneck number of channels 272 | which is twice larger in every block. The number of channels in outer 1x1 273 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 274 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
275 | 276 | Args: 277 | pretrained (bool): If True, returns a model pre-trained on ImageNet 278 | progress (bool): If True, displays a progress bar of the download to stderr 279 | """ 280 | kwargs['width_per_group'] = 64 * 2 281 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 282 | pretrained, progress, **kwargs) 283 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
-------------------------------------------------------------------------------- /checkpoint_saver.py: --------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | from pathlib import Path
5 | import re
6 | import warnings
7 | 
8 | from utils import find_checkpoint_path, find_all_files
9 | from utils.metrics import MetricMode, MetricMonitor
10 | 
11 | # for type hint
12 | from typing import Dict, Any, Optional, Union
13 | from simple_estimator import SimPLEEstimator
14 | from utils import Logger
15 | 
16 | 
17 | class CheckpointSaver:
18 |     def __init__(self,
19 |                  estimator: SimPLEEstimator,
20 |                  logger: Logger,
21 |                  checkpoint_metric: str,
22 |                  best_checkpoint_str: str,
23 |                  best_checkpoint_pattern: str,
24 |                  latest_checkpoint_str: str,
25 |                  latest_checkpoint_pattern: str,
26 |                  delayed_best_model_saving: bool = True):
27 |         """
28 | 
29 |         Args:
30 |             estimator: Estimator, used to get experiment-related data
31 |             logger: Logger, used for logging
32 |             checkpoint_metric: save the model whenever the best value of this metric key improves
33 |             best_checkpoint_str: path str format used to save the best checkpoint file
34 |             best_checkpoint_pattern: regex pattern used to find the best checkpoint file
35 |             latest_checkpoint_str: path str format used to save the latest checkpoint file
36 |             latest_checkpoint_pattern: regex pattern used to find the latest checkpoint file
37 |             delayed_best_model_saving: if True, defer saving the best model until save_latest_checkpoint() is called
38 |         """
39 |         self.absolute_best_path = "best_checkpoint.pth"
40 | 
41 |         # metrics to keep track of
42 |         self.monitor = MetricMonitor()
43 |         self.monitor.track(key="mean_acc",
44 |                            best_value=-np.inf,
45 |                            mode=MetricMode.MAX,
46 |                            prefix="test")
47 |         self.monitor.track(key="mean_acc",
48 |                            best_value=-np.inf,
49 |                            mode=MetricMode.MAX,
50 |                            prefix="validation")
51 | 
52 |         self.checkpoint_metric = checkpoint_metric
53 | 
54 |         # checkpoint path patterns
55 |         self.best_checkpoint_str = best_checkpoint_str
56 |         self.best_checkpoint_pattern = re.compile(best_checkpoint_pattern)
57 | 
58 |         self.latest_checkpoint_str = latest_checkpoint_str
59 |         self.latest_checkpoint_pattern = re.compile(latest_checkpoint_pattern)
60 | 
61 |         # save estimator and logger
62 |         # this will recover best metrics and register log hooks (see the setters below)
63 |         self.estimator = estimator
64 |         self.logger = logger
65 | 
66 |         # assign flags
67 |         self.delayed_save_best_model = delayed_best_model_saving
68 |         self.is_best_model = False
69 | 
70 |     @property
71 |     def estimator(self) -> SimPLEEstimator:
72 |         return self._estimator
73 | 
74 |     @estimator.setter
75 |     def estimator(self, estimator: SimPLEEstimator) -> None:
76 |         self._estimator = estimator
77 | 
78 |         # recover best value
79 |         checkpoint_path = self.estimator.exp_args.checkpoint_path
80 |         if checkpoint_path is not None:
81 |             print(f"Recovering best metrics from {checkpoint_path}...")
82 |             self.recover_metrics(torch.load(checkpoint_path, map_location=self.device))
83 | 
84 |     @property
85 |     def logger(self) -> Logger:
86 |         return self._logger
87 | 
88 |     @logger.setter
89 |     def logger(self, logger: Logger) -> None:
90 |         self._logger = logger
91 | 
92 |         # register log hooks
93 |         print("Registering log hooks...")
94 |         self.logger.register_log_hook(self.update_best_metric, logger=self.logger)
95 | 
96 |     @property
97 |     def checkpoint_metric(self) -> str:
98 |         return self._checkpoint_metric
99 | 
100 |     @property
101 |     def checkpoint_metric(self) -> str:
        (see setter below)
103 | 
104 |         self._checkpoint_metric = checkpoint_metric
105 | 
106 |     @property
107 |     def log_dir(self) -> str:
108 |         return self.estimator.exp_args.log_dir
109 | 
110 |     @property
111 |     def best_full_checkpoint_str(self) -> str:
112 |         return str(Path(self.log_dir) / self.best_checkpoint_str)
113 | 
114 |     @property
115 |     def latest_full_checkpoint_str(self) -> str:
116 |         return str(Path(self.log_dir) / self.latest_checkpoint_str)
117 | 
118 |     @property
119 |     def device(self) -> torch.device:
120 |         return self.estimator.device
121 | 
122 |     @property
123 |     def global_step(self) -> int:
124 |         return self.estimator.global_step
125 | 
126 |     @property
127 |     def num_latest_checkpoints_kept(self) -> Optional[int]:
128 |         return self.estimator.exp_args.num_latest_checkpoints_kept
129 | 
130 |     @property
131 |     def is_save_latest_checkpoint(self) -> bool:
132 |         return self.num_latest_checkpoints_kept is None or self.num_latest_checkpoints_kept > 0
133 | 
134 |     @property
135 |     def is_remove_old_checkpoint(self) -> bool:
136 |         return self.num_latest_checkpoints_kept is not None and self.num_latest_checkpoints_kept > 0
137 | 
138 |     def save_checkpoint(self,
139 |                         checkpoint: Dict[str, Any],
140 |                         checkpoint_path: Union[str, Path],
141 |                         is_logger_save: bool = False) -> Path:
142 |         checkpoint_path = str(checkpoint_path)
143 |         torch.save(checkpoint, checkpoint_path)
144 | 
145 |         print(f"Checkpoint saved to \"{checkpoint_path}\"", flush=True)
146 | 
147 |         if is_logger_save:
148 |             self.logger.save(checkpoint_path)
149 | 
150 |         return Path(checkpoint_path)
151 | 
152 |     def save_best_checkpoint(self,
153 |                              checkpoint: Optional[Dict[str, Any]] = None,
154 |                              is_logger_save: bool = False,
155 |                              **kwargs) -> Path:
156 |         if checkpoint is None:
157 |             checkpoint = self.get_checkpoint()
158 | 
159 |         checkpoint_path = self.save_checkpoint(checkpoint_path=self.best_full_checkpoint_str.format(**kwargs),
160 |                                                checkpoint=checkpoint,
161 |                                                is_logger_save=is_logger_save)
162 |         # reset flag
163 |         self.is_best_model = False
164 | 
165 |         return checkpoint_path
166 | 
167 |     def save_latest_checkpoint(self,
168 |                                checkpoint: Optional[Dict[str, Any]] = None,
169 |                                is_logger_save: bool = False,
170 |                                **kwargs) -> Optional[Path]:
171 |         checkpoint_path: Optional[Path] = None
172 | 
173 |         if self.is_save_latest_checkpoint:
174 |             if checkpoint is None:
175 |                 checkpoint = self.get_checkpoint()
176 | 
177 |             # save new checkpoint
178 |             checkpoint_path = self.save_checkpoint(checkpoint_path=self.latest_full_checkpoint_str.format(**kwargs),
179 |                                                    checkpoint=checkpoint,
180 |                                                    is_logger_save=is_logger_save)
181 | 
182 |             # cleanup old checkpoints
183 |             self.cleanup_checkpoints()
184 | 
185 |         if self.delayed_save_best_model and self.is_best_model:
186 |             self.save_best_checkpoint(**kwargs)
187 | 
188 |         return checkpoint_path
189 | 
190 |     def get_checkpoint(self) -> Dict[str, Any]:
191 |         checkpoint = self.estimator.get_checkpoint()
192 | 
193 |         # add best metrics
194 |         checkpoint.update({"monitor_state": self.monitor.state_dict()})
195 | 
196 |         return checkpoint
197 | 
198 |     def update_best_checkpoint(self) -> None:
199 |         """
200 |         Update the logged metrics for the best checkpoint
201 | 
202 |         Returns:
203 | 
204 |         """
205 |         best_checkpoint_path = self.find_best_checkpoint_path()
206 | 
207 |         if best_checkpoint_path is None:
208 |             warnings.warn("Cannot find best checkpoint")
209 |             return
210 | 
211 |         best_checkpoint_path = str(best_checkpoint_path)
212 |         best_checkpoint = torch.load(best_checkpoint_path, map_location=self.device)
213 | 
214 |         # update best metrics
215 |         best_checkpoint.update({"monitor_state": self.monitor.state_dict()})
216 | 
217 |         self.save_checkpoint(checkpoint_path=str(Path(self.log_dir) / self.absolute_best_path),
218 |                              checkpoint=best_checkpoint)
219 | 
220 |     def find_best_checkpoint_path(self, checkpoint_dir: Optional[str] = None, ignore_absolute_best: bool = True) \
221 |             -> Optional[Path]:
222 |         if checkpoint_dir is None:
223 |             checkpoint_dir = self.log_dir
224 | 
225 |         abs_best_path = Path(checkpoint_dir) / self.absolute_best_path
226 | 
227 |         if not ignore_absolute_best and abs_best_path.is_file():
228 |             # if not ignoring the absolute best path and the path is a file, return the absolute best file path
229 |             return abs_best_path
230 | 
231 |         checkpoint_path = find_checkpoint_path(checkpoint_dir, step_filter=self.best_checkpoint_pattern)
232 | 
233 |         if checkpoint_path is None:
234 |             checkpoint_path = self.find_latest_checkpoint_path(checkpoint_dir=checkpoint_dir)
235 | 
236 |         return checkpoint_path
237 | 
238 |     def find_latest_checkpoint_path(self, checkpoint_dir: Optional[str] = None) -> Optional[Path]:
239 |         if checkpoint_dir is None:
240 |             checkpoint_dir = self.log_dir
241 | 
242 |         return find_checkpoint_path(checkpoint_dir, step_filter=self.latest_checkpoint_pattern)
243 | 
244 |     def update_best_metric(self, log_info: Dict[str, Any], logger: Logger) -> None:
245 |         updated_dict = self.monitor.update_metrics(log_info)
246 | 
247 |         for updated_key, new_best_value in updated_dict.items():
248 |             metric_dict = self.monitor[updated_key]
249 | 
250 |             translated_key = metric_dict["key"]
251 | 
252 |             # new_best_value improved over the previous best value, so log it
253 |             logger.log({translated_key: new_best_value}, step=self.global_step)
254 | 
255 |             if self.checkpoint_metric == updated_key:
256 |                 self.is_best_model = True
257 | 
258 |         if self.is_best_model and not self.delayed_save_best_model:
259 |             # if delayed saving is disabled, save the best checkpoint immediately
260 |             self.save_best_checkpoint(global_step=self.global_step)
261 | 
262 |     def recover_checkpoint(self, checkpoint: Dict[str, Any], recover_optimizer: bool = True,
263 |                            recover_train_progress: bool = True) -> None:
264 |         self.recover_metrics(checkpoint=checkpoint)
265 | 
266 |         self.estimator.load_checkpoint(checkpoint=checkpoint,
267 |                                        recover_optimizer=recover_optimizer,
268 |                                        recover_train_progress=recover_train_progress)
269 | 
270 |     def recover_metrics(self, checkpoint: Dict[str, Any]) -> None:
271 |         if "monitor_state" in checkpoint:
272 |             monitor_state = checkpoint["monitor_state"]
273 |         else:
274 |             # for backward compatibility
275 |             monitor_state = {
276 |                 "validation/mean_acc": checkpoint.get("best_val_acc", -np.inf),
277 |                 "test/mean_acc": checkpoint.get("best_test_acc", -np.inf),
278 |             }
279 | 
280 |         self.monitor.load_state_dict(monitor_state)
281 | 
282 |     def cleanup_checkpoints(self) -> None:
283 |         if not self.is_remove_old_checkpoint:
284 |             # do nothing if the model does not save latest checkpoints or if all checkpoints are kept
285 |             return
286 | 
287 |         checkpoint_paths = find_all_files(checkpoint_dir=self.log_dir,
288 |                                           search_pattern=self.latest_checkpoint_pattern)
289 | 
290 |         # sort by recency (largest step first)
291 |         checkpoint_paths.sort(key=lambda x: int(re.search(self.latest_checkpoint_pattern, x.name).group(1)),
292 |                               reverse=True)
293 | 
294 |         # remove old checkpoints
295 |         for checkpoint_path in checkpoint_paths[self.num_latest_checkpoints_kept:]:
296 |             print(f"Removing old checkpoint \"{checkpoint_path}\"", flush=True)
297 |             checkpoint_path.unlink()
298 | 
--------------------------------------------------------------------------------