├── experiments
│   ├── small_CIFAR10_4000.sh
│   ├── small_CIFAR10_1000.sh
│   ├── small_CIFAR10_2000.sh
│   ├── download.sh
│   └── experiment_AppendixB.sh
├── SSL_loss
│   ├── __init__.py
│   └── mixmatch.py
├── src
│   ├── __init__.py
│   ├── eval.py
│   ├── base.py
│   └── train.py
├── model
│   ├── __init__.py
│   ├── ema.py
│   ├── lr_scheduler.py
│   └── wideresnet.py
├── data_loader
│   ├── SSL_Dataset.py
│   ├── svhn.py
│   ├── cifar.py
│   ├── transform.py
│   └── loader.py
├── .gitignore
├── config
│   └── mixmatch
│       ├── SVHN
│       │   ├── eval_SVHN_250.json
│       │   ├── eval_SVHN_500.json
│       │   ├── eval_SVHN_1000.json
│       │   ├── eval_SVHN_2000.json
│       │   ├── eval_SVHN_4000.json
│       │   ├── train_SVHN_250.json
│       │   ├── train_SVHN_500.json
│       │   ├── train_SVHN_1000.json
│       │   ├── train_SVHN_2000.json
│       │   └── train_SVHN_4000.json
│       └── CIFAR10
│           ├── eval_CIFAR10_250.json
│           ├── eval_CIFAR10_500.json
│           ├── eval_CIFAR10_1000.json
│           ├── eval_CIFAR10_2000.json
│           ├── eval_CIFAR10_4000.json
│           ├── train_CIFAR10_250.json
│           ├── train_CIFAR10_500.json
│           ├── train_CIFAR10_1000.json
│           ├── train_CIFAR10_2000.json
│           └── train_CIFAR10_4000.json
├── main.py
├── config.py
├── utils.py
└── README.md
--------------------------------------------------------------------------------
/experiments/small_CIFAR10_4000.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=0 python main.py --cfg_path config/mixmatch/CIFAR10/train_CIFAR10_4000.json
--------------------------------------------------------------------------------
/SSL_loss/__init__.py:
--------------------------------------------------------------------------------
from .mixmatch import MixMatchLoss
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
from .train import Trainer
from .eval import Evaluator
from .base import BaseModel
--------------------------------------------------------------------------------
/experiments/small_CIFAR10_1000.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path config/mixmatch/CIFAR10/train_CIFAR10_1000.json
--------------------------------------------------------------------------------
/experiments/small_CIFAR10_2000.sh:
--------------------------------------------------------------------------------
CUDA_VISIBLE_DEVICES=0 python main.py --cfg_path config/mixmatch/CIFAR10/train_CIFAR10_2000.json
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
from .wideresnet import *
from .ema import WeightEMA
from .lr_scheduler import WarmupCosineLrScheduler
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
results/*
__pycache__/*
data_loader/__pycache__/*
SSL_loss/__pycache__/*
model/__pycache__/*
src/__pycache__/*
data/*
nohup.out
experiments/*
config/*
--------------------------------------------------------------------------------
/config/mixmatch/SVHN/eval_SVHN_250.json:
--------------------------------------------------------------------------------
{
    "mode": "eval",
    "method": "Mixmatch",
    "name": "Mixmatch_250_SVHN_6",
    "dataset": "SVHN",
    "datapath": "./data",
    "depth": 28,
    "width": 2,
    "large": false,
    "num_classes": 10,
    "num_label": 250,
    "batch_size": 64,
    "epochs": 1024,
    "save_epoch": 10,
    "ckpt": "ema_best.pth"
}
--------------------------------------------------------------------------------
/config/mixmatch/CIFAR10/eval_CIFAR10_250.json:
--------------------------------------------------------------------------------
{
    "mode": "eval",
    "method": "Mixmatch",
    "name": "Mixmatch_250",
    "dataset": "CIFAR10",
    "datapath": "./data",
    "depth": 28,
    "width": 2,
    "large": false,
    "num_classes": 10,
    "num_label": 250,
    "batch_size": 64,
    "epochs": 1024,
    "save_epoch": 10,
    "resume": true,
    "ckpt": "best.pth"
}
--------------------------------------------------------------------------------
/config/mixmatch/CIFAR10/eval_CIFAR10_500.json:
--------------------------------------------------------------------------------
{
    "mode": "eval",
    "method": "Mixmatch",
    "name": "Mixmatch_500",
    "dataset": "CIFAR10",
    "datapath": "./data",
    "depth": 28,
    "width": 2,
    "large": false,
    "num_classes": 10,
    "num_label": 500,
    "batch_size": 64,
    "epochs": 1024,
    "save_epoch": 10,
    "resume": true,
    "ckpt": "best.pth"
}
--------------------------------------------------------------------------------
/config/mixmatch/SVHN/eval_SVHN_500.json:
--------------------------------------------------------------------------------
{
    "mode": "eval",
    "method": "Mixmatch",
    "name": "Mixmatch_500_SVHN",
    "dataset": "SVHN",
    "datapath": "./data",
    "depth": 28,
    "width": 2,
    "large": false,
    "num_classes": 10,
    "num_label": 500,
    "batch_size": 64,
    "epochs": 1024,
    "save_epoch": 10,
    "resume": true,
    "ckpt": "ema_best.pth"
}
--------------------------------------------------------------------------------
/config/mixmatch/CIFAR10/eval_CIFAR10_1000.json:
--------------------------------------------------------------------------------
{
    "mode": "eval",
    "method": "Mixmatch",
    "name": "Mixmatch_1000",
    "dataset": "CIFAR10",
    "datapath": "./data",
    "depth": 28,
    "width": 2,
    "large": false,
    "num_classes": 10,
    "num_label": 1000,
    "batch_size": 64,
    "epochs": 1024,
    "save_epoch": 10,
    "resume": true,
    "ckpt": "best.pth"
}
--------------------------------------------------------------------------------
/config/mixmatch/CIFAR10/eval_CIFAR10_2000.json:
--------------------------------------------------------------------------------
{
    "mode": "eval",
    "method": "Mixmatch",
    "name": "Mixmatch_2000",
    "dataset": "CIFAR10",
    "datapath": "./data",
    "depth": 28,
    "width": 2,
    "large": false,
    "num_classes": 10,
    "num_label": 2000,
    "batch_size": 64,
    "epochs": 1024,
    "save_epoch": 10,
    "resume": true,
    "ckpt": "best.pth"
}
--------------------------------------------------------------------------------
/config/mixmatch/CIFAR10/eval_CIFAR10_4000.json:
--------------------------------------------------------------------------------
{
    "mode": "eval",
    "method": "Mixmatch",
    "name": "Mixmatch_4000",
    "dataset": "CIFAR10",
    "datapath": "./data",
    "depth": 28,
    "width": 2,
    "large": false,
    "num_classes": 10,
    "num_label": 4000,
    "batch_size": 64,
    "epochs": 1024,
    "save_epoch": 10,
    "resume": true,
16 | "ckpt": "best.pth" 17 | } -------------------------------------------------------------------------------- /config/mixmatch/SVHN/eval_SVHN_1000.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "eval", 3 | "method": "Mixmatch", 4 | "name": "Mixmatch_1000_SVHN_ADAM", 5 | "dataset": "SVHN", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":1000, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": true, 16 | "ckpt": "best.pth" 17 | } -------------------------------------------------------------------------------- /config/mixmatch/SVHN/eval_SVHN_2000.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "eval", 3 | "method": "Mixmatch", 4 | "name": "Mixmatch_SVHN_2000", 5 | "dataset": "CIFAR10", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":2000, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": true, 16 | "ckpt": "latest.pth" 17 | } -------------------------------------------------------------------------------- /config/mixmatch/SVHN/eval_SVHN_4000.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "eval", 3 | "method": "Mixmatch", 4 | "name": "Mixmatch_4000_SVHN", 5 | "dataset": "CIFAR10", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":4000, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": true, 16 | "ckpt": "ema_best.pth" 17 | } -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | from config import * 3 | from utils import * 4 | from src import * 5 | import os 6 | os.environ['CUDA_VISIBLE_DEVICES'] = "1" 7 | def main(): 8 | args = parse_args() 9 | configs = get_configs(args) 10 | if configs.mode == 'train': 11 | MixmatchTrainer = Trainer(configs) 12 | MixmatchTrainer.train() 13 | elif configs.mode == 'eval': 14 | MixmatchEvaluator = Evaluator(configs) 15 | MixmatchEvaluator.evaluate() 16 | else: 17 | raise ValueError ("Invalid mode, ['train', 'eval'] modes are supported") 18 | 19 | if __name__ == '__main__': 20 | main() 21 | -------------------------------------------------------------------------------- /config/mixmatch/SVHN/train_SVHN_500.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "train", 3 | "method":"Mixmatch", 4 | "name": "Mixmatch_500_SVHN", 5 | "dataset": "SVHN", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":500, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": false, 16 | "ckpt": "latest.pth", 17 | "verbose": false, 18 | 19 | "lr":0.002, 20 | "optim":"ADAM", 21 | "lambda_u": 250, 22 | "alpha":0.75, 23 | "T" : 0.5, 24 | "K" : 2, 25 | "weight_decay": 0.02, 26 | "ema_alpha":0.999, 27 | "seed":"None" 28 | } -------------------------------------------------------------------------------- /config/mixmatch/CIFAR10/train_CIFAR10_1000.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "train", 3 | "method": "Mixmatch", 4 | "name": "Mixmatch_1000", 5 | "dataset": "CIFAR10", 6 | 
"datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":1000, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": false, 16 | "ckpt": "latest.pth", 17 | "verbose": true, 18 | 19 | "lr":0.002, 20 | "optim":"ADAM", 21 | "lambda_u": 75, 22 | "alpha":0.75, 23 | "T" : 0.5, 24 | "K" : 2, 25 | "ema_alpha":0.999, 26 | "weight_deacy": 0.0004, 27 | "seed":3114 28 | } -------------------------------------------------------------------------------- /config/mixmatch/CIFAR10/train_CIFAR10_250.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "train", 3 | "method":"Mixmatch", 4 | "name": "Mixmatch_250", 5 | "dataset": "CIFAR10", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":250, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": false, 16 | "ckpt": "latest.pth", 17 | "verbose": true, 18 | 19 | "lr":0.002, 20 | "optim":"ADAM", 21 | "lambda_u": 75, 22 | "alpha":0.75, 23 | "T" : 0.5, 24 | "K" : 2, 25 | "ema_alpha":0.999, 26 | "weight_deacy": 0.0004, 27 | "seed":2114 28 | } -------------------------------------------------------------------------------- /config/mixmatch/CIFAR10/train_CIFAR10_500.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "train", 3 | "method": "Mixmatch", 4 | "name": "Mixmatch_500", 5 | "dataset": "CIFAR10", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":500, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": false, 16 | "ckpt": "latest.pth", 17 | "verbose": false, 18 | 19 | "lr":0.002, 20 | "optim":"ADAM", 21 | "lambda_u": 75, 22 | "alpha":0.75, 23 | "T" : 0.5, 24 | "K" : 2, 25 | "ema_alpha":0.999, 26 | "weight_deacy": 0.0004, 27 | "seed":3114 28 | } -------------------------------------------------------------------------------- /config/mixmatch/SVHN/train_SVHN_2000.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "train", 3 | "method":"Mixmatch", 4 | "name": "Mixmatch_2000_SVHN", 5 | "dataset": "SVHN", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":2000, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": false, 16 | "ckpt": "latest.pth", 17 | "verbose": false, 18 | 19 | "lr":0.0002, 20 | "optim":"ADAM", 21 | "lambda_u": 250, 22 | "alpha":0.5, 23 | "T" : 0.5, 24 | "K" : 2, 25 | "weight_deacy": 0.0004, 26 | "ema_alpha":0.999, 27 | "seed":"None" 28 | } -------------------------------------------------------------------------------- /config/mixmatch/SVHN/train_SVHN_250.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "train", 3 | "method":"Mixmatch", 4 | "name": "Mixmatch_250_SVHN_7", 5 | "dataset": "SVHN", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":250, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": false, 16 | "ckpt": "ema_best.pth", 17 | "verbose": false, 18 | 19 | "lr":0.0002, 20 | "optim":"ADAM", 21 | "lambda_u": 250, 22 | "alpha":0.75, 23 | "T" : 0.5, 24 | "K" : 2, 25 | "weight_decay": 0.02, 26 | "ema_alpha":0.999, 27 | "seed":478 28 | } 
-------------------------------------------------------------------------------- /config/mixmatch/SVHN/train_SVHN_4000.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "train", 3 | "method":"Mixmatch", 4 | "name": "Mixmatch_4000_SVHN", 5 | "dataset": "SVHN", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":4000, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": false, 16 | "ckpt": "latest.pth", 17 | "verbose": false, 18 | 19 | "lr":0.0002, 20 | "optim":"ADAM", 21 | "lambda_u": 250, 22 | "alpha":0.5, 23 | "T" : 0.5, 24 | "K" : 2, 25 | "weight_deacy": 0.0004, 26 | "ema_alpha":0.999, 27 | "seed":"None" 28 | } -------------------------------------------------------------------------------- /data_loader/SSL_Dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | 5 | class SSL_Dataset(torch.utils.data.Dataset): 6 | def __init__(self, transform=None, target_transform=None): 7 | self.transform =transform 8 | self.target_transform = target_transform 9 | 10 | def __getitem__(self, index): 11 | raise NotImplementedError 12 | 13 | def __len__(self): 14 | raise NotImplementedError 15 | 16 | def _transpose(self, x, source='NHWC', target='NCHW'): 17 | return x.transpose([source.index(d) for d in target]) 18 | 19 | def _get_PIL(self, x): 20 | return Image.fromarray(x) 21 | 22 | -------------------------------------------------------------------------------- /config/mixmatch/CIFAR10/train_CIFAR10_2000.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "train", 3 | "method": "Mixmatch", 4 | "name": "Mixmatch_2000", 5 | "dataset": "CIFAR10", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":2000, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": false, 16 | "ckpt": "latest.pth", 17 | "verbose": false, 18 | 19 | "lr":0.002, 20 | "optim":"ADAM", 21 | "lambda_u": 75, 22 | "alpha":0.75, 23 | "T" : 0.5, 24 | "K" : 2, 25 | "ema_alpha":0.999, 26 | "weight_deacy": 0.0004, 27 | "seed":3114 28 | } -------------------------------------------------------------------------------- /config/mixmatch/CIFAR10/train_CIFAR10_4000.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "train", 3 | "method": "Mixmatch", 4 | "name": "Mixmatch_4000", 5 | "dataset": "CIFAR10", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":false, 10 | "num_classes":10, 11 | "num_label":4000, 12 | "batch_size":64, 13 | "epochs":1024, 14 | "save_epoch":10, 15 | "resume": false, 16 | "ckpt": "latest.pth", 17 | "verbose": false, 18 | 19 | "lr":0.002, 20 | "optim":"ADAM", 21 | "lambda_u": 75, 22 | "alpha":0.75, 23 | "T" : 0.5, 24 | "K" : 2, 25 | "ema_alpha":0.999, 26 | "weight_deacy": 0.0004, 27 | "seed":3114 28 | } 29 | -------------------------------------------------------------------------------- /config/mixmatch/SVHN/train_SVHN_1000.json: -------------------------------------------------------------------------------- 1 | { 2 | "mode": "train", 3 | "method":"Mixmatch", 4 | "name": "Mixmatch_1000_SVHN_ADAM_Large", 5 | "dataset": "SVHN", 6 | "datapath":"./data", 7 | "depth":28, 8 | "width":2, 9 | "large":true, 10 | "num_classes":10, 11 | "num_label":1000, 12 | "batch_size":64, 13 | 
"epochs":1024, 14 | "save_epoch":10, 15 | "resume": false, 16 | "ckpt": "latest.pth", 17 | "verbose": false, 18 | 19 | "lr":0.0002, 20 | "optim":"ADAM", 21 | "lambda_u": 250, 22 | "alpha":0.75, 23 | "T" : 0.5, 24 | "K" : 2, 25 | "weight_decay": 0.02, 26 | "ema_alpha":0.999, 27 | "seed":"None" 28 | } -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import utils 4 | 5 | def get_configs(args): 6 | with open(args.cfg_path, "r") as f: 7 | configs = json.load(f) 8 | arg_dict = vars(args) 9 | for key in arg_dict : 10 | if key in configs: 11 | if arg_dict[key] is not None: 12 | configs[key] = arg_dict[key] 13 | configs = utils.ConfigMapper(configs) 14 | return configs 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser(description='Semi-supervised Learning') 18 | parser.add_argument('--cfg_path', type=str, default='./config/train.json') 19 | return parser.parse_args() 20 | 21 | return parser.parse_args() -------------------------------------------------------------------------------- /model/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | class WeightEMA(object): 5 | def __init__(self, model, ema_model, wd, alpha=0.999): 6 | self.model = model 7 | self.ema_model = ema_model 8 | self.alpha = alpha 9 | self.params = list(model.state_dict().values()) 10 | self.ema_params = list(ema_model.state_dict().values()) 11 | self.wd = wd 12 | 13 | for param, ema_param in zip(self.params, self.ema_params): 14 | param.data.copy_(ema_param.data) 15 | 16 | def step(self): 17 | one_minus_alpha = 1.0 - self.alpha 18 | for param, ema_param in zip(self.params, self.ema_params): 19 | if ema_param.dtype==torch.float32: 20 | ema_param.mul_(self.alpha) 21 | ema_param.add_(param * one_minus_alpha) 22 | # customized weight decay 23 | param.mul_(1 - self.wd) 24 | 25 | -------------------------------------------------------------------------------- /experiments/download.sh: -------------------------------------------------------------------------------- 1 | 2 | # CIFAR10 3 | for VAR in 1sJaSoNvqiaczxB9e-x5QqJ0Hy0gM_jO3 1y-PkCmIpyXsZpMxCawrYhXxhppf919Jh 1gwytj5uUnpiARkQFesU9XmiAaPGOPVrt 14lERKZdQbNV4mxN4QUXrepPjjBoGruYF 1mYQjfuT-4Wmm8yH-4YnjEEnBk5OeQn3D 4 | do 5 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=$VAR" > /dev/null 6 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=$VAR" -o "$VAR.pth" 7 | done 8 | 9 | mkdir results 10 | label_list=(250 500 1000 2000 4000) 11 | file_list=(1sJaSoNvqiaczxB9e-x5QqJ0Hy0gM_jO3 1y-PkCmIpyXsZpMxCawrYhXxhppf919Jh 1gwytj5uUnpiARkQFesU9XmiAaPGOPVrt 14lERKZdQbNV4mxN4QUXrepPjjBoGruYF 1mYQjfuT-4Wmm8yH-4YnjEEnBk5OeQn3D) 12 | for index in {0..4} 13 | do 14 | filepath="results/Mixmatch_${label_list[index]}/CIFAR10_28-2_${label_list[index]}" 15 | mkdir -p $filepath 16 | mv ${file_list[index]}.pth best.pth 17 | mv best.pth $filepath 18 | done 19 | rm cookie -------------------------------------------------------------------------------- /experiments/experiment_AppendixB.sh: -------------------------------------------------------------------------------- 1 | # This shell file will generate the experimental results of Appendix B: B.1, B.2 2 | 3 | # CIFAR10 4 | CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path 
CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path config/mixmatch/CIFAR10/train_CIFAR10_250.json
CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path config/mixmatch/CIFAR10/train_CIFAR10_500.json
CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path config/mixmatch/CIFAR10/train_CIFAR10_1000.json
CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path config/mixmatch/CIFAR10/train_CIFAR10_2000.json
CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path config/mixmatch/CIFAR10/train_CIFAR10_4000.json

# SVHN
CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path config/mixmatch/SVHN/train_SVHN_250.json
CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path config/mixmatch/SVHN/train_SVHN_500.json
CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path config/mixmatch/SVHN/train_SVHN_1000.json
CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path config/mixmatch/SVHN/train_SVHN_2000.json
CUDA_VISIBLE_DEVICES=1 python main.py --cfg_path config/mixmatch/SVHN/train_SVHN_4000.json
--------------------------------------------------------------------------------
/data_loader/svhn.py:
--------------------------------------------------------------------------------
import numpy as np
from data_loader.SSL_Dataset import SSL_Dataset
from PIL import Image

class SVHN_labeled(SSL_Dataset):
    def __init__(self, data, targets, indexes=None, transform=None, target_transform=None):
        super(SVHN_labeled, self).__init__(transform=transform, target_transform=target_transform)
        self.data = data
        self.targets = targets
        if indexes is not None:
            self.data = self.data[indexes]
            self.targets = np.array(self.targets)[indexes]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]
        # SVHN is stored as CHW; convert to HWC before building the PIL image
        img = Image.fromarray(np.transpose(img, (1, 2, 0)))
        if self.transform is not None:
            try:
                img = self.transform(img)
            except TypeError:
                # transform may be passed as a tuple/list of transforms;
                # fall back to its first element
                img = self.transform[0](img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

class SVHN_unlabeled(SVHN_labeled):
    def __init__(self, data, targets, indexes=None,
                 transform=None, target_transform=None):
        super(SVHN_unlabeled, self).__init__(data, targets, indexes, transform=transform, target_transform=target_transform)
        self.targets = np.array([-1 for _ in range(len(self.targets))])
--------------------------------------------------------------------------------
/data_loader/cifar.py:
--------------------------------------------------------------------------------
import numpy as np
from data_loader.SSL_Dataset import SSL_Dataset
from PIL import Image

class CIFAR_labeled(SSL_Dataset):
    def __init__(self, data, targets, indexes=None, transform=None, target_transform=None):
        super(CIFAR_labeled, self).__init__(transform=transform, target_transform=target_transform)
        self.data = data
        self.targets = targets
        if indexes is not None:
            self.data = self.data[indexes]
            self.targets = np.array(self.targets)[indexes]
        #self.data = self._transpose(self._normalize(self.data))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            try:
                img = self.transform(img)
            except TypeError:
                # transform may be passed as a tuple/list of transforms;
                # fall back to its first element
                img = self.transform[0](img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

class CIFAR_unlabeled(CIFAR_labeled):
    def __init__(self, data, targets, indexes=None,
                 transform=None, target_transform=None):
        super(CIFAR_unlabeled, self).__init__(data, targets, indexes, transform=transform, target_transform=target_transform)
        self.targets = np.array([-1 for _ in range(len(self.targets))])
--------------------------------------------------------------------------------
/data_loader/transform.py:
--------------------------------------------------------------------------------
import numpy as np
import torch


def pad(x, border=4):
    return np.pad(x, [(0, 0), (border, border), (border, border)], mode='reflect')

class RandomPadandCrop(object):
    """Randomly crop the image after reflection padding.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size, border=4):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size
        self.border = border

    def __call__(self, x):
        x = pad(x, self.border)

        h, w = x.shape[1:]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        x = x[:, top: top + new_h, left: left + new_w]

        return x

class RandomFlip(object):
    """Randomly flip the image horizontally.
    """
    def __call__(self, x):
        if np.random.rand() < 0.5:
            x = x[:, :, ::-1]

        return x.copy()

class GaussianNoise(object):
    """Add Gaussian noise to the image.
    """
    def __call__(self, x):
        c, h, w = x.shape
        x += np.random.randn(c, h, w) * 0.15
        return x
57 | """ 58 | def __call__(self, x): 59 | x = torch.from_numpy(x) 60 | return x 61 | 62 | class Compose(object): 63 | def __init__(self, ops): 64 | self.ops = ops 65 | 66 | def __call__(self, im): 67 | for op in self.ops: 68 | im = op(im) 69 | return im -------------------------------------------------------------------------------- /model/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | import math 4 | 5 | import torch 6 | from torch.optim.lr_scheduler import _LRScheduler, LambdaLR 7 | import numpy as np 8 | 9 | 10 | class WarmupCosineLrScheduler(_LRScheduler): 11 | def __init__( 12 | self, 13 | optimizer, 14 | max_iter, 15 | warmup_iter, 16 | warmup_ratio=5e-4, 17 | warmup='exp', 18 | last_epoch=-1, 19 | ): 20 | self.max_iter = max_iter 21 | self.warmup_iter = warmup_iter 22 | self.warmup_ratio = warmup_ratio 23 | self.warmup = warmup 24 | super(WarmupCosineLrScheduler, self).__init__(optimizer, last_epoch) 25 | 26 | def get_lr(self): 27 | ratio = self.get_lr_ratio() 28 | lrs = [ratio * lr for lr in self.base_lrs] 29 | return lrs 30 | 31 | def get_lr_ratio(self): 32 | if self.last_epoch < self.warmup_iter: 33 | ratio = self.get_warmup_ratio() 34 | else: 35 | real_iter = self.last_epoch - self.warmup_iter 36 | real_max_iter = self.max_iter - self.warmup_iter 37 | ratio = np.cos((7 * np.pi * real_iter) / (16 * real_max_iter)) 38 | return ratio 39 | 40 | def get_warmup_ratio(self): 41 | assert self.warmup in ('linear', 'exp') 42 | alpha = self.last_epoch / self.warmup_iter 43 | if self.warmup == 'linear': 44 | ratio = self.warmup_ratio + (1 - self.warmup_ratio) * alpha 45 | elif self.warmup == 'exp': 46 | ratio = self.warmup_ratio ** (1. - alpha) 47 | return ratio 48 | 49 | 50 | if __name__ == "__main__": 51 | ''' 52 | For testing 53 | ''' 54 | model = torch.nn.Conv2d(3, 16, 3, 1, 1) 55 | optim = torch.optim.SGD(model.parameters(), lr=1e-3) 56 | 57 | max_iter = 20000 58 | lr_scheduler = WarmupCosineLrScheduler(optim, max_iter, 0) 59 | 60 | lrs = [] 61 | for _ in range(max_iter): 62 | lr = lr_scheduler.get_lr()[0] 63 | lrs.append(lr) 64 | lr_scheduler.step() 65 | import matplotlib 66 | import matplotlib.pyplot as plt 67 | import numpy as np 68 | lrs = np.array(lrs) 69 | n_lrs = len(lrs) 70 | plt.plot(np.arange(n_lrs), lrs) 71 | plt.title('CosineLrScheduler') 72 | plt.grid() 73 | plt.show() 74 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | import os 4 | from model import * 5 | import torch.utils.data as data 6 | from data_loader.loader import get_test_data 7 | from tqdm import tqdm 8 | from src.base import BaseModel 9 | from utils import * 10 | import torchvision.transforms as transforms 11 | # Evaluate the performance with ema-model 12 | class Evaluator(BaseModel): 13 | def __init__(self, configs): 14 | BaseModel.__init__(self, configs) 15 | self.ema_model = WRN(num_classes=configs.num_classes, 16 | depth=configs.depth, 17 | width=configs.width, 18 | large=configs.large).to(self.device) 19 | for param in self.ema_model.parameters(): 20 | param.detach_() 21 | 22 | _, transform_val = get_transform(configs.method, configs.dataset) 23 | test_set = get_test_data(self.configs.datapath, self.configs.dataset, transform_val) 24 | self.test_loader = data.DataLoader(test_set, batch_size=configs.batch_size, shuffle=False, 
        self.test_loader = data.DataLoader(test_set, batch_size=configs.batch_size, shuffle=False,
                                           num_workers=0, drop_last=False)
        self.eval_criterion = nn.CrossEntropyLoss().to(self.device)

        ckpt_path = os.path.join(self.out_dir, self.configs.ckpt)
        self._load_checkpoint(ckpt_path)

    def evaluate(self):
        loss_ema_meter = AverageMeter()
        top1_ema_meter = AverageMeter()
        top5_ema_meter = AverageMeter()

        with torch.no_grad():
            tq = tqdm(self.test_loader, total=len(self.test_loader), leave=False)
            for x, y in tq:
                x, y = x.to(self.device), y.to(self.device)

                logits_ema, _ = self.ema_model(x)
                loss_ema = self.eval_criterion(logits_ema, y)
                prob_ema = torch.softmax(logits_ema, dim=1)
                top1_ema, top5_ema = accuracy(prob_ema, y, (1, 5))

                loss_ema_meter.update(loss_ema.item())
                top1_ema_meter.update(top1_ema.item())
                top5_ema_meter.update(top5_ema.item())
                tq.set_description("[{}] Top1: {:.4f}. Top5: {:.4f}. Loss: {:.4f}.".format("EMA", top1_ema_meter.avg, top5_ema_meter.avg, loss_ema_meter.avg))
        self.logger.info(" [{}] Top1: {:.4f}. Top5: {:.4f}. Loss: {:.4f}.".format("EMA ", top1_ema_meter.avg, top5_ema_meter.avg, loss_ema_meter.avg))
--------------------------------------------------------------------------------
/src/base.py:
--------------------------------------------------------------------------------
import os
import torch
from abc import ABC, abstractmethod
from utils import create_logger

class BaseModel(ABC):
    def __init__(self, configs):
        self.configs = configs
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger, writer, out_dir = create_logger(self.configs)
        self.logger = logger
        self.writer = writer
        self.out_dir = out_dir

    def validation(self):
        """
        Validate the model while training
        """
        pass

    def train(self):
        """
        Train the model
        """
        pass

    def evaluate(self):
        """
        Evaluate the model
        """
        pass

    def _load_checkpoint(self, ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.logger.info(f" Loading the checkpoint from {ckpt_path}")
        for key in checkpoint:
            try:
                # module states (model, ema_model, optimizer) are restored in place
                getattr(self, key).load_state_dict(checkpoint[key])
            except AttributeError:
                # plain values (epoch, top1_val, ...) are assigned directly
                setattr(self, key, checkpoint[key])
        self.logger.info(f" Loading Done.. {ckpt_path}")
        self.logger.info(" top1_ema_val : {:.2f}".format(self.top1_ema_val))
        self.logger.info(" Training resumes from epoch {0}".format(checkpoint['epoch']))
        return checkpoint['epoch']

    def _save_checkpoint(self, epoch, is_best_val=False, ema_is_best_val=False):
        # latest checkpoint
        model = getattr(self, 'model')
        ema_model = getattr(self, 'ema_model')
        optimizer = getattr(self, 'optimizer')
        top1_val = getattr(self, 'top1_val')
        top1_ema_val = getattr(self, 'top1_ema_val')

        checkpoint = dict()
        checkpoint['model'] = model.state_dict()
        checkpoint['ema_model'] = ema_model.state_dict()
        checkpoint['optimizer'] = optimizer.state_dict()
        checkpoint['top1_val'] = top1_val
        checkpoint['top1_ema_val'] = top1_ema_val
        checkpoint['epoch'] = epoch + 1
        if is_best_val:
            torch.save(checkpoint, os.path.join(self.out_dir, 'best.pth'))
            self.logger.info(" Best saving ... ")
        else:
            torch.save(checkpoint, os.path.join(self.out_dir, 'latest.pth'))

        if ema_is_best_val:
            torch.save(checkpoint, os.path.join(self.out_dir, 'ema_best.pth'))
            self.logger.info(" EMA Best saving ... ")
        else:
            torch.save(checkpoint, os.path.join(self.out_dir, 'ema_latest.pth'))
--------------------------------------------------------------------------------
/SSL_loss/mixmatch.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import *

class MixMatchLoss(nn.Module):
    def __init__(self, configs, device):
        super().__init__()
        self.K = configs.K
        self.T = configs.T
        self.bt = configs.batch_size
        self.epochs = configs.epochs
        self.lambda_u = configs.lambda_u
        self.beta_dist = torch.distributions.beta.Beta(configs.alpha, configs.alpha)
        self.num_classes = configs.num_classes
        self.device = device

    def sharpen(self, y):
        # temperature sharpening: p_i^(1/T) / sum_j p_j^(1/T)
        y = y.pow(1/self.T)
        return y / y.sum(dim=1, keepdim=True)

    def cal_loss(self, logit_x, y, logit_u_x, y_hat, epoch, lambda_u):
        """
        :param logit_x   : f(x)
        :param y         : true target of x
        :param logit_u_x : f(u_x)
        :param y_hat     : guessed label of u_x
        :param epoch     : current epoch
        :param lambda_u  : linearly increase the weight from 0 to lambda_u
        :return          : CE loss of x, MSE loss of (f(u_x), y_hat), weight of u_x
        """
        probs_u = torch.softmax(logit_u_x, dim=1)  # score of u_x
        loss_x = -torch.mean(torch.sum(F.log_softmax(logit_x, dim=1) * y, dim=1))  # cross entropy
        loss_u_x = F.mse_loss(probs_u, y_hat)  # MSE loss
        linear_weight = float(np.clip(epoch / self.epochs, 0.0, 1.0))  # linearly ramp up the contribution of the unlabeled set

        return loss_x, loss_u_x, lambda_u * linear_weight

    def mixup(self, all_inputs, all_targets):
        lam = self.beta_dist.sample().item()
        lam = max(lam, 1-lam)
        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]
        mixed_input = lam * input_a + (1 - lam) * input_b
        mixed_target = lam * target_a + (1 - lam) * target_b
        mixed_input = list(torch.split(mixed_input, self.bt))
        mixed_input = mixmatch_interleave(mixed_input, self.bt)
        return mixed_input, mixed_target

    def forward(self, input):
        x = input['x']
        y = input['y']
        u_x = list(input['u_x'])
        current = input['current']
        model = input['model']

        # make one-hot label
        y = F.one_hot(y, self.num_classes)
        x, y = x.to(self.device), y.to(self.device)
        u_x = [i.to(self.device) for i in u_x]

        # guess labels by averaging the predictions over the K augmentations
        with torch.no_grad():
            y_hat = sum([model(k)[0].softmax(1) for k in u_x]) / self.K
            y_hat = self.sharpen(y_hat)
            y_hat.detach_()

        # mixup
        all_inputs = torch.cat([x] + u_x, dim=0)
        all_targets = torch.cat([y] + [y_hat]*self.K, dim=0)
        mixed_input, mixed_target = self.mixup(all_inputs, all_targets)

        logit = [model(mixed_input[i])[0] for i in range(len(mixed_input))]
        logits = mixmatch_interleave(logit, self.bt)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)
        loss_x, loss_u, w = self.cal_loss(logits_x, mixed_target[:self.bt], logits_u, mixed_target[self.bt:], current, self.lambda_u)
        return loss_x, loss_u, w
--------------------------------------------------------------------------------
/model/wideresnet.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activate_before_residual=False):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001)
        self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001)
        self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                                                padding=0, bias=False) or None
        self.activate_before_residual = activate_before_residual

    def forward(self, x):
        if not self.equalInOut and self.activate_before_residual == True:
            x = self.relu1(self.bn1(x))
        else:
            out = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        return torch.add(x if self.equalInOut else self.convShortcut(x), out)

class NetworkBlock(nn.Module):
    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, activate_before_residual=False):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate, activate_before_residual):
        layers = []
        for i in range(int(nb_layers)):
            layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate, activate_before_residual))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)

class WRN(nn.Module):
    def __init__(self, num_classes, depth=28, width=2, dropRate=0.0, large=False):
        super(WRN, self).__init__()
        if large:
            nChannels = [16, 135, 135*width, 270*width]
        else:
            nChannels = [16, 16*width, 32*width, 64*width]
        assert((depth - 4) % 6 == 0)
        n = (depth - 4) / 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activate_before_residual=True)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3], momentum=0.001)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out), out
--------------------------------------------------------------------------------
/data_loader/loader.py:
--------------------------------------------------------------------------------
import numpy as np
import torchvision
from data_loader.cifar import CIFAR_labeled, CIFAR_unlabeled
from data_loader.svhn import SVHN_labeled, SVHN_unlabeled

class Augmentation:
    def __init__(self, K, transform):
        self.transform = transform
        self.K = K

    def __call__(self, x):
        # Apply the stochastic augmentation K times
        out = [self.transform(x) for _ in range(self.K)]
        return out

def train_val_split(labels, n_labeled, num_class, num_val):
    # The split depends on the random seed, so performance may change
    # if you use a different seed value.
    n_labeled_per_class = int(n_labeled/num_class)
    labels = np.array(labels)
    train_labeled_idxs = []
    train_unlabeled_idxs = []
    val_idxs = []

    for i in range(num_class):
        idxs = np.where(labels == i)[0]
        np.random.shuffle(idxs)
        train_labeled_idxs.extend(idxs[:n_labeled_per_class])
        if num_val == 0:
            train_unlabeled_idxs.extend(idxs[n_labeled_per_class:])
        else:
            train_unlabeled_idxs.extend(idxs[n_labeled_per_class:-num_val])
            val_idxs.extend(idxs[-num_val:])
    np.random.shuffle(train_labeled_idxs)
    np.random.shuffle(train_unlabeled_idxs)
    if not num_val == 0:
        np.random.shuffle(val_idxs)

    return train_labeled_idxs, train_unlabeled_idxs, val_idxs

def get_trainval_data(root, method, dataset, K, n_labeled, num_class,
                      transform_train=None, transform_val=None,
                      download=True):
    if dataset=='CIFAR10':
        """
        TOTAL : 60,000 32x32 RGB images, 10 classes (6,000 images per class)
        Training : 50,000 images, 5,000 images per class
        Test : 10,000 images, 1,000 images per class
        """
        print("Dataset : CIFAR10")
        base_dataset = torchvision.datasets.CIFAR10(root, train=True, download=download)
        num_val = 5000 // num_class

        train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.targets, n_labeled, num_class, num_val)
        train_labeled_dataset = CIFAR_labeled(base_dataset.data, base_dataset.targets, train_labeled_idxs, transform=transform_train)
        train_unlabeled_dataset = CIFAR_unlabeled(base_dataset.data, base_dataset.targets, train_unlabeled_idxs, transform=Augmentation(K, transform_train))
        val_dataset = CIFAR_labeled(base_dataset.data, base_dataset.targets, val_idxs, transform=transform_val)
    elif dataset=='CIFAR100':
        """
        TOTAL : 60,000 32x32 RGB images, 100 classes (600 images per class)
        Training : 50,000 images, 500 images per class
        Test : 10,000 images, 100 images per class
        """
        print("Dataset : CIFAR100")
        base_dataset = torchvision.datasets.CIFAR100(root, train=True, download=download)
        num_val = 5000 // num_class

        train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.targets, n_labeled, num_class, num_val)
        train_labeled_dataset = CIFAR_labeled(base_dataset.data, base_dataset.targets, train_labeled_idxs, transform=transform_train)
        train_unlabeled_dataset = CIFAR_unlabeled(base_dataset.data, base_dataset.targets, train_unlabeled_idxs, transform=Augmentation(K, transform_train))
        val_dataset = CIFAR_labeled(base_dataset.data, base_dataset.targets, val_idxs, transform=transform_val)

    elif dataset =='SVHN':
        """
        TOTAL : 32x32 RGB, 99,289 images (+531,131 extra), 10 classes
        Training : 73,257 images (+531,131 extra)
        Test : 26,032 images
        """
        print("Dataset : SVHN")
        base_dataset = torchvision.datasets.SVHN(root, split='train', download=download)

        #num_val = 5000 // num_class
        num_val = 0
        train_labeled_idxs, train_unlabeled_idxs, val_idxs = train_val_split(base_dataset.labels, n_labeled, num_class, num_val)
        train_labeled_dataset = SVHN_labeled(base_dataset.data, base_dataset.labels, train_labeled_idxs, transform=transform_train)
        train_unlabeled_dataset = SVHN_unlabeled(base_dataset.data, base_dataset.labels, train_unlabeled_idxs, transform=Augmentation(K, transform_train))
        val_dataset = get_test_data(root, dataset, transform_val)
    elif dataset =='STL10':
        """
        TOTAL : 113,000 96x96 RGB images, 10 classes
        Training : 500 labeled images per class (pre-defined 10 folds) + 100,000 unlabeled images
        Test : 800 images per class
        """
        print("Dataset : STL10")
        base_dataset = torchvision.datasets.STL10(root, split='train', download=download)
        num_val = 5000 // num_class
        raise NotImplementedError

    print(f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} #Val: {len(val_idxs)}")
    return train_labeled_dataset, train_unlabeled_dataset, val_dataset

def get_test_data(root, dataset, transform_val):
    if dataset=='CIFAR10':
        base_dataset = torchvision.datasets.CIFAR10(root, train=False, download=True)
        test_dataset = CIFAR_labeled(base_dataset.data, base_dataset.targets, None, transform=transform_val)
    elif dataset=='CIFAR100':
        base_dataset = torchvision.datasets.CIFAR100(root, train=False, download=True)
        test_dataset = CIFAR_labeled(base_dataset.data, base_dataset.targets, None, transform=transform_val)
    elif dataset=='SVHN':
        base_dataset = torchvision.datasets.SVHN(root, split='test', download=True)
        test_dataset = SVHN_labeled(base_dataset.data, base_dataset.labels, None, transform=transform_val)
    elif dataset=='STL10':
        base_dataset = torchvision.datasets.STL10(root, split='test', download=True)
        raise NotImplementedError
    return test_dataset
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import logging
import os, sys
import torch
from tensorboardX import SummaryWriter
import data_loader.transform as T
from data_loader.randaugment import RandAugmentMC as RandomAugment
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import LambdaLR
import math

class ConfigMapper(object):
    def __init__(self, args):
        for key in args:
            self.__dict__[key] = args[key]

class AverageMeter(object):
    """
    Computes and stores the average and current value
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        no_progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        return max(0., math.cos(math.pi * num_cycles * no_progress))

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

def get_normalize(_dataset):
    if _dataset == 'CIFAR10':
        return (0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)
    elif _dataset =='CIFAR100':
        return (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)
    elif _dataset =='SVHN':
        return (0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)
    elif _dataset =='STL10':
        return (0.4409, 0.4279, 0.3868), (0.2683, 0.2611, 0.2687)
    else:
        raise NotImplementedError

def get_mixmatch_transform(_dataset):
    mean, std = get_normalize(_dataset)
    if _dataset=='CIFAR10' or _dataset=='CIFAR100':
transforms.Compose([ 66 | transforms.RandomHorizontalFlip(p=0.5), 67 | transforms.RandomCrop(size=32, 68 | padding=int(32*0.125), 69 | padding_mode='reflect'), 70 | transforms.ToTensor(), 71 | transforms.Normalize(mean,std) 72 | ]) 73 | test_transform = transforms.Compose([ 74 | transforms.ToTensor(), 75 | transforms.Normalize(mean,std) 76 | ]) 77 | elif _dataset == 'SVHN': 78 | train_transform = transforms.Compose([ 79 | transforms.RandomCrop(size=32, 80 | padding=4, 81 | padding_mode='reflect'), 82 | transforms.ToTensor(), 83 | transforms.Normalize(mean,std) 84 | ]) 85 | test_transform = transforms.Compose([ 86 | transforms.ToTensor(), 87 | transforms.Normalize(mean,std) 88 | ]) 89 | elif _dataset =='STL10': 90 | train_transform = transforms.Compose([ 91 | transforms.RandomCrop(96, padding=int(96*0.125), padding_mode='reflect'), 92 | transforms.RandomHorizontalFlip(), 93 | transforms.ToTensor(), 94 | transforms.Normalize(mean,std) 95 | ]) 96 | test_transform = transforms.Compose([ 97 | transforms.ToTensor(), 98 | transforms.Normalize(mean, std) 99 | ]) 100 | else: 101 | raise NotImplementedError 102 | 103 | return train_transform, test_transform 104 | 105 | def get_transform(method, _dataset): 106 | if method == 'Mixmatch': 107 | return get_mixmatch_transform(_dataset) 108 | else: 109 | raise NotImplementedError 110 | 111 | def mixmatch_interleave_offsets(batch, nu): 112 | groups = [batch // (nu + 1)] * (nu + 1) 113 | for x in range(batch - sum(groups)): 114 | groups[-x - 1] += 1 115 | offsets = [0] 116 | for g in groups: 117 | offsets.append(offsets[-1] + g) 118 | assert offsets[-1] == batch 119 | return offsets 120 | 121 | 122 | def mixmatch_interleave(xy, batch): 123 | nu = len(xy) - 1 124 | offsets = mixmatch_interleave_offsets(batch, nu) 125 | xy = [[v[offsets[p]:offsets[p + 1]] for p in range(nu + 1)] for v in xy] 126 | for i in range(1, nu + 1): 127 | xy[0][i], xy[i][i] = xy[i][i], xy[0][i] 128 | return [torch.cat(v, dim=0) for v in xy] 129 | 130 | 131 | def accuracy(output, target, topk=(1,)): 132 | """Computes the precision@k for the specified values of k""" 133 | maxk = max(topk) 134 | batch_size = target.size(0) 135 | 136 | _, pred = output.topk(maxk, 1, largest=True, sorted=True) 137 | pred = pred.t() 138 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 139 | 140 | res = [] 141 | for k in topk: 142 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 143 | res.append(correct_k.mul_(100.0 / batch_size)) 144 | return res 145 | 146 | def create_logger(configs): 147 | result_dir = os.path.join("results", configs.name) # "results/MixMatch" 148 | os.makedirs(result_dir, exist_ok=True) 149 | out_dir = os.path.join(result_dir, str(configs.dataset) + '_' + str(configs.depth) + '-' +str(configs.width) + '_' + str(configs.num_label)) 150 | os.makedirs(out_dir, exist_ok=True) 151 | log_dir = os.path.join(result_dir, "log") 152 | os.makedirs(log_dir, exist_ok=True) 153 | writer = SummaryWriter(log_dir=log_dir) 154 | 155 | log_file = '{}.log'.format(configs.name) 156 | final_log_file = os.path.join(out_dir, log_file) 157 | head = '%(asctime)-15s %(message)s' 158 | logging.basicConfig(filename=str(final_log_file), 159 | format=head) 160 | logger = logging.getLogger() 161 | logger.setLevel(logging.INFO) 162 | 163 | console_handler = logging.StreamHandler(sys.stdout) 164 | console_handler.setFormatter(logging.Formatter(head)) 165 | logger.addHandler(console_handler) 166 | 167 | if configs.mode =='train': 168 | logger.info(f" Desc = PyTorch Implementation of MixMatch") 169 | 
logger.info(f" Task = {configs.dataset}@{configs.num_label}") 170 | logger.info(f" Model = WideResNet {configs.depth}x{configs.width}") 171 | logger.info(f" large model = {configs.large}") 172 | logger.info(f" Batch size = {configs.batch_size}") 173 | logger.info(f" Epoch = {configs.epochs}") 174 | logger.info(f" Optim = {configs.optim}") 175 | logger.info(f" lambda_u = {configs.lambda_u}") 176 | logger.info(f" alpha = {configs.alpha}") 177 | logger.info(f" T = {configs.T}") 178 | logger.info(f" K = {configs.K}") 179 | return logger, writer, out_dir 180 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-MixMatch - A Holistic Approach to Semi-Supervised Learning 2 | 3 | :warning: Unofficial reproduced code for **[MixMatch](https://arxiv.org/pdf/1905.02249.pdf)**. 4 | This repository covers a variety of dataset e.g., CIFAR-10, CIFAR-100, STL-10, MiniImageNet, etc. 5 | 6 | ## :hammer: Setup 7 | 8 | ### Dependency 9 | 10 | ``` 11 | pytorch > 1.0 12 | torchvision > 0.5 13 | tqdm 14 | tensorboardX > 2.0 15 | ``` 16 | 17 | ### Dataset 18 | 19 | You have to specify the datapath, for example, `data` folder in this codebase. 20 | `torchvision` will automatically download the corresponding dataset(e.g., [CIFAR-10/100](https://www.cs.toronto.edu/~kriz/cifar.html), [SVHN](http://ufldl.stanford.edu/housenumbers/),[STL10](https://cs.stanford.edu/~acoates/stl10/)) under `data` folder if `download=True`. 21 | Or you also can directly download the datasets under your datapath and use a symbolic link instead as below. 22 | 23 | ```bash 24 | mkdir data 25 | ln -s ${datapath} data 26 | ``` 27 | 28 | 29 | ## :rainbow: Training 30 | 31 | We maintain the code with several configuration files. 32 | To train MixMatch model, just follow the below command with a configuration file. 33 | 34 | ```bash 35 | python main.py --cfg_path config/${method}/${dataset}/${config_name} 36 | ``` 37 | 38 | If you want to train the model on background, refer to the below command. Plus, we recommend you to use `verbose : false` in the configuration file. 39 | 40 | ```bash 41 | nohup python main.py --cfg_path config/${method}/${dataset}/${config_name} & 42 | ``` 43 | 44 | Training configurations are located under `config` folder. You can tune the each parameter. 45 | Plus, `experiments` folder includes the shell files to reproduce the results introduced in the paper. 46 | MixMatch has 4 primary parameter: `lambda_u, alpha, T` and ` K`. (See 3.5 section of [MixMatch](https://arxiv.org/pdf/1905.02249.pdf)) 47 | The original paper fixes the `T` and `K` as `0.5` and `2`, respectively. 48 | The authors vary the value of `lambda_u` and `alpha` depending on the type of dataset. 49 | CIFAR-10, for instance, `lambda_u=75` and `alpha=0.5` are used. 50 | Specifically, they mentioned that *`lambda_u=100` and `alpha=0.75` are good starting points for tunning*. 51 | For those who want to use a custom dataset, you can refer to that mention. 52 | This is an example configuration for CIFAR-10 dataset. 
53 | 54 | ```python 55 | { 56 | "mode": "train", # mode [train/eval] 57 | "method":"Mixmatch", # type of SSL method [Mixmatch] 58 | "name": "Experiment1", # name of trial 59 | "dataset": "CIFAR10", # dataset [CIFAR10, CIFAR100, STL-10, SVHN] 60 | "datapath":"./data", # datapath 61 | "depth":28, # ResNet depth 62 | "width":2, # ResNet width 63 | "large":false, # flag of using large model(i.e., 135 filter size) 64 | "num_classes":10, # Number of class, e.g., CIFAR-10 : 10 65 | "num_label":250, # The number of available label [250, 1000, 4000] 66 | "batch_size":64, # batch size 67 | "epochs":1024, # epoch 68 | "save_epoch":10, # interval of saving checkpoint 69 | "resume": false, # resuming the training 70 | "ckpt": "latest.pth", # checkpoint name 71 | "verbose": false, # If True, print training log on the console 72 | 73 | /* Training Configuration */ 74 | "lr":0.002, 75 | "lambda_u": 75, 76 | "optim":"ADAM", # type of optimizer [Adam, SGD] 77 | "alpha":0.75, 78 | "T" : 0.5, # fixed across all experiments, but you can adjust it 79 | "K" : 2, # fixed across all experiments, but you can adjust it 80 | "ema_alpha":0.999, 81 | "seed":2114 # Different seed yields different result 82 | } 83 | ``` 84 | 85 | - `lambda_u` : A hyper-parameter weighting the contribution of the unlabeled examples to the training loss 86 | - `alpha` : Hyperparameter for the Beta distribution used in MixU 87 | - `T` : Temperature parameter for sharpening used in MixMatch 88 | - `K` : Number of augmentations used when guessing labels in MixMatch 89 | - `seed` : A number to initialize the random sampling. The results might be changed if you use different seed since it leads to different sampling strategy. 90 | 91 | ### Training Example 92 | 93 | Training MixMatch on WideResNet28x2 using a CIFAR10 with 250 labeled data 94 | 95 | > python main.py --cfg_path config/mixmatch/CIFAR10/train_CIFAR10_250.json 96 | 97 | ### Evaluation Example 98 | 99 | Evaluating MixMatch on WideResNet28x2 using a CIFAR10 with 250 labeled data 100 | 101 | > python main.py --cfg_path config/mixmatch/CIFAR10/eval_CIFAR10_250.json 102 | 103 | ## :gift: Pre-trained model 104 | 105 | We provide the pre-trained model of CIFAR10 dataset. You can easily download the checkpoint files using below commands. 106 | This shell file will automatically download the files and organize them to the desired path. The default result directory is `results`. 107 | For those who cannot download the files using shell file, access the [link](https://drive.google.com/drive/folders/1Fjh-9aSvhAVYrxxXkxnrtW5s6yrprjRs?usp=sharing) directly. 108 | In the case of downloading the file directly, plz modify the `"ckpt": $checkpoint_name` in the configuration file. For instance, `"ckpt": Mixmatch_250.pth`. 
109 | 110 | ``` 111 | bash experiments/download.sh 112 | python main.py --cfg_path config/mixmatch/CIFAR10/eval_CIFAR10_250.json 113 | python main.py --cfg_path config/mixmatch/CIFAR10/eval_CIFAR10_500.json 114 | python main.py --cfg_path config/mixmatch/CIFAR10/eval_CIFAR10_1000.json 115 | python main.py --cfg_path config/mixmatch/CIFAR10/eval_CIFAR10_2000.json 116 | python main.py --cfg_path config/mixmatch/CIFAR10/eval_CIFAR10_4000.json 117 | ``` 118 | ## :link: Experiments 119 | 120 | ### Table 121 | 122 | **CIFAR-10** | 250 | 500 | 1000 | 2000 | 4000 | 123 | | :-----:| :-----:| :-----:| :-----:| :-----:| :-----:| 124 | #Paper | 88.92±0.87 | 90.35±0.94 | 92.25±0.32 | 92.97±0.15 | 93.76±0.06 | 125 | **Repo #Shallow** | 88.53 | 88.60 | 90.72 | 93.10 | 93.27 | 126 | 127 | **SVHN** | 250 | 500 | 1000 | 2000 | 4000 | 128 | | :-----:| :-----:| :-----:| :-----:| :-----:| :-----:| 129 | #Paper | 96.22±0.87 | 96.36±0.94 | 96.73±0.32 | 96.96±0.15 | 97.11±0.06 | 130 | **Repo #Shallow** | 94.10 | 94.27 | 94.52 | 95.11 | 96.08 | 131 | 132 | ### Training log 133 | 134 | We provide a board to monitor log values. 135 | Follow the below commands to view the progress. 136 | 137 | ```bash 138 | cd results/${name} 139 | tensorboard --logdir=log/ --bind_all 140 | ``` 141 | 142 | ## Reference 143 | 144 | - YU1ut [MixMatch-pytorch](https://github.com/YU1ut/MixMatch-pytorch) 145 | - perrying [realistic-ssl-evaluation-pytorch](https://github.com/perrying/realistic-ssl-evaluation-pytorch) 146 | - google-research [mixmatch](https://github.com/google-research/mixmatch) 147 | 148 | 149 | ``` 150 | @article{berthelot2019mixmatch, 151 | title={MixMatch: A Holistic Approach to Semi-Supervised Learning}, 152 | author={Berthelot, David and Carlini, Nicholas and Goodfellow, Ian and Papernot, Nicolas and Oliver, Avital and Raffel, Colin}, 153 | journal={arXiv preprint arXiv:1905.02249}, 154 | year={2019} 155 | } 156 | ``` -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.data as data 4 | import numpy as np 5 | import random 6 | import time 7 | 8 | from data_loader.loader import get_trainval_data 9 | from src.base import BaseModel 10 | from SSL_loss.mixmatch import MixMatchLoss 11 | from tqdm import tqdm 12 | from utils import * 13 | from model import * 14 | 15 | class Trainer(BaseModel): 16 | def __init__(self, configs): 17 | BaseModel.__init__(self, configs) 18 | self.model = WRN(num_classes=configs.num_classes, 19 | depth=configs.depth, 20 | width=configs.width, 21 | large=configs.large).to(self.device) 22 | 23 | self.ema_model = WRN(num_classes=configs.num_classes, 24 | depth=configs.depth, 25 | width=configs.width, 26 | large=configs.large).to(self.device) 27 | for param in self.ema_model.parameters(): 28 | param.detach_() 29 | 30 | if self.configs.seed == "None": 31 | manualSeed = random.randint(1, 10000) 32 | else: 33 | manualSeed = self.configs.seed 34 | np.random.seed(manualSeed) 35 | torch.manual_seed(manualSeed) 36 | 37 | self.logger.info(" Total params: {:.2f}M".format( 38 | sum(p.numel() for p in self.model.parameters()) / 1e6)) 39 | self.logger.info(" Sampling seed : {0}".format(manualSeed)) 40 | transform_train, transform_val = get_transform(configs.method, configs.dataset) 41 | 42 | train_labeled_set, train_unlabeled_set, val_set = get_trainval_data(configs.datapath, configs.method, configs.dataset, configs.K, \ 43 | 
configs.num_label, configs.num_classes, transform_train=transform_train, transform_val=transform_val) 44 | 45 | if configs.method=='Mixmatch': 46 | self.train_loader = data.DataLoader(train_labeled_set, batch_size=configs.batch_size, shuffle=True, num_workers=0, drop_last=True) 47 | self.u_train_loader = data.DataLoader(train_unlabeled_set, batch_size=configs.batch_size, shuffle=True, num_workers=0, drop_last=True) 48 | else: 49 | raise NotImplementedError 50 | self.val_loader = data.DataLoader(val_set, batch_size=configs.batch_size, shuffle=False, num_workers=0) 51 | 52 | self.criterion = MixMatchLoss(configs, self.device) 53 | self.eval_criterion = nn.CrossEntropyLoss().to(self.device) 54 | if configs.optim == 'ADAM': 55 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr = configs.lr) 56 | elif configs.optim =='SGD': 57 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr = configs.lr, momentum=0.9, nesterov=True, weight_decay=configs.weight_decay) 58 | self.ema_optimizer = WeightEMA(self.model, self.ema_model, configs.weight_decay*configs.lr, alpha=configs.ema_alpha) 59 | 60 | self.top1_val = 0 61 | self.top1_ema_val = 0 62 | 63 | if self.configs.resume: 64 | ckpt_path = os.path.join(self.out_dir, self.configs.ckpt) 65 | self.start_epoch = self._load_checkpoint(ckpt_path) 66 | else: 67 | self.start_epoch = 0 68 | 69 | def _terminate(self): 70 | # terminate the logger and SummaryWriter 71 | self.writer.close() 72 | 73 | def evaluate(self, epoch): 74 | self.model.eval() 75 | loss_meter = AverageMeter() 76 | top1_meter = AverageMeter() 77 | top5_meter = AverageMeter() 78 | 79 | loss_ema_meter = AverageMeter() 80 | top1_ema_meter = AverageMeter() 81 | top5_ema_meter = AverageMeter() 82 | 83 | is_best_val = False 84 | ema_is_best_val = False 85 | with torch.no_grad(): 86 | if self.configs.verbose: 87 | tq = tqdm(self.val_loader, total=self.val_loader.__len__(), leave=False) 88 | else: 89 | tq = self.val_loader 90 | for x,y in tq: 91 | 92 | x, y = x.to(self.device), y.to(self.device) 93 | logits, _ = self.model(x) 94 | logits_ema, _ = self.ema_model(x) 95 | 96 | loss = self.eval_criterion(logits, y) 97 | prob = torch.softmax(logits, dim=1) 98 | top1, top5 = accuracy(prob, y, (1,5)) 99 | 100 | loss_meter.update(loss.item()) 101 | top1_meter.update(top1.item()) 102 | top5_meter.update(top5.item()) 103 | 104 | loss_ema = self.eval_criterion(logits_ema, y) 105 | prob_ema = torch.softmax(logits_ema, dim=1) 106 | top1_ema, top5_ema = accuracy(prob_ema, y, (1,5)) 107 | 108 | loss_ema_meter.update(loss_ema.item()) 109 | top1_ema_meter.update(top1_ema.item()) 110 | top5_ema_meter.update(top5_ema.item()) 111 | if self.configs.verbose: 112 | tq.set_description("[{}] Epoch:{}. Top1: {:.4f}. Top5: {:.4f}. Loss: {:.4f}.".format("VAL", epoch, top1_meter.avg, top5_meter.avg, loss_meter.avg)) 113 | self.logger.info(" [{}] Epoch:{}. Top1: {:.4f}. Top5: {:.4f}. Loss: {:.4f}.".format(" VAL ", epoch, top1_meter.avg, top5_meter.avg, loss_meter.avg)) 114 | self.logger.info(" [{}] Epoch:{}. Top1: {:.4f}. Top5: {:.4f}. 
Loss: {:.4f}.".format(" EMA ", epoch, top1_ema_meter.avg, top5_ema_meter.avg, loss_ema_meter.avg)) 115 | self.writer.add_scalars('val_acc/top1', { 116 | 'top1_val': top1_meter.avg, 117 | 'top1_ema_val': top1_ema_meter.avg, 118 | }, epoch) 119 | self.writer.add_scalars('val_acc/top5', { 120 | 'top5_val': top5_meter.avg, 121 | 'top5_ema_val': top5_ema_meter.avg 122 | }, epoch) 123 | self.writer.add_scalars('val_loss', { 124 | 'loss_val': loss_meter.avg, 125 | 'loss_ema_val': loss_ema_meter.avg 126 | }, epoch) 127 | 128 | if self.top1_val < top1_meter.avg: 129 | self.top1_val = top1_meter.avg 130 | is_best_val = True 131 | if self.top1_ema_val < top1_ema_meter.avg: 132 | self.top1_ema_val = top1_ema_meter.avg 133 | ema_is_best_val = True 134 | self._save_checkpoint(epoch, is_best_val, ema_is_best_val) 135 | 136 | 137 | def train(self): 138 | 139 | epoch_start = time.time() # start time 140 | loss_meter = AverageMeter() 141 | loss_x_meter = AverageMeter() 142 | loss_u_meter = AverageMeter() 143 | n_iters = 1024 144 | train_loader_iter = iter(self.train_loader) 145 | u_train_loader_iter = iter(self.u_train_loader) 146 | for epoch in range(self.start_epoch, self.configs.epochs): 147 | self.model.train() 148 | if self.configs.verbose: 149 | tq = tqdm(range(n_iters), total = n_iters, leave=True) 150 | else: 151 | tq = range(n_iters) 152 | for it in tq: 153 | try: 154 | x, y = train_loader_iter.next() 155 | except: 156 | train_loader_iter = iter(self.train_loader) 157 | x, y = train_loader_iter.next() 158 | try: 159 | u_x, _ = u_train_loader_iter.next() 160 | except: 161 | u_train_loader_iter = iter(self.u_train_loader) 162 | u_x, _ = u_train_loader_iter.next() 163 | 164 | # forward inputs 165 | 166 | current = epoch + it / n_iters 167 | input = {'model' : self.model, 168 | 'u_x' : u_x, 169 | 'x' : x, 170 | 'y' : y, 171 | 'current' : current} 172 | 173 | # compute mixmatch loss 174 | loss_x, loss_u, w = self.criterion(input) 175 | loss = loss_x + loss_u * w 176 | 177 | # update 178 | self.optimizer.zero_grad() 179 | loss.backward() 180 | self.optimizer.step() 181 | self.ema_optimizer.step() 182 | 183 | 184 | # logging 185 | loss_meter.update(loss.item()) 186 | loss_x_meter.update(loss_x.item()) 187 | loss_u_meter.update(loss_u.item()) 188 | 189 | t = time.time() - epoch_start 190 | if self.configs.verbose: 191 | tq.set_description(" Epoch [{}/{}], iter: {}. loss: {:.4f}. loss_x: {:.4f}. loss_u: {:.4f}. weight :{:.4f} Time: {:.2f}".format( 192 | epoch, self.configs.epochs, it + 1, 193 | loss_meter.avg, 194 | loss_x_meter.avg, 195 | loss_u_meter.avg, 196 | w, 197 | t)) 198 | self.logger.info(" Epoch [{}/{}], iter: {}. loss: {:.4f}. loss_x: {:.4f}. loss_u: {:.4f}. weight : {:.4f}. Time: {:.2f}".format( 199 | epoch, self.configs.epochs, it + 1, 200 | loss_meter.avg, 201 | loss_x_meter.avg, 202 | loss_u_meter.avg, 203 | w, 204 | t)) 205 | self.writer.add_scalars('train_loss', { 206 | 'loss': loss_meter.avg, 207 | 'loss_x': loss_x_meter.avg, 208 | 'loss_u': loss_u_meter.avg, 209 | }, epoch) 210 | self.evaluate(epoch) 211 | self._terminate() 212 | 213 | 214 | 215 | --------------------------------------------------------------------------------