├── README.md
├── Results.ipynb
├── __init__.py
├── discriminative
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── coreset.cpython-36.pyc
│   │   ├── multihead_models.cpython-36.pyc
│   │   ├── utils.cpython-36.pyc
│   │   └── vcl.cpython-36.pyc
│   ├── data
│   │   └── mnist.pkl.gz
│   ├── final_results
│   │   ├── VCL-split.npy
│   │   ├── VCL.npy
│   │   ├── VGR-split.npy
│   │   ├── kcen-VCL-200.npy
│   │   ├── kcen-coreset-only-split.npy
│   │   ├── kcen-coreset-only200.npy
│   │   ├── kcenVCL-split.npy
│   │   ├── only-coreset-1000.npy
│   │   ├── only-coreset-200.npy
│   │   ├── only-coreset-2500.npy
│   │   ├── only-coreset-400.npy
│   │   ├── only-coreset-5000.npy
│   │   ├── pca-kcen-coreset-only-split.npy
│   │   ├── pca-kcenVCL-split.npy
│   │   ├── rand-VCL-1000.npy
│   │   ├── rand-VCL-200.npy
│   │   ├── rand-VCL-2500.npy
│   │   ├── rand-VCL-400.npy
│   │   ├── rand-VCL-5000.npy
│   │   ├── rand-coreset-only-split.npy
│   │   └── randVCL-split.npy
│   ├── misc
│   │   ├── KL.pdf
│   │   ├── Loss.pdf
│   │   ├── VGR.png
│   │   ├── gan_pics.png
│   │   ├── images
│   │   │   ├── 0.png
│   │   │   ├── 1.png
│   │   │   ├── 2.png
│   │   │   ├── 3.png
│   │   │   └── 4.png
│   │   ├── permuted_mnist_coreset_sizes.png
│   │   ├── permuted_mnist_main.png
│   │   ├── split_mnist_main_part1.png
│   │   ├── split_mnist_main_part2.png
│   │   ├── split_mnist_pca_part1.png
│   │   └── split_mnist_pca_part2.png
│   ├── results
│   │   ├── VCL-split.npy
│   │   ├── VCL.npy
│   │   ├── VGR-split.npy
│   │   ├── kcen-VCL-200.npy
│   │   ├── kcen-coreset-only-split.npy
│   │   ├── kcen-coreset-only200.npy
│   │   ├── kcenVCL-split.npy
│   │   ├── only-coreset-1000.npy
│   │   ├── only-coreset-200.npy
│   │   ├── only-coreset-2500.npy
│   │   ├── only-coreset-400.npy
│   │   ├── only-coreset-5000.npy
│   │   ├── pca-kcen-coreset-only-split.npy
│   │   ├── pca-kcenVCL-split.npy
│   │   ├── rand-VCL-1000.npy
│   │   ├── rand-VCL-200.npy
│   │   ├── rand-VCL-2500.npy
│   │   ├── rand-VCL-400.npy
│   │   ├── rand-VCL-5000.npy
│   │   ├── rand-coreset-only-split.npy
│   │   └── randVCL-split.npy
│   ├── run_permuted.py
│   ├── run_permuted_just_coreset.py
│   ├── run_split.py
│   ├── run_split_just_coreset.py
│   └── utils
│       ├── DataGenerator.py
│       ├── GAN.py
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── DataGenerator.cpython-36.pyc
│       │   ├── __init__.cpython-36.pyc
│       │   ├── coreset.cpython-36.pyc
│       │   ├── multihead_models.cpython-36.pyc
│       │   ├── test.cpython-36.pyc
│       │   └── vcl.cpython-36.pyc
│       ├── coreset.py
│       ├── multihead_models.py
│       ├── test.py
│       └── vcl.py
└── requirements.txt
/README.md:
--------------------------------------------------------------------------------
1 | # Variational Continual Learning (VCL)
2 | An implementation of the Variational Continual Learning (VCL) algorithms proposed by Nguyen, Li, Bui, and Turner (ICLR 2018).
3 |
4 | ```
5 | @inproceedings{nguyen2018variational,
6 | title = {Variational Continual Learning},
7 | author = {Nguyen, Cuong V. and Li, Yingzhen and Bui, Thang D. and Turner, Richard E.},
8 | booktitle = {International Conference on Learning Representations},
9 | year = {2018}
10 | }
11 | ```
12 | **To run the Permuted MNIST experiment:**
13 |
14 | python run_permuted.py
15 |
16 | **To run the Split MNIST experiment:**
17 |
18 | python run_split.py
19 |
20 | **Requirements:**
21 |
22 | - PyTorch 1.0 (with torchvision; the code also uses NumPy, SciPy, and scikit-learn)
23 | - Python 3.6
24 |
25 |
26 | ## Results
27 | ### VCL in deep discriminative models
28 |
29 |
30 | **Permuted MNIST**
31 | 
32 | 
33 |
34 |
35 | **Split MNIST**
36 | 
37 | 
38 | ... with Variational Generative Replay (VGR):
39 | 
40 |
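41 | **To inspect saved results** (a minimal sketch; it assumes each `.npy` file holds a matrix of per-task test accuracies with one row per training stage and NaNs for tasks not yet seen -- check `Results.ipynb` before relying on this layout):
42 |
43 | ```python
44 | import numpy as np
45 | import matplotlib.pyplot as plt
46 |
47 | acc = np.load("discriminative/results/VCL.npy")   # path and array layout assumed
48 | mean_acc = np.nanmean(acc, axis=1)                # average accuracy over tasks seen so far
49 | plt.plot(range(1, len(mean_acc) + 1), mean_acc, marker="o")
50 | plt.xlabel("Number of tasks seen")
51 | plt.ylabel("Mean test accuracy")
52 | plt.show()
53 | ```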
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/__init__.py
--------------------------------------------------------------------------------
/discriminative/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/__init__.py
--------------------------------------------------------------------------------
/discriminative/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/discriminative/__pycache__/coreset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/__pycache__/coreset.cpython-36.pyc
--------------------------------------------------------------------------------
/discriminative/__pycache__/multihead_models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/__pycache__/multihead_models.cpython-36.pyc
--------------------------------------------------------------------------------
/discriminative/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/discriminative/__pycache__/vcl.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/__pycache__/vcl.cpython-36.pyc
--------------------------------------------------------------------------------
/discriminative/data/mnist.pkl.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/data/mnist.pkl.gz
--------------------------------------------------------------------------------
/discriminative/final_results/VCL-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/VCL-split.npy
--------------------------------------------------------------------------------
/discriminative/final_results/VCL.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/VCL.npy
--------------------------------------------------------------------------------
/discriminative/final_results/VGR-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/VGR-split.npy
--------------------------------------------------------------------------------
/discriminative/final_results/kcen-VCL-200.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/kcen-VCL-200.npy
--------------------------------------------------------------------------------
/discriminative/final_results/kcen-coreset-only-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/kcen-coreset-only-split.npy
--------------------------------------------------------------------------------
/discriminative/final_results/kcen-coreset-only200.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/kcen-coreset-only200.npy
--------------------------------------------------------------------------------
/discriminative/final_results/kcenVCL-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/kcenVCL-split.npy
--------------------------------------------------------------------------------
/discriminative/final_results/only-coreset-1000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/only-coreset-1000.npy
--------------------------------------------------------------------------------
/discriminative/final_results/only-coreset-200.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/only-coreset-200.npy
--------------------------------------------------------------------------------
/discriminative/final_results/only-coreset-2500.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/only-coreset-2500.npy
--------------------------------------------------------------------------------
/discriminative/final_results/only-coreset-400.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/only-coreset-400.npy
--------------------------------------------------------------------------------
/discriminative/final_results/only-coreset-5000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/only-coreset-5000.npy
--------------------------------------------------------------------------------
/discriminative/final_results/pca-kcen-coreset-only-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/pca-kcen-coreset-only-split.npy
--------------------------------------------------------------------------------
/discriminative/final_results/pca-kcenVCL-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/pca-kcenVCL-split.npy
--------------------------------------------------------------------------------
/discriminative/final_results/rand-VCL-1000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/rand-VCL-1000.npy
--------------------------------------------------------------------------------
/discriminative/final_results/rand-VCL-200.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/rand-VCL-200.npy
--------------------------------------------------------------------------------
/discriminative/final_results/rand-VCL-2500.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/rand-VCL-2500.npy
--------------------------------------------------------------------------------
/discriminative/final_results/rand-VCL-400.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/rand-VCL-400.npy
--------------------------------------------------------------------------------
/discriminative/final_results/rand-VCL-5000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/rand-VCL-5000.npy
--------------------------------------------------------------------------------
/discriminative/final_results/rand-coreset-only-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/rand-coreset-only-split.npy
--------------------------------------------------------------------------------
/discriminative/final_results/randVCL-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/final_results/randVCL-split.npy
--------------------------------------------------------------------------------
/discriminative/misc/KL.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/KL.pdf
--------------------------------------------------------------------------------
/discriminative/misc/Loss.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/Loss.pdf
--------------------------------------------------------------------------------
/discriminative/misc/VGR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/VGR.png
--------------------------------------------------------------------------------
/discriminative/misc/gan_pics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/gan_pics.png
--------------------------------------------------------------------------------
/discriminative/misc/images/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/images/0.png
--------------------------------------------------------------------------------
/discriminative/misc/images/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/images/1.png
--------------------------------------------------------------------------------
/discriminative/misc/images/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/images/2.png
--------------------------------------------------------------------------------
/discriminative/misc/images/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/images/3.png
--------------------------------------------------------------------------------
/discriminative/misc/images/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/images/4.png
--------------------------------------------------------------------------------
/discriminative/misc/permuted_mnist_coreset_sizes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/permuted_mnist_coreset_sizes.png
--------------------------------------------------------------------------------
/discriminative/misc/permuted_mnist_main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/permuted_mnist_main.png
--------------------------------------------------------------------------------
/discriminative/misc/split_mnist_main_part1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/split_mnist_main_part1.png
--------------------------------------------------------------------------------
/discriminative/misc/split_mnist_main_part2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/split_mnist_main_part2.png
--------------------------------------------------------------------------------
/discriminative/misc/split_mnist_pca_part1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/split_mnist_pca_part1.png
--------------------------------------------------------------------------------
/discriminative/misc/split_mnist_pca_part2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/misc/split_mnist_pca_part2.png
--------------------------------------------------------------------------------
/discriminative/results/VCL-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/VCL-split.npy
--------------------------------------------------------------------------------
/discriminative/results/VCL.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/VCL.npy
--------------------------------------------------------------------------------
/discriminative/results/VGR-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/VGR-split.npy
--------------------------------------------------------------------------------
/discriminative/results/kcen-VCL-200.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/kcen-VCL-200.npy
--------------------------------------------------------------------------------
/discriminative/results/kcen-coreset-only-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/kcen-coreset-only-split.npy
--------------------------------------------------------------------------------
/discriminative/results/kcen-coreset-only200.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/kcen-coreset-only200.npy
--------------------------------------------------------------------------------
/discriminative/results/kcenVCL-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/kcenVCL-split.npy
--------------------------------------------------------------------------------
/discriminative/results/only-coreset-1000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/only-coreset-1000.npy
--------------------------------------------------------------------------------
/discriminative/results/only-coreset-200.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/only-coreset-200.npy
--------------------------------------------------------------------------------
/discriminative/results/only-coreset-2500.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/only-coreset-2500.npy
--------------------------------------------------------------------------------
/discriminative/results/only-coreset-400.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/only-coreset-400.npy
--------------------------------------------------------------------------------
/discriminative/results/only-coreset-5000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/only-coreset-5000.npy
--------------------------------------------------------------------------------
/discriminative/results/pca-kcen-coreset-only-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/pca-kcen-coreset-only-split.npy
--------------------------------------------------------------------------------
/discriminative/results/pca-kcenVCL-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/pca-kcenVCL-split.npy
--------------------------------------------------------------------------------
/discriminative/results/rand-VCL-1000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/rand-VCL-1000.npy
--------------------------------------------------------------------------------
/discriminative/results/rand-VCL-200.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/rand-VCL-200.npy
--------------------------------------------------------------------------------
/discriminative/results/rand-VCL-2500.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/rand-VCL-2500.npy
--------------------------------------------------------------------------------
/discriminative/results/rand-VCL-400.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/rand-VCL-400.npy
--------------------------------------------------------------------------------
/discriminative/results/rand-VCL-5000.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/rand-VCL-5000.npy
--------------------------------------------------------------------------------
/discriminative/results/rand-coreset-only-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/rand-coreset-only-split.npy
--------------------------------------------------------------------------------
/discriminative/results/randVCL-split.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/results/randVCL-split.npy
--------------------------------------------------------------------------------
/discriminative/run_permuted.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import discriminative.utils.vcl as vcl
3 | import discriminative.utils.coreset as coreset
4 | from discriminative.utils.DataGenerator import PermutedMnistGenerator
5 |
6 | hidden_size = [100, 100]
7 | batch_size = 256
8 | no_epochs = 100
9 | single_head = True
10 | num_tasks = 10
11 |
12 | np.random.seed(1)
13 | #Just VCL
14 | coreset_size = 0
15 | data_gen = PermutedMnistGenerator(num_tasks)
16 | vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
17 | coreset.rand_from_batch, coreset_size, batch_size, single_head)
18 | np.save("./results/VCL", vcl_result)
19 | print(vcl_result)
20 |
21 | #VCL + Random Coreset
22 | np.random.seed(1)
23 |
24 | for coreset_size in [200, 400, 1000, 2500, 5000]:
25 |     data_gen = PermutedMnistGenerator(num_tasks)
26 |     rand_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
27 |         coreset.rand_from_batch, coreset_size, batch_size, single_head)  # plain random-coreset VCL, no GAN replay
28 | np.save("./results/rand-VCL-{}".format(coreset_size), rand_vcl_result)
29 | print(rand_vcl_result)
30 |
31 | #VCL + k-center coreset
32 | np.random.seed(1)
33 | coreset_size = 200
34 | data_gen = PermutedMnistGenerator(num_tasks)
35 | kcen_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
36 | coreset.k_center, coreset_size, batch_size, single_head)
37 | print(kcen_vcl_result)
38 | np.save("./results/kcen-VCL-{}".format(coreset_size), kcen_vcl_result)
39 |
40 |
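41 | # Note: the discriminative.* imports need the repository root on PYTHONPATH,
42 | # while 'data/mnist.pkl.gz' and './results/' are resolved relative to the
43 | # working directory, so a safe invocation is:
44 | #   cd discriminative && PYTHONPATH=.. python run_permuted.py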
--------------------------------------------------------------------------------
/discriminative/run_permuted_just_coreset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import discriminative.utils.vcl as vcl
3 | import discriminative.utils.coreset as coreset
4 | from discriminative.utils.DataGenerator import PermutedMnistGenerator
5 |
6 | hidden_size = [100, 100]
7 | batch_size = 256
8 | no_epochs = 100
9 | single_head = True
10 | num_tasks = 10
11 |
12 | np.random.seed(0)
13 | #for coreset_size in [400,1000,2500,5000]:
14 | # data_gen = PermutedMnistGenerator(num_tasks)
15 | # vcl_result = vcl.run_coreset_only(hidden_size, no_epochs, data_gen,
16 | # coreset.rand_from_batch, coreset_size, batch_size, single_head)
17 | # np.save("./results/only-coreset-{}".format(coreset_size), vcl_result)
18 | # print(vcl_result)
19 |
20 | np.random.seed(0)
21 | coreset_size = 200
22 | data_gen = PermutedMnistGenerator(num_tasks)
23 | kcen_vcl_result = vcl.run_coreset_only(hidden_size, no_epochs, data_gen,
24 | coreset.k_center, coreset_size, batch_size, single_head)
25 | print(kcen_vcl_result)
26 | np.save("./results/kcen-coreset-only{}".format(coreset_size), kcen_vcl_result)
27 |
28 |
29 |
--------------------------------------------------------------------------------
/discriminative/run_split.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import discriminative.utils.vcl as vcl
3 | import discriminative.utils.coreset as coreset
4 | from discriminative.utils.DataGenerator import SplitMnistGenerator
5 |
6 | hidden_size = [256, 256]
7 | batch_size = None
8 | no_epochs = 120
9 | single_head = False
10 | run_coreset_only = False
11 | np.random.seed(0)
12 |
13 | #Just VCL
14 | coreset_size = 0
15 | data_gen = SplitMnistGenerator()
16 | vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
17 | coreset.rand_from_batch, coreset_size, batch_size, single_head)
18 | np.save("./results/VCL-split", vcl_result)
19 |
20 | #VGR: VCL with GAN-based generative replay (gan_bol=True)
21 | coreset_size = 40
22 | data_gen = SplitMnistGenerator()
23 | rand_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
24 |     coreset.rand_from_batch, coreset_size, batch_size, single_head, gan_bol=True)
25 | print(rand_vcl_result)
26 | np.save("./results/VGR-all-split", rand_vcl_result)
27 |
28 | #VCL + k-center coreset
29 | coreset_size = 40
30 | data_gen = SplitMnistGenerator()
31 | kcen_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
32 | coreset.k_center, coreset_size, batch_size, single_head)
33 | print(kcen_vcl_result)
34 | np.save("./results/kcenVCL-split", kcen_vcl_result)
35 |
--------------------------------------------------------------------------------
/discriminative/run_split_just_coreset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import discriminative.utils.vcl as vcl
3 | import discriminative.utils.coreset as coreset
4 | from discriminative.utils.DataGenerator import SplitMnistGenerator
5 |
6 |
7 | hidden_size = [256, 256]
8 | batch_size = None
9 | no_epochs = 120
10 | single_head = False
11 | coreset_size = 40
12 |
13 | np.random.seed(0)
14 | data_gen = SplitMnistGenerator()
15 | rand_vcl_result = vcl.run_coreset_only(hidden_size, no_epochs, data_gen,
16 | coreset.rand_from_batch, coreset_size, batch_size, single_head)
17 | print(rand_vcl_result)
18 | np.save("./results/rand-coreset-only-split", rand_vcl_result)
19 |
20 | np.random.seed(0)
21 | data_gen = SplitMnistGenerator()
22 | kcen_vcl_result = vcl.run_coreset_only(hidden_size, no_epochs, data_gen,
23 | coreset.k_center, coreset_size, batch_size, single_head)
24 | print(kcen_vcl_result)
25 | np.save("./results/kcen-coreset-only-split", kcen_vcl_result)
26 |
--------------------------------------------------------------------------------
/discriminative/utils/DataGenerator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gzip
3 | import pickle as cp
4 | from copy import deepcopy
5 |
6 |
7 | class PermutedMnistGenerator():
8 | def __init__(self, max_iter=10):
9 |
10 | with gzip.open('data/mnist.pkl.gz', 'rb') as file:
11 |             u = cp._Unpickler(file)
12 |             u.encoding = 'latin1'  # the MNIST pickle was written by Python 2; latin1 lets Python 3 decode it
13 | p = u.load()
14 | train_set, valid_set, test_set = p
15 |
16 |
17 | self.X_train = np.vstack((train_set[0], valid_set[0]))
18 | self.Y_train = np.hstack((train_set[1], valid_set[1]))
19 | self.X_test = test_set[0]
20 | self.Y_test = test_set[1]
21 | self.max_iter = max_iter
22 | self.cur_iter = 0
23 |
24 | def get_dims(self):
25 | # Get data input and output dimensions
26 | return self.X_train.shape[1], 10
27 |
28 | def next_task(self):
29 | if self.cur_iter >= self.max_iter:
30 | raise Exception('Number of tasks exceeded!')
31 | else:
32 |             np.random.seed(self.cur_iter)  # fixed per-task seed: the same pixel permutation is reproduced every run
33 | perm_inds = np.arange(self.X_train.shape[1])
34 | np.random.shuffle(perm_inds)
35 |
36 | # Retrieve train data
37 | next_x_train = deepcopy(self.X_train)
38 | next_x_train = next_x_train[:,perm_inds]
39 | next_y_train = self.Y_train
40 |
41 | # Retrieve test data
42 | next_x_test = deepcopy(self.X_test)
43 | next_x_test = next_x_test[:,perm_inds]
44 | next_y_test = self.Y_test
45 |
46 | self.cur_iter += 1
47 |
48 | return next_x_train, next_y_train, next_x_test, next_y_test
49 |
50 |
51 | class SplitMnistGenerator():
52 | def __init__(self):
53 | with gzip.open('data/mnist.pkl.gz', 'rb') as file:
54 | u = cp._Unpickler(file)
55 | u.encoding = 'latin1'
56 | p = u.load()
57 | train_set, valid_set, test_set = p
58 |
59 | self.X_train = np.vstack((train_set[0], valid_set[0]))
60 | self.X_test = test_set[0]
61 | self.train_label = np.hstack((train_set[1], valid_set[1]))
62 | self.test_label = test_set[1]
63 |
64 | self.sets_0 = [0, 2, 4, 6, 8]
65 | self.sets_1 = [1, 3, 5, 7, 9]
66 | self.max_iter = len(self.sets_0)
67 | self.cur_iter = 0
68 |
69 | def get_dims(self):
70 | # Get data input and output dimensions
71 | return self.X_train.shape[1], 2
72 |
73 | def next_task(self):
74 | if self.cur_iter >= self.max_iter:
75 | raise Exception('Number of tasks exceeded!')
76 | else:
77 | # Retrieve train data
78 | train_0_id = np.where(self.train_label == self.sets_0[self.cur_iter])[0]
79 | train_1_id = np.where(self.train_label == self.sets_1[self.cur_iter])[0]
80 | next_x_train = np.vstack((self.X_train[train_0_id], self.X_train[train_1_id]))
81 |
82 |             next_y_train = np.vstack((np.ones((train_0_id.shape[0], 1)), np.zeros((train_1_id.shape[0], 1)))).squeeze(-1)  # even digit -> label 1, odd digit -> label 0
83 |
84 | # Retrieve test data
85 | test_0_id = np.where(self.test_label == self.sets_0[self.cur_iter])[0]
86 | test_1_id = np.where(self.test_label == self.sets_1[self.cur_iter])[0]
87 | next_x_test = np.vstack((self.X_test[test_0_id], self.X_test[test_1_id]))
88 |
89 |             next_y_test = np.vstack((np.ones((test_0_id.shape[0], 1)), np.zeros((test_1_id.shape[0], 1)))).squeeze(-1)
90 |
91 | self.cur_iter += 1
92 |
93 | return next_x_train, next_y_train, next_x_test, next_y_test
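94 |
95 | # Illustrative usage (added sketch, not part of the original file): each call to
96 | # next_task() returns one task's train/test arrays, up to max_iter tasks.
97 | if __name__ == "__main__":
98 |     gen = PermutedMnistGenerator(max_iter=3)
99 |     in_dim, out_dim = gen.get_dims()   # 784 input pixels, 10 classes
100 |     for t in range(gen.max_iter):
101 |         x_tr, y_tr, x_te, y_te = gen.next_task()
102 |         print("task", t, x_tr.shape, x_te.shape)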
--------------------------------------------------------------------------------
/discriminative/utils/GAN.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torchvision.transforms as transforms
5 | from torchvision.utils import save_image
6 | from torch.utils.data import DataLoader
7 | from torchvision import datasets
8 | from torch.autograd import Variable
9 | import torch.nn as nn
10 | import torch
11 |
12 | n_epochs = 50
13 | batch_size = 64
14 | lr = 0.0002
15 | b1 = 0.5
16 | b2 = 0.999
17 | n_cpu = 8
18 | latent_dim = 100
19 | num_classes = 2
20 | img_size = 28
21 | channels = 1
22 | sample_interval = 400
23 | threshold = 0.99
24 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25 | cuda = True if torch.cuda.is_available() else False
26 | FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
27 | LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
28 |
29 |
30 |
31 | class VGR():
32 | def __init__(self, task_id):
33 | self.task_id = task_id
34 | # Loss functions
35 | self.adversarial_loss = torch.nn.BCELoss()
36 | self.auxiliary_loss = torch.nn.CrossEntropyLoss()
37 | # Initialize generator and discriminator
38 | self.generator = Generator()
39 | self.discriminator = Discriminator()
40 | if cuda:
41 | self.generator.cuda()
42 | self.discriminator.cuda()
43 | self.adversarial_loss.cuda()
44 | self.auxiliary_loss.cuda()
45 |
46 | # Initialize weights
47 | self.generator.apply(weights_init_normal)
48 | self.discriminator.apply(weights_init_normal)
49 |
50 | # Optimizers
51 | self.optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
52 | self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
53 |
54 |
55 | def train(self, x_train, y_train):
56 | N = x_train.shape[0]
57 |         x_train -= 0.5   # rescale inputs from [0, 1] to [-1, 1] in place,
58 |         x_train /= 0.5   # matching the generator's Tanh output range
59 | for epoch in range(n_epochs):
60 |
61 | total_batch = int(np.ceil(N * 1.0 / batch_size))
62 | perm_inds = np.arange(x_train.shape[0])
63 | np.random.shuffle(perm_inds)
64 | cur_x_train = x_train[perm_inds]
65 | cur_y_train = y_train[perm_inds]
66 | # Loop over all batches
67 | for i in range(total_batch):
68 | start_ind = i*batch_size
69 | end_ind = np.min([(i+1)*batch_size, N])
70 | batch_x = torch.Tensor(cur_x_train[start_ind:end_ind, :]).to(device = device)
71 | batch_y = torch.Tensor(cur_y_train[start_ind:end_ind]).to(device = device)
72 | batch_x = batch_x.reshape(-1,img_size,img_size).unsqueeze(1)
73 | bsize = batch_x.shape[0]
74 |
75 | # Adversarial ground truths
76 | valid = Variable(FloatTensor(bsize, 1).fill_(1.0), requires_grad=False)
77 | fake = Variable(FloatTensor(bsize, 1).fill_(0.0), requires_grad=False)
78 |
79 | # Configure input
80 | real_imgs = Variable(batch_x.type(FloatTensor))
81 | labels = Variable(batch_y.type(LongTensor))
82 |
83 | # -----------------
84 | # Train Generator
85 | # -----------------
86 |
87 | self.optimizer_G.zero_grad()
88 |
89 | # Sample noise and labels as generator input
90 | z = Variable(FloatTensor(np.random.normal(0, 1, (bsize, latent_dim))))
91 |
92 | # Generate a batch of images
93 | gen_imgs = self.generator(z)
94 |
95 | # Loss measures generator's ability to fool the discriminator
96 | validity, _ = self.discriminator(gen_imgs)
97 | g_loss = self.adversarial_loss(validity, valid)
98 |
99 | g_loss.backward()
100 | self.optimizer_G.step()
101 |
102 | # ---------------------
103 | # Train Discriminator
104 | # ---------------------
105 |
106 | self.optimizer_D.zero_grad()
107 |
108 | # Loss for real images
109 | real_pred, real_aux = self.discriminator(real_imgs)
110 | d_real_loss = (self.adversarial_loss(real_pred, valid) + self.auxiliary_loss(real_aux, labels)) / 2
111 |
112 | # Loss for fake images
113 | fake_pred, fake_aux = self.discriminator(gen_imgs.detach())
114 | d_fake_loss = self.adversarial_loss(fake_pred, fake)
115 |
116 | # Total discriminator loss
117 | d_loss = (d_real_loss + d_fake_loss) / 2
118 |
119 | # Calculate discriminator accuracy
120 | pred = real_aux.data.cpu().numpy()
121 | gt = labels.data.cpu().numpy()
122 | d_acc = np.mean(np.argmax(pred, axis=1) == gt)
123 |
124 | d_loss.backward()
125 | self.optimizer_D.step()
126 |
127 | print ("Epoch {}/{}, Discriminator loss: {}, acc: {}%, Generator loss: {}".format(epoch, n_epochs, d_loss.item(), 100 * d_acc, g_loss.item()))
128 |
129 |
130 | self.generator.cpu()
131 | self.discriminator.cpu()
132 | self.adversarial_loss.cpu()
133 | self.auxiliary_loss.cpu()
134 |
135 |
136 |
137 | def generate_samples(self, no_samples, task_id, current_nb = 0):
138 | # Sample noise and labels as generator input
139 | z = Variable(torch.FloatTensor(np.random.normal(0, 1, (no_samples, latent_dim))))
140 | # Generate a batch of images
141 | gen_imgs = self.generator(z)
142 | _,labels = self.discriminator(gen_imgs)
143 |         kept_indices = torch.nonzero(torch.max(labels, 1)[0] >= threshold).squeeze(1)  # keep only samples classified with confidence >= threshold
144 | print(kept_indices.shape[0])
145 | labels = labels.index_select(0,kept_indices)
146 | gen_imgs = gen_imgs.index_select(0,kept_indices)
147 | if current_nb == 0:
148 | save_image(gen_imgs.data[:25], 'images/%d.png' % task_id, nrow=5, normalize=True)
149 | gen_imgs = gen_imgs.squeeze(1).reshape(kept_indices.shape[0],784)
150 | labels = labels.argmax(1).type(torch.FloatTensor)
151 | gen_imgs = gen_imgs.data.cpu()
152 | labels = labels.data.cpu()
153 |
154 |
155 | new_current_nb = current_nb + kept_indices.shape[0]
156 | if(new_current_nb < no_samples):
157 | new_gen_imgs, new_labels = self.generate_samples(no_samples,task_id, new_current_nb)
158 | gen_imgs = np.vstack((gen_imgs,new_gen_imgs))
159 | labels = np.hstack((labels,new_labels))
160 |
161 | return gen_imgs*0.5+0.5, labels
162 |
163 | def weights_init_normal(m):
164 | classname = m.__class__.__name__
165 | if classname.find('Conv') != -1:
166 | torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
167 | elif classname.find('BatchNorm') != -1:
168 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
169 | torch.nn.init.constant_(m.bias.data, 0.0)
170 |
171 | class Generator(nn.Module):
172 | def __init__(self):
173 | super(Generator, self).__init__()
174 | self.init_size = img_size // 4 # Initial size before upsampling
175 | self.l1 = nn.Sequential(nn.Linear(latent_dim, 128*self.init_size**2))
176 | self.conv_blocks = nn.Sequential(
177 | nn.BatchNorm2d(128),
178 | nn.Upsample(scale_factor=2),
179 | nn.Conv2d(128, 128, 5, stride=1, padding=2),
180 | nn.BatchNorm2d(128, 0.8),
181 | nn.LeakyReLU(0.2, inplace=True),
182 | nn.Upsample(scale_factor=2),
183 | nn.Conv2d(128, 64, 5, stride=1, padding=2),
184 | nn.BatchNorm2d(64, 0.8),
185 | nn.LeakyReLU(0.2, inplace=True),
186 | nn.Conv2d(64, channels, 5, stride=1, padding=2),
187 | nn.Tanh()
188 | )
189 |
190 | def forward(self, noise):
191 | out = self.l1(noise)
192 | out = out.view(out.shape[0], 128, self.init_size, self.init_size)
193 | img = self.conv_blocks(out)
194 | return img
195 |
196 | class Discriminator(nn.Module):
197 | def __init__(self):
198 | super(Discriminator, self).__init__()
199 |
200 | def discriminator_block(in_filters, out_filters, bn=True):
201 | """Returns layers of each discriminator block"""
202 | block = [ nn.Conv2d(in_filters, out_filters, 3, 2, 1),
203 | nn.LeakyReLU(0.2, inplace=True),
204 | nn.Dropout2d(0.25)]
205 | if bn:
206 | block.append(nn.BatchNorm2d(out_filters, 0.8))
207 | return block
208 |
209 | self.conv_blocks = nn.Sequential(
210 | *discriminator_block(channels, 16, bn=False),
211 | *discriminator_block(16, 32),
212 | *discriminator_block(32, 64),
213 | *discriminator_block(64, 128),
214 | )
215 |
216 | # Output layers
217 |         self.adv_layer = nn.Sequential(nn.Linear(512, 1),
218 |                                        nn.Sigmoid())
219 |         self.aux_layer = nn.Sequential(nn.Linear(512, num_classes),
220 |                                        nn.Softmax(dim=1))  # explicit dim avoids the implicit-dim deprecation warning
221 |
222 | def forward(self, img):
223 | out = self.conv_blocks(img)
224 | out = out.view(out.shape[0], -1)
225 | validity = self.adv_layer(out)
226 | label = self.aux_layer(out)
227 |
228 | return validity, label
229 |
230 |
231 |
--------------------------------------------------------------------------------
/discriminative/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/utils/__init__.py
--------------------------------------------------------------------------------
/discriminative/utils/__pycache__/DataGenerator.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/utils/__pycache__/DataGenerator.cpython-36.pyc
--------------------------------------------------------------------------------
/discriminative/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/discriminative/utils/__pycache__/coreset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/utils/__pycache__/coreset.cpython-36.pyc
--------------------------------------------------------------------------------
/discriminative/utils/__pycache__/multihead_models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/utils/__pycache__/multihead_models.cpython-36.pyc
--------------------------------------------------------------------------------
/discriminative/utils/__pycache__/test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/utils/__pycache__/test.cpython-36.pyc
--------------------------------------------------------------------------------
/discriminative/utils/__pycache__/vcl.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pa-dqn/VariationalContinualLearning/c1ca089ad8bcd3388d1a7f8dae2ede13d70cd52b/discriminative/utils/__pycache__/vcl.cpython-36.pyc
--------------------------------------------------------------------------------
/discriminative/utils/coreset.py:
--------------------------------------------------------------------------------
1 | #### Coreset selection code by nvcuong (original VCL implementation) ###
2 | #
3 | #
4 | # Nothing needed changing here for the PyTorch port: these routines are pure NumPy.
5 |
6 | import sklearn.decomposition as decomp
7 | import numpy as np
8 |
9 | """ Random coreset selection """
10 | def rand_from_batch(x_coreset, y_coreset, x_train, y_train, coreset_size):
11 | # Randomly select from (x_train, y_train) and add to current coreset (x_coreset, y_coreset)
12 | idx = np.random.choice(x_train.shape[0], coreset_size, False)
13 | x_coreset.append(x_train[idx,:])
14 | y_coreset.append(y_train[idx])
15 | x_train = np.delete(x_train, idx, axis=0)
16 | y_train = np.delete(y_train, idx, axis=0)
17 | return x_coreset, y_coreset, x_train, y_train
18 |
19 | """ K-center coreset selection """
20 | def k_center(x_coreset, y_coreset, x_train, y_train, coreset_size):
21 | # Select K centers from (x_train, y_train) and add to current coreset (x_coreset, y_coreset)
22 | dists = np.full(x_train.shape[0], np.inf)
23 | current_id = 0
24 | dists = update_distance(dists, x_train, current_id)
25 | idx = [ current_id ]
26 |
27 | for i in range(1, coreset_size):
28 | current_id = np.argmax(dists)
29 | dists = update_distance(dists, x_train, current_id)
30 | idx.append(current_id)
31 |
32 | x_coreset.append(x_train[idx,:])
33 | y_coreset.append(y_train[idx])
34 | x_train = np.delete(x_train, idx, axis=0)
35 | y_train = np.delete(y_train, idx, axis=0)
36 |
37 | return x_coreset, y_coreset, x_train, y_train
38 |
39 |
40 | """ K-center performed on reduced data by pca coreset selection """
41 | def pca_k_center(x_coreset, y_coreset, x_train, y_train, coreset_size):
42 | # Select K centers from (x_train, y_train) and add to current coreset (x_coreset, y_coreset)
43 | pca = decomp.PCA(20)
44 | pca.fit(x_train)
45 | x_train_reduced = pca.transform(x_train)
46 | dists = np.full(x_train_reduced.shape[0], np.inf)
47 | current_id = 0
48 | dists = update_distance(dists, x_train_reduced, current_id)
49 | idx = [ current_id ]
50 |
51 | for i in range(1, coreset_size):
52 | current_id = np.argmax(dists)
53 | dists = update_distance(dists, x_train_reduced, current_id)
54 | idx.append(current_id)
55 |
56 | x_coreset.append(x_train[idx,:])
57 | y_coreset.append(y_train[idx])
58 | x_train = np.delete(x_train, idx, axis=0)
59 | y_train = np.delete(y_train, idx, axis=0)
60 |
61 | return x_coreset, y_coreset, x_train, y_train
62 |
63 |
64 |
65 | def attention_like_coreset(x_coreset, y_coreset, x_train, y_train, coreset_size):
66 | """TODO: learning a subnetwork that chooses the coreset (attention-like) """
67 | return x_coreset, y_coreset, x_train, y_train
68 |
69 |
70 | def uncertainty_like_coreset(x_coreset, y_coreset, x_train, y_train, coreset_size):
71 | """TODO: Keeping only the instances that were not classified w.h. certainty """
72 | return x_coreset, y_coreset, x_train, y_train
73 |
74 |
75 | def update_distance(dists, x_train, current_id):
76 |     # Distance of every point to its nearest chosen center so far,
77 |     # vectorized over the whole training pool.
78 |     new_dists = np.linalg.norm(x_train - x_train[current_id, :], axis=1)
79 |     return np.minimum(dists, new_dists)
80 |
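81 | # Illustrative usage (added sketch, not part of the original file): the selection
82 | # functions append the chosen rows to the running coreset lists and return the
83 | # remaining pool with those rows removed.
84 | if __name__ == "__main__":
85 |     x_pool = np.random.rand(1000, 784)
86 |     y_pool = np.random.randint(0, 10, size=1000)
87 |     x_cs, y_cs, x_pool, y_pool = k_center([], [], x_pool, y_pool, coreset_size=20)
88 |     print(len(x_cs[0]), x_pool.shape)   # 20 rows selected, (980, 784) remaining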
--------------------------------------------------------------------------------
/discriminative/utils/multihead_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | from scipy.stats import truncnorm
6 | from copy import deepcopy
7 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
8 |
9 | np.random.seed(0)
10 |
11 | # variable initialization functions
12 | def truncated_normal(size, stddev=1, variable = False, mean=0):
13 | mu, sigma = mean, stddev
14 | lower, upper= -2 * sigma, 2 * sigma
15 | X = truncnorm(
16 | (lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
17 | X_tensor = torch.Tensor(data = X.rvs(size)).to(device = device)
18 | X_tensor.requires_grad = variable
19 | return X_tensor
20 |
21 | def init_tensor(value, dout, din = 1, variable = False):
22 | if din != 1:
23 | x = value * torch.ones([din, dout]).to(device = device)
24 | else:
25 | x = value * torch.ones([dout]).to(device = device)
26 | x.requires_grad=variable
27 |
28 | return x
29 |
30 | class Cla_NN(object):  # base class: holds the shared training loop; subclasses create weights and optimizer
31 | def __init__(self, input_size, hidden_size, output_size, training_size):
32 | return
33 |
34 |
35 | def train(self, x_train, y_train, task_idx, no_epochs=1000, batch_size=100, display_epoch=5):
36 | N = x_train.shape[0]
37 | self.training_size = N
38 | if batch_size > N:
39 | batch_size = N
40 |
41 | costs = []
42 | # Training cycle
43 | for epoch in range(no_epochs):
44 | perm_inds = np.arange(x_train.shape[0])
45 | np.random.shuffle(perm_inds)
46 | cur_x_train = x_train[perm_inds]
47 | cur_y_train = y_train[perm_inds]
48 |
49 | avg_cost = 0.
50 | total_batch = int(np.ceil(N * 1.0 / batch_size))
51 | # Loop over all batches
52 | for i in range(total_batch):
53 | start_ind = i*batch_size
54 | end_ind = np.min([(i+1)*batch_size, N])
55 | batch_x = torch.Tensor(cur_x_train[start_ind:end_ind, :]).to(device = device)
56 | batch_y = torch.Tensor(cur_y_train[start_ind:end_ind]).to(device = device)
57 |
58 | ##TODO: check if we need to lock the gradient somewhere
59 | self.optimizer.zero_grad()
60 | cost = self.get_loss(batch_x, batch_y, task_idx)
61 | cost.backward()
62 | self.optimizer.step()
63 |
64 | # Compute average loss
65 |                 avg_cost += cost.item() / total_batch  # .item() detaches the scalar so the graph is freed each batch
66 | # Display logs per epoch step
67 | if epoch % display_epoch == 0:
68 | print("Epoch:", '%04d' % (epoch+1), "cost=", \
69 | "{:.9f}".format(avg_cost))
70 | costs.append(avg_cost)
71 | print("Optimization Finished!")
72 | return costs
73 |
74 | def prediction_prob(self, x_test, task_idx):
75 | prob = F.softmax(self._prediction(x_test, task_idx, self.no_pred_samples), dim=-1)
76 | return prob
77 |
78 |
79 |
80 |
81 |
82 | """ Neural Network Model """
83 | class Vanilla_NN(Cla_NN):
84 | def __init__(self, input_size, hidden_size, output_size, training_size, learning_rate=0.001):
85 | #
86 | super(Vanilla_NN, self).__init__(input_size, hidden_size, output_size, training_size)
87 | # # init weights and biases
88 | self.W, self.b, self.W_last, self.b_last, self.size = self.create_weights(
89 | input_size, hidden_size, output_size)
90 | self.no_layers = len(hidden_size) + 1
91 | self.weights = self.W + self.b + self.W_last + self.b_last
92 | self.training_size = training_size
93 | self.optimizer = optim.Adam(self.weights, lr=learning_rate)
94 |
95 | def _prediction(self, inputs, task_idx):
96 | act = inputs
97 | for i in range(self.no_layers-1):
98 | pre = torch.add(torch.matmul(act, self.W[i]), self.b[i])
99 | act = F.relu(pre)
100 | pre = torch.add(torch.matmul(act, self.W_last[task_idx]), self.b_last[task_idx])
101 | return pre
102 |
103 | def _logpred(self, inputs, targets, task_idx):
104 |
105 | loss = torch.nn.CrossEntropyLoss()
106 | pred = self._prediction(inputs, task_idx)
107 | log_lik = - loss(pred, targets.type(torch.long))
108 | return log_lik
109 |
110 | def prediction_prob(self, x_test, task_idx):
111 | prob = F.softmax(self._prediction(x_test, task_idx), dim=-1)
112 | return prob
113 |
114 | def get_loss(self, batch_x, batch_y, task_idx):
115 | return -self._logpred(batch_x, batch_y, task_idx)
116 |
117 | def create_weights(self, in_dim, hidden_size, out_dim):
118 | hidden_size = deepcopy(hidden_size)
119 | hidden_size.append(out_dim)
120 | hidden_size.insert(0, in_dim)
121 |
122 | no_layers = len(hidden_size) - 1
123 | W = []
124 | b = []
125 | W_last = []
126 | b_last = []
127 | for i in range(no_layers-1):
128 | din = hidden_size[i]
129 | dout = hidden_size[i+1]
130 |
131 | #Initializiation values of means
132 | Wi_m = truncated_normal([din, dout], stddev=0.1, variable = True)
133 | bi_m = truncated_normal([dout], stddev=0.1, variable = True)
134 |
135 | #Append to list weights
136 | W.append(Wi_m)
137 | b.append(bi_m)
138 |
139 | Wi = truncated_normal([hidden_size[-2], out_dim], stddev=0.1, variable = True)
140 | bi = truncated_normal([out_dim], stddev=0.1, variable = True)
141 | W_last.append(Wi)
142 | b_last.append(bi)
143 | return W, b, W_last, b_last, hidden_size
144 |
145 | def get_weights(self):
146 | weights = [self.weights[:self.no_layers-1], self.weights[self.no_layers-1:2*(self.no_layers-1)], [self.weights[-2]], [self.weights[-1]]]
147 | return weights
148 |
149 | """ Bayesian Neural Network with Mean field VI approximation """
150 | class MFVI_NN(Cla_NN):
151 | def __init__(self, input_size, hidden_size, output_size, training_size,
152 | no_train_samples=10, no_pred_samples=100, single_head = False, prev_means=None, learning_rate=0.001):
153 | ##TODO: handle single head
154 | super(MFVI_NN, self).__init__(input_size, hidden_size, output_size, training_size)
155 |
156 | m1, v1, hidden_size = self.create_weights(
157 | input_size, hidden_size, output_size, prev_means)
158 |
159 | self.input_size = input_size
160 | self.out_size = output_size
161 | self.size = hidden_size
162 | self.single_head = single_head
163 |
164 | self.W_m, self.b_m = m1[0], m1[1]
165 | self.W_v, self.b_v = v1[0], v1[1]
166 |
167 | self.W_last_m, self.b_last_m = [], []
168 | self.W_last_v, self.b_last_v = [], []
169 |
170 |
171 | m2, v2 = self.create_prior(input_size, self.size, output_size)
172 |
173 | self.prior_W_m, self.prior_b_m, = m2[0], m2[1]
174 | self.prior_W_v, self.prior_b_v = v2[0], v2[1]
175 |
176 | self.prior_W_last_m, self.prior_b_last_m = [], []
177 | self.prior_W_last_v, self.prior_b_last_v = [], []
178 |
179 | self.W_m_copy, self.W_v_copy, self.b_m_copy, self.b_v_copy = None, None, None, None
180 | self.W_last_m_copy, self.W_last_v_copy, self.b_last_m_copy, self.b_last_v_copy = None, None, None, None
181 | self.prior_W_m_copy, self.prior_W_v_copy, self.prior_b_m_copy, self.prior_b_v_copy = None, None, None, None
182 | self.prior_W_last_m_copy, self.prior_W_last_v_copy, self.prior_b_last_m_copy, self.prior_b_last_v_copy = None, None, None, None
183 |
184 |
185 |
186 | self.no_layers = len(self.size) - 1
187 | self.no_train_samples = no_train_samples
188 | self.no_pred_samples = no_pred_samples
189 | self.training_size = training_size
190 | self.learning_rate = learning_rate
191 |
192 | if prev_means is not None:
193 | self.init_first_head(prev_means)
194 | else:
195 | self.create_head()
196 |
197 |
198 | m1.append(self.W_last_m)
199 | m1.append(self.b_last_m)
200 | v1.append(self.W_last_v)
201 | v1.append(self.b_last_v)
202 |
203 | r1 = m1 + v1
204 | self.weights = [item for sublist in r1 for item in sublist]
205 |
206 |
207 | self.optimizer = optim.Adam(self.weights, lr=learning_rate)
208 |
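# Negative ELBO per data point: KL(posterior || prior) scaled by 1/training_size,
# minus a Monte Carlo estimate of the expected log-likelihood.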
209 | def get_loss(self, batch_x, batch_y, task_idx):
210 | return torch.div(self._KL_term(), self.training_size) - self._logpred(batch_x, batch_y, task_idx)
211 |
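# Forward pass through K sampled networks using the reparameterization trick:
# w = mean + eps * exp(0.5 * log_var) with eps ~ N(0, I), so gradients flow back
# to the variational parameters through the sampled weights.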
212 | def _prediction(self, inputs, task_idx, no_samples):
213 | K = no_samples
214 | size = self.size
215 |
216 | act = torch.unsqueeze(inputs, 0).repeat([K, 1, 1])
217 | for i in range(self.no_layers-1):
218 | din = self.size[i]
219 | dout = self.size[i+1]
220 | eps_w = torch.normal(torch.zeros((K, din, dout)), torch.ones((K, din, dout))).to(device = device)
221 | eps_b = torch.normal(torch.zeros((K, 1, dout)), torch.ones((K, 1, dout))).to(device = device)
222 | weights = torch.add(eps_w * torch.exp(0.5*self.W_v[i]), self.W_m[i])
223 | biases = torch.add(eps_b * torch.exp(0.5*self.b_v[i]), self.b_m[i])
224 | pre = torch.add(torch.einsum('mni,mio->mno', act, weights), biases)
225 | act = F.relu(pre)
226 |
227 | din = self.size[-2]
228 | dout = self.size[-1]
229 |
230 | eps_w = torch.normal(torch.zeros((K, din, dout)), torch.ones((K, din, dout))).to(device = device)
231 | eps_b = torch.normal(torch.zeros((K, 1, dout)), torch.ones((K, 1, dout))).to(device = device)
232 | Wtask_m = self.W_last_m[task_idx]
233 | Wtask_v = self.W_last_v[task_idx]
234 | btask_m = self.b_last_m[task_idx]
235 | btask_v = self.b_last_v[task_idx]
236 |
237 | weights = torch.add(eps_w * torch.exp(0.5*Wtask_v),Wtask_m)
238 | biases = torch.add(eps_b * torch.exp(0.5*btask_v), btask_m)
239 | act = torch.unsqueeze(act, 3)
240 | weights = torch.unsqueeze(weights, 1)
241 | pre = torch.add(torch.sum(act * weights, dim = 2), biases)
242 | return pre
243 |
244 | def _logpred(self, inputs, targets, task_idx):
245 | loss = torch.nn.CrossEntropyLoss()
246 | pred = self._prediction(inputs, task_idx, self.no_train_samples).view(-1,self.out_size)
247 | targets = targets.repeat([self.no_train_samples, 1]).view(-1)
248 | log_liks = -loss(pred, targets.type(torch.long))  # CrossEntropyLoss already averages over the K samples and the batch
249 | log_lik = log_liks.mean()  # mean() of a scalar, kept as a harmless no-op
250 | return log_lik
251 |
252 |
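# Closed-form KL divergence between diagonal Gaussians, accumulated over shared
# layers and all task heads. With posterior N(m, exp(v)) (v is a log-variance)
# and prior N(m0, v0) (v0 is a variance), each element contributes
#   0.5 * (log(v0) - v + (exp(v) + (m0 - m)^2) / v0 - 1),
# the "-1" terms being collected into const_term.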
253 | def _KL_term(self):
254 | kl = 0
255 | for i in range(self.no_layers-1):
256 | din = self.size[i]
257 | dout = self.size[i+1]
258 | m, v = self.W_m[i], self.W_v[i]
259 | m0, v0 = self.prior_W_m[i], self.prior_W_v[i]
260 |
261 | const_term = -0.5 * dout * din
262 | log_std_diff = 0.5 * torch.sum(torch.log(v0) - v)
263 | mu_diff_term = 0.5 * torch.sum((torch.exp(v) + (m0 - m)**2) / v0)
264 | kl += const_term + log_std_diff + mu_diff_term
265 |
266 | m, v = self.b_m[i], self.b_v[i]
267 | m0, v0 = self.prior_b_m[i], self.prior_b_v[i]
268 |
269 | const_term = -0.5 * dout
270 | log_std_diff = 0.5 * torch.sum(torch.log(v0) - v)
271 | mu_diff_term = 0.5 * torch.sum((torch.exp(v) + (m0 - m)**2) / v0)
272 | kl += log_std_diff + mu_diff_term + const_term
273 |
274 | no_tasks = len(self.W_last_m)
275 | din = self.size[-2]
276 | dout = self.size[-1]
277 |
278 | for i in range(no_tasks):
279 | m, v = self.W_last_m[i], self.W_last_v[i]
280 | m0, v0 = self.prior_W_last_m[i], self.prior_W_last_v[i]
281 |
282 | const_term = - 0.5 * dout * din
283 | log_std_diff = 0.5 * torch.sum(torch.log(v0) - v)
284 | mu_diff_term = 0.5 * torch.sum((torch.exp(v) + (m0 - m)**2) / v0)
285 | kl += const_term + log_std_diff + mu_diff_term
286 |
287 | m, v = self.b_last_m[i], self.b_last_v[i]
288 | m0, v0 = self.prior_b_last_m[i], self.prior_b_last_v[i]
289 |
290 | const_term = -0.5 * dout
291 | log_std_diff = 0.5 * torch.sum(torch.log(v0) - v)
292 | mu_diff_term = 0.5 * torch.sum((torch.exp(v) + (m0 - m)**2) / v0)
293 | kl += const_term + log_std_diff + mu_diff_term
294 | return kl
295 |
296 | def save_weights(self):
297 | ''' Save the weights before fine-tuning on the coreset, prior to computing the test accuracy '''
298 |
299 | print("Saving weights before coreset training")
300 | self.W_m_copy = [self.W_m[i].clone().detach().data for i in range(len(self.W_m))]
301 | self.W_v_copy = [self.W_v[i].clone().detach().data for i in range(len(self.W_v))]
302 | self.b_m_copy = [self.b_m[i].clone().detach().data for i in range(len(self.b_m))]
303 | self.b_v_copy = [self.b_v[i].clone().detach().data for i in range(len(self.b_v))]
304 |
305 | self.W_last_m_copy = [self.W_last_m[i].clone().detach().data for i in range(len(self.W_last_m))]
306 | self.W_last_v_copy = [self.W_last_v[i].clone().detach().data for i in range(len(self.W_last_v))]
307 | self.b_last_m_copy = [self.b_last_m[i].clone().detach().data for i in range(len(self.b_last_m))]
308 | self.b_last_v_copy = [self.b_last_v[i].clone().detach().data for i in range(len(self.b_last_v))]
309 |
310 | self.prior_W_m_copy = [self.prior_W_m[i].data for i in range(len(self.prior_W_m))]
311 | self.prior_W_v_copy = [self.prior_W_v[i].data for i in range(len(self.prior_W_v))]
312 | self.prior_b_m_copy = [self.prior_b_m[i].data for i in range(len(self.prior_b_m))]
313 | self.prior_b_v_copy = [self.prior_b_v[i].data for i in range(len(self.prior_b_v))]
314 |
315 | self.prior_W_last_m_copy = [self.prior_W_last_m[i].data for i in range(len(self.prior_W_last_m))]
316 | self.prior_W_last_v_copy = [self.prior_W_last_v[i].data for i in range(len(self.prior_W_last_v))]
317 | self.prior_b_last_m_copy = [self.prior_b_last_m[i].data for i in range(len(self.prior_b_last_m))]
318 | self.prior_b_last_v_copy = [self.prior_b_last_v[i].data for i in range(len(self.prior_b_last_v))]
319 |
320 | return
321 |
322 | def load_weights(self):
323 | ''' Reload the saved weights after computing the test accuracy '''
324 |
325 | print("Reloading previous weights after coreset training")
326 | self.weights = []
327 | self.W_m = [self.W_m_copy[i].clone().detach().data for i in range(len(self.W_m))]
328 | self.W_v = [self.W_v_copy[i].clone().detach().data for i in range(len(self.W_v))]
329 | self.b_m = [self.b_m_copy[i].clone().detach().data for i in range(len(self.b_m))]
330 | self.b_v = [self.b_v_copy[i].clone().detach().data for i in range(len(self.b_v))]
331 |
332 | for i in range(len(self.W_m)):
333 | self.W_m[i].requires_grad = True
334 | self.W_v[i].requires_grad = True
335 | self.b_m[i].requires_grad = True
336 | self.b_v[i].requires_grad = True
337 |
338 | self.weights += self.W_m
339 | self.weights += self.W_v
340 | self.weights += self.b_m
341 | self.weights += self.b_v
342 |
343 |
344 | self.W_last_m = [self.W_last_m_copy[i].clone().detach().data for i in range(len(self.W_last_m))]
345 | self.W_last_v = [self.W_last_v_copy[i].clone().detach().data for i in range(len(self.W_last_v))]
346 | self.b_last_m = [self.b_last_m_copy[i].clone().detach().data for i in range(len(self.b_last_m))]
347 | self.b_last_v = [self.b_last_v_copy[i].clone().detach().data for i in range(len(self.b_last_v))]
348 |
349 | for i in range(len(self.W_last_m)):
350 | self.W_last_m[i].requires_grad = True
351 | self.W_last_v[i].requires_grad = True
352 | self.b_last_m[i].requires_grad = True
353 | self.b_last_v[i].requires_grad = True
354 |
355 | self.weights += self.W_last_m
356 | self.weights += self.W_last_v
357 | self.weights += self.b_last_m
358 | self.weights += self.b_last_v
359 |
360 | self.optimizer = optim.Adam(self.weights, lr=self.learning_rate)
361 | self.prior_W_m = [self.prior_W_m_copy[i].data for i in range(len(self.prior_W_m))]
362 | self.prior_W_v = [self.prior_W_v_copy[i].data for i in range(len(self.prior_W_v))]
363 | self.prior_b_m = [self.prior_b_m_copy[i].data for i in range(len(self.prior_b_m))]
364 | self.prior_b_v = [self.prior_b_v_copy[i].data for i in range(len(self.prior_b_v))]
365 |
366 | self.prior_W_last_m = [self.prior_W_last_m_copy[i].data for i in range(len(self.prior_W_last_m))]
367 | self.prior_W_last_v = [self.prior_W_last_v_copy[i].data for i in range(len(self.prior_W_last_v))]
368 | self.prior_b_last_m = [self.prior_b_last_m_copy[i].data for i in range(len(self.prior_b_last_m))]
369 | self.prior_b_last_v = [self.prior_b_last_v_copy[i].data for i in range(len(self.prior_b_last_v))]
370 |
371 | return
372 |
373 | def clean_copy_weights(self):
374 | self.W_m_copy, self.W_v_copy, self.b_m_copy, self.b_v_copy = None, None, None, None
375 | self.W_last_m_copy, self.W_last_v_copy, self.b_last_m_copy, self.b_last_v_copy = None, None, None, None
376 | self.prior_W_m_copy, self.prior_W_v_copy, self.prior_b_m_copy, self.prior_b_v_copy = None, None, None, None
377 | self.prior_W_last_m_copy, self.prior_W_last_v_copy, self.prior_b_last_m_copy, self.prior_b_last_v_copy = None, None, None, None
378 |
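# Multi-head setup: each new task gets a fresh output head, with means drawn from
# a truncated normal, log-variances initialized at -6 (variance e^-6), and a
# standard-normal prior (zero mean, unit variance).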
379 | def create_head(self):
380 | '''Create a new head when a new task is detected'''
381 | print("creating a new head")
382 | din = self.size[-2]
383 | dout = self.size[-1]
384 |
385 | W_m= truncated_normal([din, dout], stddev=0.1, variable=True)
386 | b_m= truncated_normal([dout], stddev=0.1, variable=True)
387 | W_v = init_tensor(-6.0, dout = dout, din = din, variable= True)
388 | b_v = init_tensor(-6.0, dout = dout, variable= True)
389 |
390 | self.W_last_m.append(W_m)
391 | self.W_last_v.append(W_v)
392 | self.b_last_m.append(b_m)
393 | self.b_last_v.append(b_v)
394 |
395 |
396 | W_m_p = torch.zeros([din, dout]).to(device = device)
397 | b_m_p = torch.zeros([dout]).to(device = device)
398 | W_v_p = init_tensor(1, dout = dout, din = din)
399 | b_v_p = init_tensor(1, dout = dout)
400 |
401 | self.prior_W_last_m.append(W_m_p)
402 | self.prior_W_last_v.append(W_v_p)
403 | self.prior_b_last_m.append(b_m_p)
404 | self.prior_b_last_v.append(b_v_p)
405 | self.weights = []
406 | self.weights += self.W_m
407 | self.weights += self.W_v
408 | self.weights += self.b_m
409 | self.weights += self.b_v
410 | self.weights += self.W_last_m
411 | self.weights += self.W_last_v
412 | self.weights += self.b_last_m
413 | self.weights += self.b_last_v
414 | self.optimizer = optim.Adam(self.weights, lr=self.learning_rate)
415 |
416 | return
417 |
418 |
419 | def init_first_head(self, prev_means):
420 | '''When the MFVI_NN is instantiated, initialize the weights with those of the Vanilla NN'''
421 | print("initializing first head")
422 | din = self.size[-2]
423 | dout = self.size[-1]
424 | self.prior_W_last_m = [torch.zeros([din, dout]).to(device = device)]
425 | self.prior_b_last_m = [torch.zeros([dout]).to(device = device)]
426 | self.prior_W_last_v = [init_tensor(1, dout = dout, din = din)]
427 | self.prior_b_last_v = [init_tensor(1, dout = dout)]
428 |
429 | W_last_m = prev_means[2][0].detach().data
430 | W_last_m.requires_grad = True
431 | self.W_last_m = [W_last_m]
432 | self.W_last_v = [init_tensor(-6.0, dout = dout, din = din, variable= True)]
433 |
434 |
435 | b_last_m = prev_means[3][0].detach().data
436 | b_last_m.requires_grad = True
437 | self.b_last_m = [b_last_m]
438 | self.b_last_v = [init_tensor(-6.0, dout = dout, variable= True)]
439 |
440 | return
441 |
442 | def create_weights(self, in_dim, hidden_size, out_dim, prev_means):
443 | hidden_size = deepcopy(hidden_size)
444 | hidden_size.append(out_dim)
445 | hidden_size.insert(0, in_dim)
446 |
447 | no_layers = len(hidden_size) - 1
448 | W_m = []
449 | b_m = []
450 | W_v = []
451 | b_v = []
452 |
453 | for i in range(no_layers-1):
454 | din = hidden_size[i]
455 | dout = hidden_size[i+1]
456 | if prev_means is not None:
457 | W_m_i = prev_means[0][i].detach().data
458 | W_m_i.requires_grad = True
459 | bi_m_i = prev_means[1][i].detach().data
460 | bi_m_i.requires_grad = True
461 | else:
462 | # Initialization values of the means
463 | W_m_i = truncated_normal([din, dout], stddev=0.1, variable=True)
464 | bi_m_i = truncated_normal([dout], stddev=0.1, variable=True)
465 | # Initialization values of the variances (stored as log-variances)
466 | W_v_i = init_tensor(-6.0, dout=dout, din=din, variable=True)
467 | bi_v_i = init_tensor(-6.0, dout=dout, variable=True)
468 |
469 | # Append the weights to the lists
470 | W_m.append(W_m_i)
471 | b_m.append(bi_m_i)
472 | W_v.append(W_v_i)
473 | b_v.append(bi_v_i)
474 |
475 | return [W_m, b_m], [W_v, b_v], hidden_size
476 |
477 | def create_prior(self, in_dim, hidden_size, out_dim, initial_mean = 0, initial_variance = 1):
478 |
479 | no_layers = len(hidden_size) - 1
480 | W_m = []
481 | b_m = []
482 |
483 | W_v = []
484 | b_v = []
485 |
486 | for i in range(no_layers - 1):
487 | din = hidden_size[i]
488 | dout = hidden_size[i + 1]
489 |
490 | # Initialization values of the means
491 | W_m_val = initial_mean + torch.zeros([din, dout]).to(device = device)  # '+' rather than '*': multiplying zeros would discard a nonzero initial_mean
492 | bi_m_val = initial_mean + torch.zeros([dout]).to(device = device)
493 |
494 | # Initialization values of the variances
495 | W_v_val = initial_variance * init_tensor(1, dout = dout, din = din)
496 | bi_v_val = initial_variance * init_tensor(1, dout = dout)
497 |
498 | # Append to list weights
499 | W_m.append(W_m_val)
500 | b_m.append(bi_m_val)
501 | W_v.append(W_v_val)
502 | b_v.append(bi_v_val)
503 |
504 | return [W_m, b_m], [W_v, b_v]
505 |
506 |
507 |
508 |
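# The core VCL recursion: after training on task t, the posterior becomes the
# prior for task t+1. Posteriors store log-variances while priors store plain
# variances, hence the torch.exp when copying.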
509 | def update_prior(self):
510 | print("updating prior...")
511 | for i in range(len(self.W_m)):
512 | self.prior_W_m[i].data.copy_(self.W_m[i].clone().detach().data)
513 | self.prior_b_m[i].data.copy_(self.b_m[i].clone().detach().data)
514 | self.prior_W_v[i].data.copy_(torch.exp(self.W_v[i].clone().detach().data))
515 | self.prior_b_v[i].data.copy_(torch.exp(self.b_v[i].clone().detach().data))
516 |
517 | length = len(self.W_last_m)
518 |
519 | for i in range(length):
520 | self.prior_W_last_m[i].data.copy_(self.W_last_m[i].clone().detach().data)
521 | self.prior_b_last_m[i].data.copy_(self.b_last_m[i].clone().detach().data)
522 | self.prior_W_last_v[i].data.copy_(torch.exp(self.W_last_v[i].clone().detach().data))
523 | self.prior_b_last_v[i].data.copy_(torch.exp(self.b_last_v[i].clone().detach().data))
524 |
525 | return
--------------------------------------------------------------------------------
/discriminative/utils/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use('agg')
4 | import torch
5 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
6 |
7 |
8 | def merge_coresets(x_coresets, y_coresets):
9 | merged_x, merged_y = x_coresets[0], y_coresets[0]
10 | for i in range(1, len(x_coresets)):
11 | merged_x = np.vstack((merged_x, x_coresets[i]))
12 | merged_y = np.hstack((merged_y, y_coresets[i]))
13 | return merged_x, merged_y
14 |
15 |
16 | def get_coreset(x_coresets, y_coresets, single_head, coreset_size = 5000, gans = None, task_id=0):
17 | if gans is not None:
18 | if single_head:
19 | merged_x, merged_y = gans[0].generate_samples(coreset_size, task_id)
20 | for i in range(1, len(gans)):
21 | new_x, new_y = gans[i].generate_samples(coreset_size, task_id)
22 | merged_x = np.vstack((merged_x,new_x))
23 | merged_y = np.hstack((merged_y,new_y))
24 | return merged_x, merged_y
25 | else:
26 | return gans.generate_samples(coreset_size, task_id)  # a single GAN in this branch; generate_samples returns an (x, y) pair
27 | else:
28 | if single_head:
29 | return merge_coresets(x_coresets, y_coresets)
30 | else:
31 | return x_coresets, y_coresets
32 |
33 |
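# Evaluation protocol: when coresets (or GANs, for VGR) are available, the model
# is first fine-tuned on them (the merged coreset for a single head, the matching
# per-task coreset otherwise). The caller is expected to have called
# model.save_weights() beforehand and to restore with model.load_weights() after.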
34 | def get_scores(model, x_testsets, y_testsets, no_epochs, single_head, x_coresets, y_coresets, batch_size=None, just_vanilla = False, gans = None):
35 |
36 | acc = []
37 | if single_head:
38 | if len(x_coresets) > 0 or gans is not None:
39 | x_train, y_train = get_coreset(x_coresets, y_coresets, single_head, coreset_size = 6000, gans = gans, task_id=0)
40 |
41 | bsize = x_train.shape[0] if (batch_size is None) else batch_size
42 | x_train = torch.Tensor(x_train)
43 | y_train = torch.Tensor(y_train)
44 | model.train(x_train, y_train, 0, no_epochs, bsize)
45 |
46 | for i in range(len(x_testsets)):
47 | if not single_head:
48 | if len(x_coresets)>0 or gans is not None:
49 | model.load_weights()
50 | gan_i = None
51 | if gans is not None:
52 | gan_i = gans[i]
53 | x_train, y_train = get_coreset(None, None, single_head, coreset_size = 6000, gans= gan_i, task_id=i)
54 | else:
55 | x_train, y_train = get_coreset(x_coresets[i], y_coresets[i], single_head, coreset_size = 6000, gans= None, task_id=i)
56 | bsize = x_train.shape[0] if (batch_size is None) else batch_size
57 | x_train = torch.Tensor(x_train)
58 | y_train = torch.Tensor(y_train)
59 | model.train(x_train, y_train, i, no_epochs, bsize)
60 |
61 | head = 0 if single_head else i
62 | x_test, y_test = x_testsets[i], y_testsets[i]
63 | N = x_test.shape[0]
64 | bsize = N if (batch_size is None) else batch_size
65 | cur_acc = 0
66 | total_batch = int(np.ceil(N * 1.0 / bsize))
67 | # Loop over all batches (a separate index keeps the task index i intact)
68 | for b in range(total_batch):
69 | start_ind = b*bsize
70 | end_ind = np.min([(b+1)*bsize, N])
71 | batch_x_test = torch.Tensor(x_test[start_ind:end_ind, :]).to(device = device)
72 | batch_y_test = torch.Tensor(y_test[start_ind:end_ind]).type(torch.LongTensor).to(device = device)
73 | pred = model.prediction_prob(batch_x_test, head)
74 | if not just_vanilla:
75 | pred_mean = pred.mean(0)
76 | else:
77 | pred_mean = pred
78 | pred_y = torch.argmax(pred_mean, dim=1)
79 | cur_acc += (end_ind - start_ind) - (pred_y - batch_y_test).nonzero().shape[0]  # batch size minus the number of misclassified examples
80 |
81 | cur_acc = float(cur_acc)
82 | cur_acc /= N
83 | acc.append(cur_acc)
84 | print("Accuracy is {}".format(cur_acc))
85 | return acc
86 |
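# Accumulates per-task accuracies into a matrix: row t holds the accuracies on
# tasks 0..t measured after training on task t, with NaN padding in the columns
# for tasks not yet seen.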
87 | def concatenate_results(score, all_score):
88 | if all_score.size == 0:
89 | all_score = np.reshape(score, (1,-1))
90 | else:
91 | new_arr = np.empty((all_score.shape[0], all_score.shape[1]+1))
92 | new_arr[:] = np.nan
93 | new_arr[:,:-1] = all_score
94 | all_score = np.vstack((new_arr, score))
95 | return all_score
--------------------------------------------------------------------------------
/discriminative/utils/vcl.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import discriminative.utils.test as test
3 | from discriminative.utils.multihead_models import Vanilla_NN, MFVI_NN
4 | import torch
5 | import discriminative.utils.GAN as GAN
6 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
7 | try:
8 | from torchviz import make_dot, make_dot_from_trace
9 | except ImportError:
10 | print("Torchviz was not found.")
11 |
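# Per-task VCL loop: (1) on the first task only, train a deterministic Vanilla_NN
# by maximum likelihood and use its weights to initialize the variational means;
# (2) optionally split a coreset off the training data, or train a per-task GAN
# when gan_bol is set (VGR); (3) train the mean-field model; (4) propagate
# posterior -> prior; (5) save weights, evaluate (get_scores does the last-minute
# coreset fine-tuning), then restore, and add a new head in the multi-head case.
#
# A minimal invocation sketch (hedged: the real drivers are run_permuted.py /
# run_split.py, not shown here; data_gen stands for any generator exposing
# get_dims(), next_task() and max_iter, and coreset_fn for any function with
# signature coreset_fn(x_coresets, y_coresets, x_train, y_train, coreset_size)
# returning the updated four-tuple; all values below are illustrative only):
#
#   all_acc = run_vcl(hidden_size=[100, 100], no_epochs=100, data_gen=data_gen,
#                     coreset_method=coreset_fn, coreset_size=200,
#                     batch_size=256, single_head=False)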
12 | def run_vcl(hidden_size, no_epochs, data_gen, coreset_method, coreset_size=0, batch_size=None, single_head=True, gan_bol = False):
13 | in_dim, out_dim = data_gen.get_dims()
14 | x_coresets, y_coresets = [], []
15 | x_testsets, y_testsets = [], []
16 | gans = []
17 | all_acc = np.array([])
18 |
19 | for task_id in range(data_gen.max_iter):
20 | x_train, y_train, x_test, y_test = data_gen.next_task()
21 | x_testsets.append(x_test)
22 | y_testsets.append(y_test)
23 |
24 | # Set the readout head to train
25 | head = 0 if single_head else task_id
26 | bsize = x_train.shape[0] if (batch_size is None) else batch_size
27 |
28 | # Train network with maximum likelihood to initialize first model
29 | if task_id == 0:
30 | print_graph_bol = False  # set to True to visualize the computational graph (requires torchviz)
31 | ml_model = Vanilla_NN(in_dim, hidden_size, out_dim, x_train.shape[0])
32 | ml_model.train(x_train, y_train, task_id, no_epochs, bsize)
33 | mf_weights = ml_model.get_weights()
34 | mf_model = MFVI_NN(in_dim, hidden_size, out_dim, x_train.shape[0], single_head = single_head, prev_means=mf_weights)
35 |
36 | if not gan_bol:
37 | if coreset_size > 0:
38 | x_coresets, y_coresets, x_train, y_train = coreset_method(x_coresets, y_coresets, x_train, y_train, coreset_size)
39 | gans = None
40 | if print_graph_bol:
41 | #Just if you want to see the computational graph
42 | output_tensor = mf_model._KL_term()  # the full loss, mf_model.get_loss(...), can be visualized instead
43 | print_graph(mf_model, output_tensor)
44 | print_graph_bol = False
45 |
46 | if gan_bol:
47 | gan_i = GAN.VGR(task_id)
48 | gan_i.train(x_train, y_train)
49 | gans.append(gan_i)
50 | mf_model.train(x_train, y_train, head, no_epochs, bsize)
51 |
52 | mf_model.update_prior()
53 | # Save weights before testing (and before the last-minute training on the coreset)
54 | mf_model.save_weights()
55 |
56 | acc = test.get_scores(mf_model, x_testsets, y_testsets, no_epochs, single_head, x_coresets, y_coresets, batch_size, False, gans)
57 | all_acc = test.concatenate_results(acc, all_acc)
58 |
59 | mf_model.load_weights()
60 | mf_model.clean_copy_weights()
61 |
62 |
63 | if not single_head:
64 | mf_model.create_head()
65 |
66 | return all_acc
67 |
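# Coreset-only baseline: the variational model is never trained on the full task
# data and the prior is never updated; at test time get_scores fine-tunes a saved
# copy of the model on the coresets alone, then the copy is discarded.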
68 | def run_coreset_only(hidden_size, no_epochs, data_gen, coreset_method, coreset_size=0, batch_size=None, single_head=True):
69 | in_dim, out_dim = data_gen.get_dims()
70 | x_coresets, y_coresets = [], []
71 | x_testsets, y_testsets = [], []
72 | all_acc = np.array([])
73 |
74 | for task_id in range(data_gen.max_iter):
75 | x_train, y_train, x_test, y_test = data_gen.next_task()
76 | x_testsets.append(x_test)
77 | y_testsets.append(y_test)
78 |
79 | head = 0 if single_head else task_id
80 | bsize = x_train.shape[0] if (batch_size is None) else batch_size
81 |
82 | if task_id == 0:
83 | mf_model = MFVI_NN(in_dim, hidden_size, out_dim, x_train.shape[0], single_head = single_head, prev_means=None)
84 |
85 | if coreset_size > 0:
86 | x_coresets, y_coresets, x_train, y_train = coreset_method(x_coresets, y_coresets, x_train, y_train, coreset_size)
87 |
88 |
89 | mf_model.save_weights()
90 |
91 | acc = test.get_scores(mf_model, x_testsets, y_testsets, no_epochs, single_head, x_coresets, y_coresets, batch_size, just_vanilla =False)
92 |
93 | all_acc = test.concatenate_results(acc, all_acc)
94 |
95 | mf_model.load_weights()
96 | mf_model.clean_copy_weights()
97 |
98 | if not single_head:
99 | mf_model.create_head()
100 |
101 | return all_acc
102 |
103 | def print_graph(model, output):
104 | params = dict()
105 | for i in range(len(model.W_m)):
106 | params["W_m{}".format(i)] = model.W_m[i]
107 | params["W_v{}".format(i)] = model.W_v[i]
108 | params["b_m{}".format(i)] = model.b_m[i]
109 | params["b_v{}".format(i)] = model.b_v[i]
110 | params["prior_W_m".format(i)] = model.prior_W_m[i]
111 | params["prior_W_v".format(i)] = model.prior_W_v[i]
112 | params["prior_b_m".format(i)] = model.prior_b_m[i]
113 | params["prior_b_v".format(i)] = model.prior_b_v[i]
114 |
115 | for i in range(len(model.W_last_m)):
116 | params["W_last_m".format(i)] = model.W_last_m[i]
117 | params["W_last_v".format(i)] = model.W_last_v[i]
118 | params["b_last_m".format(i)] = model.b_last_m[i]
119 | params["b_last_v".format(i)] = model.b_last_v[i]
120 | params["prior_W_last_m".format(i)] = model.prior_W_last_m[i]
121 | params["prior_W_last_v".format(i)] = model.prior_W_last_v[i]
122 | params["prior_b_last_m".format(i)] = model.prior_b_last_m[i]
123 | params["prior_b_last_v".format(i)] = model.prior_b_last_v[i]
124 | dot = make_dot(output, params=params)
125 | dot.view()
126 |
127 | return
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | tensorflow==1.4.0
4 | matplotlib==1.5.3
5 |
6 |
--------------------------------------------------------------------------------