├── assets
│   └── main.png
├── libs
│   ├── byol.py
│   ├── ebclr.py
│   ├── evaluation.py
│   ├── gaussian_blur.py
│   ├── models.py
│   ├── moco.py
│   ├── net.py
│   ├── resnets.py
│   ├── simclr.py
│   ├── simsiam.py
│   └── utils.py
├── configs_simsiam
│   ├── cifar100_b256.yaml
│   ├── cifar10_b256.yaml
│   ├── fmnist_b256.yaml
│   └── mnist_b256.yaml
├── configs_byol
│   ├── cifar100_b256.yaml
│   ├── cifar10_b256.yaml
│   ├── fmnist_b256.yaml
│   └── mnist_b256.yaml
├── configs_simclr
│   ├── cifar100_b128.yaml
│   ├── cifar100_b16.yaml
│   ├── cifar100_b256.yaml
│   ├── cifar100_b64.yaml
│   ├── cifar10_b128.yaml
│   ├── cifar10_b16.yaml
│   ├── cifar10_b256.yaml
│   ├── cifar10_b64.yaml
│   ├── fmnist_b128.yaml
│   ├── fmnist_b16.yaml
│   ├── fmnist_b256.yaml
│   ├── fmnist_b64.yaml
│   ├── mnist_b128.yaml
│   ├── mnist_b16.yaml
│   ├── mnist_b256.yaml
│   └── mnist_b64.yaml
├── configs_moco
│   ├── cifar100_b256.yaml
│   ├── cifar10_b256.yaml
│   ├── fmnist_b256.yaml
│   └── mnist_b256.yaml
├── configs_ebclr
│   ├── cifar100_b128.yaml
│   ├── cifar100_b16.yaml
│   ├── cifar100_b64.yaml
│   ├── cifar10_b128.yaml
│   ├── cifar10_b16.yaml
│   ├── cifar10_b64.yaml
│   ├── fmnist_b128.yaml
│   ├── fmnist_b16.yaml
│   ├── fmnist_b64.yaml
│   ├── mnist_b128.yaml
│   ├── mnist_b16.yaml
│   └── mnist_b64.yaml
├── Train.ipynb
├── README.md
└── Evaluation.ipynb
/assets/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/1202kbs/EBCLR/HEAD/assets/main.png
--------------------------------------------------------------------------------
/configs_simsiam/fmnist_b256.yaml:
--------------------------------------------------------------------------------
1 | epochs : 200
2 | bs : 256
3 | p_epoch : 5
4 | s_epoch : 5
5 |
6 | data:
7 | data_dir : "./data"
8 | dataset : "fmnist"
9 | data_ratio : 1.0
10 |
11 | optim:
12 | optimizer: "sgd"
13 | init_lr: 0.03
14 | lr_schedule: "const"
15 | weight_decay: 0.0001
16 |
17 | net:
18 | encoder : "resnet18"
19 | use_bn : true
20 | use_sn : false
21 | use_wn : false
22 | act : "relu"
23 | proj_dim : 2048
24 | nc : 1
25 |
26 | t:
27 | crop_scale:
28 | min: 0.08
29 | max: 1.0
--------------------------------------------------------------------------------
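
Each config above parses into a nested dict. The repo's own loader, `load_config`, lives in `libs/utils.py` (not included in this dump); assuming it is a thin wrapper over PyYAML, the sketch below shows how the fields are accessed:

```python
# Minimal sketch of loading one of these configs; the repo's load_config in
# libs/utils.py (not shown in this dump) is assumed to behave equivalently.
import yaml

with open('./configs_simsiam/fmnist_b256.yaml') as f:
    config = yaml.safe_load(f)

print(config['bs'])                      # 256
print(config['optim']['init_lr'])        # 0.03
print(config['t']['crop_scale']['min'])  # 0.08
```
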
/configs_simsiam/mnist_b256.yaml:
--------------------------------------------------------------------------------
1 | epochs : 200
2 | bs : 256
3 | p_epoch : 5
4 | s_epoch : 5
5 |
6 | data:
7 | data_dir : "./data"
8 | dataset : "mnist"
9 | data_ratio : 1.0
10 |
11 | optim:
12 | optimizer: "sgd"
13 | init_lr: 0.03
14 | lr_schedule: "const"
15 | weight_decay: 0.0001
16 |
17 | net:
18 | encoder : "resnet18"
19 | use_bn : true
20 | use_sn : false
21 | use_wn : false
22 | act : "relu"
23 | proj_dim : 2048
24 | nc : 1
25 |
26 | t:
27 | crop_scale:
28 | min: 0.08
29 | max: 1.0
--------------------------------------------------------------------------------
/configs_byol/fmnist_b256.yaml:
--------------------------------------------------------------------------------
1 | momentum : 0.9
2 | epochs : 200
3 | bs : 256
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "fmnist"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.03
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 256
25 | nc : 1
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
--------------------------------------------------------------------------------
/configs_byol/mnist_b256.yaml:
--------------------------------------------------------------------------------
1 | momentum : 0.9
2 | epochs : 200
3 | bs : 256
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "mnist"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.03
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 256
25 | nc : 1
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
--------------------------------------------------------------------------------
/configs_simclr/fmnist_b256.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 256
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "fmnist"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.03
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 1
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
--------------------------------------------------------------------------------
/configs_simclr/fmnist_b64.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 64
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "fmnist"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.0075
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 1
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
--------------------------------------------------------------------------------
/configs_simclr/mnist_b128.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 128
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "mnist"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.015
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 1
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
--------------------------------------------------------------------------------
/configs_simclr/mnist_b16.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 16
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "mnist"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.001875
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 1
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
--------------------------------------------------------------------------------
/configs_simclr/mnist_b256.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 256
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "mnist"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.03
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 1
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
--------------------------------------------------------------------------------
/configs_simclr/mnist_b64.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 64
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "mnist"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.0075
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 1
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
--------------------------------------------------------------------------------
/configs_simclr/fmnist_b128.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 128
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "fmnist"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.015
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 1
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
--------------------------------------------------------------------------------
/configs_simclr/fmnist_b16.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 16
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "fmnist"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.001875
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 1
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
--------------------------------------------------------------------------------
/configs_moco/fmnist_b256.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.07
2 | queue_size : 65536
3 | momentum : 0.999
4 | epochs : 200
5 | bs : 256
6 | p_epoch : 5
7 | s_epoch : 5
8 |
9 | data:
10 | data_dir : "./data"
11 | dataset : "fmnist"
12 | data_ratio : 1.0
13 |
14 | optim:
15 | optimizer: "sgd"
16 | init_lr: 0.03
17 | lr_schedule: "const"
18 | weight_decay: 0.0001
19 |
20 | net:
21 | encoder : "resnet18"
22 | use_bn : true
23 | use_sn : false
24 | use_wn : false
25 | act : "relu"
26 | proj_dim : 128
27 | nc : 1
28 |
29 | t:
30 | crop_scale:
31 | min: 0.08
32 | max: 1.0
--------------------------------------------------------------------------------
/configs_moco/mnist_b256.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.07
2 | queue_size : 65536
3 | momentum : 0.999
4 | epochs : 200
5 | bs : 256
6 | p_epoch : 5
7 | s_epoch : 5
8 |
9 | data:
10 | data_dir : "./data"
11 | dataset : "mnist"
12 | data_ratio : 1.0
13 |
14 | optim:
15 | optimizer: "sgd"
16 | init_lr: 0.03
17 | lr_schedule: "const"
18 | weight_decay: 0.0001
19 |
20 | net:
21 | encoder : "resnet18"
22 | use_bn : true
23 | use_sn : false
24 | use_wn : false
25 | act : "relu"
26 | proj_dim : 128
27 | nc : 1
28 |
29 | t:
30 | crop_scale:
31 | min: 0.08
32 | max: 1.0
--------------------------------------------------------------------------------
/configs_simsiam/cifar100_b256.yaml:
--------------------------------------------------------------------------------
1 | epochs : 200
2 | bs : 256
3 | p_epoch : 5
4 | s_epoch : 5
5 |
6 | data:
7 | data_dir : "./data"
8 | dataset : "cifar100"
9 | data_ratio : 1.0
10 |
11 | optim:
12 | optimizer: "sgd"
13 | init_lr: 0.03
14 | lr_schedule: "const"
15 | weight_decay: 0.0001
16 |
17 | net:
18 | encoder : "resnet18"
19 | use_bn : true
20 | use_sn : false
21 | use_wn : false
22 | act : "relu"
23 | proj_dim : 2048
24 | nc : 3
25 |
26 | t:
27 | crop_scale:
28 | min: 0.08
29 | max: 1.0
30 | flip_p : 0.5
31 | jitter:
32 | b: 0.8
33 | c: 0.8
34 | s: 0.8
35 | h: 0.2
36 | jitter_p : 0.8
37 | gray_p : 0.2
38 | blur_scale : 0.1
--------------------------------------------------------------------------------
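
The `t:` block parameterizes the view-generating augmentations. `get_t` is defined in `libs/utils.py`, which this dump does not include; the following is a hedged sketch of how these fields likely map onto the standard SimCLR-style torchvision pipeline (`make_t` is a hypothetical name):

```python
# Hypothetical reading of the `t:` block, modeled on the standard SimCLR
# augmentation pipeline; the actual get_t() lives in libs/utils.py (not shown).
from torchvision import transforms
from libs.gaussian_blur import GaussianBlur

def make_t(size, t_cfg):
    jit = t_cfg['jitter']
    return transforms.Compose([
        transforms.RandomResizedCrop(
            size, scale=(t_cfg['crop_scale']['min'], t_cfg['crop_scale']['max'])),
        transforms.RandomHorizontalFlip(p=t_cfg['flip_p']),
        transforms.RandomApply(
            [transforms.ColorJitter(jit['b'], jit['c'], jit['s'], jit['h'])],
            p=t_cfg['jitter_p']),
        transforms.RandomGrayscale(p=t_cfg['gray_p']),
        # blur kernel sized relative to the image, e.g. 0.1 * 32px -> 3px
        GaussianBlur(kernel_size=int(t_cfg['blur_scale'] * size)),
    ])
```
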
/configs_simsiam/cifar10_b256.yaml:
--------------------------------------------------------------------------------
1 | epochs : 200
2 | bs : 256
3 | p_epoch : 5
4 | s_epoch : 5
5 |
6 | data:
7 | data_dir : "./data"
8 | dataset : "cifar10"
9 | data_ratio : 1.0
10 |
11 | optim:
12 | optimizer: "sgd"
13 | init_lr: 0.03
14 | lr_schedule: "const"
15 | weight_decay: 0.0001
16 |
17 | net:
18 | encoder : "resnet18"
19 | use_bn : true
20 | use_sn : false
21 | use_wn : false
22 | act : "relu"
23 | proj_dim : 2048
24 | nc : 3
25 |
26 | t:
27 | crop_scale:
28 | min: 0.08
29 | max: 1.0
30 | flip_p : 0.5
31 | jitter:
32 | b: 0.8
33 | c: 0.8
34 | s: 0.8
35 | h: 0.2
36 | jitter_p : 0.8
37 | gray_p : 0.2
38 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_byol/cifar100_b256.yaml:
--------------------------------------------------------------------------------
1 | momentum : 0.9
2 | epochs : 200
3 | bs : 256
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "cifar100"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.03
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 256
25 | nc : 3
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
31 | flip_p : 0.5
32 | jitter:
33 | b: 0.8
34 | c: 0.8
35 | s: 0.8
36 | h: 0.2
37 | jitter_p : 0.8
38 | gray_p : 0.2
39 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_byol/cifar10_b256.yaml:
--------------------------------------------------------------------------------
1 | momentum : 0.9
2 | epochs : 200
3 | bs : 256
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "cifar10"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.03
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 256
25 | nc : 3
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
31 | flip_p : 0.5
32 | jitter:
33 | b: 0.8
34 | c: 0.8
35 | s: 0.8
36 | h: 0.2
37 | jitter_p : 0.8
38 | gray_p : 0.2
39 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_simclr/cifar100_b128.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 128
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "cifar100"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.015
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 3
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
31 | flip_p : 0.5
32 | jitter:
33 | b: 0.8
34 | c: 0.8
35 | s: 0.8
36 | h: 0.2
37 | jitter_p : 0.8
38 | gray_p : 0.2
39 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_simclr/cifar100_b256.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 256
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "cifar100"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.03
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 3
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
31 | flip_p : 0.5
32 | jitter:
33 | b: 0.8
34 | c: 0.8
35 | s: 0.8
36 | h: 0.2
37 | jitter_p : 0.8
38 | gray_p : 0.2
39 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_simclr/cifar100_b64.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 64
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "cifar100"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.0075
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 3
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
31 | flip_p : 0.5
32 | jitter:
33 | b: 0.8
34 | c: 0.8
35 | s: 0.8
36 | h: 0.2
37 | jitter_p : 0.8
38 | gray_p : 0.2
39 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_simclr/cifar10_b128.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 128
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "cifar10"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.015
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 3
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
31 | flip_p : 0.5
32 | jitter:
33 | b: 0.8
34 | c: 0.8
35 | s: 0.8
36 | h: 0.2
37 | jitter_p : 0.8
38 | gray_p : 0.2
39 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_simclr/cifar10_b16.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 16
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "cifar10"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.001875
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 3
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
31 | flip_p : 0.5
32 | jitter:
33 | b: 0.8
34 | c: 0.8
35 | s: 0.8
36 | h: 0.2
37 | jitter_p : 0.8
38 | gray_p : 0.2
39 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_simclr/cifar10_b256.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 256
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "cifar10"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.03
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 3
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
31 | flip_p : 0.5
32 | jitter:
33 | b: 0.8
34 | c: 0.8
35 | s: 0.8
36 | h: 0.2
37 | jitter_p : 0.8
38 | gray_p : 0.2
39 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_simclr/cifar10_b64.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 64
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "cifar10"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.0075
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 3
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
31 | flip_p : 0.5
32 | jitter:
33 | b: 0.8
34 | c: 0.8
35 | s: 0.8
36 | h: 0.2
37 | jitter_p : 0.8
38 | gray_p : 0.2
39 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_simclr/cifar100_b16.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.1
2 | epochs : 200
3 | bs : 16
4 | p_epoch : 5
5 | s_epoch : 5
6 |
7 | data:
8 | data_dir : "./data"
9 | dataset : "cifar100"
10 | data_ratio : 1.0
11 |
12 | optim:
13 | optimizer: "sgd"
14 | init_lr: 0.001875
15 | lr_schedule: "const"
16 | weight_decay: 0.0001
17 |
18 | net:
19 | encoder : "resnet18"
20 | use_bn : true
21 | use_sn : false
22 | use_wn : false
23 | act : "relu"
24 | proj_dim : 128
25 | nc : 3
26 |
27 | t:
28 | crop_scale:
29 | min: 0.08
30 | max: 1.0
31 | flip_p : 0.5
32 | jitter:
33 | b: 0.8
34 | c: 0.8
35 | s: 0.8
36 | h: 0.2
37 | jitter_p : 0.8
38 | gray_p : 0.2
39 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_moco/cifar100_b256.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.07
2 | queue_size : 65536
3 | momentum : 0.999
4 | epochs : 200
5 | bs : 256
6 | p_epoch : 5
7 | s_epoch : 5
8 |
9 | data:
10 | data_dir : "./data"
11 | dataset : "cifar100"
12 | data_ratio : 1.0
13 |
14 | optim:
15 | optimizer: "sgd"
16 | init_lr: 0.03
17 | lr_schedule: "const"
18 | weight_decay: 0.0001
19 |
20 | net:
21 | encoder : "resnet18"
22 | use_bn : true
23 | use_sn : false
24 | use_wn : false
25 | act : "relu"
26 | proj_dim : 128
27 | nc : 3
28 |
29 | t:
30 | crop_scale:
31 | min: 0.08
32 | max: 1.0
33 | flip_p : 0.5
34 | jitter:
35 | b: 0.8
36 | c: 0.8
37 | s: 0.8
38 | h: 0.2
39 | jitter_p : 0.8
40 | gray_p : 0.2
41 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_moco/cifar10_b256.yaml:
--------------------------------------------------------------------------------
1 | temperature : 0.07
2 | queue_size : 65536
3 | momentum : 0.999
4 | epochs : 200
5 | bs : 256
6 | p_epoch : 5
7 | s_epoch : 5
8 |
9 | data:
10 | data_dir : "./data"
11 | dataset : "cifar10"
12 | data_ratio : 1.0
13 |
14 | optim:
15 | optimizer: "sgd"
16 | init_lr: 0.03
17 | lr_schedule: "const"
18 | weight_decay: 0.0001
19 |
20 | net:
21 | encoder : "resnet18"
22 | use_bn : true
23 | use_sn : false
24 | use_wn : false
25 | act : "relu"
26 | proj_dim : 128
27 | nc : 3
28 |
29 | t:
30 | crop_scale:
31 | min: 0.08
32 | max: 1.0
33 | flip_p : 0.5
34 | jitter:
35 | b: 0.8
36 | c: 0.8
37 | s: 0.8
38 | h: 0.2
39 | jitter_p : 0.8
40 | gray_p : 0.2
41 | blur_scale : 0.1
--------------------------------------------------------------------------------
/configs_ebclr/fmnist_b16.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 16
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 960
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "fmnist"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.6
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0001
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 1
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | noise_std : 0.03
--------------------------------------------------------------------------------
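
The `sgld` and `buffer` blocks configure the Langevin sampler and persistent replay buffer used by EBCLR; the sampler itself is implemented in `libs/ebclr.py`, which is not shown in this dump. As a rough, hypothetical illustration, a single SGLD update with these settings looks like:

```python
# Hypothetical single SGLD step on an energy function; the real sampler in
# libs/ebclr.py (not shown) also anneals the noise std between min_std and
# max_std and draws initial samples from the replay buffer with probability rho.
import torch

def sgld_step(x, energy_fn, lr=0.05, noise_std=0.01):
    x = x.clone().requires_grad_(True)
    grad, = torch.autograd.grad(energy_fn(x).sum(), x)
    with torch.no_grad():
        x = x - lr * grad + noise_std * torch.randn_like(x)  # Langevin update
    return x.detach()
```
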
/configs_ebclr/fmnist_b64.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 64
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 480
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "fmnist"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.6
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0001
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 1
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | noise_std : 0.03
--------------------------------------------------------------------------------
/configs_ebclr/mnist_b128.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 128
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 240
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "mnist"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.4
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0002
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 1
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | noise_std : 0.03
--------------------------------------------------------------------------------
/configs_ebclr/mnist_b16.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 16
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 960
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "mnist"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.4
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0001
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 1
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | noise_std : 0.03
--------------------------------------------------------------------------------
/configs_ebclr/mnist_b64.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 64
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 480
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "mnist"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.4
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0001
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 1
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | noise_std : 0.03
--------------------------------------------------------------------------------
/configs_ebclr/fmnist_b128.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 128
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 240
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "fmnist"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.6
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0002
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 1
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | noise_std : 0.03
--------------------------------------------------------------------------------
/configs_ebclr/cifar100_b128.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 128
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 320
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "cifar100"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.2
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0002
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 3
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | flip_p : 0.5
49 | jitter:
50 | b: 0.8
51 | c: 0.8
52 | s: 0.8
53 | h: 0.2
54 | jitter_p : 0.8
55 | gray_p : 0.2
56 | blur_scale : 0.1
57 | noise_std : 0.03
--------------------------------------------------------------------------------
/configs_ebclr/cifar100_b16.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 16
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 1280
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "cifar100"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.2
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0001
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 3
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | flip_p : 0.5
49 | jitter:
50 | b: 0.8
51 | c: 0.8
52 | s: 0.8
53 | h: 0.2
54 | jitter_p : 0.8
55 | gray_p : 0.2
56 | blur_scale : 0.1
57 | noise_std : 0.03
--------------------------------------------------------------------------------
/configs_ebclr/cifar100_b64.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 64
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 640
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "cifar100"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.2
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0001
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 3
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | flip_p : 0.5
49 | jitter:
50 | b: 0.8
51 | c: 0.8
52 | s: 0.8
53 | h: 0.2
54 | jitter_p : 0.8
55 | gray_p : 0.2
56 | blur_scale : 0.1
57 | noise_std : 0.03
--------------------------------------------------------------------------------
/configs_ebclr/cifar10_b128.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 128
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 320
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "cifar10"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.2
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0002
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 3
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | flip_p : 0.5
49 | jitter:
50 | b: 0.8
51 | c: 0.8
52 | s: 0.8
53 | h: 0.2
54 | jitter_p : 0.8
55 | gray_p : 0.2
56 | blur_scale : 0.1
57 | noise_std : 0.03
--------------------------------------------------------------------------------
/configs_ebclr/cifar10_b16.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 16
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 1280
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "cifar10"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.2
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0001
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 3
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | flip_p : 0.5
49 | jitter:
50 | b: 0.8
51 | c: 0.8
52 | s: 0.8
53 | h: 0.2
54 | jitter_p : 0.8
55 | gray_p : 0.2
56 | blur_scale : 0.1
57 | noise_std : 0.03
--------------------------------------------------------------------------------
/configs_ebclr/cifar10_b64.yaml:
--------------------------------------------------------------------------------
1 | temperature: 0.1
2 | normalize: true
3 | bs: 64
4 | lmda1: 0.001
5 | lmda2: 0.1
6 | epochs: 200
7 | p_iter: 640
8 | s_epoch: 5
9 | save_buffer: true
10 |
11 | data:
12 | data_dir : "./data"
13 | dataset : "cifar10"
14 | data_ratio : 1.0
15 |
16 | sgld:
17 | iter: 10
18 | lr: 0.05
19 | min_std: 0.01
20 | max_std: 0.05
21 | threshold: 3.0
22 | tau: 1.0
23 |
24 | buffer:
25 | CD_ratio: 1.0
26 | size: 50000
27 | rho: 0.2
28 | bs: 256
29 |
30 | optim:
31 | optimizer: "adam"
32 | init_lr: 0.0001
33 | lr_schedule: "const"
34 |
35 | net:
36 | encoder: "resnet18"
37 | use_bn: false
38 | use_sn: false
39 | use_wn: false
40 | act: "lrelu"
41 | proj_dim: 128
42 | nc: 3
43 |
44 | t:
45 | crop_scale:
46 | min: 0.08
47 | max: 1.0
48 | flip_p : 0.5
49 | jitter:
50 | b: 0.8
51 | c: 0.8
52 | s: 0.8
53 | h: 0.2
54 | jitter_p : 0.8
55 | gray_p : 0.2
56 | blur_scale : 0.1
57 | noise_std : 0.03
--------------------------------------------------------------------------------
/Train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "87eead4f",
7 | "metadata": {
8 | "tags": []
9 | },
10 | "outputs": [],
11 | "source": [
12 | "from libs import ebclr\n",
13 | "\n",
14 | "ebclr.run(log_dir='./logs', config_dir='./configs_ebclr/cifar10_b16.yaml', start_epoch=0, device=0)"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "id": "9a4cc447",
21 | "metadata": {},
22 | "outputs": [],
23 | "source": []
24 | }
25 | ],
26 | "metadata": {
27 | "kernelspec": {
28 | "display_name": "Python 3 (ipykernel)",
29 | "language": "python",
30 | "name": "python3"
31 | },
32 | "language_info": {
33 | "codemirror_mode": {
34 | "name": "ipython",
35 | "version": 3
36 | },
37 | "file_extension": ".py",
38 | "mimetype": "text/x-python",
39 | "name": "python",
40 | "nbconvert_exporter": "python",
41 | "pygments_lexer": "ipython3",
42 | "version": "3.9.7"
43 | }
44 | },
45 | "nbformat": 4,
46 | "nbformat_minor": 5
47 | }
48 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EBCLR
2 | Official PyTorch implementation of [Energy-Based Contrastive Learning of Visual Representations](https://arxiv.org/abs/2202.04933), NeurIPS 2022.
3 |
4 | We propose a visual representation learning framework, Energy-Based Contrastive Learning (EBCLR), that combines Energy-Based Models (EBMs) with contrastive learning. EBCLR associates distance in the projection space with the density of positive pairs to learn useful visual representations. The figure below illustrates the general idea of EBCLR. We find that EBCLR shows accelerated convergence and robustness to a small number of negative pairs per positive pair.
5 |
6 | ![EBCLR overview](./assets/main.png)
7 |
8 |
9 |
10 | ## How to Run This Code
11 |
12 | Example codes for training and linear / KNN evaluation can be found in `Train.ipynb` and `Evaluation.ipynb`, respectively.
13 |
14 | ## References
15 |
16 | If you find the code useful for your research, please consider citing
17 | ```bib
18 | @inproceedings{
19 | kim2022ebclr,
20 | title={Energy-Based Contrastive Learning of Visual Representations},
21 | author={Beomsu Kim and Jong Chul Ye},
22 | booktitle={NeurIPS},
23 | year={2022}
24 | }
25 | ```
26 |
--------------------------------------------------------------------------------
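
For quick reference, the notebook entry point reduces to the call below (copied from `Train.ipynb`; any config under `configs_ebclr/` can be swapped in):

```python
from libs import ebclr

# Train EBCLR on CIFAR10 with batch size 16; checkpoints are written to ./logs
ebclr.run(log_dir='./logs', config_dir='./configs_ebclr/cifar10_b16.yaml',
          start_epoch=0, device=0)
```
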
/libs/gaussian_blur.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 |
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch
6 |
7 | class GaussianBlur(object):
8 | """blur a single image on CPU"""
9 |
10 | def __init__(self, kernel_size):
11 | radius = kernel_size // 2
12 | kernel_size = radius * 2 + 1
13 | self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), stride=1, padding=0, bias=False, groups=3)
14 | self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size), stride=1, padding=0, bias=False, groups=3)
15 | self.k = kernel_size
16 | self.r = radius
17 |
18 | self.blur = nn.Sequential(
19 | nn.ReflectionPad2d(radius),
20 | self.blur_h,
21 | self.blur_v
22 | )
23 |
24 | def __call__(self, img):
25 |
26 | # img = img[None]
27 |
28 | sigma = np.random.uniform(0.1, 2.0)
29 | x = np.arange(-self.r, self.r + 1)
30 | x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
31 | x = x / x.sum()
32 | x = torch.from_numpy(x).view(1, -1).repeat(3, 1)
33 |
34 | self.blur_h.weight.data.copy_(x.view(3, 1, self.k, 1))
35 | self.blur_v.weight.data.copy_(x.view(3, 1, 1, self.k))
36 |
37 | with torch.no_grad():
38 | img = self.blur(img)
39 | # img = img.squeeze()
40 |
41 | return img
--------------------------------------------------------------------------------
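
A short usage sketch: because the blur is built from grouped `nn.Conv2d` layers with three channels, it expects 3-channel image tensors, and the commented-out `img[None]` lines suggest it is fed batches in this codebase:

```python
# Usage sketch for GaussianBlur; assumes batched (N, 3, H, W) float tensors.
import torch
from libs.gaussian_blur import GaussianBlur

blur = GaussianBlur(kernel_size=3)   # e.g. blur_scale 0.1 on a 32px image
batch = torch.rand(8, 3, 32, 32)
out = blur(batch)                    # sigma drawn uniformly from [0.1, 2.0]
assert out.shape == batch.shape      # reflection padding preserves the shape
```
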
/libs/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | import torch
6 |
7 | def Linear(in_features, out_features, bias=True, use_bn=False, use_sn=False, use_wn=False):
8 |
9 | layer = nn.Linear(in_features, out_features, bias=bias)
10 |
11 | if use_bn:
12 | layer = nn.Sequential(layer, nn.BatchNorm1d(out_features))
13 |
14 | if use_sn:
15 | layer = nn.utils.spectral_norm(layer)
16 |
17 | if use_wn:
18 | layer = nn.utils.weight_norm(layer)
19 |
20 | return layer
21 |
22 | class LIN(nn.Module):
23 |
24 | def __init__(self, in_dim, out_dim, use_sn=False):
25 | super(LIN, self).__init__()
26 |
27 | self.fc1 = Linear(in_dim, out_dim, use_sn=use_sn)
28 |
29 | def forward(self, x):
30 | x = self.fc1(x)
31 | return x
32 |
33 | class MLP(nn.Module):
34 |
35 | def __init__(self, in_dim, out_dim, n_layers=2, use_bn=False, use_sn=False, use_wn=False, act=nn.LeakyReLU(0.2)):
36 | super(MLP, self).__init__()
37 | self.act = act
38 |
39 | self.fc1 = Linear(in_dim, in_dim, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn)
40 |
41 | # fcs = []
42 | # for _ in range(n_layers - 1):
43 | # fcs.append(Linear(in_dim, in_dim, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn))
44 | # fcs.append(self.act)
45 | # self.fcs = nn.Sequential(*fcs)
46 |
47 | self.fc2 = Linear(in_dim, out_dim, use_sn=use_sn, use_wn=use_wn)
48 |
49 | def forward(self, x):
50 | x = self.act(self.fc1(x))
51 | x = self.fc2(x)
52 | return x
--------------------------------------------------------------------------------
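
For reference, this two-layer `MLP` is the projection head that `init_enc_proj` in `libs/net.py` attaches to each encoder; a quick shape check:

```python
import torch
from libs.models import MLP

proj = MLP(in_dim=512, out_dim=128)  # ResNet-18 features -> 128-d projections
z = proj(torch.randn(4, 512))
print(z.shape)                       # torch.Size([4, 128])
```
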
/libs/net.py:
--------------------------------------------------------------------------------
1 | from libs.resnets import ResNet18, ResNet34, ResNet50
2 | from libs.models import MLP
3 |
4 | import torch.nn as nn
5 |
6 | def get_act(net_config):
7 |
8 | if net_config['act'] == 'relu':
9 | return nn.ReLU()
10 | elif net_config['act'] == 'lrelu':
11 | return nn.LeakyReLU(0.2)
12 | else:
13 | raise NotImplementedError
14 |
15 | def init_enc(net_config, device):
16 |
17 | if net_config['encoder'] == 'resnet18':
18 | enc = ResNet18(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device)
19 | feature_dim = 512
20 | elif net_config['encoder'] == 'resnet34':
21 | enc = ResNet34(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device)
22 | feature_dim = 512
23 | elif net_config['encoder'] == 'resnet50':
24 | enc = ResNet50(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device)
25 | feature_dim = 2048
26 | else:
27 | raise NotImplementedError
28 |
29 | return enc, feature_dim
30 |
31 | def init_enc_proj(net_config, device):
32 |
33 | if net_config['encoder'] == 'resnet18':
34 | enc = ResNet18(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device)
35 | feature_dim = 512
36 | elif net_config['encoder'] == 'resnet34':
37 | enc = ResNet34(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device)
38 | feature_dim = 512
39 | elif net_config['encoder'] == 'resnet50':
40 | enc = ResNet50(nc=net_config['nc'], use_bn=net_config['use_bn'], use_sn=net_config['use_sn'], use_wn=net_config['use_wn'], act=get_act(net_config)).cuda(device)
41 | feature_dim = 2048
42 | else:
43 | raise NotImplementedError
44 |
45 | if 'proj_layers' in net_config:
46 | proj = MLP(feature_dim, net_config['proj_dim'], n_layers=net_config['proj_layers'], act=get_act(net_config)).cuda(device)
47 | else:
48 | proj = MLP(feature_dim, net_config['proj_dim'], n_layers=2, act=get_act(net_config)).cuda(device)
49 |
50 | return enc, proj, feature_dim
--------------------------------------------------------------------------------
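
Building an encoder/projector pair only needs the `net:` block of a config; a minimal sketch (note that `init_enc_proj` moves both modules to a CUDA device, so a GPU is required):

```python
from libs.net import init_enc_proj

# `net:` block of configs_simclr/cifar10_b256.yaml, written out as a dict
net_config = {'encoder': 'resnet18', 'use_bn': True, 'use_sn': False,
              'use_wn': False, 'act': 'relu', 'proj_dim': 128, 'nc': 3}

enc, proj, feature_dim = init_enc_proj(net_config, device=0)  # needs CUDA
print(feature_dim)  # 512 for resnet18/34, 2048 for resnet50
```
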
/Evaluation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "3d0be4ea",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "from libs.evaluation import transfer_acc, eval_acc, eval_knn_transfer_acc, eval_knn_acc"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "id": "a5973077",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "for epoch in ['curr']:\n",
21 | "\n",
22 | " ckpt_dir='./logs/{}.pt'.format(epoch)\n",
23 | " config_dir='./configs_ebclr/cifar10_b16.yaml'\n",
24 | "\n",
25 | " eval_config = {\n",
26 | " 'K' : 20,\n",
27 | " 'bs' : 50,\n",
28 | " 'standardize' : True\n",
29 | " }\n",
30 | "\n",
31 | " data_config = {\n",
32 | " 'data_dir' : './data',\n",
33 | " 'dataset' : 'cifar10',\n",
34 | " 'data_ratio' : 1.0\n",
35 | " }\n",
36 | " \n",
37 | " device = 0\n",
38 | " \n",
39 | " print('Epoch {}'.format(epoch))\n",
40 | " \n",
41 | " eval_knn_acc(ckpt_dir, config_dir, eval_config, data_config, device)"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "id": "995cb890",
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "for epoch in ['curr']:\n",
52 | "\n",
53 | " ckpt_dir='./logs/{}.pt'.format(epoch)\n",
54 | " config_dir='./configs_ebclr/cifar10_b16.yaml'\n",
55 | "\n",
56 | " eval_config = {\n",
57 | " 'K' : 20,\n",
58 | " 'bs' : 50,\n",
59 | " 'standardize' : False\n",
60 | " }\n",
61 | " \n",
62 | " source_data_config = {\n",
63 | " 'data_dir' : './data',\n",
64 | " 'dataset' : 'cifar100',\n",
65 | " 'data_ratio' : 1.0\n",
66 | " }\n",
67 | "\n",
68 | " target_data_config = {\n",
69 | " 'data_dir' : './data',\n",
70 | " 'dataset' : 'cifar10',\n",
71 | " 'data_ratio' : 1.0\n",
72 | " }\n",
73 | " \n",
74 | " device = 0\n",
75 | " \n",
76 | " print('Epoch {}'.format(epoch))\n",
77 | " \n",
78 | " eval_knn_transfer_acc(ckpt_dir, config_dir, eval_config, source_data_config, target_data_config, device)"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "id": "b43feb97",
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "for epoch in ['curr']:\n",
89 | "\n",
90 | " ckpt_dir='./logs/{}.pt'.format(epoch)\n",
91 | " config_dir='./configs_ebclr/cifar10_b16.yaml'\n",
92 | "\n",
93 | " eval_config = {\n",
94 | " 'epochs' : 200,\n",
95 | " 'bs' : 512,\n",
96 | " 'lr' : 1e-3,\n",
97 | " 'p_epoch' : 10,\n",
98 | " 'standardize' : False\n",
99 | " }\n",
100 | "\n",
101 | " data_config = {\n",
102 | " 'data_dir' : './data',\n",
103 | " 'dataset' : 'cifar100',\n",
104 | " 'data_ratio' : 1.0\n",
105 | " }\n",
106 | " \n",
107 | " device = 0\n",
108 | " \n",
109 | " print('Epoch {}'.format(epoch))\n",
110 | "\n",
111 | " eval_acc(ckpt_dir, config_dir, eval_config, data_config, device)"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "id": "10124fbf",
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "for epoch in ['curr']:\n",
122 | "\n",
123 | " ckpt_dir='./logs/{}.pt'.format(epoch)\n",
124 | " config_dir='./configs_ebclr/cifar10_b16.yaml'\n",
125 | "\n",
126 | " eval_config = {\n",
127 | " 'epochs' : 200,\n",
128 | " 'bs' : 512,\n",
129 | " 'lr' : 8e-4,\n",
130 | " 'p_epoch' : 10,\n",
131 | " 'standardize' : False\n",
132 | " }\n",
133 | " \n",
134 | " source_data_config = {\n",
135 | " 'data_dir' : './data',\n",
136 | " 'dataset' : 'cifar100',\n",
137 | " 'data_ratio' : 1.0\n",
138 | " }\n",
139 | "\n",
140 | " target_data_config = {\n",
141 | " 'data_dir' : './data',\n",
142 | " 'dataset' : 'cifar10',\n",
143 | " 'data_ratio' : 1.0\n",
144 | " }\n",
145 | " \n",
146 | " device = 0\n",
147 | " \n",
148 | " print('Epoch {}'.format(epoch))\n",
149 | " \n",
150 | " transfer_acc(ckpt_dir, config_dir, eval_config, source_data_config, target_data_config, device)"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "id": "1cfe6e8c",
157 | "metadata": {},
158 | "outputs": [],
159 | "source": []
160 | }
161 | ],
162 | "metadata": {
163 | "kernelspec": {
164 | "display_name": "Python 3 (ipykernel)",
165 | "language": "python",
166 | "name": "python3"
167 | },
168 | "language_info": {
169 | "codemirror_mode": {
170 | "name": "ipython",
171 | "version": 3
172 | },
173 | "file_extension": ".py",
174 | "mimetype": "text/x-python",
175 | "name": "python",
176 | "nbconvert_exporter": "python",
177 | "pygments_lexer": "ipython3",
178 | "version": "3.9.7"
179 | }
180 | },
181 | "nbformat": 4,
182 | "nbformat_minor": 5
183 | }
184 |
--------------------------------------------------------------------------------
/libs/simclr.py:
--------------------------------------------------------------------------------
1 | from libs.utils import log_softmax, load_data, get_t, get_lr_schedule, get_optimizer, load_config
2 | from libs.net import init_enc_proj
3 |
4 | import torchvision.transforms.functional as TF
5 | import torchvision.transforms as transforms
6 | import torch.nn.functional as F
7 | import matplotlib.pyplot as plt
8 | import torch.optim as optim
9 | import torch.nn as nn
10 | import numpy as np
11 |
12 | import torchvision
13 | import itertools
14 | import torch
15 | import time
16 | import os
17 |
18 | def save_ckpt(enc, proj, opt, iteration, runtime, ckpt_path):
19 |
20 | checkpoint = {
21 | 'iteration' : iteration,
22 | 'runtime' : runtime,
23 | 'opt' : opt.state_dict()
24 | }
25 |
26 | if isinstance(enc, nn.DataParallel):
27 | checkpoint['enc_state_dict'] = enc.module.state_dict()
28 | else:
29 | checkpoint['enc_state_dict'] = enc.state_dict()
30 |
31 | if isinstance(proj, nn.DataParallel):
32 | checkpoint['proj_state_dict'] = proj.module.state_dict()
33 | else:
34 | checkpoint['proj_state_dict'] = proj.state_dict()
35 |
36 | torch.save(checkpoint, ckpt_path)
37 |
38 | def load_ckpt(enc, proj, opt, ckpt_path):
39 | ckpt = torch.load(ckpt_path)
40 |
41 | if isinstance(enc, nn.DataParallel):
42 | enc.module.load_state_dict(ckpt['enc_state_dict'])
43 | else:
44 | enc.load_state_dict(ckpt['enc_state_dict'])
45 |
46 | if isinstance(proj, nn.DataParallel):
47 | proj.module.load_state_dict(ckpt['proj_state_dict'])
48 | else:
49 | proj.load_state_dict(ckpt['proj_state_dict'])
50 |
51 | opt.load_state_dict(ckpt['opt'])
52 | it = ckpt['iteration']
53 | rt = ckpt['runtime']
54 |
55 | return it, rt
56 |
57 | class SimCLR_Loss(nn.Module):
58 |
59 | def __init__(self, temperature):
60 | super(SimCLR_Loss, self).__init__()
61 | self.temperature = temperature
62 |
63 | self.criterion = nn.CrossEntropyLoss(reduction="sum")
64 | self.similarity_f = nn.CosineSimilarity(dim=2)
65 |
66 | def mask_correlated_samples(self, batch_size):
67 | N = 2 * batch_size
68 | mask = torch.ones((N, N), dtype=bool)
69 | mask = mask.fill_diagonal_(0)
70 |
71 | for i in range(batch_size):
72 | mask[i, batch_size + i] = 0
73 | mask[batch_size + i, i] = 0
74 | return mask
75 |
76 | def forward(self, z_i, z_j):
77 |
78 | batch_size = z_i.shape[0]
79 |
80 | N = 2 * batch_size
81 |
82 | z = torch.cat((z_i, z_j), dim=0)
83 |
84 | sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature
85 |
86 | sim_i_j = torch.diag(sim, batch_size)
87 | sim_j_i = torch.diag(sim, -batch_size)
88 |
89 | mask = self.mask_correlated_samples(batch_size)
90 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
91 | negative_samples = sim[mask].reshape(N, -1)
92 |
93 | labels = torch.zeros(N, dtype=torch.long, device=positive_samples.device)
94 |
95 | logits = torch.cat((positive_samples, negative_samples), dim=1)
96 | loss = self.criterion(logits, labels)
97 | loss /= N
98 |
99 | return loss
100 |
101 | def SimCLR(train_X, enc, proj, config, log_dir, start_epoch):
102 |
103 | train_n = train_X.shape[0]
104 | size = train_X.shape[2]
105 | nc = train_X.shape[1]
106 |
107 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda()
108 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda()
109 |
110 | simclr_loss = SimCLR_Loss(config['temperature'])
111 |
112 | params = itertools.chain(enc.parameters(), proj.parameters())
113 | net = lambda x : proj(enc((x - mean) / std))
114 |
115 | t = get_t(size, config['t'])
116 | lr_schedule = get_lr_schedule(config)
117 | opt = get_optimizer(config['optim'], params)
118 |
119 | if start_epoch == 0:
120 | it, rt = 0, 0
121 | else:
122 | it, rt = load_ckpt(enc, proj, opt, log_dir + '/{}.pt'.format(start_epoch))
123 |
124 | enc.train()
125 | proj.train()
126 |
127 | epoch_n_iter = int(np.ceil(train_n / config['bs']))
128 | while True:
129 | if it % epoch_n_iter == 0:
130 | train_X = train_X[torch.randperm(train_n)]
131 |
132 | i = it % epoch_n_iter
133 | it += 1
134 |
135 | s = time.time()
136 |
137 | X = train_X[i * config['bs']:(i + 1) * config['bs']]
138 |
139 | X_v1 = t(X).cuda()
140 | X_v2 = t(X).cuda()
141 |
142 | loss = simclr_loss(net(X_v1), net(X_v2))
143 |
144 | # Update
145 | opt.param_groups[0]['lr'] = lr_schedule(it)
146 | opt.zero_grad()
147 | loss.backward()
148 | opt.step()
149 |
150 | e = time.time()
151 | rt += (e - s)
152 |
153 | if it % config['p_iter'] == 0:
154 | save_ckpt(enc, proj, opt, it, rt, log_dir + '/curr.pt')
155 | print('Epoch : {:.3f} | Loss : {:.3f} | LR : {:.3e} | Time : {:.3f}'.format(it / epoch_n_iter, loss.item(), lr_schedule(it), rt))
156 |
157 | if it % (epoch_n_iter * config['s_epoch']) == 0:
158 | save_ckpt(enc, proj, opt, it, rt, log_dir + '/{}.pt'.format(it // epoch_n_iter))
159 |
160 | if it >= config['its']:
161 | break
162 |
163 | def run(log_dir, config_dir, start_epoch, device):
164 |
165 | if not os.path.isdir(log_dir):
166 | os.makedirs(log_dir)
167 |
168 | config = load_config(config_dir)
169 |
170 | train_X, train_y, test_X, test_y, n_classes = load_data(config['data'])
171 |
172 | epoch_n_iter = int(np.ceil(train_X.shape[0] / config['bs']))
173 |
174 | if 'epochs' in config:
175 | config['its'] = epoch_n_iter * config['epochs']
176 |
177 | if 'p_epoch' in config:
178 | config['p_iter'] = epoch_n_iter * config['p_epoch']
179 |
180 | enc, proj, _ = init_enc_proj(config['net'], device)
181 |
182 | print('Running SimCLR from epoch {}'.format(start_epoch))
183 |
184 | with torch.cuda.device(device):
185 | SimCLR(train_X, enc, proj, config, log_dir, start_epoch)
186 |
187 | print('Finished!')
--------------------------------------------------------------------------------
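
`SimCLR_Loss` is the standard NT-Xent loss and can be exercised in isolation; a minimal sketch with random projections:

```python
import torch
from libs.simclr import SimCLR_Loss

loss_fn = SimCLR_Loss(temperature=0.1)  # matches the configs_simclr YAMLs
z_i = torch.randn(16, 128)              # projections of the first views
z_j = torch.randn(16, 128)              # projections of the second views
loss = loss_fn(z_i, z_j)                # NT-Xent over 32 samples, 30 negatives each
print(loss.item())
```
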
/libs/resnets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, use_bn=False, use_sn=False, use_wn=False, avg_pool=False):
6 |
7 | if avg_pool:
8 | layer = nn.Conv2d(in_channels, out_channels, kernel_size, 1, padding, bias=False if use_bn else bias)
9 | else:
10 | layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False if use_bn else bias)
11 |
12 | if use_bn:
13 | layer = nn.Sequential(layer, nn.BatchNorm2d(out_channels))
14 |
15 | if use_sn:
16 | layer = nn.utils.spectral_norm(layer)
17 |
18 | if use_wn:
19 | layer = nn.utils.weight_norm(layer)
20 |
21 | if avg_pool and stride > 1:
22 | layer = nn.Sequential(layer, nn.AvgPool2d(stride, stride))
23 |
24 | return layer
25 |
26 | def Linear(in_features, out_features, bias=True, use_bn=False, use_sn=False, use_wn=False):
27 |
28 | layer = nn.Linear(in_features, out_features, bias=bias)
29 |
30 | if use_bn:
31 | layer = nn.Sequential(layer, nn.BatchNorm1d(out_features))
32 |
33 | if use_sn:
34 | layer = nn.utils.spectral_norm(layer)
35 |
36 | if use_wn:
37 | layer = nn.utils.weight_norm(layer)
38 |
39 | return layer
40 |
41 | class BasicBlock(nn.Module):
42 | expansion = 1
43 |
44 | def __init__(self, in_planes, planes, stride=1, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)):
45 | super(BasicBlock, self).__init__()
46 | self.conv1 = Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool)
47 | self.conv2 = Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool)
48 | self.act = act
49 |
50 | self.shortcut = nn.Sequential()
51 | if stride != 1 or in_planes != self.expansion*planes:
52 | self.shortcut = nn.Sequential(
53 | Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool)
54 | )
55 |
56 | def forward(self, x):
57 | out = self.act(self.conv1(x))
58 | out = self.conv2(out)
59 | out += self.shortcut(x)
60 | out = self.act(out)
61 | return out
62 |
63 |
64 | class Bottleneck(nn.Module):
65 | expansion = 4
66 |
67 | def __init__(self, in_planes, planes, stride=1, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)):
68 | super(Bottleneck, self).__init__()
69 | self.conv1 = Conv2d(in_planes, planes, kernel_size=1, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool)
70 | self.conv2 = Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool)
71 | self.conv3 = Conv2d(planes, self.expansion * planes, kernel_size=1, bias=True, use_sn=use_sn, use_bn=use_bn, use_wn=use_wn, avg_pool=avg_pool)
72 | self.act = act
73 |
74 | self.shortcut = nn.Sequential()
75 | if stride != 1 or in_planes != self.expansion*planes:
76 | self.shortcut = nn.Sequential(
77 | Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool)
78 | )
79 |
80 | def forward(self, x):
81 | out = self.act(self.conv1(x))
82 | out = self.act(self.conv2(out))
83 | out = self.conv3(out)
84 | out += self.shortcut(x)
85 | out = self.act(out)
86 | return out
87 |
88 |
89 | class ResNet(nn.Module):
90 | def __init__(self, nc, block, num_blocks, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)):
91 | super(ResNet, self).__init__()
92 | self.in_planes = 64
93 |
94 | self.conv1 = Conv2d(nc, 64, kernel_size=3, stride=1, padding=1, bias=True, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool)
95 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool, act=act)
96 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool, act=act)
97 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool, act=act)
98 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, use_bn=use_bn, use_sn=use_sn, use_wn=use_wn, avg_pool=avg_pool, act=act)
99 | self.act = act
100 |
101 | def _make_layer(self, block, planes, num_blocks, stride, use_bn, use_sn, use_wn, avg_pool, act):
102 | strides = [stride] + [1]*(num_blocks-1)
103 | layers = []
104 | for stride in strides:
105 | layers.append(block(self.in_planes, planes, stride, use_bn, use_sn, use_wn, avg_pool, act))
106 | self.in_planes = planes * block.expansion
107 | return nn.Sequential(*layers)
108 |
109 | def forward(self, x):
110 | out = self.act(self.conv1(x))
111 | out = self.layer1(out)
112 | out = self.layer2(out)
113 | out = self.layer3(out)
114 | out = self.layer4(out)
115 | out = F.adaptive_avg_pool2d(out, 1)
116 | out = out.view(out.size(0), -1)
117 | return out
118 |
119 |
120 | def ResNet18(nc=3, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)):
121 | return ResNet(nc, BasicBlock, [2, 2, 2, 2], use_bn, use_sn, use_wn, avg_pool, act)
122 |
123 |
124 | def ResNet34(nc=3, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)):
125 | return ResNet(nc, BasicBlock, [3, 4, 6, 3], use_bn, use_sn, use_wn, avg_pool, act)
126 |
127 |
128 | def ResNet50(nc=3, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)):
129 | return ResNet(nc, Bottleneck, [3, 4, 6, 3], use_bn, use_sn, use_wn, avg_pool, act)
130 |
131 |
132 | def ResNet101(nc=3, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)):
133 | return ResNet(nc, Bottleneck, [3, 4, 23, 3], use_bn, use_sn, use_wn, avg_pool, act)
134 |
135 |
136 | def ResNet152(nc=3, use_bn=False, use_sn=False, use_wn=False, avg_pool=False, act=nn.LeakyReLU(0.2)):
137 | return ResNet(nc, Bottleneck, [3, 8, 36, 3], use_bn, use_sn, use_wn, avg_pool, act)
--------------------------------------------------------------------------------
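A quick sanity check of the encoders above, as a standalone sketch: since forward() ends with global average pooling, a BasicBlock ResNet returns a flat 512-dimensional feature regardless of input resolution.

    import torch
    from libs.resnets import ResNet18

    enc = ResNet18(nc=3, use_bn=True)
    x = torch.randn(8, 3, 32, 32)   # a batch of CIFAR-sized images
    print(enc(x).shape)             # torch.Size([8, 512])
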
/libs/utils.py:
--------------------------------------------------------------------------------
1 | from libs.gaussian_blur import GaussianBlur
2 |
3 | import torchvision.transforms.functional as TF
4 | import torchvision.transforms as transforms
5 | import torch.nn.functional as F
6 | import matplotlib.pyplot as plt
7 | import torch.optim as optim
8 | import torch.nn as nn
9 | import numpy as np
10 |
11 | import torchvision
12 | import torch
13 | import yaml
14 |
15 | class MSGLD:
16 |
17 | def __init__(self, config):
18 | self.config = config
19 |
20 | def __get_std__(self, count):
21 | return self.config['min_std'] + (self.config['max_std'] - self.config['min_std']) * torch.maximum(1.0 - count / self.config['threshold'], torch.zeros_like(count).cuda())
22 |
23 | def __call__(self, log_pdf, init, count):
24 | out = init.detach().clone().requires_grad_(True)
25 | for i in range(self.config['iter']):
26 | lp = log_pdf(out).sum()
27 | lp.backward()
28 | out.data = out + self.config['lr'] * torch.clamp(out.grad, -self.config['tau'], self.config['tau']) + self.__get_std__(count) * torch.randn_like(out)
29 | out.grad.zero_()
30 | return out.detach().clone()
31 |
32 | def logsumexp(log_p):
33 | m, _ = torch.max(log_p, dim=1, keepdim=True)
34 | return torch.log(torch.exp(log_p - m).sum(dim=1, keepdim=True)) + m
35 |
36 | def log_softmax(log_p):
37 | return log_p - logsumexp(log_p)
38 |
39 | def softmax(log_p):
40 | m, _ = torch.max(log_p, dim=1, keepdim=True)
41 | f = torch.exp(log_p - m)
42 | return f / f.sum(dim=1, keepdim=True)
43 |
44 | def get_t(img_size, t_config):
45 |
46 | t = []
47 |
48 | if 'crop_scale' in t_config:
49 | t.append(transforms.RandomResizedCrop(size=img_size, scale=tuple(t_config['crop_scale'].values())))
50 |
51 | if 'flip_p' in t_config:
52 | t.append(transforms.RandomHorizontalFlip(p=t_config['flip_p']))
53 |
54 | if 'jitter' in t_config:
55 | jitter = transforms.ColorJitter(t_config['jitter']['b'], t_config['jitter']['c'], t_config['jitter']['s'], t_config['jitter']['h'])
56 | t.append(transforms.RandomApply([jitter], p=t_config['jitter_p']))
57 |
58 | if 'gray_p' in t_config:
59 | t.append(transforms.RandomGrayscale(p=t_config['gray_p']))
60 |
61 | if 'blur_scale' in t_config:
62 | blur = GaussianBlur(kernel_size=int(t_config['blur_scale'] * img_size))
63 | t.append(transforms.RandomApply([blur], p=1.0))
64 |
65 | if 'noise_std' in t_config:
66 | t.append(lambda x : x + torch.randn_like(x) * t_config['noise_std'])
67 |
68 | return transforms.Compose(t)
69 |
70 | def get_optimizer(optim_config, params):
71 | if optim_config['optimizer'] == 'sgd':
72 | return optim.SGD(params, lr=optim_config['init_lr'], momentum=0.9, weight_decay=optim_config['weight_decay'] if ('weight_decay' in optim_config) else 0.0)
73 | elif optim_config['optimizer'] == 'adam':
74 | return optim.Adam(params, lr=optim_config['init_lr'], weight_decay=optim_config['weight_decay'] if ('weight_decay' in optim_config) else 0.0)
75 | else:
76 | raise NotImplementedError
77 |
78 | def get_lr_schedule(config):
79 |
80 | if config['optim']['lr_schedule'] == 'const':
81 | return lambda it : config['optim']['init_lr']
82 | elif config['optim']['lr_schedule'] == 'cosine':
83 | return lambda it : config['optim']['init_lr'] * np.cos(0.5 * np.pi * it / config['its'])
84 | else:
85 | raise NotImplementedError
86 |
87 | def inner(x, y, normalize=False):
88 | x = x.flatten(start_dim=1)
89 | y = y.flatten(start_dim=1)
90 |
91 | if normalize:
92 | x = x / x.norm(dim=1, keepdim=True)
93 | y = y / y.norm(dim=1, keepdim=True)
94 |
95 | d = (x * y).sum(dim=1, keepdim=True)
96 | return d
97 |
98 | def dist(x, y, normalize=False):
99 | x = x.flatten(start_dim=1)
100 | y = y.flatten(start_dim=1)
101 |
102 | if normalize:
103 | x = x / (x.norm(dim=1, keepdim=True) + 1e-10)
104 | y = y / (y.norm(dim=1, keepdim=True) + 1e-10)
105 |
106 | d = (x - y).square().sum(dim=1, keepdim=True)
107 | return d
108 |
109 | def generate_views(x, t, n_samples):
110 | views = []
111 | shape = list(x.shape[1:])
112 | for _ in range(n_samples):
113 | views.append(t(x)[:,None])
114 | views = torch.cat(views, dim=1).reshape([-1] + shape)
115 | return views
116 |
117 | def load_data(data_config, download=False):
118 |
119 | root = data_config['data_dir']
120 | dataset = data_config['dataset']
121 |
122 | if dataset == 'mnist':
123 | trainset = torchvision.datasets.MNIST(root=root, train=True, download=download, transform=transforms.ToTensor())
124 | testset = torchvision.datasets.MNIST(root=root, train=False, download=download, transform=transforms.ToTensor())
125 |
126 | train_X, train_y = trainset.data[:,None] / 255, trainset.targets
127 | test_X, test_y = testset.data[:,None] / 255, testset.targets
128 | n_classes = 10
129 | elif dataset == 'fmnist':
130 | trainset = torchvision.datasets.FashionMNIST(root=root, train=True, download=download, transform=transforms.ToTensor())
131 | testset = torchvision.datasets.FashionMNIST(root=root, train=False, download=download, transform=transforms.ToTensor())
132 |
133 | train_X, train_y = trainset.data[:,None] / 255, trainset.targets
134 | test_X, test_y = testset.data[:,None] / 255, testset.targets
135 | n_classes = 10
136 | elif dataset == 'cifar10':
137 | trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=download, transform=transforms.ToTensor())
138 | testset = torchvision.datasets.CIFAR10(root=root, train=False, download=download, transform=transforms.ToTensor())
139 |
140 | train_X = torch.tensor(trainset.data).permute(0,3,1,2) / 255
141 | train_y = torch.tensor(trainset.targets, dtype=int)
142 |
143 | test_X = torch.tensor(testset.data).permute(0,3,1,2) / 255
144 | test_y = torch.tensor(testset.targets, dtype=int)
145 | n_classes = 10
146 | elif dataset == 'cifar100':
147 | trainset = torchvision.datasets.CIFAR100(root=root, train=True, download=download, transform=transforms.ToTensor())
148 | testset = torchvision.datasets.CIFAR100(root=root, train=False, download=download, transform=transforms.ToTensor())
149 |
150 | train_X = torch.tensor(trainset.data).permute(0,3,1,2) / 255
151 | train_y = torch.tensor(trainset.targets, dtype=int)
152 |
153 | test_X = torch.tensor(testset.data).permute(0,3,1,2) / 255
154 | test_y = torch.tensor(testset.targets, dtype=int)
155 | n_classes = 100
156 |
157 | train_n = train_X.shape[0]
158 | idx = torch.randperm(train_n)
159 | train_X = train_X[idx][:int(train_n * data_config['data_ratio'])]
160 | train_y = train_y[idx][:int(train_n * data_config['data_ratio'])]
161 |
162 | return train_X, train_y, test_X, test_y, n_classes
163 |
164 | def save_config(config, config_path):
165 | with open(config_path, 'w') as f:
166 | yaml.dump(config, f)
167 |
168 | def load_config(config_path):
169 | with open(config_path, 'r') as f:
170 | config = yaml.load(f, Loader=yaml.FullLoader)
171 | return config
--------------------------------------------------------------------------------
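A sketch of how get_t assembles the augmentation pipeline from a config dictionary; the t_config below is hypothetical and only mirrors the keys the function reads (the YAML configs in this repository supply the actual values):

    import torch
    from libs.utils import get_t

    t_config = {
        'crop_scale': {'min': 0.08, 'max': 1.0},   # tuple(values()) -> (0.08, 1.0)
        'flip_p': 0.5,
        'jitter': {'b': 0.4, 'c': 0.4, 's': 0.4, 'h': 0.1},
        'jitter_p': 0.8,
        'gray_p': 0.2,
    }
    t = get_t(32, t_config)                  # a transforms.Compose applied to tensor batches
    views = t(torch.rand(16, 3, 32, 32))     # one stochastic view per image
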
/libs/simsiam.py:
--------------------------------------------------------------------------------
1 | from libs.utils import log_softmax, load_data, get_t, get_lr_schedule, get_optimizer, load_config
2 | from libs.net import init_enc
3 |
4 | import torchvision.transforms.functional as TF
5 | import torchvision.transforms as transforms
6 | import torch.nn.functional as F
7 | import matplotlib.pyplot as plt
8 | import torch.optim as optim
9 | import torch.nn as nn
10 | import numpy as np
11 |
12 | import torchvision
13 | import itertools
14 | import torch
15 | import time
16 | import os
17 |
18 | def save_ckpt(enc, proj, pred, opt, iteration, runtime, ckpt_path):
19 |
20 | checkpoint = {
21 | 'iteration' : iteration,
22 | 'runtime' : runtime,
23 | 'opt' : opt.state_dict()
24 | }
25 |
26 | if isinstance(enc, nn.DataParallel):
27 | checkpoint['enc_state_dict'] = enc.module.state_dict()
28 | else:
29 | checkpoint['enc_state_dict'] = enc.state_dict()
30 |
31 | if isinstance(proj, nn.DataParallel):
32 | checkpoint['proj_state_dict'] = proj.module.state_dict()
33 | else:
34 | checkpoint['proj_state_dict'] = proj.state_dict()
35 |
36 | if isinstance(pred, nn.DataParallel):
37 | checkpoint['pred_state_dict'] = pred.module.state_dict()
38 | else:
39 | checkpoint['pred_state_dict'] = pred.state_dict()
40 |
41 | torch.save(checkpoint, ckpt_path)
42 |
43 | def load_ckpt(enc, proj, pred, opt, ckpt_path):
44 | ckpt = torch.load(ckpt_path)
45 |
46 | if isinstance(enc, nn.DataParallel):
47 | enc.module.load_state_dict(ckpt['enc_state_dict'])
48 | else:
49 | enc.load_state_dict(ckpt['enc_state_dict'])
50 |
51 | if isinstance(proj, nn.DataParallel):
52 | proj.module.load_state_dict(ckpt['proj_state_dict'])
53 | else:
54 | proj.load_state_dict(ckpt['proj_state_dict'])
55 |
56 | if isinstance(pred, nn.DataParallel):
57 | pred.module.load_state_dict(ckpt['pred_state_dict'])
58 | else:
59 | pred.load_state_dict(ckpt['pred_state_dict'])
60 |
61 | opt.load_state_dict(ckpt['opt'])
62 | it = ckpt['iteration']
63 | rt = ckpt['runtime']
64 |
65 | return it, rt
66 |
67 | def D(p, z):
68 | return - F.cosine_similarity(p, z.detach(), dim=-1).mean()
69 |
70 | class Preprocess(nn.Module):
71 | def __init__(self, mean, std):
72 | super(Preprocess, self).__init__()
73 | self.mean = mean
74 | self.std = std
75 |
76 | def forward(self, tensors):
77 | return (tensors - self.mean) / self.std
78 |
79 | class projection_MLP(nn.Module):
80 |
81 | def __init__(self, in_dim, hidden_dim=2048, out_dim=2048):
82 | super().__init__()
83 |
84 | self.layer1 = nn.Sequential(
85 | nn.Linear(in_dim, hidden_dim),
86 | nn.BatchNorm1d(hidden_dim),
87 | nn.ReLU(inplace=True)
88 | )
89 | self.layer2 = nn.Sequential(
90 | nn.Linear(hidden_dim, hidden_dim),
91 | nn.BatchNorm1d(hidden_dim),
92 | nn.ReLU(inplace=True)
93 | )
94 | self.layer3 = nn.Sequential(
95 | nn.Linear(hidden_dim, out_dim),
96 | nn.BatchNorm1d(out_dim)
97 | )
98 | self.num_layers = 3
99 |
100 | def set_layers(self, num_layers):
101 | self.num_layers = num_layers
102 |
103 | def forward(self, x):
104 | if self.num_layers == 3:
105 | x = self.layer1(x)
106 | x = self.layer2(x)
107 | x = self.layer3(x)
108 | elif self.num_layers == 2:
109 | x = self.layer1(x)
110 | x = self.layer3(x)
111 | else:
112 | raise ValueError('num_layers must be 2 or 3')
113 | return x
114 |
115 | class prediction_MLP(nn.Module):
116 | def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048):
117 | super().__init__()
118 |
119 | self.layer1 = nn.Sequential(
120 | nn.Linear(in_dim, hidden_dim),
121 | nn.BatchNorm1d(hidden_dim),
122 | nn.ReLU(inplace=True)
123 | )
124 |
125 | self.layer2 = nn.Linear(hidden_dim, out_dim)
126 |
127 | def forward(self, x):
128 | x = self.layer1(x)
129 | x = self.layer2(x)
130 | return x
131 |
132 | def SimSiam(train_X, enc, proj, pred, config, log_dir, start_epoch):
133 |
134 | train_n = train_X.shape[0]
135 | size = train_X.shape[2]
136 | nc = train_X.shape[1]
137 |
138 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda()
139 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda()
140 |
141 | params = itertools.chain(enc.parameters(), proj.parameters(), pred.parameters())
142 | net = nn.Sequential(Preprocess(mean, std), enc, proj)
143 |
144 | t = get_t(size, config['t'])
145 | lr_schedule = get_lr_schedule(config)
146 | opt = get_optimizer(config['optim'], params)
147 |
148 | if start_epoch == 0:
149 | it, rt = 0, 0
150 | else:
151 | it, rt = load_ckpt(enc, proj, pred, opt, log_dir + '/{}.pt'.format(start_epoch))
152 |
153 | enc.train()
154 | proj.train()
155 | pred.train()
156 |
157 | epoch_n_iter = int(np.ceil(train_n / config['bs']))
158 | while True:
159 | if it % epoch_n_iter == 0:
160 | train_X = train_X[torch.randperm(train_n)]
161 |
162 | i = it % epoch_n_iter
163 | it += 1
164 |
165 | s = time.time()
166 |
167 | X = train_X[i * config['bs']:(i + 1) * config['bs']]
168 |
169 | X_v1 = t(X).cuda()
170 | X_v2 = t(X).cuda()
171 |
172 | Z_v1, Z_v2 = net(X_v1), net(X_v2)
173 | P_v1, P_v2 = pred(Z_v1), pred(Z_v2)
174 |
175 | loss = D(P_v1, Z_v2) / 2 + D(P_v2, Z_v1) / 2
176 |
177 | # Update
178 | opt.param_groups[0]['lr'] = lr_schedule(it)
179 | opt.zero_grad()
180 | loss.backward()
181 | opt.step()
182 |
183 | e = time.time()
184 | rt += (e - s)
185 |
186 | if it % config['p_iter'] == 0:
187 | save_ckpt(enc, proj, pred, opt, it, rt, log_dir + '/curr.pt')
188 | print('Epoch : {:.3f} | Loss : {:.3f} | LR : {:.3e} | Time : {:.3f}'.format(it / epoch_n_iter, loss.item(), lr_schedule(it), rt))
189 |
190 | if it % (epoch_n_iter * config['s_epoch']) == 0:
191 | save_ckpt(enc, proj, pred, opt, it, rt, log_dir + '/{}.pt'.format(it // epoch_n_iter))
192 |
193 | if it >= config['its']:
194 | break
195 |
196 | def run(log_dir, config_dir, start_epoch, device):
197 |
198 | if not os.path.isdir(log_dir):
199 | os.makedirs(log_dir)
200 |
201 | config = load_config(config_dir)
202 |
203 | train_X, train_y, test_X, test_y, n_classes = load_data(config['data'])
204 |
205 | epoch_n_iter = int(np.ceil(train_X.shape[0] / config['bs']))
206 |
207 | if 'epochs' in config:
208 | config['its'] = epoch_n_iter * config['epochs']
209 |
210 | if 'p_epoch' in config:
211 | config['p_iter'] = epoch_n_iter * config['p_epoch']
212 |
213 | enc, feature_dim = init_enc(config['net'], device)
214 | proj = projection_MLP(feature_dim).cuda(device)
215 | pred = prediction_MLP().cuda(device)
216 |
217 | print('Running SimSiam from epoch {}'.format(start_epoch))
218 |
219 | with torch.cuda.device(device):
220 | SimSiam(train_X, enc, proj, pred, config, log_dir, start_epoch)
221 |
222 | print('Finished!')
--------------------------------------------------------------------------------
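For reference, D above is the SimSiam negative cosine similarity with a stop-gradient on the target branch, and the training loop symmetrizes it over the two views; a standalone sketch with random features (512 matches the encoder output dimension used in this repository):

    import torch
    from libs.simsiam import D, projection_MLP, prediction_MLP

    proj, pred = projection_MLP(in_dim=512), prediction_MLP()
    z1, z2 = torch.randn(4, 512), torch.randn(4, 512)
    h1, h2 = proj(z1), proj(z2)
    loss = D(pred(h1), h2) / 2 + D(pred(h2), h1) / 2   # gradients flow only through pred(...)
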
/libs/moco.py:
--------------------------------------------------------------------------------
1 | from libs.utils import log_softmax, load_data, get_t, get_lr_schedule, get_optimizer, load_config
2 | from libs.net import init_enc_proj
3 |
4 | import torchvision.transforms.functional as TF
5 | import torchvision.transforms as transforms
6 | import torch.nn.functional as F
7 | import matplotlib.pyplot as plt
8 | import torch.optim as optim
9 | import torch.nn as nn
10 | import numpy as np
11 |
12 | import torchvision
13 | import itertools
14 | import torch
15 | import time
16 | import os
17 |
18 | def save_ckpt(enc_q, proj_q, enc_k, proj_k, queue, opt, iteration, runtime, save_all, ckpt_path):
19 |
20 | checkpoint = {
21 | 'iteration' : iteration,
22 | 'runtime' : runtime,
23 | 'opt' : opt.state_dict()
24 | }
25 |
26 | if save_all:
27 | checkpoint['queue'] = queue
28 |
29 | if isinstance(enc_k, nn.DataParallel):
30 | checkpoint['enc_k_state_dict'] = enc_k.module.state_dict()
31 | else:
32 | checkpoint['enc_k_state_dict'] = enc_k.state_dict()
33 |
34 | if isinstance(proj_k, nn.DataParallel):
35 | checkpoint['proj_k_state_dict'] = proj_k.module.state_dict()
36 | else:
37 | checkpoint['proj_k_state_dict'] = proj_k.state_dict()
38 |
39 | if isinstance(enc_q, nn.DataParallel):
40 | checkpoint['enc_state_dict'] = enc_q.module.state_dict()
41 | else:
42 | checkpoint['enc_state_dict'] = enc_q.state_dict()
43 |
44 | if isinstance(proj_q, nn.DataParallel):
45 | checkpoint['proj_state_dict'] = proj_q.module.state_dict()
46 | else:
47 | checkpoint['proj_state_dict'] = proj_q.state_dict()
48 |
49 | torch.save(checkpoint, ckpt_path)
50 |
51 | def load_ckpt(enc_q, proj_q, enc_k, proj_k, queue, opt, ckpt_path):
52 | ckpt = torch.load(ckpt_path)
53 |
54 | if isinstance(enc_q, nn.DataParallel):
55 | enc_q.module.load_state_dict(ckpt['enc_state_dict'])
56 | else:
57 | enc_q.load_state_dict(ckpt['enc_state_dict'])
58 |
59 | if isinstance(proj_q, nn.DataParallel):
60 | proj_q.module.load_state_dict(ckpt['proj_state_dict'])
61 | else:
62 | proj_q.load_state_dict(ckpt['proj_state_dict'])
63 |
64 | if isinstance(enc_k, nn.DataParallel):
65 | enc_k.module.load_state_dict(ckpt['enc_k_state_dict'])
66 | else:
67 | enc_k.load_state_dict(ckpt['enc_k_state_dict'])
68 |
69 | if isinstance(proj_k, nn.DataParallel):
70 | proj_k.module.load_state_dict(ckpt['proj_k_state_dict'])
71 | else:
72 | proj_k.load_state_dict(ckpt['proj_k_state_dict'])
73 |
74 | queue.data = ckpt['queue'].data
75 |
76 | opt.load_state_dict(ckpt['opt'])
77 | it = ckpt['iteration']
78 | rt = ckpt['runtime']
79 |
80 | return it, rt
81 |
82 | def shuffled_idx(batch_size):
83 | shuffled_idxs = torch.randperm(batch_size).long().cuda()
84 | reverse_idxs = torch.zeros(batch_size).long().cuda()
85 | value = torch.arange(batch_size).long().cuda()
86 | reverse_idxs.index_copy_(0, shuffled_idxs, value)
87 | return shuffled_idxs, reverse_idxs
88 |
89 | def MoCo(train_X, enc_q, proj_q, enc_k, proj_k, config, log_dir, start_epoch):
90 |
91 | train_n = train_X.shape[0]
92 | size = train_X.shape[2]
93 | nc = train_X.shape[1]
94 |
95 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda()
96 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda()
97 |
98 | criterion = nn.CrossEntropyLoss().cuda()
99 |
100 | queue = torch.randn(config['queue_size'], config['net']['proj_dim']).cuda()
101 |
102 | params_q = itertools.chain(enc_q.parameters(), proj_q.parameters())
103 | net_q = lambda x : proj_q(enc_q((x - mean) / std))
104 |
105 | params_k = itertools.chain(enc_k.parameters(), proj_k.parameters())  # unused: the key network is updated via EMA, not by the optimizer
106 | net_k = lambda x : proj_k(enc_k((x - mean) / std))
107 |
108 | t = get_t(size, config['t'])
109 | lr_schedule = get_lr_schedule(config)
110 | opt = get_optimizer(config['optim'], params_q)
111 |
112 | if start_epoch == 0:
113 | it, rt = 0, 0
114 | else:
115 | it, rt = load_ckpt(enc_q, proj_q, enc_k, proj_k, queue, opt, log_dir + '/{}.pt'.format(start_epoch))
116 |
117 | enc_q.train()
118 | proj_q.train()
119 |
120 | enc_k.train()
121 | proj_k.train()
122 |
123 | epoch_n_iter = int(np.ceil(train_n / config['bs']))
124 | while True:
125 | if it % epoch_n_iter == 0:
126 | train_X = train_X[torch.randperm(train_n)]
127 |
128 | i = it % epoch_n_iter
129 | it += 1
130 |
131 | s = time.time()
132 |
133 | X = train_X[i * config['bs']:(i + 1) * config['bs']]
134 |
135 | X_v1 = t(X).cuda()
136 | X_v2 = t(X).cuda()
137 |
138 | Z_v1 = F.normalize(net_q(X_v1), dim=1)
139 |
140 | with torch.no_grad():
141 |
142 | for p_q, p_k in zip(enc_q.parameters(), enc_k.parameters()):
143 | p_k.data = p_k.data * config['momentum'] + p_q.detach().data * (1 - config['momentum'])
144 |
145 | for p_q, p_k in zip(proj_q.parameters(), proj_k.parameters()):
146 | p_k.data = p_k.data * config['momentum'] + p_q.detach().data * (1 - config['momentum'])
147 |
148 | shuffled_idxs, reverse_idxs = shuffled_idx(X.shape[0])
149 | X_v2 = X_v2[shuffled_idxs]
150 | Z_v2 = F.normalize(net_k(X_v2), dim=1)
151 | Z_v2 = Z_v2[reverse_idxs]
152 |
153 | Z_mem = F.normalize(queue.clone().detach(), dim=1)
154 |
155 | pos = torch.bmm(Z_v1.view(Z_v1.shape[0],1,-1), Z_v2.view(Z_v2.shape[0],-1,1)).squeeze(-1)
156 | neg = torch.mm(Z_v1, Z_mem.transpose(1,0))
157 |
158 | logits = torch.cat((pos, neg), dim=1) / config['temperature']
159 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
160 | loss = criterion(logits, labels)
161 |
162 | # Update
163 | opt.param_groups[0]['lr'] = lr_schedule(it)
164 | opt.zero_grad()
165 | loss.backward()
166 | opt.step()
167 |
168 | # Update queue
169 | queue = queue[Z_v2.shape[0]:]
170 | queue = torch.cat((queue, Z_v2.detach().clone()), dim=0)
171 |
172 | e = time.time()
173 | rt += (e - s)
174 |
175 | if it % config['p_iter'] == 0:
176 | save_ckpt(enc_q, proj_q, enc_k, proj_k, queue, opt, it, rt, False, log_dir + '/curr.pt')
177 | print('Epoch : {:.3f} | Loss : {:.3f} | LR : {:.3e} | Time : {:.3f}'.format(it / epoch_n_iter, loss.item(), lr_schedule(it), rt))
178 |
179 | if it % (epoch_n_iter * config['s_epoch']) == 0:
180 | save_ckpt(enc_q, proj_q, enc_k, proj_k, queue, opt, it, rt, True, log_dir + '/{}.pt'.format(it // epoch_n_iter))
181 |
182 | if it >= config['its']:
183 | break
184 |
185 | def run(log_dir, config_dir, start_epoch, device):
186 |
187 | if not os.path.isdir(log_dir):
188 | os.makedirs(log_dir)
189 |
190 | config = load_config(config_dir)
191 |
192 | train_X, train_y, test_X, test_y, n_classes = load_data(config['data'])
193 |
194 | epoch_n_iter = int(np.ceil(train_X.shape[0] / config['bs']))
195 |
196 | if 'epochs' in config:
197 | config['its'] = epoch_n_iter * config['epochs']
198 |
199 | if 'p_epoch' in config:
200 | config['p_iter'] = epoch_n_iter * config['p_epoch']
201 |
202 | enc_q, proj_q, _ = init_enc_proj(config['net'], device)
203 | enc_k, proj_k, _ = init_enc_proj(config['net'], device)
204 |
205 | for p_q, p_k in zip(enc_q.parameters(), enc_k.parameters()):
206 | p_k.data.copy_(p_q.data)
207 | p_k.requires_grad = False
208 |
209 | for p_q, p_k in zip(proj_q.parameters(), proj_k.parameters()):
210 | p_k.data.copy_(p_q.data)
211 | p_k.requires_grad = False
212 |
213 | print('Running MoCo from epoch {}'.format(start_epoch))
214 |
215 | with torch.cuda.device(device):
216 | MoCo(train_X, enc_q, proj_q, enc_k, proj_k, config, log_dir, start_epoch)
217 |
218 | print('Finished!')
--------------------------------------------------------------------------------
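The key network above is never touched by the optimizer; it tracks the query network through an exponential moving average, p_k <- m * p_k + (1 - m) * p_q. An equivalent in-place form of that update loop, as a sketch:

    import torch
    import torch.nn as nn

    def ema_update(q_net: nn.Module, k_net: nn.Module, m: float) -> None:
        # Same update as the no_grad loops inside MoCo(), written with in-place ops.
        with torch.no_grad():
            for p_q, p_k in zip(q_net.parameters(), k_net.parameters()):
                p_k.mul_(m).add_(p_q, alpha=1.0 - m)
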
/libs/byol.py:
--------------------------------------------------------------------------------
1 | from libs.utils import log_softmax, load_data, get_t, get_lr_schedule, get_optimizer, load_config
2 | from libs.net import init_enc
3 |
4 | import torchvision.transforms.functional as TF
5 | import torchvision.transforms as transforms
6 | import torch.nn.functional as F
7 | import matplotlib.pyplot as plt
8 | import torch.optim as optim
9 | import torch.nn as nn
10 | import numpy as np
11 |
12 | import torchvision
13 | import itertools
14 | import torch
15 | import time
16 | import os
17 |
18 | def save_ckpt(enc_o, proj_o, pred_o, enc_t, proj_t, opt, iteration, runtime, save_all, ckpt_path):
19 |
20 | checkpoint = {
21 | 'iteration' : iteration,
22 | 'runtime' : runtime,
23 | 'opt' : opt.state_dict()
24 | }
25 |
26 | if save_all:
27 |
28 | if isinstance(enc_t, nn.DataParallel):
29 | checkpoint['enc_t_state_dict'] = enc_t.module.state_dict()
30 | else:
31 | checkpoint['enc_t_state_dict'] = enc_t.state_dict()
32 |
33 | if isinstance(proj_t, nn.DataParallel):
34 | checkpoint['proj_t_state_dict'] = proj_t.module.state_dict()
35 | else:
36 | checkpoint['proj_t_state_dict'] = proj_t.state_dict()
37 |
38 | if isinstance(enc_o, nn.DataParallel):
39 | checkpoint['enc_state_dict'] = enc_o.module.state_dict()
40 | else:
41 | checkpoint['enc_state_dict'] = enc_o.state_dict()
42 |
43 | if isinstance(proj_o, nn.DataParallel):
44 | checkpoint['proj_state_dict'] = proj_o.module.state_dict()
45 | else:
46 | checkpoint['proj_state_dict'] = proj_o.state_dict()
47 |
48 | if isinstance(pred_o, nn.DataParallel):
49 | checkpoint['pred_state_dict'] = pred_o.module.state_dict()
50 | else:
51 | checkpoint['pred_state_dict'] = pred_o.state_dict()
52 |
53 | torch.save(checkpoint, ckpt_path)
54 |
55 | def load_ckpt(enc_o, proj_o, pred_o, enc_t, proj_t, opt, ckpt_path):
56 | ckpt = torch.load(ckpt_path)
57 |
58 | if isinstance(enc_o, nn.DataParallel):
59 | enc_o.module.load_state_dict(ckpt['enc_state_dict'])
60 | else:
61 | enc_o.load_state_dict(ckpt['enc_state_dict'])
62 |
63 | if isinstance(proj_o, nn.DataParallel):
64 | proj_o.module.load_state_dict(ckpt['proj_state_dict'])
65 | else:
66 | proj_o.load_state_dict(ckpt['proj_state_dict'])
67 |
68 | if isinstance(pred_o, nn.DataParallel):
69 | pred_o.module.load_state_dict(ckpt['pred_state_dict'])
70 | else:
71 | pred_o.load_state_dict(ckpt['pred_state_dict'])
72 |
73 | if isinstance(enc_t, nn.DataParallel):
74 | enc_t.module.load_state_dict(ckpt['enc_t_state_dict'])
75 | else:
76 | enc_t.load_state_dict(ckpt['enc_t_state_dict'])
77 |
78 | if isinstance(proj_t, nn.DataParallel):
79 | proj_t.module.load_state_dict(ckpt['proj_t_state_dict'])
80 | else:
81 | proj_t.load_state_dict(ckpt['proj_t_state_dict'])
82 |
83 | opt.load_state_dict(ckpt['opt'])
84 | it = ckpt['iteration']
85 | rt = ckpt['runtime']
86 |
87 | return it, rt
88 |
89 | def D(x, y):
90 | return (2 - 2 * F.cosine_similarity(x, y.detach(), dim=-1)).mean()
91 |
92 | class Preprocess(nn.Module):
93 | def __init__(self, mean, std):
94 | super(Preprocess, self).__init__()
95 | self.mean = mean
96 | self.std = std
97 |
98 | def forward(self, tensors):
99 | return (tensors - self.mean) / self.std
100 |
101 | class MLP(nn.Module):
102 | def __init__(self, in_dim, hidden_dim=4096, out_dim=256):
103 | super().__init__()
104 |
105 | self.layer1 = nn.Sequential(
106 | nn.Linear(in_dim, hidden_dim),
107 | nn.BatchNorm1d(hidden_dim),
108 | nn.ReLU(inplace=True)
109 | )
110 |
111 | self.layer2 = nn.Linear(hidden_dim, out_dim)
112 |
113 | def forward(self, x):
114 | x = self.layer1(x)
115 | x = self.layer2(x)
116 | return x
117 |
118 | def BYOL(train_X, enc_o, proj_o, pred_o, enc_t, proj_t, config, log_dir, start_epoch):
119 |
120 | train_n = train_X.shape[0]
121 | size = train_X.shape[2]
122 | nc = train_X.shape[1]
123 |
124 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda()
125 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda()
126 |
127 | params = itertools.chain(enc_o.parameters(), proj_o.parameters(), pred_o.parameters())
128 | net_o = nn.Sequential(Preprocess(mean, std), enc_o, proj_o)
129 | net_t = nn.Sequential(Preprocess(mean, std), enc_t, proj_t)
130 |
131 | t = get_t(size, config['t'])
132 | lr_schedule = get_lr_schedule(config)
133 | opt = get_optimizer(config['optim'], params)
134 |
135 | if start_epoch == 0:
136 | it, rt = 0, 0
137 | else:
138 | it, rt = load_ckpt(enc_o, proj_o, pred_o, enc_t, proj_t, opt, log_dir + '/{}.pt'.format(start_epoch))
139 |
140 | enc_o.train()
141 | proj_o.train()
142 | pred_o.train()
143 |
144 | enc_t.train()
145 | proj_t.train()
146 |
147 | epoch_n_iter = int(np.ceil(train_n / config['bs']))
148 | while True:
149 | if it % epoch_n_iter == 0:
150 | train_X = train_X[torch.randperm(train_n)]
151 |
152 | i = it % epoch_n_iter
153 | it += 1
154 |
155 | s = time.time()
156 |
157 | X = train_X[i * config['bs']:(i + 1) * config['bs']]
158 |
159 | X_v1 = t(X).cuda()
160 | X_v2 = t(X).cuda()
161 |
162 | Z_v1_o, Z_v2_o = net_o(X_v1), net_o(X_v2)
163 | P_v1_o, P_v2_o = pred_o(Z_v1_o), pred_o(Z_v2_o)
164 |
165 | with torch.no_grad():
166 | Z_v1_t, Z_v2_t = net_t(X_v1), net_t(X_v2)
167 |
168 | loss = D(P_v1_o, Z_v2_t) / 2 + D(P_v2_o, Z_v1_t) / 2
169 |
170 | # Update
171 | opt.param_groups[0]['lr'] = lr_schedule(it)
172 | opt.zero_grad()
173 | loss.backward()
174 | opt.step()
175 |
176 | # Online network update
177 | with torch.no_grad():
178 |
179 | for p_o, p_t in zip(enc_o.parameters(), enc_t.parameters()):
180 | p_t.data = p_t.data * config['momentum'] + p_o.detach().data * (1 - config['momentum'])
181 |
182 | for p_o, p_t in zip(proj_o.parameters(), proj_t.parameters()):
183 | p_t.data = p_t.data * config['momentum'] + p_o.detach().data * (1 - config['momentum'])
184 |
185 | e = time.time()
186 | rt += (e - s)
187 |
188 | if it % config['p_iter'] == 0:
189 | save_ckpt(enc_o, proj_o, pred_o, enc_t, proj_t, opt, it, rt, False, log_dir + '/curr.pt')
190 | print('Epoch : {:.3f} | Loss : {:.3f} | LR : {:.3e} | Time : {:.3f}'.format(it / epoch_n_iter, loss.item(), lr_schedule(it), rt))
191 |
192 | if it % (epoch_n_iter * config['s_epoch']) == 0:
193 | save_ckpt(enc_o, proj_o, pred_o, enc_t, proj_t, opt, it, rt, True, log_dir + '/{}.pt'.format(it // epoch_n_iter))
194 |
195 | if it >= config['its']:
196 | break
197 |
198 | def run(log_dir, config_dir, start_epoch, device):
199 |
200 | if not os.path.isdir(log_dir):
201 | os.makedirs(log_dir)
202 |
203 | config = load_config(config_dir)
204 |
205 | train_X, train_y, test_X, test_y, n_classes = load_data(config['data'])
206 |
207 | epoch_n_iter = int(np.ceil(train_X.shape[0] / config['bs']))
208 |
209 | if 'epochs' in config:
210 | config['its'] = epoch_n_iter * config['epochs']
211 |
212 | if 'p_epoch' in config:
213 | config['p_iter'] = epoch_n_iter * config['p_epoch']
214 |
215 | enc_o, feature_dim = init_enc(config['net'], device)
216 | proj_o = MLP(feature_dim).cuda(device)
217 | pred_o = MLP(256).cuda(device)
218 |
219 | enc_t, feature_dim = init_enc(config['net'], device)
220 | proj_t = MLP(feature_dim).cuda(device)
221 |
222 | for p_o, p_t in zip(enc_o.parameters(), enc_t.parameters()):
223 | p_t.data.copy_(p_o.data)
224 | p_t.requires_grad = False
225 |
226 | for p_o, p_t in zip(proj_o.parameters(), proj_t.parameters()):
227 | p_t.data.copy_(p_o.data)
228 | p_t.requires_grad = False
229 |
230 | print('Running BYOL from epoch {}'.format(start_epoch))
231 |
232 | with torch.cuda.device(device):
233 | BYOL(train_X, enc_o, proj_o, pred_o, enc_t, proj_t, config, log_dir, start_epoch)
234 |
235 | print('Finished!')
--------------------------------------------------------------------------------
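The BYOL distance D relies on the identity ||x/|x| - y/|y|||^2 = 2 - 2 cos(x, y), i.e. it is the squared Euclidean distance between L2-normalized vectors; a quick numerical check of that identity:

    import torch
    import torch.nn.functional as F

    x, y = torch.randn(4, 256), torch.randn(4, 256)
    lhs = (F.normalize(x, dim=-1) - F.normalize(y, dim=-1)).square().sum(dim=-1)
    rhs = 2 - 2 * F.cosine_similarity(x, y, dim=-1)
    assert torch.allclose(lhs, rhs, atol=1e-5)
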
/libs/evaluation.py:
--------------------------------------------------------------------------------
1 | from libs.utils import log_softmax, load_config, load_data
2 | from libs.net import init_enc_proj
3 | from libs.models import LIN
4 |
5 | from tqdm import tqdm
6 |
7 | import matplotlib.pyplot as plt
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | import torch.nn as nn
11 | import numpy as np
12 |
13 | import itertools
14 | import torch
15 |
16 | # Use this function to get linear evaluation accuracy
17 | def eval_acc(ckpt_dir, config_dir, eval_config, data_config, device):
18 |
19 | config = load_config(config_dir)
20 |
21 | train_X, train_y, test_X, test_y, n_classes = load_data(data_config)
22 |
23 | if eval_config['standardize']:
24 | nc = train_X.shape[1]
25 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda(device)
26 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda(device)
27 | else:
28 | mean = 0.5
29 | std = 0.5
30 |
31 | enc, _, _ = init_enc_proj(config['net'], device)
32 | enc.load_state_dict(torch.load(ckpt_dir)['enc_state_dict'])
33 | net = nn.Sequential(Preprocess(mean, std), enc)
34 | net.eval()
35 |
36 | lin = LIN(512, n_classes).cuda(device)
37 |
38 | with torch.cuda.device(device):
39 | train(lin, net, train_X, train_y, test_X, test_y, n_classes, eval_config)
40 |
41 | # Use this function to get linear transfer accuracy
42 | def transfer_acc(ckpt_dir, config_dir, eval_config, source_data_config, target_data_config, device):
43 |
44 | config = load_config(config_dir)
45 |
46 | if eval_config['standardize']:
47 | train_X, train_y, test_X, test_y, n_classes = load_data(source_data_config)
48 | nc = train_X.shape[1]
49 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda(device)
50 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda(device)
51 | else:
52 | mean = 0.5
53 | std = 0.5
54 |
55 | train_X, train_y, test_X, test_y, n_classes = load_data(target_data_config)
56 |
57 | enc, _, _ = init_enc_proj(config['net'], device)
58 | enc.load_state_dict(torch.load(ckpt_dir)['enc_state_dict'])
59 | net = nn.Sequential(Preprocess(mean, std), enc)
60 | net.eval()
61 |
62 | lin = LIN(512, n_classes).cuda(device)
63 |
64 | with torch.cuda.device(device):
65 | train(lin, net, train_X, train_y, test_X, test_y, n_classes, eval_config)
66 |
67 | # Use this function to get knn evaluation accuracy
68 | def eval_knn_acc(ckpt_dir, config_dir, eval_config, data_config, device):
69 | config = load_config(config_dir)
70 | train_X, train_y, test_X, test_y, n_classes = load_data(data_config)
71 |
72 | if eval_config['standardize']:
73 | nc = train_X.shape[1]
74 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda(device)
75 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda(device)
76 | else:
77 | mean = 0.5
78 | std = 0.5
79 |
80 | enc, _, _ = init_enc_proj(config['net'], device)
81 | enc.load_state_dict(torch.load(ckpt_dir)['enc_state_dict'])
82 | net = nn.Sequential(Preprocess(mean, std), enc)
83 | net.eval()
84 |
85 | with torch.cuda.device(device):
86 | test_n = test_X.shape[0]
87 | train_Z, test_Z = embed_dataset(net, train_X, test_X)
88 |
89 | acc = 0
90 | N = int(np.ceil(test_n / eval_config['bs']))
91 | for i in tqdm(range(N)):
92 | Z = test_Z[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda()
93 | y = test_y[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda()
94 |
95 | D = (Z[:,None] - train_Z[None].cuda()).norm(dim=2).cpu()
96 | w, inds = D.topk(eval_config['K'], dim=1, largest=False)
97 |
98 | v = train_y[inds]
99 | a = v.reshape(-1)
100 | a = F.one_hot(a, num_classes=n_classes)
101 | a = a.reshape(v.shape[0],v.shape[1],n_classes)
102 | weight_pred = a / (w[...,None] + 1e-10)
103 | weight_pred = weight_pred.sum(dim=1)
104 | wp = weight_pred.argmax(dim=1)
105 | acc += (wp.cuda() == y).sum()
106 |
107 | print(acc / test_n)
108 |
109 | # Use this function to get knn transfer accuracy
110 | def eval_knn_transfer_acc(ckpt_dir, config_dir, eval_config, source_data_config, target_data_config, device):
111 | config = load_config(config_dir)
112 |
113 | if eval_config['standardize']:
114 | train_X, train_y, test_X, test_y, n_classes = load_data(source_data_config)
115 | nc = train_X.shape[1]
116 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda(device)
117 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda(device)
118 | else:
119 | mean = 0.5
120 | std = 0.5
121 |
122 | train_X, train_y, test_X, test_y, n_classes = load_data(target_data_config)
123 |
124 | enc, _, _ = init_enc_proj(config['net'], device)
125 | enc.load_state_dict(torch.load(ckpt_dir)['enc_state_dict'])
126 | net = nn.Sequential(Preprocess(mean, std), enc)
127 | net.eval()
128 |
129 | with torch.cuda.device(device):
130 | test_n = test_X.shape[0]
131 | train_Z, test_Z = embed_dataset(net, train_X, test_X)
132 |
133 | acc = 0
134 | N = int(np.ceil(test_n / eval_config['bs']))
135 | for i in tqdm(range(N)):
136 | Z = test_Z[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda()
137 | y = test_y[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda()
138 |
139 | D = (Z[:,None] - train_Z[None].cuda()).norm(dim=2).cpu()
140 | w, inds = D.topk(eval_config['K'], dim=1, largest=False)
141 |
142 | v = train_y[inds]
143 | a = v.reshape(-1)
144 | a = F.one_hot(a, num_classes=n_classes)
145 | a = a.reshape(v.shape[0],v.shape[1],n_classes)
146 | weight_pred = a / (w[...,None] + 1e-10)
147 | weight_pred = weight_pred.sum(dim=1)
148 | wp = weight_pred.argmax(dim=1)
149 | acc += (wp.cuda() == y).sum()
150 |
151 | print(acc / test_n)
152 |
153 | def train(lin, net, train_X, train_y, test_X, test_y, n_classes, eval_config):
154 |
155 | train_n = train_X.shape[0]
156 | train_Z, test_Z = embed_dataset(net, train_X, test_X)
157 | opt = optim.Adam(lin.parameters(), lr=eval_config['lr'])
158 |
159 | for epoch in range(eval_config['epochs']):
160 | for i in range(int(np.ceil(train_n / eval_config['bs']))):
161 |
162 | Z = train_Z[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda()
163 | y = train_y[i * eval_config['bs']:(i + 1) * eval_config['bs']].cuda()
164 |
165 | log_p = lin(Z)
166 |
167 | loss = -(F.one_hot(y, num_classes=n_classes) * log_softmax(log_p)).mean()
168 |
169 | opt.zero_grad()
170 | loss.backward()
171 | opt.step()
172 |
173 | if (epoch + 1) % eval_config['p_epoch'] == 0:
174 | print('Epoch {} | Accuracy : {:.3f}'.format(epoch + 1, accuracy(lin, test_Z, test_y)))
175 |
176 | class Preprocess(nn.Module):
177 | def __init__(self, mean, std):
178 | super(Preprocess, self).__init__()
179 | self.mean = mean
180 | self.std = std
181 |
182 | def forward(self, tensors):
183 | return (tensors - self.mean) / self.std
184 |
185 | def embed_dataset(net, train_X, test_X, bs=64):
186 |
187 | train_n = train_X.shape[0]
188 | test_n = test_X.shape[0]
189 | nc = train_X.shape[1]
190 |
191 | train_Z = []
192 | for i in tqdm(range(int(np.ceil(train_n / bs)))):
193 | batch = train_X[i * bs:(i + 1) * bs].cuda()
194 | train_Z.append(net(batch).detach().cpu())
195 | train_Z = torch.cat(train_Z, dim=0)
196 |
197 | test_Z = []
198 | for i in tqdm(range(int(np.ceil(test_n / bs)))):
199 | batch = test_X[i * bs:(i + 1) * bs].cuda()
200 | test_Z.append(net(batch).detach().cpu())
201 | test_Z = torch.cat(test_Z, dim=0)
202 |
203 | return train_Z, test_Z
204 |
205 | def accuracy(clf, test_Z, test_y, bs=500):
206 |
207 | test_n = test_Z.shape[0]
208 |
209 | acc = 0
210 | for i in range(int(np.ceil(test_n / bs))):
211 | Z = test_Z[i * bs:(i + 1) * bs].cuda()
212 | y = test_y[i * bs:(i + 1) * bs].cuda()
213 | log_p = clf(Z)
214 | acc += (torch.argmax(log_p, dim=1) == y).sum() / test_n
215 |
216 | return acc.item()
--------------------------------------------------------------------------------
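A usage sketch for the linear-evaluation entry point above; the checkpoint path and both config dictionaries are illustrative, with keys mirroring what eval_acc and train actually read:

    from libs import evaluation

    eval_config = {'standardize': True, 'lr': 1e-3, 'bs': 256, 'epochs': 100, 'p_epoch': 10}
    data_config = {'data_dir': 'data', 'dataset': 'cifar10', 'data_ratio': 1.0}

    evaluation.eval_acc(ckpt_dir='logs/ebclr_cifar10/100.pt',
                        config_dir='configs_ebclr/cifar10_b128.yaml',
                        eval_config=eval_config,
                        data_config=data_config,
                        device=0)
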
/libs/ebclr.py:
--------------------------------------------------------------------------------
1 | from libs.utils import MSGLD, dist, logsumexp, log_softmax, load_data, get_t, get_optimizer, get_lr_schedule, load_config
2 | from libs.net import init_enc_proj
3 |
4 | import torchvision.transforms.functional as TF
5 | import torchvision.transforms as transforms
6 | import torch.nn.functional as F
7 | import matplotlib.pyplot as plt
8 | import torch.optim as optim
9 | import torch.nn as nn
10 | import numpy as np
11 |
12 | import torchvision
13 | import itertools
14 | import torch
15 | import time
16 | import os
17 |
18 | def save_ckpt(buffer, enc, proj, opt, iteration, runtime, save_buffer, ckpt_path):
19 |
20 | checkpoint = {
21 | 'iteration' : iteration,
22 | 'runtime' : runtime,
23 | 'opt' : opt.state_dict()
24 | }
25 |
26 | if save_buffer:
27 | checkpoint['buffer'] = buffer.buffer
28 | checkpoint['counter'] = buffer.counter
29 |
30 | if isinstance(enc, nn.DataParallel):
31 | checkpoint['enc_state_dict'] = enc.module.state_dict()
32 | else:
33 | checkpoint['enc_state_dict'] = enc.state_dict()
34 |
35 | if isinstance(proj, nn.DataParallel):
36 | checkpoint['proj_state_dict'] = proj.module.state_dict()
37 | else:
38 | checkpoint['proj_state_dict'] = proj.state_dict()
39 |
40 | torch.save(checkpoint, ckpt_path)
41 |
42 | def load_ckpt(enc, proj, buffer, opt, ckpt_path):
43 | ckpt = torch.load(ckpt_path)
44 |
45 | if isinstance(enc, nn.DataParallel):
46 | enc.module.load_state_dict(ckpt['enc_state_dict'])
47 | else:
48 | enc.load_state_dict(ckpt['enc_state_dict'])
49 |
50 | if isinstance(proj, nn.DataParallel):
51 | proj.module.load_state_dict(ckpt['proj_state_dict'])
52 | else:
53 | proj.load_state_dict(ckpt['proj_state_dict'])
54 |
55 | buffer.buffer = ckpt['buffer']
56 | buffer.counter = ckpt['counter']
57 |
58 | opt.load_state_dict(ckpt['opt'])
59 |
60 | return ckpt['iteration'], ckpt['runtime']
61 |
62 | def transform_data(x, t, bs):
63 | n_iter = int(np.ceil(x.shape[0] / bs))
64 | return torch.cat([t(x[j * bs:(j + 1) * bs]) for j in range(n_iter)], dim=0)
65 |
66 | def similarity_f(x, y, normalize):
67 |
68 | if normalize:
69 | x = x / (x.norm(dim=1, keepdim=True) + 1e-10)
70 | y = y / (y.norm(dim=1, keepdim=True) + 1e-10)
71 |
72 | return -(x[:,None] - y[None]).square().sum(dim=2)
73 |
74 | class SimCLR_Loss(nn.Module):
75 |
76 | def __init__(self, normalize, temperature):
77 | super(SimCLR_Loss, self).__init__()
78 | self.normalize = normalize
79 | self.temperature = temperature
80 | self.criterion = nn.CrossEntropyLoss(reduction="sum")
81 |
82 | def mask_correlated_samples(self, batch_size):
83 | N = 2 * batch_size
84 | mask = torch.ones((N, N), dtype=bool)
85 | mask = mask.fill_diagonal_(0)
86 |
87 | for i in range(batch_size):
88 | mask[i, batch_size + i] = 0
89 | mask[batch_size + i, i] = 0
90 | return mask
91 |
92 | def forward(self, z_i, z_j):
93 |
94 | batch_size = z_i.shape[0]
95 |
96 | N = 2 * batch_size
97 |
98 | z = torch.cat((z_i, z_j), dim=0)
99 |
100 | sim = similarity_f(z, z, self.normalize) / self.temperature
101 |
102 | sim_i_j = torch.diag(sim, batch_size)
103 | sim_j_i = torch.diag(sim, -batch_size)
104 |
105 | mask = self.mask_correlated_samples(batch_size)
106 | positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
107 | negative_samples = sim[mask].reshape(N, -1)
108 |
109 | labels = torch.zeros(N, dtype=torch.long, device=positive_samples.device)
110 |
111 | logits = torch.cat((positive_samples, negative_samples), dim=1)
112 | loss = self.criterion(logits, labels)
113 | loss /= N
114 |
115 | return loss
116 |
117 | class Buffer:
118 |
119 | def __init__(self, train_X, t, buffer_config):
120 | self.train_X = train_X
121 | self.config = buffer_config
122 | self.t = t
123 |
124 | self.train_n = train_X.shape[0]
125 | self.size = train_X.shape[2]
126 | self.nc = train_X.shape[1]
127 |
128 | self.idx = 0
129 |
130 | self.counter = torch.ones(size=[self.config['size'],1,1,1]) * 5.0
131 | self.init_buffer()
132 |
133 | def __get_sample__(self, n_samples):
134 | r_idx = torch.randint(self.train_X.shape[0], size=[n_samples])
135 | samples = self.train_X[r_idx]
136 | samples = transform_data(samples, self.t, self.config['bs'])
137 | return samples
138 |
139 | def __get_rand__(self, n_samples):
140 | if 'CD_ratio' in self.config:
141 | samples = self.__get_sample__(n_samples)
142 | return (self.config['CD_ratio'] * samples + (1.0 - self.config['CD_ratio']) * torch.rand_like(samples)) * 2.0 - 1.0
143 | else:
144 | return torch.rand(size=[n_samples, self.nc, self.size, self.size]) * 2.0 - 1.0
145 |
146 | def init_buffer(self):
147 | self.buffer = self.__get_rand__(self.config['size'])
148 |
149 | def sample(self, n_samples):
150 | self.idx = torch.randint(self.config['size'], size=[n_samples])
151 | sample = self.buffer[self.idx]
152 | count = self.counter[self.idx].clone()
153 |
154 | r_idx = torch.randint(n_samples, size=[int(n_samples * self.config['rho'])])
155 | sample[r_idx] = self.__get_rand__(n_samples)[r_idx]
156 | count[r_idx] = 0.0
157 |
158 | self.counter[self.idx] = (count + 1.0)
159 |
160 | return sample.cuda(), count.cuda()
161 |
162 | def update(self, samples):
163 | self.buffer[self.idx] = samples.detach().clone().cpu()
164 |
165 | def shuffled_idx(batch_size):
166 | shuffled_idxs = torch.randperm(batch_size).long().cuda()
167 | reverse_idxs = torch.zeros(batch_size).long().cuda()
168 | value = torch.arange(batch_size).long().cuda()
169 | reverse_idxs.index_copy_(0, shuffled_idxs, value)
170 | return shuffled_idxs, reverse_idxs
171 |
172 | def EBCLR(train_X, enc, proj, config, log_dir, start_epoch):
173 |
174 | train_n = train_X.shape[0]
175 | size = train_X.shape[2]
176 | nc = train_X.shape[1]
177 |
178 | mean = train_X.transpose(1,0).flatten(start_dim=1).mean(dim=1).reshape(1,nc,1,1).cuda()
179 | std = train_X.transpose(1,0).flatten(start_dim=1).std(dim=1).reshape(1,nc,1,1).cuda()
180 |
181 | simclr_loss = SimCLR_Loss(config['normalize'], config['temperature'])
182 | logits = lambda x, y : similarity_f(x, y, config['normalize']) / config['temperature']
183 | log_pdf = lambda x, y : logsumexp(logits(x, y))
184 |
185 | params = itertools.chain(enc.parameters(), proj.parameters())
186 | net = lambda x : proj(enc(x))
187 |
188 | t = get_t(size, config['t'])
189 | buffer = Buffer(train_X, t, config['buffer'])
190 | lr_schedule = get_lr_schedule(config)
191 | opt = get_optimizer(config['optim'], params)
192 | sgld = MSGLD(config['sgld'])
193 |
194 | if start_epoch == 0:
195 | it, rt = 0, 0
196 | else:
197 | it, rt = load_ckpt(enc, proj, buffer, opt, log_dir + '/{}.pt'.format(start_epoch))
198 |
199 | enc.train()
200 | proj.train()
201 |
202 | epoch_n_iter = int(np.ceil(train_n / config['bs']))
203 | while True:
204 | if it % epoch_n_iter == 0:
205 | idx = torch.randperm(train_n)
206 | V1 = transform_data(train_X, t, config['bs'])[idx]
207 | V2 = transform_data(train_X, t, config['bs'])[idx]
208 |
209 | i = it % epoch_n_iter
210 | it += 1
211 |
212 | s = time.time()
213 |
214 | X_v1 = V1[i * config['bs']:(i + 1) * config['bs']].cuda() * 2.0 - 1.0
215 | X_v2 = V2[i * config['bs']:(i + 1) * config['bs']].cuda() * 2.0 - 1.0
216 |
217 | if config['net']['use_bn']:
218 | shuffled_idxs, reverse_idxs = shuffled_idx(X_v2.shape[0])
219 | Z_v1, Z_v2 = net(X_v1), net(X_v2[shuffled_idxs])[reverse_idxs]
220 | else:
221 | Z_v1, Z_v2 = net(X_v1), net(X_v2)
222 |
223 | if 'neg_bs' in config:
224 | X_init, count = buffer.sample(config['neg_bs'])
225 | else:
226 | X_init, count = buffer.sample(X_v1.shape[0])
227 |
228 | X_n = sgld(lambda x : log_pdf(net(x), Z_v2.detach().clone()), X_init, count)
229 | buffer.update(X_n)
230 | Z_n = net(X_n)
231 |
232 | log_pdf_d = log_pdf(Z_v1, Z_v2)
233 | log_pdf_n = log_pdf(Z_n, Z_v2)
234 | gen_loss = -log_pdf_d.mean() + log_pdf_n.mean() + config['lmda1'] * (Z_v1.square().sum() + Z_n.square().sum())
235 |
236 | loss = config['lmda2'] * gen_loss + simclr_loss(Z_v1, Z_v2)
237 |
238 | # Update
239 | opt.param_groups[0]['lr'] = lr_schedule(it)
240 | opt.zero_grad()
241 | loss.backward()
242 |
243 | if config['optim']['optimizer'] == 'adam' and len(opt.state_dict()['state']) > 0:
244 | for j, param in enumerate(opt.param_groups[0]['params']):  # params (an itertools.chain) was consumed when opt was built
245 | bound = 3.0 * opt.state_dict()['state'][j]['exp_avg_sq'].sqrt()
246 | param.grad = torch.clamp(param.grad, -bound, bound)
247 |
248 | opt.step()
249 |
250 | e = time.time()
251 | rt += (e - s)
252 |
253 | if it % config['p_iter'] == 0:
254 | save_ckpt(buffer, enc, proj, opt, it, rt, False, log_dir + '/curr.pt')
255 |
256 | pdf_d = torch.exp(log_pdf_d).mean().item()
257 | pdf_n = torch.exp(log_pdf_n).mean().item()
258 |
259 | print('Epoch : {:.3f} | Data PDF : {:.3e} | Noise PDF : {:.3e} | LR : {:.3e} | Time : {:.3f}'.format(it / epoch_n_iter, pdf_d, pdf_n, lr_schedule(it), rt))
260 |
261 | if it % (epoch_n_iter * config['s_epoch']) == 0:
262 | save_ckpt(buffer, enc, proj, opt, it, rt, config['save_buffer'], log_dir + '/{}.pt'.format(it // epoch_n_iter))
263 |
264 | if it >= config['its']:
265 | break
266 |
267 | def run(log_dir, config_dir, start_epoch, device):
268 |
269 | if not os.path.isdir(log_dir):
270 | os.makedirs(log_dir)
271 |
272 | config = load_config(config_dir)
273 |
274 | train_X, train_y, test_X, test_y, n_classes = load_data(config['data'])
275 |
276 | epoch_n_iter = int(np.ceil(train_X.shape[0] / config['bs']))
277 |
278 | if 'epochs' in config:
279 | config['its'] = epoch_n_iter * config['epochs']
280 |
281 | if 'p_epoch' in config:
282 | config['p_iter'] = epoch_n_iter * config['p_epoch']
283 |
284 | enc, proj, _ = init_enc_proj(config['net'], device)
285 |
286 | print('Starting EBCLR from epoch {}'.format(start_epoch))
287 |
288 | with torch.cuda.device(device):
289 | EBCLR(train_X, enc, proj, config, log_dir, start_epoch)
290 |
291 | print('Finished!')
--------------------------------------------------------------------------------
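For reference, the unnormalized log-density that MSGLD ascends in EBCLR() is the logsumexp of the similarity logits between a sample's projection and the positive-view projections; a standalone restatement (omitting the optional L2-normalization branch of similarity_f):

    import torch

    def log_pdf(z: torch.Tensor, z_pos: torch.Tensor, temperature: float) -> torch.Tensor:
        # logits[i, j] = -||z_i - z_pos_j||^2 / temperature;
        # the log-density of z_i is logsumexp_j logits[i, j], as in ebclr.py above.
        logits = -(z[:, None] - z_pos[None]).square().sum(dim=2) / temperature
        return torch.logsumexp(logits, dim=1, keepdim=True)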