├── assets
│   └── main.png
├── libs
│   ├── gaussian_blur.py
│   ├── models.py
│   ├── net.py
│   ├── simclr.py
│   ├── resnets.py
│   ├── utils.py
│   ├── simsiam.py
│   ├── moco.py
│   ├── byol.py
│   ├── evaluation.py
│   └── ebclr.py
├── configs_simsiam
│   ├── fmnist_b256.yaml
│   ├── mnist_b256.yaml
│   ├── cifar100_b256.yaml
│   └── cifar10_b256.yaml
├── configs_byol
│   ├── fmnist_b256.yaml
│   ├── mnist_b256.yaml
│   ├── cifar100_b256.yaml
│   └── cifar10_b256.yaml
├── configs_simclr
│   ├── fmnist_b256.yaml
│   ├── fmnist_b64.yaml
│   ├── mnist_b128.yaml
│   ├── mnist_b16.yaml
│   ├── mnist_b256.yaml
│   ├── mnist_b64.yaml
│   ├── fmnist_b128.yaml
│   ├── fmnist_b16.yaml
│   ├── cifar100_b128.yaml
│   ├── cifar100_b256.yaml
│   ├── cifar100_b64.yaml
│   ├── cifar10_b128.yaml
│   ├── cifar10_b16.yaml
│   ├── cifar10_b256.yaml
│   ├── cifar10_b64.yaml
│   └── cifar100_b16.yaml
├── configs_moco
│   ├── fmnist_b256.yaml
│   ├── mnist_b256.yaml
│   ├── cifar100_b256.yaml
│   └── cifar10_b256.yaml
├── configs_ebclr
│   ├── fmnist_b16.yaml
│   ├── fmnist_b64.yaml
│   ├── mnist_b128.yaml
│   ├── mnist_b16.yaml
│   ├── mnist_b64.yaml
│   ├── fmnist_b128.yaml
│   ├── cifar100_b128.yaml
│   ├── cifar100_b16.yaml
│   ├── cifar100_b64.yaml
│   ├── cifar10_b128.yaml
│   ├── cifar10_b16.yaml
│   └── cifar10_b64.yaml
├── Train.ipynb
├── README.md
└── Evaluation.ipynb
/assets/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/1202kbs/EBCLR/HEAD/assets/main.png --------------------------------------------------------------------------------
/configs_simsiam/fmnist_b256.yaml: -------------------------------------------------------------------------------- 1 | epochs : 200 2 | bs : 256 3 | p_epoch : 5 4 | s_epoch : 5 5 | 6 | data: 7 | data_dir : "./data" 8 | dataset : "fmnist" 9 | data_ratio : 1.0 10 | 11 | optim: 12 | optimizer: "sgd" 13 | init_lr: 0.03 14 | lr_schedule: "const" 15 | weight_decay: 0.0001 16 | 17 | net: 18 | encoder : "resnet18" 19 | use_bn : true 20 | use_sn : false 21 | use_wn : false 22 | act : "relu" 23 | proj_dim : 2048 24 | nc : 1 25 | 26 | t: 27 | crop_scale: 28 | min: 0.08 29 | max: 1.0 -------------------------------------------------------------------------------- /configs_simsiam/mnist_b256.yaml: -------------------------------------------------------------------------------- 1 | epochs : 200 2 | bs : 256 3 | p_epoch : 5 4 | s_epoch : 5 5 | 6 | data: 7 | data_dir : "./data" 8 | dataset : "mnist" 9 | data_ratio : 1.0 10 | 11 | optim: 12 | optimizer: "sgd" 13 | init_lr: 0.03 14 | lr_schedule: "const" 15 | weight_decay: 0.0001 16 | 17 | net: 18 | encoder : "resnet18" 19 | use_bn : true 20 | use_sn : false 21 | use_wn : false 22 | act : "relu" 23 | proj_dim : 2048 24 | nc : 1 25 | 26 | t: 27 | crop_scale: 28 | min: 0.08 29 | max: 1.0 -------------------------------------------------------------------------------- /configs_byol/fmnist_b256.yaml: -------------------------------------------------------------------------------- 1 | momentum : 0.9 2 | epochs : 200 3 | bs : 256 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "fmnist" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.03 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 256 25 | nc : 1 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 -------------------------------------------------------------------------------- /configs_byol/mnist_b256.yaml: -------------------------------------------------------------------------------- 1 | momentum : 0.9 2 | epochs : 200 3 | bs : 256 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "mnist" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.03 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 256 25 | nc : 1 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 -------------------------------------------------------------------------------- /configs_simclr/fmnist_b256.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 256 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "fmnist" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.03 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 1 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 --------------------------------------------------------------------------------
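A pattern worth noting before the remaining SimCLR configs: `init_lr` scales linearly with batch size from a base of 0.03 at batch size 256 (0.015 at 128, 0.0075 at 64, 0.001875 at 16). The snippet below illustrates that rule; `scaled_init_lr` is a hypothetical name, not a helper from this repo.

```python
# Hypothetical helper (not part of this repo) illustrating the linear
# learning-rate scaling implied by the configs: init_lr = 0.03 * (bs / 256).
def scaled_init_lr(batch_size, base_lr=0.03, base_bs=256):
    return base_lr * batch_size / base_bs

for bs in (256, 128, 64, 16):
    print(bs, scaled_init_lr(bs))  # 0.03, 0.015, 0.0075, 0.001875
```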
/configs_simclr/fmnist_b64.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 64 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "fmnist" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.0075 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 1 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 -------------------------------------------------------------------------------- /configs_simclr/mnist_b128.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 128 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "mnist" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.015 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 1 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 -------------------------------------------------------------------------------- /configs_simclr/mnist_b16.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 16 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "mnist" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.001875 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 1 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 -------------------------------------------------------------------------------- /configs_simclr/mnist_b256.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 256 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "mnist" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.03 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 1 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 -------------------------------------------------------------------------------- /configs_simclr/mnist_b64.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 64 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "mnist" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.0075 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 1 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 -------------------------------------------------------------------------------- /configs_simclr/fmnist_b128.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 
200 3 | bs : 128 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "fmnist" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.015 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 1 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 -------------------------------------------------------------------------------- /configs_simclr/fmnist_b16.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 16 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "fmnist" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.001875 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 1 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 -------------------------------------------------------------------------------- /configs_moco/fmnist_b256.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.07 2 | queue_size : 65536 3 | momentum : 0.999 4 | epochs : 200 5 | bs : 256 6 | p_epoch : 5 7 | s_epoch : 5 8 | 9 | data: 10 | data_dir : "./data" 11 | dataset : "fmnist" 12 | data_ratio : 1.0 13 | 14 | optim: 15 | optimizer: "sgd" 16 | init_lr: 0.03 17 | lr_schedule: "const" 18 | weight_decay: 0.0001 19 | 20 | net: 21 | encoder : "resnet18" 22 | use_bn : true 23 | use_sn : false 24 | use_wn : false 25 | act : "relu" 26 | proj_dim : 128 27 | nc : 1 28 | 29 | t: 30 | crop_scale: 31 | min: 0.08 32 | max: 1.0 -------------------------------------------------------------------------------- /configs_moco/mnist_b256.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.07 2 | queue_size : 65536 3 | momentum : 0.999 4 | epochs : 200 5 | bs : 256 6 | p_epoch : 5 7 | s_epoch : 5 8 | 9 | data: 10 | data_dir : "./data" 11 | dataset : "mnist" 12 | data_ratio : 1.0 13 | 14 | optim: 15 | optimizer: "sgd" 16 | init_lr: 0.03 17 | lr_schedule: "const" 18 | weight_decay: 0.0001 19 | 20 | net: 21 | encoder : "resnet18" 22 | use_bn : true 23 | use_sn : false 24 | use_wn : false 25 | act : "relu" 26 | proj_dim : 128 27 | nc : 1 28 | 29 | t: 30 | crop_scale: 31 | min: 0.08 32 | max: 1.0 -------------------------------------------------------------------------------- /configs_simsiam/cifar100_b256.yaml: -------------------------------------------------------------------------------- 1 | epochs : 200 2 | bs : 256 3 | p_epoch : 5 4 | s_epoch : 5 5 | 6 | data: 7 | data_dir : "./data" 8 | dataset : "cifar100" 9 | data_ratio : 1.0 10 | 11 | optim: 12 | optimizer: "sgd" 13 | init_lr: 0.03 14 | lr_schedule: "const" 15 | weight_decay: 0.0001 16 | 17 | net: 18 | encoder : "resnet18" 19 | use_bn : true 20 | use_sn : false 21 | use_wn : false 22 | act : "relu" 23 | proj_dim : 2048 24 | nc : 3 25 | 26 | t: 27 | crop_scale: 28 | min: 0.08 29 | max: 1.0 30 | flip_p : 0.5 31 | jitter: 32 | b: 0.8 33 | c: 0.8 34 | s: 0.8 35 | h: 0.2 36 | jitter_p : 0.8 37 | gray_p : 0.2 38 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_simsiam/cifar10_b256.yaml: 
-------------------------------------------------------------------------------- 1 | epochs : 200 2 | bs : 256 3 | p_epoch : 5 4 | s_epoch : 5 5 | 6 | data: 7 | data_dir : "./data" 8 | dataset : "cifar10" 9 | data_ratio : 1.0 10 | 11 | optim: 12 | optimizer: "sgd" 13 | init_lr: 0.03 14 | lr_schedule: "const" 15 | weight_decay: 0.0001 16 | 17 | net: 18 | encoder : "resnet18" 19 | use_bn : true 20 | use_sn : false 21 | use_wn : false 22 | act : "relu" 23 | proj_dim : 2048 24 | nc : 3 25 | 26 | t: 27 | crop_scale: 28 | min: 0.08 29 | max: 1.0 30 | flip_p : 0.5 31 | jitter: 32 | b: 0.8 33 | c: 0.8 34 | s: 0.8 35 | h: 0.2 36 | jitter_p : 0.8 37 | gray_p : 0.2 38 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_byol/cifar100_b256.yaml: -------------------------------------------------------------------------------- 1 | momentum : 0.9 2 | epochs : 200 3 | bs : 256 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "cifar100" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.03 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 256 25 | nc : 3 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 31 | flip_p : 0.5 32 | jitter: 33 | b: 0.8 34 | c: 0.8 35 | s: 0.8 36 | h: 0.2 37 | jitter_p : 0.8 38 | gray_p : 0.2 39 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_byol/cifar10_b256.yaml: -------------------------------------------------------------------------------- 1 | momentum : 0.9 2 | epochs : 200 3 | bs : 256 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "cifar10" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.03 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 256 25 | nc : 3 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 31 | flip_p : 0.5 32 | jitter: 33 | b: 0.8 34 | c: 0.8 35 | s: 0.8 36 | h: 0.2 37 | jitter_p : 0.8 38 | gray_p : 0.2 39 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_simclr/cifar100_b128.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 128 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "cifar100" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.015 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 3 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 31 | flip_p : 0.5 32 | jitter: 33 | b: 0.8 34 | c: 0.8 35 | s: 0.8 36 | h: 0.2 37 | jitter_p : 0.8 38 | gray_p : 0.2 39 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_simclr/cifar100_b256.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 256 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "cifar100" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd"
14 | init_lr: 0.03 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 3 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 31 | flip_p : 0.5 32 | jitter: 33 | b: 0.8 34 | c: 0.8 35 | s: 0.8 36 | h: 0.2 37 | jitter_p : 0.8 38 | gray_p : 0.2 39 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_simclr/cifar100_b64.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 64 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "cifar100" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.0075 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 3 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 31 | flip_p : 0.5 32 | jitter: 33 | b: 0.8 34 | c: 0.8 35 | s: 0.8 36 | h: 0.2 37 | jitter_p : 0.8 38 | gray_p : 0.2 39 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_simclr/cifar10_b128.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 128 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "cifar10" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.015 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 3 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 31 | flip_p : 0.5 32 | jitter: 33 | b: 0.8 34 | c: 0.8 35 | s: 0.8 36 | h: 0.2 37 | jitter_p : 0.8 38 | gray_p : 0.2 39 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_simclr/cifar10_b16.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 16 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "cifar10" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.001875 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 3 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 31 | flip_p : 0.5 32 | jitter: 33 | b: 0.8 34 | c: 0.8 35 | s: 0.8 36 | h: 0.2 37 | jitter_p : 0.8 38 | gray_p : 0.2 39 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_simclr/cifar10_b256.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 256 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "cifar10" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.03 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 3 26 | 27 | t: 28 | crop_scale: 29 | 
min: 0.08 30 | max: 1.0 31 | flip_p : 0.5 32 | jitter: 33 | b: 0.8 34 | c: 0.8 35 | s: 0.8 36 | h: 0.2 37 | jitter_p : 0.8 38 | gray_p : 0.2 39 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_simclr/cifar10_b64.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 64 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "cifar10" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.0075 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 3 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 31 | flip_p : 0.5 32 | jitter: 33 | b: 0.8 34 | c: 0.8 35 | s: 0.8 36 | h: 0.2 37 | jitter_p : 0.8 38 | gray_p : 0.2 39 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_simclr/cifar100_b16.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.1 2 | epochs : 200 3 | bs : 16 4 | p_epoch : 5 5 | s_epoch : 5 6 | 7 | data: 8 | data_dir : "./data" 9 | dataset : "cifar100" 10 | data_ratio : 1.0 11 | 12 | optim: 13 | optimizer: "sgd" 14 | init_lr: 0.001875 15 | lr_schedule: "const" 16 | weight_decay: 0.0001 17 | 18 | net: 19 | encoder : "resnet18" 20 | use_bn : true 21 | use_sn : false 22 | use_wn : false 23 | act : "relu" 24 | proj_dim : 128 25 | nc : 3 26 | 27 | t: 28 | crop_scale: 29 | min: 0.08 30 | max: 1.0 31 | flip_p : 0.5 32 | jitter: 33 | b: 0.8 34 | c: 0.8 35 | s: 0.8 36 | h: 0.2 37 | jitter_p : 0.8 38 | gray_p : 0.2 39 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_moco/cifar100_b256.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.07 2 | queue_size : 65536 3 | momentum : 0.999 4 | epochs : 200 5 | bs : 256 6 | p_epoch : 5 7 | s_epoch : 5 8 | 9 | data: 10 | data_dir : "./data" 11 | dataset : "cifar100" 12 | data_ratio : 1.0 13 | 14 | optim: 15 | optimizer: "sgd" 16 | init_lr: 0.03 17 | lr_schedule: "const" 18 | weight_decay: 0.0001 19 | 20 | net: 21 | encoder : "resnet18" 22 | use_bn : true 23 | use_sn : false 24 | use_wn : false 25 | act : "relu" 26 | proj_dim : 128 27 | nc : 3 28 | 29 | t: 30 | crop_scale: 31 | min: 0.08 32 | max: 1.0 33 | flip_p : 0.5 34 | jitter: 35 | b: 0.8 36 | c: 0.8 37 | s: 0.8 38 | h: 0.2 39 | jitter_p : 0.8 40 | gray_p : 0.2 41 | blur_scale : 0.1 -------------------------------------------------------------------------------- /configs_moco/cifar10_b256.yaml: -------------------------------------------------------------------------------- 1 | temperature : 0.07 2 | queue_size : 65536 3 | momentum : 0.999 4 | epochs : 200 5 | bs : 256 6 | p_epoch : 5 7 | s_epoch : 5 8 | 9 | data: 10 | data_dir : "./data" 11 | dataset : "cifar10" 12 | data_ratio : 1.0 13 | 14 | optim: 15 | optimizer: "sgd" 16 | init_lr: 0.03 17 | lr_schedule: "const" 18 | weight_decay: 0.0001 19 | 20 | net: 21 | encoder : "resnet18" 22 | use_bn : true 23 | use_sn : false 24 | use_wn : false 25 | act : "relu" 26 | proj_dim : 128 27 | nc : 3 28 | 29 | t: 30 | crop_scale: 31 | min: 0.08 32 | max: 1.0 33 | flip_p : 0.5 34 | jitter: 35 | b: 0.8 36 | c: 0.8 37 | s: 0.8 38 | h: 0.2 39 | jitter_p : 0.8 40 | gray_p : 0.2 41 | blur_scale : 0.1 
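The `t:` blocks in these configs are consumed by `get_t` in `libs/utils.py` (reproduced later in this dump), which assembles a torchvision augmentation pipeline from whichever keys are present. Below is a minimal sketch of that wiring, assuming the YAML is parsed with PyYAML.

```python
# Sketch: build the augmentation pipeline from a config's `t:` block.
# Assumes PyYAML and the repo layout above; get_t is defined in libs/utils.py.
import yaml

from libs.utils import get_t

with open('configs_moco/cifar10_b256.yaml') as f:
    config = yaml.safe_load(f)

t = get_t(32, config['t'])  # CIFAR images are 32x32
# `t` is a transforms.Compose; the trainers apply it to whole (N, 3, 32, 32)
# batches to generate the two augmented views, e.g. X_v1 = t(X), X_v2 = t(X).
```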
-------------------------------------------------------------------------------- /configs_ebclr/fmnist_b16.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 16 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 960 8 | s_epoch: 5 9 | save_buffer: true 10 | 11 | data: 12 | data_dir : "./data" 13 | dataset : "fmnist" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.6 28 | bs: 256 29 | 30 | optim: 31 | optimizer: "adam" 32 | init_lr: 0.0001 33 | lr_schedule: "const" 34 | 35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 1 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | noise_std : 0.03 -------------------------------------------------------------------------------- /configs_ebclr/fmnist_b64.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 64 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 480 8 | s_epoch: 5 9 | save_buffer: true 10 | 11 | data: 12 | data_dir : "./data" 13 | dataset : "fmnist" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.6 28 | bs: 256 29 | 30 | optim: 31 | optimizer: "adam" 32 | init_lr: 0.0001 33 | lr_schedule: "const" 34 | 35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 1 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | noise_std : 0.03 -------------------------------------------------------------------------------- /configs_ebclr/mnist_b128.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 128 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 240 8 | s_epoch: 5 9 | save_buffer: true 10 | 11 | data: 12 | data_dir : "./data" 13 | dataset : "mnist" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.4 28 | bs: 256 29 | 30 | optim: 31 | optimizer: "adam" 32 | init_lr: 0.0002 33 | lr_schedule: "const" 34 | 35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 1 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | noise_std : 0.03 -------------------------------------------------------------------------------- /configs_ebclr/mnist_b16.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 16 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 960 8 | s_epoch: 5 9 | save_buffer: true 10 | 11 | data: 12 | data_dir : "./data" 13 | dataset : "mnist" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.4 28 | bs: 256 29 | 30 | optim: 31 | optimizer: "adam" 32 | init_lr: 0.0001 33 | lr_schedule: "const" 34 | 
35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 1 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | noise_std : 0.03 -------------------------------------------------------------------------------- /configs_ebclr/mnist_b64.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 64 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 480 8 | s_epoch: 5 9 | save_buffer: true 10 | 11 | data: 12 | data_dir : "./data" 13 | dataset : "mnist" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.4 28 | bs: 256 29 | 30 | optim: 31 | optimizer: "adam" 32 | init_lr: 0.0001 33 | lr_schedule: "const" 34 | 35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 1 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | noise_std : 0.03 -------------------------------------------------------------------------------- /configs_ebclr/fmnist_b128.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 128 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 240 8 | s_epoch: 5 9 | save_buffer: true 10 | 11 | data: 12 | data_dir : "./data" 13 | dataset : "fmnist" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.6 28 | bs: 256 29 | 30 | optim: 31 | optimizer: "adam" 32 | init_lr: 0.0002 33 | lr_schedule: "const" 34 | 35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 1 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | noise_std : 0.03 -------------------------------------------------------------------------------- /configs_ebclr/cifar100_b128.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 128 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 320 8 | s_epoch: 5 9 | save_buffer: true 10 | 11 | data: 12 | data_dir : "./data" 13 | dataset : "cifar100" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.2 28 | bs: 256 29 | 30 | optim: 31 | optimizer: "adam" 32 | init_lr: 0.0002 33 | lr_schedule: "const" 34 | 35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 3 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | flip_p : 0.5 49 | jitter: 50 | b: 0.8 51 | c: 0.8 52 | s: 0.8 53 | h: 0.2 54 | jitter_p : 0.8 55 | gray_p : 0.2 56 | blur_scale : 0.1 57 | noise_std : 0.03 -------------------------------------------------------------------------------- /configs_ebclr/cifar100_b16.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 16 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 1280 8 | s_epoch: 5 9 | save_buffer: true 10 | 
11 | data: 12 | data_dir : "./data" 13 | dataset : "cifar100" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.2 28 | bs: 256 29 | 30 | optim: 31 | optimizer: "adam" 32 | init_lr: 0.0001 33 | lr_schedule: "const" 34 | 35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 3 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | flip_p : 0.5 49 | jitter: 50 | b: 0.8 51 | c: 0.8 52 | s: 0.8 53 | h: 0.2 54 | jitter_p : 0.8 55 | gray_p : 0.2 56 | blur_scale : 0.1 57 | noise_std : 0.03 -------------------------------------------------------------------------------- /configs_ebclr/cifar100_b64.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 64 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 640 8 | s_epoch: 5 9 | save_buffer: true 10 | 11 | data: 12 | data_dir : "./data" 13 | dataset : "cifar100" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.2 28 | bs: 256 29 | 30 | optim: 31 | optimizer: "adam" 32 | init_lr: 0.0001 33 | lr_schedule: "const" 34 | 35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 3 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | flip_p : 0.5 49 | jitter: 50 | b: 0.8 51 | c: 0.8 52 | s: 0.8 53 | h: 0.2 54 | jitter_p : 0.8 55 | gray_p : 0.2 56 | blur_scale : 0.1 57 | noise_std : 0.03 -------------------------------------------------------------------------------- /configs_ebclr/cifar10_b128.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 128 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 320 8 | s_epoch: 5 9 | save_buffer: true 10 | 11 | data: 12 | data_dir : "./data" 13 | dataset : "cifar10" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.2 28 | bs: 256 29 | 30 | optim: 31 | optimizer: "adam" 32 | init_lr: 0.0002 33 | lr_schedule: "const" 34 | 35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 3 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | flip_p : 0.5 49 | jitter: 50 | b: 0.8 51 | c: 0.8 52 | s: 0.8 53 | h: 0.2 54 | jitter_p : 0.8 55 | gray_p : 0.2 56 | blur_scale : 0.1 57 | noise_std : 0.03 -------------------------------------------------------------------------------- /configs_ebclr/cifar10_b16.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 16 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 1280 8 | s_epoch: 5 9 | save_buffer: true 10 | 11 | data: 12 | data_dir : "./data" 13 | dataset : "cifar10" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.2 28 | bs: 256 29 | 30 | 
optim: 31 | optimizer: "adam" 32 | init_lr: 0.0001 33 | lr_schedule: "const" 34 | 35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 3 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | flip_p : 0.5 49 | jitter: 50 | b: 0.8 51 | c: 0.8 52 | s: 0.8 53 | h: 0.2 54 | jitter_p : 0.8 55 | gray_p : 0.2 56 | blur_scale : 0.1 57 | noise_std : 0.03 -------------------------------------------------------------------------------- /configs_ebclr/cifar10_b64.yaml: -------------------------------------------------------------------------------- 1 | temperature: 0.1 2 | normalize: true 3 | bs: 64 4 | lmda1: 0.001 5 | lmda2: 0.1 6 | epochs: 200 7 | p_iter: 640 8 | s_epoch: 5 9 | save_buffer: true 10 | 11 | data: 12 | data_dir : "./data" 13 | dataset : "cifar10" 14 | data_ratio : 1.0 15 | 16 | sgld: 17 | iter: 10 18 | lr: 0.05 19 | min_std: 0.01 20 | max_std: 0.05 21 | threshold: 3.0 22 | tau: 1.0 23 | 24 | buffer: 25 | CD_ratio: 1.0 26 | size: 50000 27 | rho: 0.2 28 | bs: 256 29 | 30 | optim: 31 | optimizer: "adam" 32 | init_lr: 0.0001 33 | lr_schedule: "const" 34 | 35 | net: 36 | encoder: "resnet18" 37 | use_bn: false 38 | use_sn: false 39 | use_wn: false 40 | act: "lrelu" 41 | proj_dim: 128 42 | nc: 3 43 | 44 | t: 45 | crop_scale: 46 | min: 0.08 47 | max: 1.0 48 | flip_p : 0.5 49 | jitter: 50 | b: 0.8 51 | c: 0.8 52 | s: 0.8 53 | h: 0.2 54 | jitter_p : 0.8 55 | gray_p : 0.2 56 | blur_scale : 0.1 57 | noise_std : 0.03 -------------------------------------------------------------------------------- /Train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "87eead4f", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "from libs import ebclr\n", 13 | "\n", 14 | "ebclr.run(log_dir='./logs', config_dir='./configs_ebclr/cifar10_b16.yaml', start_epoch=0, device=0)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "9a4cc447", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [] 24 | } 25 | ], 26 | "metadata": { 27 | "kernelspec": { 28 | "display_name": "Python 3 (ipykernel)", 29 | "language": "python", 30 | "name": "python3" 31 | }, 32 | "language_info": { 33 | "codemirror_mode": { 34 | "name": "ipython", 35 | "version": 3 36 | }, 37 | "file_extension": ".py", 38 | "mimetype": "text/x-python", 39 | "name": "python", 40 | "nbconvert_exporter": "python", 41 | "pygments_lexer": "ipython3", 42 | "version": "3.9.7" 43 | } 44 | }, 45 | "nbformat": 4, 46 | "nbformat_minor": 5 47 | } 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EBCLR 2 | Official PyTorch implementation of [Energy-Based Contrastive Learning of Visual Representations](https://arxiv.org/abs/2202.04933), NeurIPS 2022. 3 | 4 | We propose a visual representation learning framework, Energy-Based Contrastive Learning (EBCLR), that combines Energy-Based Models (EBMs) with contrastive learning. EBCLR associates the distance in the projection space with the density of positive pairs to learn useful visual representations. The figure below illustrates the general idea of EBCLR. We find that EBCLR shows accelerated convergence and robustness to a small number of negative pairs per positive pair. 5 | 6 |
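Schematically (an editorial paraphrase of this idea, not the paper's exact objective), EBCLR treats the joint density of a positive pair $(v, v')$ as an energy-based model whose energy grows with the distance between their projections,

$$p_\theta(v, v') \propto \exp\!\left(-\lVert z_\theta(v) - z_\theta(v') \rVert^2 / \tau\right),$$

so pairs mapped close together in projection space are assigned high density. The generative part of the objective is trained with SGLD sampling, which is why every config under `configs_ebclr` carries an `sgld:` block.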

7 | ![EBCLR overview](./assets/main.png) 8 |

9 | 10 | ## How to Run This Code 11 | 12 | Example code for training and linear / KNN evaluation can be found in `Train.ipynb` and `Evaluation.ipynb`, respectively. 13 | 14 | ## References 15 | 16 | If you find the code useful for your research, please consider citing 17 | ```bib 18 | @inproceedings{ 19 | kim2022ebclr, 20 | title={Energy-Based Contrastive Learning of Visual Representations}, 21 | author={Beomsu Kim and Jong Chul Ye}, 22 | booktitle={NeurIPS}, 23 | year={2022} 24 | } 25 | ``` 26 | -------------------------------------------------------------------------------- /libs/gaussian_blur.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch 6 | 7 | class GaussianBlur(object): 8 | """blur a batch of images on CPU""" 9 | 10 | def __init__(self, kernel_size): 11 | radius = kernel_size // 2 12 | kernel_size = radius * 2 + 1 13 | self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), stride=1, padding=0, bias=False, groups=3) 14 | self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size), stride=1, padding=0, bias=False, groups=3) 15 | self.k = kernel_size 16 | self.r = radius 17 | 18 | self.blur = nn.Sequential( 19 | nn.ReflectionPad2d(radius), 20 | self.blur_h, 21 | self.blur_v 22 | ) 23 | 24 | def __call__(self, img): 25 | 26 | # img = img[None] 27 | 28 | sigma = np.random.uniform(0.1, 2.0) 29 | x = np.arange(-self.r, self.r + 1) 30 | x = np.exp(-np.power(x, 2) / (2 * sigma * sigma)) 31 | x = x / x.sum() 32 | x = torch.from_numpy(x).view(1, -1).repeat(3, 1) 33 | 34 | self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1)) 35 | self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k)) 36 | 37 | with torch.no_grad(): 38 | img = self.blur(img) 39 | # img = img.squeeze() 40 | 41 | return img -------------------------------------------------------------------------------- /libs/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | import torch 6 | 7 | def Linear(in_features, out_features, bias=True, use_bn=False, use_sn=False, use_wn=False): 8 | 9 | layer = nn.Linear(in_features, out_features, bias=bias) 10 | 11 | if use_bn: 12 | layer = nn.Sequential(layer, nn.BatchNorm1d(out_features)) 13 | 14 | if use_sn: 15 | layer = nn.utils.spectral_norm(layer) 16 | 17 | if use_wn: 18 | layer = nn.utils.weight_norm(layer) 19 | 20 | return layer 21 | 22 | class LIN(nn.Module): 23 | 24 | def __init__(self, in_dim, out_dim, use_sn=False): 25 | super(LIN, self).__init__() 26 | 27 | self.fc1 = Linear(in_dim, out_dim, use_sn=use_sn) 28 | 29 | def forward(self, x): 30 | x = self.fc1(x) 31 | return x 32 | 33 | class MLP(nn.Module): 34 | 35 | def __init__(self, in_dim, out_dim, n_layers=2, use_bn=False, use_sn=False, use_wn=False, act=nn.LeakyReLU(0.2)): 36 | super(MLP, self).__init__() 37 | self.act = act 38 | 39 | self.fc1 = Linear(in_dim, in_dim, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn) 40 | 41 | # fcs = [] 42 | # for _ in range(n_layers - 1): 43 | # fcs.append(Linear(in_dim, in_dim, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn)) 44 | # fcs.append(self.act) 45 | # self.fcs = nn.Sequential(*fcs) 46 | 47 | self.fc2 = Linear(in_dim, out_dim, use_sn=use_sn, use_wn=use_wn) 48 | 49 | def forward(self, x): 50 | x = self.act(self.fc1(x)) 51 | x = self.fc2(x) 52 | return x
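`libs/net.py` (the next file) instantiates this `MLP` as the projection head on top of the encoder's pooled features. A minimal usage sketch follows; the 512-to-128 shapes follow the ResNet18 configs above, and the batch size 8 is arbitrary.

```python
# Sketch: the MLP above as a SimCLR-style projection head.
import torch

from libs.models import MLP

proj = MLP(in_dim=512, out_dim=128, act=torch.nn.LeakyReLU(0.2))
h = torch.randn(8, 512)  # a batch of ResNet18 features
z = proj(h)              # z = fc2(act(fc1(h)))
print(z.shape)           # torch.Size([8, 128])
```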
-------------------------------------------------------------------------------- /libs/net.py: -------------------------------------------------------------------------------- 1 | from libs.resnets import ResNet18, ResNet34, ResNet50 2 | from libs.models import MLP 3 | 4 | import torch.nn as nn 5 | 6 | def get_act(net_config): 7 | 8 | if net_config['act'] == 'relu': 9 | return nn.ReLU() 10 | elif net_config['act'] == 'lrelu': 11 | return nn.LeakyReLU(0.2) 12 | else: 13 | raise NotImplementedError 14 | 15 | def init_enc(net_config, device): 16 | 17 | if net_config['encoder'] == 'resnet18': 18 | enc = ResNet18(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device) 19 | feature_dim = 512 20 | elif net_config['encoder'] == 'resnet34': 21 | enc = ResNet34(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device) 22 | feature_dim = 512 23 | elif net_config['encoder'] == 'resnet50': 24 | enc = ResNet50(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device) 25 | feature_dim = 2048 26 | else: 27 | raise NotImplementedError 28 | 29 | return enc, feature_dim 30 | 31 | def init_enc_proj(net_config, device): 32 | 33 | if net_config['encoder'] == 'resnet18': 34 | enc = ResNet18(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device) 35 | feature_dim = 512 36 | elif net_config['encoder'] == 'resnet34': 37 | enc = ResNet34(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device) 38 | feature_dim = 512 39 | elif net_config['encoder'] == 'resnet50': 40 | enc = ResNet50(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device) 41 | feature_dim = 2048 42 | else: 43 | raise NotImplementedError 44 | 45 | if 'proj_layers' in net_config: 46 | proj = MLP(feature_dim, net_config['proj_dim'], n_layers=net_config['proj_layers'], act=get_act(net_config)).cuda(device) 47 | else: 48 | proj = MLP(feature_dim, net_config['proj_dim'], n_layers=2, act=get_act(net_config)).cuda(device) 49 | 50 | return enc, proj, feature_dim -------------------------------------------------------------------------------- /Evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "3d0be4ea", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from libs.evaluation import transfer_acc, eval_acc, eval_knn_transfer_acc, eval_knn_acc" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "a5973077", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "for epoch in ['curr']:\n", 21 | "\n", 22 | " ckpt_dir='./logs/{}.pt'.format(epoch)\n", 23 | " config_dir='./configs_ebclr/cifar10_b16.yaml'\n", 24 | "\n", 25 | " eval_config = {\n", 26 | " 'K' : 20,\n", 27 | " 'bs' : 50,\n", 28 | " 'standardize' : True\n", 29 | " }\n", 30 | "\n", 31 | " data_config = {\n", 32 | " 'data_dir' : './data',\n", 33 | " 'dataset' : 'cifar10',\n", 34 | " 'data_ratio' : 1.0\n", 35 | " }\n", 36 | " \n", 37 | " device = 0\n", 38 | " \n", 39 | " print('Epoch 
{}'.format(epoch))\n", 40 | " \n", 41 | " eval_knn_acc(ckpt_dir, config_dir, eval_config, data_config, device)" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "995cb890", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "for epoch in ['curr']:\n", 52 | "\n", 53 | " ckpt_dir='./logs/{}.pt'.format(epoch)\n", 54 | " config_dir='./configs_ebclr/cifar10_b16.yaml'\n", 55 | "\n", 56 | " eval_config = {\n", 57 | " 'K' : 20,\n", 58 | " 'bs' : 50,\n", 59 | " 'standardize' : False\n", 60 | " }\n", 61 | " \n", 62 | " source_data_config = {\n", 63 | " 'data_dir' : './data',\n", 64 | " 'dataset' : 'cifar100',\n", 65 | " 'data_ratio' : 1.0\n", 66 | " }\n", 67 | "\n", 68 | " target_data_config = {\n", 69 | " 'data_dir' : './data',\n", 70 | " 'dataset' : 'cifar10',\n", 71 | " 'data_ratio' : 1.0\n", 72 | " }\n", 73 | " \n", 74 | " device = 0\n", 75 | " \n", 76 | " print('Epoch {}'.format(epoch))\n", 77 | " \n", 78 | " eval_knn_transfer_acc(ckpt_dir, config_dir, eval_config, source_data_config, target_data_config, device)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "b43feb97", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "for epoch in ['curr']:\n", 89 | "\n", 90 | " ckpt_dir='./logs/{}.pt'.format(epoch)\n", 91 | " config_dir='./configs_ebclr/cifar10_b16.yaml'\n", 92 | "\n", 93 | " eval_config = {\n", 94 | " 'epochs' : 200,\n", 95 | " 'bs' : 512,\n", 96 | " 'lr' : 1e-3,\n", 97 | " 'p_epoch' : 10,\n", 98 | " 'standardize' : False\n", 99 | " }\n", 100 | "\n", 101 | " data_config = {\n", 102 | " 'data_dir' : './data',\n", 103 | " 'dataset' : 'cifar100',\n", 104 | " 'data_ratio' : 1.0\n", 105 | " }\n", 106 | " \n", 107 | " device = 0\n", 108 | " \n", 109 | " print('Epoch {}'.format(epoch))\n", 110 | "\n", 111 | " eval_acc(ckpt_dir, config_dir, eval_config, data_config, device)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "10124fbf", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "for epoch in ['curr']:\n", 122 | "\n", 123 | " ckpt_dir='./logs/{}.pt'.format(epoch)\n", 124 | " config_dir='./configs_ebclr/cifar10_b16.yaml'\n", 125 | "\n", 126 | " eval_config = {\n", 127 | " 'epochs' : 200,\n", 128 | " 'bs' : 512,\n", 129 | " 'lr' : 8e-4,\n", 130 | " 'p_epoch' : 10,\n", 131 | " 'standardize' : False\n", 132 | " }\n", 133 | " \n", 134 | " source_data_config = {\n", 135 | " 'data_dir' : './data',\n", 136 | " 'dataset' : 'cifar100',\n", 137 | " 'data_ratio' : 1.0\n", 138 | " }\n", 139 | "\n", 140 | " target_data_config = {\n", 141 | " 'data_dir' : './data',\n", 142 | " 'dataset' : 'cifar10',\n", 143 | " 'data_ratio' : 1.0\n", 144 | " }\n", 145 | " \n", 146 | " device = 0\n", 147 | " \n", 148 | " print('Epoch {}'.format(epoch))\n", 149 | " \n", 150 | " transfer_acc(ckpt_dir, config_dir, eval_config, source_data_config, target_data_config, device)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "1cfe6e8c", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [] 160 | } 161 | ], 162 | "metadata": { 163 | "kernelspec": { 164 | "display_name": "Python 3 (ipykernel)", 165 | "language": "python", 166 | "name": "python3" 167 | }, 168 | "language_info": { 169 | "codemirror_mode": { 170 | "name": "ipython", 171 | "version": 3 172 | }, 173 | "file_extension": ".py", 174 | "mimetype": "text/x-python", 175 | "name": "python", 176 | "nbconvert_exporter": "python", 177 | "pygments_lexer": 
"ipython3", 178 | "version": "3.9.7" 179 | } 180 | }, 181 | "nbformat": 4, 182 | "nbformat_minor": 5 183 | } 184 | -------------------------------------------------------------------------------- /libs/simclr.py: -------------------------------------------------------------------------------- 1 | from libs.utils import log_softmax, load_data, get_t, get_lr_schedule, get_optimizer, load_config 2 | from libs.net import init_enc_proj 3 | 4 | import torchvision.transforms.functional as TF 5 | import torchvision.transforms as transforms 6 | import torch.nn.functional as F 7 | import matplotlib.pyplot as plt 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | import torchvision 13 | import itertools 14 | import torch 15 | import time 16 | import os 17 | 18 | def save_ckpt(enc, proj, opt, iteration, runtime, ckpt_path): 19 | 20 | checkpoint = { 21 | 'iteration' : iteration, 22 | 'runtime' : runtime, 23 | 'opt' : opt.state_dict() 24 | } 25 | 26 | if isinstance(enc, nn.DataParallel): 27 | checkpoint['enc_state_dict'] = enc.module.state_dict() 28 | else: 29 | checkpoint['enc_state_dict'] = enc.state_dict() 30 | 31 | if isinstance(proj, nn.DataParallel): 32 | checkpoint['proj_state_dict'] = proj.module.state_dict() 33 | else: 34 | checkpoint['proj_state_dict'] = proj.state_dict() 35 | 36 | torch.save(checkpoint, ckpt_path) 37 | 38 | def load_ckpt(enc, proj, opt, ckpt_path): 39 | ckpt = torch.load(ckpt_path) 40 | 41 | if isinstance(enc, nn.DataParallel): 42 | enc.module.load_state_dict(ckpt['enc_state_dict']) 43 | else: 44 | enc.load_state_dict(ckpt['enc_state_dict']) 45 | 46 | if isinstance(proj, nn.DataParallel): 47 | proj.module.load_state_dict(ckpt['proj_state_dict']) 48 | else: 49 | proj.load_state_dict(ckpt['proj_state_dict']) 50 | 51 | opt.load_state_dict(ckpt['opt']) 52 | it = ckpt['iteration'] 53 | rt = ckpt['runtime'] 54 | 55 | return it, rt 56 | 57 | class SimCLR_Loss(nn.Module): 58 | 59 | def __init__(self, temperature): 60 | super(SimCLR_Loss, self).__init__() 61 | self.temperature = temperature 62 | 63 | self.criterion = nn.CrossEntropyLoss(reduction="sum") 64 | self.similarity_f = nn.CosineSimilarity(dim=2) 65 | 66 | def mask_correlated_samples(self, batch_size): 67 | N = 2 * batch_size 68 | mask = torch.ones((N, N), dtype=bool) 69 | mask = mask.fill_diagonal_(0) 70 | 71 | for i in range(batch_size): 72 | mask[i, batch_size + i] = 0 73 | mask[batch_size + i, i] = 0 74 | return mask 75 | 76 | def forward(self, z_i, z_j): 77 | 78 | batch_size = z_i.shape[0] 79 | 80 | N = 2 * batch_size 81 | 82 | z = torch.cat((z_i, z_j), dim=0) 83 | 84 | sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature 85 | 86 | sim_i_j = torch.diag(sim, batch_size) 87 | sim_j_i = torch.diag(sim, -batch_size) 88 | 89 | mask = self.mask_correlated_samples(batch_size) 90 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1) 91 | negative_samples = sim[mask].reshape(N, -1) 92 | 93 | labels = torch.from_numpy(np.array([0] * N)).reshape(-1).to(positive_samples.device).long() 94 | 95 | logits = torch.cat((positive_samples, negative_samples), dim=1) 96 | loss = self.criterion(logits, labels) 97 | loss /= N 98 | 99 | return loss 100 | 101 | def SimCLR(train_X, enc, proj, config, log_dir, start_epoch): 102 | 103 | train_n = train_X.shape[0] 104 | size = train_X.shape[2] 105 | nc = train_X.shape[1] 106 | 107 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda() 108 | std = 
train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda() 109 | 110 | simclr_loss = SimCLR_Loss(config['temperature']) 111 | 112 | params = itertools.chain(enc.parameters(), proj.parameters()) 113 | net = lambda x : proj(enc((x - mean) / std)) 114 | 115 | t = get_t(size, config['t']) 116 | lr_schedule = get_lr_schedule(config) 117 | opt = get_optimizer(config['optim'], params) 118 | 119 | if start_epoch == 0: 120 | it, rt = 0, 0 121 | else: 122 | it, rt = load_ckpt(enc, proj, opt, log_dir + '/{}.pt'.format(start_epoch)) 123 | 124 | enc.train() 125 | proj.train() 126 | 127 | epoch_n_iter = int(np.ceil(train_n / config['bs'])) 128 | while True: 129 | if it % epoch_n_iter == 0: 130 | train_X = train_X[torch.randperm(train_n)] 131 | 132 | i = it % epoch_n_iter 133 | it += 1 134 | 135 | s = time.time() 136 | 137 | X = train_X[i * config['bs']:(i + 1) * config['bs']] 138 | 139 | X_v1 = t(X).cuda() 140 | X_v2 = t(X).cuda() 141 | 142 | loss = simclr_loss(net(X_v1), net(X_v2)) 143 | 144 | # Update 145 | opt.param_groups[0]['lr'] = lr_schedule(it) 146 | opt.zero_grad() 147 | loss.backward() 148 | opt.step() 149 | 150 | e = time.time() 151 | rt += (e - s) 152 | 153 | if it % config['p_iter'] == 0: 154 | save_ckpt(enc, proj, opt, it, rt, log_dir + '/curr.pt') 155 | print('Epoch : {:.3f} | Loss : {:.3f} | LR : {:.3e} | Time : {:.3f}'.format(it / epoch_n_iter, loss.item(), lr_schedule(it), rt)) 156 | 157 | if it % (epoch_n_iter * config['s_epoch']) == 0: 158 | save_ckpt(enc, proj, opt, it, rt, log_dir + '/{}.pt'.format(it // epoch_n_iter)) 159 | 160 | if it >= config['its']: 161 | break 162 | 163 | def run(log_dir, config_dir, start_epoch, device): 164 | 165 | if not os.path.isdir(log_dir): 166 | os.makedirs(log_dir) 167 | 168 | config = load_config(config_dir) 169 | 170 | train_X, train_y, test_X, test_y, n_classes = load_data(config['data']) 171 | 172 | epoch_n_iter = int(np.ceil(train_X.shape[0] / config['bs'])) 173 | 174 | if 'epochs' in config: 175 | config['its'] = epoch_n_iter * config['epochs'] 176 | 177 | if 'p_epoch' in config: 178 | config['p_iter'] = epoch_n_iter * config['p_epoch'] 179 | 180 | enc, proj, _ = init_enc_proj(config['net'], device) 181 | 182 | print('Running SimCLR from epoch {}'.format(start_epoch)) 183 | 184 | with torch.cuda.device(device): 185 | SimCLR(train_X, enc, proj, config, log_dir, start_epoch) 186 | 187 | print('Finished!') -------------------------------------------------------------------------------- /libs/resnets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, use_bn=False, use_sn=False, use_wn=False, avg_pool=False): 6 | 7 | if avg_pool: 8 | layer = nn.Conv2d(in_channels, out_channels, kernel_size, 1, padding, bias=False if use_bn else bias) 9 | else: 10 | layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False if use_bn else bias) 11 | 12 | if use_bn: 13 | layer = nn.Sequential(layer, nn.BatchNorm2d(out_channels)) 14 | 15 | if use_sn: 16 | layer = nn.utils.spectral_norm(layer) 17 | 18 | if use_wn: 19 | layer = nn.utils.weight_norm(layer) 20 | 21 | if avg_pool and stride > 1: 22 | layer = nn.Sequential(layer, nn.AvgPool2d(stride, stride)) 23 | 24 | return layer 25 | 26 | def Linear(in_features, out_features, bias=True, use_bn=False, use_sn=False, use_wn=False): 27 | 28 | layer = nn.Linear(in_features,
out_features, bias=bias) 29 | 30 | if use_bn: 31 | layer = nn.Sequential(layer, nn.BatchNorm1d(out_features)) 32 | 33 | if use_sn: 34 | layer = nn.utils.spectral_norm(layer) 35 | 36 | if use_wn: 37 | layer = nn.utils.weight_norm(layer) 38 | 39 | return layer 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, in_planes, planes, stride=1, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)): 45 | super(BasicBlock, self).__init__() 46 | self.conv1 = Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool) 47 | self.conv2 = Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool) 48 | self.act = act 49 | 50 | self.shortcut = nn.Sequential() 51 | if stride != 1 or in_planes != self.expansion*planes: 52 | self.shortcut = nn.Sequential( 53 | Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool) 54 | ) 55 | 56 | def forward(self, x): 57 | out = self.act(self.conv1(x)) 58 | out = self.conv2(out) 59 | out += self.shortcut(x) 60 | out = self.act(out) 61 | return out 62 | 63 | 64 | class Bottleneck(nn.Module): 65 | expansion = 4 66 | 67 | def __init__(self, in_planes, planes, stride=1, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)): 68 | super(Bottleneck, self).__init__() 69 | self.conv1 = Conv2d(in_planes, planes, kernel_size=1, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool) 70 | self.conv2 = Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool) 71 | self.conv3 = Conv2d(planes, self.expansion * planes, kernel_size=1, bias=True, use_sn=use_sn, use_bn=use_bn, use_wn=use_wn, avg_pool=avg_pool) 72 | self.act = act 73 | 74 | self.shortcut = nn.Sequential() 75 | if stride != 1 or in_planes != self.expansion*planes: 76 | self.shortcut = nn.Sequential( 77 | Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool) 78 | ) 79 | 80 | def forward(self, x): 81 | out = self.act(self.conv1(x)) 82 | out = self.act(self.conv2(out)) 83 | out = self.conv3(out) 84 | out += self.shortcut(x) 85 | out = self.act(out) 86 | return out 87 | 88 | 89 | class ResNet(nn.Module): 90 | def __init__(self, nc, block, num_blocks, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)): 91 | super(ResNet, self).__init__() 92 | self.in_planes = 64 93 | 94 | self.conv1 = Conv2d(nc, 64, kernel_size=3, stride=1, padding=1, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool) 95 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool, act=act) 96 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool, act=act) 97 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool, act=act) 98 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool, act=act) 99 | self.act = act 100 | 101 | def _make_layer(self, block, planes, num_blocks, stride, use_bn, use_sn, 
use_wn, avg_pool, act): 102 | strides = [stride] + [1]*(num_blocks-1) 103 | layers = [] 104 | for stride in strides: 105 | layers.append(block(self.in_planes, planes, stride, use_bn, use_sn, use_wn, avg_pool, act)) 106 | self.in_planes = planes * block.expansion 107 | return nn.Sequential(*layers) 108 | 109 | def forward(self, x): 110 | out = self.act(self.conv1(x)) 111 | out = self.layer1(out) 112 | out = self.layer2(out) 113 | out = self.layer3(out) 114 | out = self.layer4(out) 115 | out = F.adaptive_avg_pool2d(out, 1) 116 | out = out.view(out.size(0), -1) 117 | return out 118 | 119 | 120 | def ResNet18(nc=3, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)): 121 | return ResNet(nc, BasicBlock, [2, 2, 2, 2], use_bn, use_sn, use_wn, avg_pool, act) 122 | 123 | 124 | def ResNet34(nc=3, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)): 125 | return ResNet(nc, BasicBlock, [3, 4, 6, 3], use_bn, use_sn, use_wn, avg_pool, act) 126 | 127 | 128 | def ResNet50(nc=3, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)): 129 | return ResNet(nc, Bottleneck, [3, 4, 6, 3], use_bn, use_sn, use_wn, avg_pool, act) 130 | 131 | 132 | def ResNet101(nc=3, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)): 133 | return ResNet(nc, Bottleneck, [3, 4, 23, 3], use_bn, use_sn, use_wn, avg_pool, act) 134 | 135 | 136 | def ResNet152(nc=3, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)): 137 | return ResNet(nc, Bottleneck, [3, 8, 36, 3], use_bn, use_sn, use_wn, avg_pool, act) -------------------------------------------------------------------------------- /libs/utils.py: -------------------------------------------------------------------------------- 1 | from libs.gaussian_blur import GaussianBlur 2 | 3 | import torchvision.transforms.functional as TF 4 | import torchvision.transforms as transforms 5 | import torch.nn.functional as F 6 | import matplotlib.pyplot as plt 7 | import torch.optim as optim 8 | import torch.nn as nn 9 | import numpy as np 10 | 11 | import torchvision 12 | import torch 13 | import yaml 14 | 15 | class MSGLD: 16 | 17 | def __init__(self, config): 18 | self.config = config 19 | 20 | def __get_std__(self, count): 21 | return self.config['min_std'] + (self.config['max_std'] - self.config['min_std']) * torch.maximum(1.0 - count / self.config['threshold'], torch.zeros_like(count).cuda()) 22 | 23 | def __call__(self, log_pdf, init, count): 24 | out = init.detach().clone().requires_grad_(True) 25 | for i in range(self.config['iter']): 26 | lp = log_pdf(out).sum() 27 | lp.backward() 28 | out.data = out + self.config['lr'] * torch.clamp(out.grad, -self.config['tau'], self.config['tau']) + self.__get_std__(count) * torch.randn_like(out) 29 | out.grad.zero_() 30 | return out.detach().clone() 31 | 32 | def logsumexp(log_p): 33 | m, _ = torch.max(log_p, dim=1, keepdim=True) 34 | return torch.log(torch.exp(log_p - m).sum(dim=1, keepdim=True)) + m 35 | 36 | def log_softmax(log_p): 37 | return log_p - logsumexp(log_p) 38 | 39 | def softmax(log_p): 40 | m, _ = torch.max(log_p, dim=1, keepdim=True) 41 | f = torch.exp(log_p - m) 42 | return f / f.sum(dim=1, keepdim=True) 43 | 44 | def get_t(img_size, t_config): 45 | 46 | t = [] 47 | 48 | if 'crop_scale' in t_config: 49 | t.append(transforms.RandomResizedCrop(size=img_size, scale=tuple(t_config['crop_scale'].values()))) 50 | 51 | if 'flip_p' in t_config: 52 | 
t.append(transforms.RandomHorizontalFlip(p=t_config['flip_p'])) 53 | 54 | if 'jitter' in t_config: 55 | jitter = transforms.ColorJitter(t_config['jitter']['b'], t_config['jitter']['c'], t_config['jitter']['s'], t_config['jitter']['h']) 56 | t.append(transforms.RandomApply([jitter], p=t_config['jitter_p'])) 57 | 58 | if 'gray_p' in t_config: 59 | t.append(transforms.RandomGrayscale(p=t_config['gray_p'])) 60 | 61 | if 'blur_scale' in t_config: 62 | blur = GaussianBlur(kernel_size=int(t_config['blur_scale'] * img_size)) 63 | t.append(transforms.RandomApply([blur], p=1.0)) 64 | 65 | if 'noise_std' in t_config: 66 | t.append(lambda x : x + torch.randn_like(x) * t_config['noise_std']) 67 | 68 | return transforms.Compose(t) 69 | 70 | def get_optimizer(optim_config, params): 71 | if optim_config['optimizer'] == 'sgd': 72 | return optim.SGD(params, lr=optim_config['init_lr'], momentum=0.9, weight_decay=optim_config['weight_decay'] if ('weight_decay' in optim_config) else 0.0) 73 | elif optim_config['optimizer'] == 'adam': 74 | return optim.Adam(params, lr=optim_config['init_lr'], weight_decay=optim_config['weight_decay'] if ('weight_decay' in optim_config) else 0.0) 75 | else: 76 | raise NotImplementedError 77 | 78 | def get_lr_schedule(config): 79 | 80 | if config['optim']['lr_schedule'] == 'const': 81 | return lambda it : config['optim']['init_lr'] 82 | elif config['optim']['lr_schedule'] == 'cosine': 83 | return lambda it : config['optim']['init_lr'] * np.cos(0.5 * np.pi * it / config['its']) 84 | else: 85 | raise NotImplementedError 86 | 87 | def inner(x, y, normalize=False): 88 | x = x.flatten(start_dim=1) 89 | y = y.flatten(start_dim=1) 90 | 91 | if normalize: 92 | x = x / x.norm(dim=1, keepdim=True) 93 | y = y / y.norm(dim=1, keepdim=True) 94 | 95 | d = (x * y).sum(dim=1, keepdim=True) 96 | return d 97 | 98 | def dist(x, y, normalize=False): 99 | x = x.flatten(start_dim=1) 100 | y = y.flatten(start_dim=1) 101 | 102 | if normalize: 103 | x = x / (x.norm(dim=1, keepdim=True) + 1e-10) 104 | y = y / (y.norm(dim=1, keepdim=True) + 1e-10) 105 | 106 | d = (x - y).square().sum(dim=1, keepdim=True) 107 | return d 108 | 109 | def generate_views(x, t, n_samples): 110 | views = [] 111 | shape = list(x.shape[1:]) 112 | for _ in range(n_samples): 113 | views.append(t(x)[:,None]) 114 | views = torch.cat(views, dim=1).reshape([-1] + shape) 115 | return views 116 | 117 | def load_data(data_config, download=False): 118 | 119 | root = data_config['data_dir'] 120 | dataset = data_config['dataset'] 121 | 122 | if dataset == 'mnist': 123 | trainset = torchvision.datasets.MNIST(root=root, train=True, download=download, transform=[transforms.ToTensor()]) 124 | testset = torchvision.datasets.MNIST(root=root, train=False, download=download, transform=[transforms.ToTensor()]) 125 | 126 | train_X, train_y = trainset.data[:,None] / 255, trainset.targets 127 | test_X, test_y = testset.data[:,None] / 255, testset.targets 128 | n_classes = 10 129 | elif dataset == 'fmnist': 130 | trainset = torchvision.datasets.FashionMNIST(root=root, train=True, download=download, transform=[transforms.ToTensor()]) 131 | testset = torchvision.datasets.FashionMNIST(root=root, train=False, download=download, transform=[transforms.ToTensor()]) 132 | 133 | train_X, train_y = trainset.data[:,None] / 255, trainset.targets 134 | test_X, test_y = testset.data[:,None] / 255, testset.targets 135 | n_classes = 10 136 | elif dataset == 'cifar10': 137 | trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=download, 
transform=[transforms.ToTensor()]) 138 | testset = torchvision.datasets.CIFAR10(root=root, train=False, download=download, transform=[transforms.ToTensor()]) 139 | 140 | train_X = torch.tensor(trainset.data).permute(0,3,1,2) / 255 141 | train_y = torch.tensor(trainset.targets, dtype=int) 142 | 143 | test_X = torch.tensor(testset.data).permute(0,3,1,2) / 255 144 | test_y = torch.tensor(testset.targets, dtype=int) 145 | n_classes = 10 146 | elif dataset == 'cifar100': 147 | trainset = torchvision.datasets.CIFAR100(root=root, train=True, download=download, transform=[transforms.ToTensor()]) 148 | testset = torchvision.datasets.CIFAR100(root=root, train=False, download=download, transform=[transforms.ToTensor()]) 149 | 150 | train_X = torch.tensor(trainset.data).permute(0,3,1,2) / 255 151 | train_y = torch.tensor(trainset.targets, dtype=int) 152 | 153 | test_X = torch.tensor(testset.data).permute(0,3,1,2) / 255 154 | test_y = torch.tensor(testset.targets, dtype=int) 155 | n_classes = 100 156 | 157 | train_n = train_X.shape[0] 158 | idx = torch.randperm(train_n) 159 | train_X = train_X[idx][:int(train_n * data_config['data_ratio'])] 160 | train_y = train_y[idx][:int(train_n * data_config['data_ratio'])] 161 | 162 | return train_X, train_y, test_X, test_y, n_classes 163 | 164 | def save_config(config, config_path): 165 | with open(config_path, 'w') as f: 166 | yaml.dump(config, f) 167 | 168 | def load_config(config_path): 169 | with open(config_path, 'r') as f: 170 | config = yaml.load(f, Loader=yaml.FullLoader) 171 | return config -------------------------------------------------------------------------------- /libs/simsiam.py: -------------------------------------------------------------------------------- 1 | from libs.utils import log_softmax, load_data, get_t, get_lr_schedule, get_optimizer, load_config 2 | from libs.net import init_enc 3 | 4 | import torchvision.transforms.functional as TF 5 | import torchvision.transforms as transforms 6 | import torch.nn.functional as F 7 | import matplotlib.pyplot as plt 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | import torchvision 13 | import itertools 14 | import torch 15 | import time 16 | import os 17 | 18 | def save_ckpt(enc, proj, pred, opt, iteration, runtime, ckpt_path): 19 | 20 | checkpoint = { 21 | 'iteration' : iteration, 22 | 'runtime' : runtime, 23 | 'opt' : opt.state_dict() 24 | } 25 | 26 | if isinstance(enc, nn.DataParallel): 27 | checkpoint['enc_state_dict'] = enc.module.state_dict() 28 | else: 29 | checkpoint['enc_state_dict'] = enc.state_dict() 30 | 31 | if isinstance(proj, nn.DataParallel): 32 | checkpoint['proj_state_dict'] = proj.module.state_dict() 33 | else: 34 | checkpoint['proj_state_dict'] = proj.state_dict() 35 | 36 | if isinstance(pred, nn.DataParallel): 37 | checkpoint['pred_state_dict'] = pred.module.state_dict() 38 | else: 39 | checkpoint['pred_state_dict'] = pred.state_dict() 40 | 41 | torch.save(checkpoint, ckpt_path) 42 | 43 | def load_ckpt(enc, proj, pred, opt, ckpt_path): 44 | ckpt = torch.load(ckpt_path) 45 | 46 | if isinstance(enc, nn.DataParallel): 47 | enc.module.load_state_dict(ckpt['enc_state_dict']) 48 | else: 49 | enc.load_state_dict(ckpt['enc_state_dict']) 50 | 51 | if isinstance(proj, nn.DataParallel): 52 | proj.module.load_state_dict(ckpt['proj_state_dict']) 53 | else: 54 | proj.load_state_dict(ckpt['proj_state_dict']) 55 | 56 | if isinstance(pred, nn.DataParallel): 57 | pred.module.load_state_dict(ckpt['pred_state_dict']) 58 | else: 59 | 
pred.load_state_dict(ckpt['pred_state_dict']) 60 | 61 | opt.load_state_dict(ckpt['opt']) 62 | it = ckpt['iteration'] 63 | rt = ckpt['runtime'] 64 | 65 | return it, rt 66 | 67 | def D(p, z): 68 | return - F.cosine_similarity(p, z.detach(), dim=-1).mean() 69 | 70 | class Preprocess(nn.Module): 71 | def __init__(self, mean, std): 72 | super(Preprocess, self).__init__() 73 | self.mean = mean 74 | self.std = std 75 | 76 | def forward(self, tensors): 77 | return (tensors - self.mean) / self.std 78 | 79 | class projection_MLP(nn.Module): 80 | 81 | def __init__(self, in_dim, hidden_dim=2048, out_dim=2048): 82 | super().__init__() 83 | 84 | self.layer1 = nn.Sequential( 85 | nn.Linear(in_dim, hidden_dim), 86 | nn.BatchNorm1d(hidden_dim), 87 | nn.ReLU(inplace=True) 88 | ) 89 | self.layer2 = nn.Sequential( 90 | nn.Linear(hidden_dim, hidden_dim), 91 | nn.BatchNorm1d(hidden_dim), 92 | nn.ReLU(inplace=True) 93 | ) 94 | self.layer3 = nn.Sequential( 95 | nn.Linear(hidden_dim, out_dim), 96 | nn.BatchNorm1d(out_dim) 97 | ) 98 | self.num_layers = 3 99 | 100 | def set_layers(self, num_layers): 101 | self.num_layers = num_layers 102 | 103 | def forward(self, x): 104 | if self.num_layers == 3: 105 | x = self.layer1(x) 106 | x = self.layer2(x) 107 | x = self.layer3(x) 108 | elif self.num_layers == 2: 109 | x = self.layer1(x) 110 | x = self.layer3(x) 111 | else: 112 | raise Exception 113 | return x 114 | 115 | class prediction_MLP(nn.Module): 116 | def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048): 117 | super().__init__() 118 | 119 | self.layer1 = nn.Sequential( 120 | nn.Linear(in_dim, hidden_dim), 121 | nn.BatchNorm1d(hidden_dim), 122 | nn.ReLU(inplace=True) 123 | ) 124 | 125 | self.layer2 = nn.Linear(hidden_dim, out_dim) 126 | 127 | def forward(self, x): 128 | x = self.layer1(x) 129 | x = self.layer2(x) 130 | return x 131 | 132 | def SimSiam(train_X, enc, proj, pred, config, log_dir, start_epoch): 133 | 134 | train_n = train_X.shape[0] 135 | size = train_X.shape[2] 136 | nc = train_X.shape[1] 137 | 138 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda() 139 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda() 140 | 141 | params = itertools.chain(enc.parameters(), proj.parameters(), pred.parameters()) 142 | net = nn.Sequential(Preprocess(mean, std), enc, proj) 143 | 144 | t = get_t(size, config['t']) 145 | lr_schedule = get_lr_schedule(config) 146 | opt = get_optimizer(config['optim'], params) 147 | 148 | if start_epoch == 0: 149 | it, rt = 0, 0 150 | else: 151 | it, rt = load_ckpt(enc, proj, pred, opt, log_dir + '/{}.pt'.format(start_epoch)) 152 | 153 | enc.train() 154 | proj.train() 155 | pred.train() 156 | 157 | epoch_n_iter = int(np.ceil(train_n / config['bs'])) 158 | while True: 159 | if it % epoch_n_iter == 0: 160 | train_X = train_X[torch.randperm(train_n)] 161 | 162 | i = it % epoch_n_iter 163 | it += 1 164 | 165 | s = time.time() 166 | 167 | X = train_X[i * config['bs']:(i + 1) * config['bs']] 168 | 169 | X_v1 = t(X).cuda() 170 | X_v2 = t(X).cuda() 171 | 172 | Z_v1, Z_v2 = net(X_v1), net(X_v2) 173 | P_v1, P_v2 = pred(Z_v1), pred(Z_v2) 174 | 175 | loss = D(P_v1, Z_v2) / 2 + D(P_v2, Z_v1) / 2 176 | 177 | # Update 178 | opt.param_groups[0]['lr'] = lr_schedule(it) 179 | opt.zero_grad() 180 | loss.backward() 181 | opt.step() 182 | 183 | e = time.time() 184 | rt += (e - s) 185 | 186 | if it % config['p_iter'] == 0: 187 | save_ckpt(enc, proj, pred, opt, it, rt, log_dir + '/curr.pt') 188 | print('Epoch : {:.3f} | Loss : 
{:.3f} | LR : {:.3e} | Time : {:.3f}'.format(it / epoch_n_iter, loss.item(), lr_schedule(it), rt)) 189 | 190 | if it % (epoch_n_iter * config['s_epoch']) == 0: 191 | save_ckpt(enc, proj, pred, opt, it, rt, log_dir + '/{}.pt'.format(it // epoch_n_iter)) 192 | 193 | if it >= config['its']: 194 | break 195 | 196 | def run(log_dir, config_dir, start_epoch, device): 197 | 198 | if not os.path.isdir(log_dir): 199 | os.makedirs(log_dir) 200 | 201 | config = load_config(config_dir) 202 | 203 | train_X, train_y, test_X, test_y, n_classes = load_data(config['data']) 204 | 205 | epoch_n_iter = int(np.ceil(train_X.shape[0] / config['bs'])) 206 | 207 | if 'epochs' in config: 208 | config['its'] = epoch_n_iter * config['epochs'] 209 | 210 | if 'p_epoch' in config: 211 | config['p_iter'] = epoch_n_iter * config['p_epoch'] 212 | 213 | enc, feature_dim = init_enc(config['net'], device) 214 | proj = projection_MLP(feature_dim).cuda(device) 215 | pred = prediction_MLP().cuda(device) 216 | 217 | print('Running SimSiam from epoch {}'.format(start_epoch)) 218 | 219 | with torch.cuda.device(device): 220 | SimSiam(train_X, enc, proj, pred, config, log_dir, start_epoch) 221 | 222 | print('Finished!') -------------------------------------------------------------------------------- /libs/moco.py: -------------------------------------------------------------------------------- 1 | from libs.utils import log_softmax, load_data, get_t, get_lr_schedule, get_optimizer, load_config 2 | from libs.net import init_enc_proj 3 | 4 | import torchvision.transforms.functional as TF 5 | import torchvision.transforms as transforms 6 | import torch.nn.functional as F 7 | import matplotlib.pyplot as plt 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | import torchvision 13 | import itertools 14 | import torch 15 | import time 16 | import os 17 | 18 | def save_ckpt(enc_q, proj_q, enc_k, proj_k, queue, opt, iteration, runtime, save_all, ckpt_path): 19 | 20 | checkpoint = { 21 | 'iteration' : iteration, 22 | 'runtime' : runtime, 23 | 'opt' : opt.state_dict() 24 | } 25 | 26 | if save_all: 27 | checkpoint['queue'] = queue 28 | 29 | if isinstance(enc_k, nn.DataParallel): 30 | checkpoint['enc_k_state_dict'] = enc_k.module.state_dict() 31 | else: 32 | checkpoint['enc_k_state_dict'] = enc_k.state_dict() 33 | 34 | if isinstance(proj_k, nn.DataParallel): 35 | checkpoint['proj_k_state_dict'] = proj_k.module.state_dict() 36 | else: 37 | checkpoint['proj_k_state_dict'] = proj_k.state_dict() 38 | 39 | if isinstance(enc_q, nn.DataParallel): 40 | checkpoint['enc_state_dict'] = enc_q.module.state_dict() 41 | else: 42 | checkpoint['enc_state_dict'] = enc_q.state_dict() 43 | 44 | if isinstance(proj_q, nn.DataParallel): 45 | checkpoint['proj_state_dict'] = proj_q.module.state_dict() 46 | else: 47 | checkpoint['proj_state_dict'] = proj_q.state_dict() 48 | 49 | torch.save(checkpoint, ckpt_path) 50 | 51 | def load_ckpt(enc_q, proj_q, enc_k, proj_k, queue, opt, ckpt_path): 52 | ckpt = torch.load(ckpt_path) 53 | 54 | if isinstance(enc_q, nn.DataParallel): 55 | enc_q.module.load_state_dict(ckpt['enc_state_dict']) 56 | else: 57 | enc_q.load_state_dict(ckpt['enc_state_dict']) 58 | 59 | if isinstance(proj_q, nn.DataParallel): 60 | proj_q.module.load_state_dict(ckpt['proj_state_dict']) 61 | else: 62 | proj_q.load_state_dict(ckpt['proj_state_dict']) 63 | 64 | if isinstance(enc_k, nn.DataParallel): 65 | enc_k.module.load_state_dict(ckpt['enc_k_state_dict']) 66 | else: 67 | enc_k.load_state_dict(ckpt['enc_k_state_dict']) 68 
| 69 | if isinstance(proj_k, nn.DataParallel): 70 | proj_k.module.load_state_dict(ckpt['proj_k_state_dict']) 71 | else: 72 | proj_k.load_state_dict(ckpt['proj_k_state_dict']) 73 | 74 | queue.data = ckpt['queue'].data 75 | 76 | opt.load_state_dict(ckpt['opt']) 77 | it = ckpt['iteration'] 78 | rt = ckpt['runtime'] 79 | 80 | return it, rt 81 | 82 | def shuffled_idx(batch_size): 83 | shuffled_idxs = torch.randperm(batch_size).long().cuda() 84 | reverse_idxs = torch.zeros(batch_size).long().cuda() 85 | value = torch.arange(batch_size).long().cuda() 86 | reverse_idxs.index_copy_(0, shuffled_idxs, value) 87 | return shuffled_idxs, reverse_idxs 88 | 89 | def MoCo(train_X, enc_q, proj_q, enc_k, proj_k, config, log_dir, start_epoch): 90 | 91 | train_n = train_X.shape[0] 92 | size = train_X.shape[2] 93 | nc = train_X.shape[1] 94 | 95 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda() 96 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda() 97 | 98 | criterion = nn.CrossEntropyLoss().cuda() 99 | 100 | queue = torch.randn(config['queue_size'], config['net']['proj_dim']).cuda() 101 | 102 | params_q = itertools.chain(enc_q.parameters(), proj_q.parameters()) 103 | net_q = lambda x : proj_q(enc_q((x - mean) / std)) 104 | 105 | params_k = itertools.chain(enc_k.parameters(), proj_k.parameters()) 106 | net_k = lambda x : proj_k(enc_k((x - mean) / std)) 107 | 108 | t = get_t(size, config['t']) 109 | lr_schedule = get_lr_schedule(config) 110 | opt = get_optimizer(config['optim'], params_q) 111 | 112 | if start_epoch == 0: 113 | it, rt = 0, 0 114 | else: 115 | it, rt = load_ckpt(enc_q, proj_q, enc_k, proj_k, queue, opt, log_dir + '/{}.pt'.format(start_epoch)) 116 | 117 | enc_q.train() 118 | proj_q.train() 119 | 120 | enc_k.train() 121 | proj_k.train() 122 | 123 | epoch_n_iter = int(np.ceil(train_n / config['bs'])) 124 | while True: 125 | if it % epoch_n_iter == 0: 126 | train_X = train_X[torch.randperm(train_n)] 127 | 128 | i = it % epoch_n_iter 129 | it += 1 130 | 131 | s = time.time() 132 | 133 | X = train_X[i * config['bs']:(i + 1) * config['bs']] 134 | 135 | X_v1 = t(X).cuda() 136 | X_v2 = t(X).cuda() 137 | 138 | Z_v1 = F.normalize(net_q(X_v1), dim=1) 139 | 140 | with torch.no_grad(): 141 | 142 | for p_q, p_k in zip(enc_q.parameters(), enc_k.parameters()): 143 | p_k.data = p_k.data * config['momentum'] + p_q.detach().data * (1 - config['momentum']) 144 | 145 | for p_q, p_k in zip(proj_q.parameters(), proj_k.parameters()): 146 | p_k.data = p_k.data * config['momentum'] + p_q.detach().data * (1 - config['momentum']) 147 | 148 | shuffled_idxs, reverse_idxs = shuffled_idx(X.shape[0]) 149 | X_v2 = X_v2[shuffled_idxs] 150 | Z_v2 = F.normalize(net_k(X_v2), dim=1) 151 | Z_v2 = Z_v2[reverse_idxs] 152 | 153 | Z_mem = F.normalize(queue.clone().detach(), dim=1) 154 | 155 | pos = torch.bmm(Z_v1.view(Z_v1.shape[0],1,-1), Z_v2.view(Z_v2.shape[0],-1,1)).squeeze(-1) 156 | neg = torch.mm(Z_v1, Z_mem.transpose(1,0)) 157 | 158 | logits = torch.cat((pos, neg), dim=1) / config['temperature'] 159 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 160 | loss = criterion(logits, labels) 161 | 162 | # Update 163 | opt.param_groups[0]['lr'] = lr_schedule(it) 164 | opt.zero_grad() 165 | loss.backward() 166 | opt.step() 167 | 168 | # Update queue 169 | queue = queue[Z_v2.shape[0]:] 170 | queue = torch.cat((queue, Z_v2.detach().clone()), dim=0) 171 | 172 | e = time.time() 173 | rt += (e - s) 174 | 175 | if it % config['p_iter'] == 0: 176 | 
save_ckpt(enc_q, proj_q, enc_k, proj_k, queue, opt, it, rt, False, log_dir + '/curr.pt') 177 | print('Epoch : {:.3f} | Loss : {:.3f} | LR : {:.3e} | Time : {:.3f}'.format(it / epoch_n_iter, loss.item(), lr_schedule(it), rt)) 178 | 179 | if it % (epoch_n_iter * config['s_epoch']) == 0: 180 | save_ckpt(enc_q, proj_q, enc_k, proj_k, queue, opt, it, rt, True, log_dir + '/{}.pt'.format(it // epoch_n_iter)) 181 | 182 | if it >= config['its']: 183 | break 184 | 185 | def run(log_dir, config_dir, start_epoch, device): 186 | 187 | if not os.path.isdir(log_dir): 188 | os.makedirs(log_dir) 189 | 190 | config = load_config(config_dir) 191 | 192 | train_X, train_y, test_X, test_y, n_classes = load_data(config['data']) 193 | 194 | epoch_n_iter = int(np.ceil(train_X.shape[0] / config['bs'])) 195 | 196 | if 'epochs' in config: 197 | config['its'] = epoch_n_iter * config['epochs'] 198 | 199 | if 'p_epoch' in config: 200 | config['p_iter'] = epoch_n_iter * config['p_epoch'] 201 | 202 | enc_q, proj_q, _ = init_enc_proj(config['net'], device) 203 | enc_k, proj_k, _ = init_enc_proj(config['net'], device) 204 | 205 | for p_q, p_k in zip(enc_q.parameters(), enc_k.parameters()): 206 | p_k.data.copy_(p_q.data) 207 | p_k.requires_grad = False 208 | 209 | for p_q, p_k in zip(proj_q.parameters(), proj_k.parameters()): 210 | p_k.data.copy_(p_q.data) 211 | p_k.requires_grad = False 212 | 213 | print('Running MoCo from epoch {}'.format(start_epoch)) 214 | 215 | with torch.cuda.device(device): 216 | MoCo(train_X, enc_q, proj_q, enc_k, proj_k, config, log_dir, start_epoch) 217 | 218 | print('Finished!') -------------------------------------------------------------------------------- /libs/byol.py: -------------------------------------------------------------------------------- 1 | from libs.utils import log_softmax, load_data, get_t, get_lr_schedule, get_optimizer, load_config 2 | from libs.net import init_enc 3 | 4 | import torchvision.transforms.functional as TF 5 | import torchvision.transforms as transforms 6 | import torch.nn.functional as F 7 | import matplotlib.pyplot as plt 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | import torchvision 13 | import itertools 14 | import torch 15 | import time 16 | import os 17 | 18 | def save_ckpt(enc_o, proj_o, pred_o, enc_t, proj_t, opt, iteration, runtime, save_all, ckpt_path): 19 | 20 | checkpoint = { 21 | 'iteration' : iteration, 22 | 'runtime' : runtime, 23 | 'opt' : opt.state_dict() 24 | } 25 | 26 | if save_all: 27 | 28 | if isinstance(enc_t, nn.DataParallel): 29 | checkpoint['enc_t_state_dict'] = enc_t.module.state_dict() 30 | else: 31 | checkpoint['enc_t_state_dict'] = enc_t.state_dict() 32 | 33 | if isinstance(proj_t, nn.DataParallel): 34 | checkpoint['proj_t_state_dict'] = proj_t.module.state_dict() 35 | else: 36 | checkpoint['proj_t_state_dict'] = proj_t.state_dict() 37 | 38 | if isinstance(enc_o, nn.DataParallel): 39 | checkpoint['enc_state_dict'] = enc_o.module.state_dict() 40 | else: 41 | checkpoint['enc_state_dict'] = enc_o.state_dict() 42 | 43 | if isinstance(proj_o, nn.DataParallel): 44 | checkpoint['proj_state_dict'] = proj_o.module.state_dict() 45 | else: 46 | checkpoint['proj_state_dict'] = proj_o.state_dict() 47 | 48 | if isinstance(pred_o, nn.DataParallel): 49 | checkpoint['pred_state_dict'] = pred_o.module.state_dict() 50 | else: 51 | checkpoint['pred_state_dict'] = pred_o.state_dict() 52 | 53 | torch.save(checkpoint, ckpt_path) 54 | 55 | def load_ckpt(enc_o, proj_o, pred_o, enc_t, proj_t, opt, 
ckpt_path): 56 | ckpt = torch.load(ckpt_path) 57 | 58 | if isinstance(enc_o, nn.DataParallel): 59 | enc_o.module.load_state_dict(ckpt['enc_state_dict']) 60 | else: 61 | enc_o.load_state_dict(ckpt['enc_state_dict']) 62 | 63 | if isinstance(proj_o, nn.DataParallel): 64 | proj_o.module.load_state_dict(ckpt['proj_state_dict']) 65 | else: 66 | proj_o.load_state_dict(ckpt['proj_state_dict']) 67 | 68 | if isinstance(pred_o, nn.DataParallel): 69 | pred_o.module.load_state_dict(ckpt['pred_state_dict']) 70 | else: 71 | pred_o.load_state_dict(ckpt['pred_state_dict']) 72 | 73 | if isinstance(enc_t, nn.DataParallel): 74 | enc_t.module.load_state_dict(ckpt['enc_t_state_dict']) 75 | else: 76 | enc_t.load_state_dict(ckpt['enc_t_state_dict']) 77 | 78 | if isinstance(proj_t, nn.DataParallel): 79 | proj_t.module.load_state_dict(ckpt['proj_t_state_dict']) 80 | else: 81 | proj_t.load_state_dict(ckpt['proj_t_state_dict']) 82 | 83 | opt.load_state_dict(ckpt['opt']) 84 | it = ckpt['iteration'] 85 | rt = ckpt['runtime'] 86 | 87 | return it, rt 88 | 89 | def D(x, y): 90 | return (2 - 2 * F.cosine_similarity(x, y.detach(), dim=-1)).mean() 91 | 92 | class Preprocess(nn.Module): 93 | def __init__(self, mean, std): 94 | super(Preprocess, self).__init__() 95 | self.mean = mean 96 | self.std = std 97 | 98 | def forward(self, tensors): 99 | return (tensors - self.mean) / self.std 100 | 101 | class MLP(nn.Module): 102 | def __init__(self, in_dim, hidden_dim=4096, out_dim=256): 103 | super().__init__() 104 | 105 | self.layer1 = nn.Sequential( 106 | nn.Linear(in_dim, hidden_dim), 107 | nn.BatchNorm1d(hidden_dim), 108 | nn.ReLU(inplace=True) 109 | ) 110 | 111 | self.layer2 = nn.Linear(hidden_dim, out_dim) 112 | 113 | def forward(self, x): 114 | x = self.layer1(x) 115 | x = self.layer2(x) 116 | return x 117 | 118 | def BYOL(train_X, enc_o, proj_o, pred_o, enc_t, proj_t, config, log_dir, start_epoch): 119 | 120 | train_n = train_X.shape[0] 121 | size = train_X.shape[2] 122 | nc = train_X.shape[1] 123 | 124 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda() 125 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda() 126 | 127 | params = itertools.chain(enc_o.parameters(), proj_o.parameters(), pred_o.parameters()) 128 | net_o = nn.Sequential(Preprocess(mean, std), enc_o, proj_o) 129 | net_t = nn.Sequential(Preprocess(mean, std), enc_t, proj_t) 130 | 131 | t = get_t(size, config['t']) 132 | lr_schedule = get_lr_schedule(config) 133 | opt = get_optimizer(config['optim'], params) 134 | 135 | if start_epoch == 0: 136 | it, rt = 0, 0 137 | else: 138 | it, rt = load_ckpt(enc_o, proj_o, pred_o, enc_t, proj_t, opt, log_dir + '/{}.pt'.format(start_epoch)) 139 | 140 | enc_o.train() 141 | proj_o.train() 142 | pred_o.train() 143 | 144 | enc_t.train() 145 | proj_t.train() 146 | 147 | epoch_n_iter = int(np.ceil(train_n / config['bs'])) 148 | while True: 149 | if it % epoch_n_iter == 0: 150 | train_X = train_X[torch.randperm(train_n)] 151 | 152 | i = it % epoch_n_iter 153 | it += 1 154 | 155 | s = time.time() 156 | 157 | X = train_X[i * config['bs']:(i + 1) * config['bs']] 158 | 159 | X_v1 = t(X).cuda() 160 | X_v2 = t(X).cuda() 161 | 162 | Z_v1_o, Z_v2_o = net_o(X_v1), net_o(X_v2) 163 | P_v1_o, P_v2_o = pred_o(Z_v1_o), pred_o(Z_v2_o) 164 | 165 | with torch.no_grad(): 166 | Z_v1_t, Z_v2_t = net_t(X_v1), net_t(X_v2) 167 | 168 | loss = D(P_v1_o, Z_v2_t) / 2 + D(P_v2_o, Z_v1_t) / 2 169 | 170 | # Update 171 | opt.param_groups[0]['lr'] = lr_schedule(it) 172 | opt.zero_grad() 173 
| loss.backward() 174 | opt.step() 175 | 176 | # Target network update (EMA of the online network) 177 | with torch.no_grad(): 178 | 179 | for p_o, p_t in zip(enc_o.parameters(), enc_t.parameters()): 180 | p_t.data = p_t.data * config['momentum'] + p_o.detach().data * (1 - config['momentum']) 181 | 182 | for p_o, p_t in zip(proj_o.parameters(), proj_t.parameters()): 183 | p_t.data = p_t.data * config['momentum'] + p_o.detach().data * (1 - config['momentum']) 184 | 185 | e = time.time() 186 | rt += (e - s) 187 | 188 | if it % config['p_iter'] == 0: 189 | save_ckpt(enc_o, proj_o, pred_o, enc_t, proj_t, opt, it, rt, False, log_dir + '/curr.pt') 190 | print('Epoch : {:.3f} | Loss : {:.3f} | LR : {:.3e} | Time : {:.3f}'.format(it / epoch_n_iter, loss.item(), lr_schedule(it), rt)) 191 | 192 | if it % (epoch_n_iter * config['s_epoch']) == 0: 193 | save_ckpt(enc_o, proj_o, pred_o, enc_t, proj_t, opt, it, rt, True, log_dir + '/{}.pt'.format(it // epoch_n_iter)) 194 | 195 | if it >= config['its']: 196 | break 197 | 198 | def run(log_dir, config_dir, start_epoch, device): 199 | 200 | if not os.path.isdir(log_dir): 201 | os.makedirs(log_dir) 202 | 203 | config = load_config(config_dir) 204 | 205 | train_X, train_y, test_X, test_y, n_classes = load_data(config['data']) 206 | 207 | epoch_n_iter = int(np.ceil(train_X.shape[0] / config['bs'])) 208 | 209 | if 'epochs' in config: 210 | config['its'] = epoch_n_iter * config['epochs'] 211 | 212 | if 'p_epoch' in config: 213 | config['p_iter'] = epoch_n_iter * config['p_epoch'] 214 | 215 | enc_o, feature_dim = init_enc(config['net'], device) 216 | proj_o = MLP(feature_dim).cuda(device) 217 | pred_o = MLP(256).cuda(device) 218 | 219 | enc_t, feature_dim = init_enc(config['net'], device) 220 | proj_t = MLP(feature_dim).cuda(device) 221 | 222 | for p_o, p_t in zip(enc_o.parameters(), enc_t.parameters()): 223 | p_t.data.copy_(p_o.data) 224 | p_t.requires_grad = False 225 | 226 | for p_o, p_t in zip(proj_o.parameters(), proj_t.parameters()): 227 | p_t.data.copy_(p_o.data) 228 | p_t.requires_grad = False 229 | 230 | print('Running BYOL from epoch {}'.format(start_epoch)) 231 | 232 | with torch.cuda.device(device): 233 | BYOL(train_X, enc_o, proj_o, pred_o, enc_t, proj_t, config, log_dir, start_epoch) 234 | 235 | print('Finished!') -------------------------------------------------------------------------------- /libs/evaluation.py: -------------------------------------------------------------------------------- 1 | from libs.utils import log_softmax, load_config, load_data 2 | from libs.net import init_enc_proj 3 | from libs.models import LIN 4 | 5 | from tqdm import tqdm 6 | 7 | import matplotlib.pyplot as plt 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | import torch.nn as nn 11 | import numpy as np 12 | 13 | import itertools 14 | import torch 15 | 16 | # Use this function to get linear evaluation accuracy 17 | def eval_acc(ckpt_dir, config_dir, eval_config, data_config, device): 18 | 19 | config = load_config(config_dir) 20 | 21 | train_X, train_y, test_X, test_y, n_classes = load_data(data_config) 22 | 23 | if eval_config['standardize']: 24 | nc = train_X.shape[1] 25 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda(device) 26 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda(device) 27 | else: 28 | mean = 0.5 29 | std = 0.5 30 | 31 | enc, _, _ = init_enc_proj(config['net'], device) 32 | enc.load_state_dict(torch.load(ckpt_dir)['enc_state_dict']) 33 | net = 
nn.Sequential(Preprocess(mean, std), enc) 34 | net.eval() 35 | 36 | lin = LIN(512, n_classes).cuda(device) 37 | 38 | with torch.cuda.device(device): 39 | train(lin, net, train_X, train_y, test_X, test_y, n_classes, eval_config) 40 | 41 | # Use this function to get linear transfer accuracy 42 | def transfer_acc(ckpt_dir, config_dir, eval_config, source_data_config, target_data_config, device): 43 | 44 | config = load_config(config_dir) 45 | 46 | if eval_config['standardize']: 47 | train_X, train_y, test_X, test_y, n_classes = load_data(source_data_config) 48 | nc = train_X.shape[1] 49 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda(device) 50 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda(device) 51 | else: 52 | mean = 0.5 53 | std = 0.5 54 | 55 | train_X, train_y, test_X, test_y, n_classes = load_data(target_data_config) 56 | 57 | enc, _, _ = init_enc_proj(config['net'], device) 58 | enc.load_state_dict(torch.load(ckpt_dir)['enc_state_dict']) 59 | net = nn.Sequential(Preprocess(mean, std), enc) 60 | net.eval() 61 | 62 | lin = LIN(512, n_classes).cuda(device) 63 | 64 | with torch.cuda.device(device): 65 | train(lin, net, train_X, train_y, test_X, test_y, n_classes, eval_config) 66 | 67 | # Use this function to get knn evaluation accuracy 68 | def eval_knn_acc(ckpt_dir, config_dir, eval_config, data_config, device): 69 | config = load_config(config_dir) 70 | train_X, train_y, test_X, test_y, n_classes = load_data(data_config) 71 | 72 | if eval_config['standardize']: 73 | nc = train_X.shape[1] 74 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda(device) 75 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda(device) 76 | else: 77 | mean = 0.5 78 | std = 0.5 79 | 80 | enc, _, _ = init_enc_proj(config['net'], device) 81 | enc.load_state_dict(torch.load(ckpt_dir)['enc_state_dict']) 82 | net = nn.Sequential(Preprocess(mean, std), enc) 83 | net.eval() 84 | 85 | with torch.cuda.device(device): 86 | test_n = test_X.shape[0] 87 | train_Z, test_Z = embed_dataset(net, train_X, test_X) 88 | 89 | acc = 0 90 | N = int(np.ceil(test_n / eval_config['bs'])) 91 | for i in tqdm(range(N)): 92 | Z = test_Z[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda() 93 | y = test_y[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda() 94 | 95 | D = (Z[:,None] - train_Z[None].cuda()).norm(dim=2).cpu() 96 | w, inds = D.topk(eval_config['K'], dim=1, largest=False) 97 | 98 | v = train_y[inds] 99 | a = v.reshape(-1) 100 | a = F.one_hot(a, num_classes=n_classes) 101 | a = a.reshape(v.shape[0],v.shape[1],n_classes) 102 | weight_pred = a / w[...,None] 103 | weight_pred = weight_pred.sum(dim=1) 104 | wp = weight_pred.argmax(dim=1) 105 | acc += (wp.cuda() == y).sum() 106 | 107 | print(acc / ((i + 1) * eval_config['bs'])) 108 | 109 | # Use this function to get knn transfer accuracy 110 | def eval_knn_transfer_acc(ckpt_dir, config_dir, eval_config, source_data_config, target_data_config, device): 111 | config = load_config(config_dir) 112 | 113 | if eval_config['standardize']: 114 | train_X, train_y, test_X, test_y, n_classes = load_data(source_data_config) 115 | nc = train_X.shape[1] 116 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda(device) 117 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda(device) 118 | else: 119 | mean = 0.5 120 | std = 0.5 121 | 122 | train_X, train_y, test_X, test_y, 
n_classes = load_data(target_data_config) 123 | 124 | enc, _, _ = init_enc_proj(config['net'], device) 125 | enc.load_state_dict(torch.load(ckpt_dir)['enc_state_dict']) 126 | net = nn.Sequential(Preprocess(mean, std), enc) 127 | net.eval() 128 | 129 | with torch.cuda.device(device): 130 | test_n = test_X.shape[0] 131 | train_Z, test_Z = embed_dataset(net, train_X, test_X) 132 | 133 | acc = 0 134 | N = int(np.ceil(test_n / eval_config['bs'])) 135 | for i in tqdm(range(N)): 136 | Z = test_Z[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda() 137 | y = test_y[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda() 138 | 139 | D = (Z[:,None] - train_Z[None].cuda()).norm(dim=2).cpu() 140 | w, inds = D.topk(eval_config['K'], dim=1, largest=False) 141 | 142 | v = train_y[inds] 143 | a = v.reshape(-1) 144 | a = F.one_hot(a, num_classes=n_classes) 145 | a = a.reshape(v.shape[0],v.shape[1],n_classes) 146 | weight_pred = a / w[...,None] 147 | weight_pred = weight_pred.sum(dim=1) 148 | wp = weight_pred.argmax(dim=1) 149 | acc += (wp.cuda() == y).sum() 150 | 151 | print(acc / ((i + 1) * eval_config['bs'])) 152 | 153 | def train(lin, net, train_X, train_y, test_X, test_y, n_classes, eval_config): 154 | 155 | train_n = train_X.shape[0] 156 | train_Z, test_Z = embed_dataset(net, train_X, test_X) 157 | opt = optim.Adam(lin.parameters(), lr=eval_config['lr']) 158 | 159 | for epoch in range(eval_config['epochs']): 160 | for i in range(int(np.ceil(train_n / eval_config['bs']))): 161 | 162 | Z = train_Z[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda() 163 | y = train_y[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda() 164 | 165 | log_p = lin(Z) 166 | 167 | loss = -(F.one_hot(y, num_classes=n_classes) * log_softmax(log_p)).mean() 168 | 169 | opt.zero_grad() 170 | loss.backward() 171 | opt.step() 172 | 173 | if (epoch + 1) % eval_config['p_epoch'] == 0: 174 | print('Epoch {} | Accuracy : {:.3f}'.format(epoch + 1, accuracy(lin, test_Z, test_y))) 175 | 176 | class Preprocess(nn.Module): 177 | def __init__(self, mean, std): 178 | super(Preprocess, self).__init__() 179 | self.mean = mean 180 | self.std = std 181 | 182 | def forward(self, tensors): 183 | return (tensors - self.mean) / self.std 184 | 185 | def embed_dataset(net, train_X, test_X, bs=64): 186 | 187 | train_n = train_X.shape[0] 188 | test_n = test_X.shape[0] 189 | nc = train_X.shape[1] 190 | 191 | train_Z = [] 192 | for i in tqdm(range(int(np.ceil(train_n / bs)))): 193 | batch = train_X[i * bs:(i + 1) * bs].cuda() 194 | train_Z.append(net(batch).detach().cpu()) 195 | train_Z = torch.cat(train_Z, dim=0) 196 | 197 | test_Z = [] 198 | for i in tqdm(range(int(np.ceil(test_n / bs)))): 199 | batch = test_X[i * bs:(i + 1) * bs].cuda() 200 | test_Z.append(net(batch).detach().cpu()) 201 | test_Z = torch.cat(test_Z, dim=0) 202 | 203 | return train_Z, test_Z 204 | 205 | def accuracy(clf, test_Z, test_y, bs=500): 206 | 207 | test_n = test_Z.shape[0] 208 | 209 | acc = 0 210 | for i in range(int(np.ceil(test_n / bs))): 211 | Z = test_Z[i * bs:(i + 1) * bs].cuda() 212 | y = test_y[i * bs:(i + 1) * bs].cuda() 213 | log_p = clf(Z) 214 | acc += (torch.argmax(log_p, dim=1) == y).sum() / test_n 215 | 216 | return acc.item() -------------------------------------------------------------------------------- /libs/ebclr.py: -------------------------------------------------------------------------------- 1 | from libs.utils import MSGLD, dist, logsumexp, log_softmax, load_data, get_t, get_optimizer, get_lr_schedule, load_config 2 | from libs.net import 
init_enc_proj 3 | 4 | import torchvision.transforms.functional as TF 5 | import torchvision.transforms as transforms 6 | import torch.nn.functional as F 7 | import matplotlib.pyplot as plt 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import numpy as np 11 | 12 | import torchvision 13 | import itertools 14 | import torch 15 | import time 16 | import os 17 | 18 | def save_ckpt(buffer, enc, proj, opt, iteration, runtime, save_buffer, ckpt_path): 19 | 20 | checkpoint = { 21 | 'iteration' : iteration, 22 | 'runtime' : runtime, 23 | 'opt' : opt.state_dict() 24 | } 25 | 26 | if save_buffer: 27 | checkpoint['buffer'] = buffer.buffer 28 | checkpoint['counter'] = buffer.counter 29 | 30 | if isinstance(enc, nn.DataParallel): 31 | checkpoint['enc_state_dict'] = enc.module.state_dict() 32 | else: 33 | checkpoint['enc_state_dict'] = enc.state_dict() 34 | 35 | if isinstance(proj, nn.DataParallel): 36 | checkpoint['proj_state_dict'] = proj.module.state_dict() 37 | else: 38 | checkpoint['proj_state_dict'] = proj.state_dict() 39 | 40 | torch.save(checkpoint, ckpt_path) 41 | 42 | def load_ckpt(enc, proj, buffer, opt, ckpt_path): 43 | ckpt = torch.load(ckpt_path) 44 | 45 | if isinstance(enc, nn.DataParallel): 46 | enc.module.load_state_dict(ckpt['enc_state_dict']) 47 | else: 48 | enc.load_state_dict(ckpt['enc_state_dict']) 49 | 50 | if isinstance(proj, nn.DataParallel): 51 | proj.module.load_state_dict(ckpt['proj_state_dict']) 52 | else: 53 | proj.load_state_dict(ckpt['proj_state_dict']) 54 | 55 | buffer.buffer = ckpt['buffer'] 56 | buffer.counter = ckpt['counter'] 57 | 58 | opt.load_state_dict(ckpt['opt']) 59 | 60 | return ckpt['iteration'], ckpt['runtime'] 61 | 62 | def transform_data(x, t, bs): 63 | n_iter = int(np.ceil(x.shape[0] / bs)) 64 | return torch.cat([t(x[j * bs:(j + 1) * bs]) for j in range(n_iter)], dim=0) 65 | 66 | def similarity_f(x, y, normalize): 67 | 68 | if normalize: 69 | x = x / (x.norm(dim=1, keepdim=True) + 1e-10) 70 | y = y / (y.norm(dim=1, keepdim=True) + 1e-10) 71 | 72 | return -(x[:,None] - y[None]).square().sum(dim=2) 73 | 74 | class SimCLR_Loss(nn.Module): 75 | 76 | def __init__(self, normalize, temperature): 77 | super(SimCLR_Loss, self).__init__() 78 | self.normalize = normalize 79 | self.temperature = temperature 80 | self.criterion = nn.CrossEntropyLoss(reduction="sum") 81 | 82 | def mask_correlated_samples(self, batch_size): 83 | N = 2 * batch_size 84 | mask = torch.ones((N, N), dtype=bool) 85 | mask = mask.fill_diagonal_(0) 86 | 87 | for i in range(batch_size): 88 | mask[i, batch_size + i] = 0 89 | mask[batch_size + i, i] = 0 90 | return mask 91 | 92 | def forward(self, z_i, z_j): 93 | 94 | batch_size = z_i.shape[0] 95 | 96 | N = 2 * batch_size 97 | 98 | z = torch.cat((z_i, z_j), dim=0) 99 | 100 | sim = similarity_f(z, z, self.normalize) / self.temperature 101 | 102 | sim_i_j = torch.diag(sim, batch_size) 103 | sim_j_i = torch.diag(sim, -batch_size) 104 | 105 | mask = self.mask_correlated_samples(batch_size) 106 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1) 107 | negative_samples = sim[mask].reshape(N, -1) 108 | 109 | labels = torch.from_numpy(np.array([0] * N)).reshape(-1).to(positive_samples.device).long() 110 | 111 | logits = torch.cat((positive_samples, negative_samples), dim=1) 112 | loss = self.criterion(logits, labels) 113 | loss /= N 114 | 115 | return loss 116 | 117 | class Buffer: 118 | 119 | def __init__(self, train_X, t, buffer_config): 120 | self.train_X = train_X 121 | self.config = buffer_config 122 | self.t = t 123 
| 124 | self.train_n = train_X.shape[0] 125 | self.size = train_X.shape[2] 126 | self.nc = train_X.shape[1] 127 | 128 | self.idx = 0 129 | 130 | self.counter = torch.ones(size=[self.config['size'],1,1,1]) * 5.0 131 | self.init_buffer() 132 | 133 | def __get_sample__(self, n_samples): 134 | r_idx = torch.randint(self.train_X.shape[0], size=[n_samples]) 135 | samples = self.train_X[r_idx] 136 | samples = transform_data(samples, self.t, self.config['bs']) 137 | return samples 138 | 139 | def __get_rand__(self, n_samples): 140 | if 'CD_ratio' in self.config: 141 | samples = self.__get_sample__(n_samples) 142 | return (self.config['CD_ratio'] * samples + (1.0 - self.config['CD_ratio']) * torch.rand_like(samples)) * 2.0 - 1.0 143 | else: 144 | return torch.rand(size=[n_samples, self.nc, self.size, self.size]) * 2.0 - 1.0 145 | 146 | def init_buffer(self): 147 | self.buffer = self.__get_rand__(self.config['size']) 148 | 149 | def sample(self, n_samples): 150 | self.idx = torch.randint(self.config['size'], size=[n_samples]) 151 | sample = self.buffer[self.idx] 152 | count = self.counter[self.idx].clone() 153 | 154 | r_idx = torch.randint(n_samples, size=[int(n_samples * self.config['rho'])]) 155 | sample[r_idx] = self.__get_rand__(n_samples)[r_idx] 156 | count[r_idx] = 0.0 157 | 158 | self.counter[self.idx] = (count + 1.0) 159 | 160 | return sample.cuda(), count.cuda() 161 | 162 | def update(self, samples): 163 | self.buffer[self.idx] = samples.detach().clone().cpu() 164 | 165 | def shuffled_idx(batch_size): 166 | shuffled_idxs = torch.randperm(batch_size).long().cuda() 167 | reverse_idxs = torch.zeros(batch_size).long().cuda() 168 | value = torch.arange(batch_size).long().cuda() 169 | reverse_idxs.index_copy_(0, shuffled_idxs, value) 170 | return shuffled_idxs, reverse_idxs 171 | 172 | def EBCLR(train_X, enc, proj, config, log_dir, start_epoch): 173 | 174 | train_n = train_X.shape[0] 175 | size = train_X.shape[2] 176 | nc = train_X.shape[1] 177 | 178 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda() 179 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda() 180 | 181 | simclr_loss = SimCLR_Loss(config['normalize'], config['temperature']) 182 | logits = lambda x, y : similarity_f(x, y, config['normalize']) / config['temperature'] 183 | log_pdf = lambda x, y : logsumexp(logits(x, y)) 184 | 185 | params = itertools.chain(enc.parameters(), proj.parameters()) 186 | net = lambda x : proj(enc(x)) 187 | 188 | t = get_t(size, config['t']) 189 | buffer = Buffer(train_X, t, config['buffer']) 190 | lr_schedule = get_lr_schedule(config) 191 | opt = get_optimizer(config['optim'], params) 192 | sgld = MSGLD(config['sgld']) 193 | 194 | if start_epoch == 0: 195 | it, rt = 0, 0 196 | else: 197 | it, rt = load_ckpt(enc, proj, buffer, opt, log_dir + '/{}.pt'.format(start_epoch)) 198 | 199 | enc.train() 200 | proj.train() 201 | 202 | epoch_n_iter = int(np.ceil(train_n / config['bs'])) 203 | while True: 204 | if it % epoch_n_iter == 0: 205 | idx = torch.randperm(train_n) 206 | V1 = transform_data(train_X, t, config['bs'])[idx] 207 | V2 = transform_data(train_X, t, config['bs'])[idx] 208 | 209 | i = it % epoch_n_iter 210 | it += 1 211 | 212 | s = time.time() 213 | 214 | X_v1 = V1[i * config['bs']:(i + 1) * config['bs']].cuda() * 2.0 - 1.0 215 | X_v2 = V2[i * config['bs']:(i + 1) * config['bs']].cuda() * 2.0 - 1.0 216 | 217 | if config['net']['use_bn']: 218 | shuffled_idxs, reverse_idxs = shuffled_idx(X_v2.shape[0]) 219 | Z_v1, Z_v2 = 
net(X_v1), net(X_v2[shuffled_idxs])[reverse_idxs] 220 | else: 221 | Z_v1, Z_v2 = net(X_v1), net(X_v2) 222 | 223 | if 'neg_bs' in config: 224 | X_init, count = buffer.sample(config['neg_bs']) 225 | else: 226 | X_init, count = buffer.sample(X_v1.shape[0]) 227 | 228 | X_n = sgld(lambda x : log_pdf(net(x), Z_v2.detach().clone()), X_init, count) 229 | buffer.update(X_n) 230 | Z_n = net(X_n) 231 | 232 | log_pdf_d = log_pdf(Z_v1, Z_v2) 233 | log_pdf_n = log_pdf(Z_n, Z_v2) 234 | gen_loss = -log_pdf_d.mean() + log_pdf_n.mean() + config['lmda1'] * (Z_v1.square().sum() + Z_n.square().sum()) 235 | 236 | loss = config['lmda2'] * gen_loss + simclr_loss(Z_v1, Z_v2) 237 | 238 | # Update 239 | opt.param_groups[0]['lr'] = lr_schedule(it) 240 | opt.zero_grad() 241 | loss.backward() 242 | 243 | if config['optim']['optimizer'] == 'adam' and len(opt.state_dict()['state']) > 0: # clamp grads to 3 std of Adam's second-moment estimate 244 | for j, param in enumerate(opt.param_groups[0]['params']): 245 | v = opt.state_dict()['state'][j]['exp_avg_sq'].sqrt() 246 | param.grad = torch.clamp(param.grad, -3.0 * v, 3.0 * v) 247 | 248 | opt.step() 249 | 250 | e = time.time() 251 | rt += (e - s) 252 | 253 | if it % config['p_iter'] == 0: 254 | save_ckpt(buffer, enc, proj, opt, it, rt, False, log_dir + '/curr.pt') 255 | 256 | pdf_d = torch.exp(log_pdf_d).mean().item() 257 | pdf_n = torch.exp(log_pdf_n).mean().item() 258 | 259 | print('Epoch : {:.3f} | Data PDF : {:.3e} | Noise PDF : {:.3e} | LR : {:.3e} | Time : {:.3f}'.format(it / epoch_n_iter, pdf_d, pdf_n, lr_schedule(it), rt)) 260 | 261 | if it % (epoch_n_iter * config['s_epoch']) == 0: 262 | save_ckpt(buffer, enc, proj, opt, it, rt, config['save_buffer'], log_dir + '/{}.pt'.format(it // epoch_n_iter)) 263 | 264 | if it >= config['its']: 265 | break 266 | 267 | def run(log_dir, config_dir, start_epoch, device): 268 | 269 | if not os.path.isdir(log_dir): 270 | os.makedirs(log_dir) 271 | 272 | config = load_config(config_dir) 273 | 274 | train_X, train_y, test_X, test_y, n_classes = load_data(config['data']) 275 | 276 | epoch_n_iter = int(np.ceil(train_X.shape[0] / config['bs'])) 277 | 278 | if 'epochs' in config: 279 | config['its'] = epoch_n_iter * config['epochs'] 280 | 281 | if 'p_epoch' in config: 282 | config['p_iter'] = epoch_n_iter * config['p_epoch'] 283 | 284 | enc, proj, _ = init_enc_proj(config['net'], device) 285 | 286 | print('Starting EBCLR from epoch {}'.format(start_epoch)) 287 | 288 | with torch.cuda.device(device): 289 | EBCLR(train_X, enc, proj, config, log_dir, start_epoch) 290 | 291 | print('Finished!') --------------------------------------------------------------------------------
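For reference, the entry points above can be driven directly, mirroring what Train.ipynb and Evaluation.ipynb do. The sketch below is illustrative only: the log directory, GPU index, and the concrete eval_config / data_config values are assumptions, while the function signatures and config keys ('standardize', 'lr', 'bs', 'epochs', 'p_epoch', 'data_dir', 'dataset', 'data_ratio') are the ones read by libs/ebclr.py, libs/evaluation.py, and libs/utils.py.

# usage_sketch.py -- hypothetical driver script, not part of the repository.
# Assumes a CUDA-capable machine (GPU 0) and datasets available under ./data.
from libs import ebclr, evaluation

# Pretrain EBCLR on CIFAR-10 from scratch; checkpoints land in log_dir as
# curr.pt (latest) and <epoch>.pt (periodic, every s_epoch epochs).
ebclr.run(log_dir='./logs/ebclr_cifar10_b128',
          config_dir='./configs_ebclr/cifar10_b128.yaml',
          start_epoch=0,   # pass a saved epoch number instead to resume
          device=0)

# Linear evaluation of a saved checkpoint. The eval_config keys below are the
# ones eval_acc() and train() read; the values are illustrative guesses.
eval_config = {'standardize': True, 'lr': 0.001, 'bs': 256, 'epochs': 100, 'p_epoch': 10}
data_config = {'data_dir': './data', 'dataset': 'cifar10', 'data_ratio': 1.0}
evaluation.eval_acc(ckpt_dir='./logs/ebclr_cifar10_b128/100.pt',
                    config_dir='./configs_ebclr/cifar10_b128.yaml',
                    eval_config=eval_config,
                    data_config=data_config,
                    device=0)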
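Similarly, the view-generation transform built by get_t in libs/utils.py can be exercised on its own; the training loops above call it as X_v1 = t(X) and X_v2 = t(X) to produce two augmented views of the same batch. The t_config values below are illustrative assumptions; get_t only reacts to the keys it finds ('crop_scale', 'flip_p', 'jitter', 'gray_p', 'blur_scale', 'noise_std').

# view_sketch.py -- illustrative only; the values in t_config are assumptions.
import torch
from libs.utils import get_t

t_config = {'crop_scale': {'min': 0.08, 'max': 1.0},  # RandomResizedCrop scale range
            'flip_p': 0.5,                            # horizontal flip probability
            'noise_std': 0.03}                        # additive Gaussian noise
t = get_t(32, t_config)  # 32x32 inputs, e.g. CIFAR-10

X = torch.rand(8, 3, 32, 32)   # a batch scaled to [0, 1], as load_data returns
X_v1, X_v2 = t(X), t(X)        # two independent views of the same batch
print(X_v1.shape, X_v2.shape)  # torch.Size([8, 3, 32, 32]) twice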