├── README.md
├── Results.ipynb
├── __init__.py
├── discriminative
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── coreset.cpython-36.pyc
│   │   ├── multihead_models.cpython-36.pyc
│   │   ├── utils.cpython-36.pyc
│   │   └── vcl.cpython-36.pyc
│   ├── data
│   │   └── mnist.pkl.gz
│   ├── final_results
│   │   ├── VCL-split.npy
│   │   ├── VCL.npy
│   │   ├── VGR-split.npy
│   │   ├── kcen-VCL-200.npy
│   │   ├── kcen-coreset-only-split.npy
│   │   ├── kcen-coreset-only200.npy
│   │   ├── kcenVCL-split.npy
│   │   ├── only-coreset-1000.npy
│   │   ├── only-coreset-200.npy
│   │   ├── only-coreset-2500.npy
│   │   ├── only-coreset-400.npy
│   │   ├── only-coreset-5000.npy
│   │   ├── pca-kcen-coreset-only-split.npy
│   │   ├── pca-kcenVCL-split.npy
│   │   ├── rand-VCL-1000.npy
│   │   ├── rand-VCL-200.npy
│   │   ├── rand-VCL-2500.npy
│   │   ├── rand-VCL-400.npy
│   │   ├── rand-VCL-5000.npy
│   │   ├── rand-coreset-only-split.npy
│   │   └── randVCL-split.npy
│   ├── misc
│   │   ├── KL.pdf
│   │   ├── Loss.pdf
│   │   ├── VGR.png
│   │   ├── gan_pics.png
│   │   ├── images
│   │   │   ├── 0.png
│   │   │   ├── 1.png
│   │   │   ├── 2.png
│   │   │   ├── 3.png
│   │   │   └── 4.png
│   │   ├── permuted_mnist_coreset_sizes.png
│   │   ├── permuted_mnist_main.png
│   │   ├── split_mnist_main_part1.png
│   │   ├── split_mnist_main_part2.png
│   │   ├── split_mnist_pca_part1.png
│   │   └── split_mnist_pca_part2.png
│   ├── results
│   │   ├── VCL-split.npy
│   │   ├── VCL.npy
│   │   ├── VGR-split.npy
│   │   ├── kcen-VCL-200.npy
│   │   ├── kcen-coreset-only-split.npy
│   │   ├── kcen-coreset-only200.npy
│   │   ├── kcenVCL-split.npy
│   │   ├── only-coreset-1000.npy
│   │   ├── only-coreset-200.npy
│   │   ├── only-coreset-2500.npy
│   │   ├── only-coreset-400.npy
│   │   ├── only-coreset-5000.npy
│   │   ├── pca-kcen-coreset-only-split.npy
│   │   ├── pca-kcenVCL-split.npy
│   │   ├── rand-VCL-1000.npy
│   │   ├── rand-VCL-200.npy
│   │   ├── rand-VCL-2500.npy
│   │   ├── rand-VCL-400.npy
│   │   ├── rand-VCL-5000.npy
│   │   ├── rand-coreset-only-split.npy
│   │   └── randVCL-split.npy
│   ├── run_permuted.py
│   ├── run_permuted_just_coreset.py
│   ├── run_split.py
│   ├── run_split_just_coreset.py
│   └── utils
│       ├── DataGenerator.py
│       ├── GAN.py
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── DataGenerator.cpython-36.pyc
│       │   ├── __init__.cpython-36.pyc
│       │   ├── coreset.cpython-36.pyc
│       │   ├── multihead_models.cpython-36.pyc
│       │   ├── test.cpython-36.pyc
│       │   └── vcl.cpython-36.pyc
│       ├── coreset.py
│       ├── multihead_models.py
│       ├── test.py
│       └── vcl.py
└── requirements.txt

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Variational Continual Learning (VCL)
An implementation of the Variational Continual Learning (VCL) algorithms proposed by Nguyen, Li, Bui, and Turner (ICLR 2018).

```
@inproceedings{nguyen2018variational,
  title = {Variational Continual Learning},
  author = {Nguyen, Cuong V. and Li, Yingzhen and Bui, Thang D. and Turner, Richard E.},
  booktitle = {International Conference on Learning Representations},
  year = {2018}
}
```

**To run the Permuted MNIST experiment:**

    python run_permuted.py

**To run the Split MNIST experiment:**

    python run_split.py

**Requirements:**

See `requirements.txt`.

## Results
### VCL in deep discriminative models

Permuted MNIST
![](/discriminative/misc/permuted_mnist_main.png)
![](/discriminative/misc/permuted_mnist_coreset_sizes.png)

Split MNIST
![](/discriminative/misc/split_mnist_main_part1.png)
![](/discriminative/misc/split_mnist_main_part2.png)
... with Variational Generative Replay (VGR):
![](/discriminative/misc/VGR.png)
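Each run script saves its output with `np.save` under `results/` (plotted in `Results.ipynb`). A minimal sketch for inspecting one of the saved arrays; the exact array layout depends on what `vcl.run_vcl` returns, so the shape should be checked before indexing:

```python
import numpy as np

# Load one of the saved result arrays (layout is an assumption --
# inspect the shape before indexing into it).
acc = np.load("discriminative/results/VCL.npy")
print(acc.shape)
print(acc)  # per-task test accuracies recorded during training
```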
--------------------------------------------------------------------------------
/discriminative/run_permuted.py:
--------------------------------------------------------------------------------
import numpy as np
import discriminative.utils.vcl as vcl
import discriminative.utils.coreset as coreset
from discriminative.utils.DataGenerator import PermutedMnistGenerator

hidden_size = [100, 100]
batch_size = 256
no_epochs = 100
single_head = True
num_tasks = 10

np.random.seed(1)
# Just VCL
coreset_size = 0
data_gen = PermutedMnistGenerator(num_tasks)
vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
                         coreset.rand_from_batch, coreset_size, batch_size, single_head)
np.save("./results/VCL", vcl_result)
print(vcl_result)

# VCL + random coreset
# (gan_bol is left at its default here; that flag enables VGR-style
# generative replay and belongs to the VGR experiment in run_split.py)
np.random.seed(1)

for coreset_size in [200, 400, 1000, 2500, 5000]:
    data_gen = PermutedMnistGenerator(num_tasks)
    rand_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
                                  coreset.rand_from_batch, coreset_size, batch_size, single_head)
    np.save("./results/rand-VCL-{}".format(coreset_size), rand_vcl_result)
    print(rand_vcl_result)

# VCL + k-center coreset
np.random.seed(1)
coreset_size = 200
data_gen = PermutedMnistGenerator(num_tasks)
kcen_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
                              coreset.k_center, coreset_size, batch_size, single_head)
print(kcen_vcl_result)
np.save("./results/kcen-VCL-{}".format(coreset_size), kcen_vcl_result)
--------------------------------------------------------------------------------
/discriminative/run_permuted_just_coreset.py:
--------------------------------------------------------------------------------
import numpy as np
import discriminative.utils.vcl as vcl
import discriminative.utils.coreset as coreset
from discriminative.utils.DataGenerator import PermutedMnistGenerator

hidden_size = [100, 100]
batch_size = 256
no_epochs = 100
single_head = True
num_tasks = 10

np.random.seed(0)
# for coreset_size in [400, 1000, 2500, 5000]:
#     data_gen = PermutedMnistGenerator(num_tasks)
#     vcl_result = vcl.run_coreset_only(hidden_size, no_epochs, data_gen,
#                                       coreset.rand_from_batch, coreset_size, batch_size, single_head)
#     np.save("./results/only-coreset-{}".format(coreset_size), vcl_result)
#     print(vcl_result)

np.random.seed(0)
coreset_size = 200
data_gen = PermutedMnistGenerator(num_tasks)
kcen_vcl_result = vcl.run_coreset_only(hidden_size, no_epochs, data_gen,
                                       coreset.k_center, coreset_size, batch_size, single_head)
print(kcen_vcl_result)
np.save("./results/kcen-coreset-only{}".format(coreset_size), kcen_vcl_result)

--------------------------------------------------------------------------------
/discriminative/run_split.py:
--------------------------------------------------------------------------------
import numpy as np
import discriminative.utils.vcl as vcl
import discriminative.utils.coreset as coreset
from discriminative.utils.DataGenerator import SplitMnistGenerator

hidden_size = [256, 256]
batch_size = None
no_epochs = 120
single_head = False
run_coreset_only = False
np.random.seed(0)

# Just VCL
coreset_size = 0
data_gen = SplitMnistGenerator()
vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
                         coreset.rand_from_batch, coreset_size, batch_size, single_head)
np.save("./results/VCL-split", vcl_result)

# VCL + VGR (gan_bol=True trains a GAN per task for generative replay)
coreset_size = 40
data_gen = SplitMnistGenerator()
rand_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
                              coreset.rand_from_batch, coreset_size, batch_size, single_head, gan_bol=True)
print(rand_vcl_result)
np.save("./results/VGR-all-split", rand_vcl_result)

# VCL + k-center coreset
coreset_size = 40
data_gen = SplitMnistGenerator()
kcen_vcl_result = vcl.run_vcl(hidden_size, no_epochs, data_gen,
                              coreset.k_center, coreset_size, batch_size, single_head)
print(kcen_vcl_result)
np.save("./results/kcenVCL-split", kcen_vcl_result)
--------------------------------------------------------------------------------
/discriminative/run_split_just_coreset.py:
--------------------------------------------------------------------------------
import numpy as np
import discriminative.utils.vcl as vcl
import discriminative.utils.coreset as coreset
from discriminative.utils.DataGenerator import SplitMnistGenerator


hidden_size = [256, 256]
batch_size = None
no_epochs = 120
single_head = False
coreset_size = 40

np.random.seed(0)
data_gen = SplitMnistGenerator()
rand_vcl_result = vcl.run_coreset_only(hidden_size, no_epochs, data_gen,
                                       coreset.rand_from_batch, coreset_size, batch_size, single_head)
print(rand_vcl_result)
np.save("./results/rand-coreset-only-split", rand_vcl_result)

np.random.seed(0)
data_gen = SplitMnistGenerator()
kcen_vcl_result = vcl.run_coreset_only(hidden_size, no_epochs, data_gen,
                                       coreset.k_center, coreset_size, batch_size, single_head)
print(kcen_vcl_result)
np.save("./results/kcen-coreset-only-split", kcen_vcl_result)

--------------------------------------------------------------------------------
/discriminative/utils/DataGenerator.py:
--------------------------------------------------------------------------------
import numpy as np
import gzip
import pickle as cp
from copy import deepcopy


class PermutedMnistGenerator():
    def __init__(self, max_iter=10):
        with gzip.open('data/mnist.pkl.gz', 'rb') as file:
            u = cp._Unpickler(file)
            u.encoding = 'latin1'
            p = u.load()
            train_set, valid_set, test_set = p

        self.X_train = np.vstack((train_set[0], valid_set[0]))
        self.Y_train = np.hstack((train_set[1], valid_set[1]))
        self.X_test = test_set[0]
        self.Y_test = test_set[1]
        self.max_iter = max_iter
        self.cur_iter = 0

    def get_dims(self):
        # Get data input and output dimensions
        return self.X_train.shape[1], 10

    def next_task(self):
        if self.cur_iter >= self.max_iter:
            raise Exception('Number of tasks exceeded!')
        else:
            np.random.seed(self.cur_iter)
            perm_inds = np.arange(self.X_train.shape[1])
            np.random.shuffle(perm_inds)

            # Retrieve train data (same pixel permutation for train and test)
            next_x_train = deepcopy(self.X_train)
            next_x_train = next_x_train[:, perm_inds]
            next_y_train = self.Y_train

            # Retrieve test data
            next_x_test = deepcopy(self.X_test)
            next_x_test = next_x_test[:, perm_inds]
            next_y_test = self.Y_test

            self.cur_iter += 1

            return next_x_train, next_y_train, next_x_test, next_y_test


class SplitMnistGenerator():
    def __init__(self):
        with gzip.open('data/mnist.pkl.gz', 'rb') as file:
            u = cp._Unpickler(file)
            u.encoding = 'latin1'
            p = u.load()
            train_set, valid_set, test_set = p

        self.X_train = np.vstack((train_set[0], valid_set[0]))
        self.X_test = test_set[0]
        self.train_label = np.hstack((train_set[1], valid_set[1]))
        self.test_label = test_set[1]

        self.sets_0 = [0, 2, 4, 6, 8]
        self.sets_1 = [1, 3, 5, 7, 9]
        self.max_iter = len(self.sets_0)
        self.cur_iter = 0

    def get_dims(self):
        # Get data input and output dimensions
        return self.X_train.shape[1], 2

    def next_task(self):
        if self.cur_iter >= self.max_iter:
            raise Exception('Number of tasks exceeded!')
        else:
            # Retrieve train data
            train_0_id = np.where(self.train_label == self.sets_0[self.cur_iter])[0]
            train_1_id = np.where(self.train_label == self.sets_1[self.cur_iter])[0]
            next_x_train = np.vstack((self.X_train[train_0_id], self.X_train[train_1_id]))
            # Label 1 for the digit from sets_0, label 0 for the digit from sets_1
            next_y_train = np.vstack((np.ones((train_0_id.shape[0], 1)), np.zeros((train_1_id.shape[0], 1)))).squeeze(-1)

            # Retrieve test data
            test_0_id = np.where(self.test_label == self.sets_0[self.cur_iter])[0]
            test_1_id = np.where(self.test_label == self.sets_1[self.cur_iter])[0]
            next_x_test = np.vstack((self.X_test[test_0_id], self.X_test[test_1_id]))
            next_y_test = np.vstack((np.ones((test_0_id.shape[0], 1)), np.zeros((test_1_id.shape[0], 1)))).squeeze(-1)

            self.cur_iter += 1

            return next_x_train, next_y_train, next_x_test, next_y_test
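A minimal usage sketch of the generators above (a hypothetical driver, not part of the repo). Note that the pickle path `data/mnist.pkl.gz` is resolved relative to the working directory, so this assumes running from inside `discriminative/`:

```python
from discriminative.utils.DataGenerator import SplitMnistGenerator

data_gen = SplitMnistGenerator()
in_dim, out_dim = data_gen.get_dims()  # 784 input pixels, 2 classes per task
for task_id in range(data_gen.max_iter):
    x_train, y_train, x_test, y_test = data_gen.next_task()
    print(task_id, x_train.shape, y_train.shape, x_test.shape)
```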
--------------------------------------------------------------------------------
/discriminative/utils/GAN.py:
--------------------------------------------------------------------------------
import numpy as np
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.nn as nn
import torch

n_epochs = 50
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
n_cpu = 8
latent_dim = 100
num_classes = 2
img_size = 28
channels = 1
sample_interval = 400
threshold = 0.99
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cuda = True if torch.cuda.is_available() else False
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor


class VGR():
    def __init__(self, task_id):
        self.task_id = task_id
        # Loss functions
        self.adversarial_loss = torch.nn.BCELoss()
        self.auxiliary_loss = torch.nn.CrossEntropyLoss()
        # Initialize generator and discriminator
        self.generator = Generator()
        self.discriminator = Discriminator()
        if cuda:
            self.generator.cuda()
            self.discriminator.cuda()
            self.adversarial_loss.cuda()
            self.auxiliary_loss.cuda()

        # Initialize weights
        self.generator.apply(weights_init_normal)
        self.discriminator.apply(weights_init_normal)

        # Optimizers
        self.optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
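    # Overview (editorial note): this is an AC-GAN-style setup -- the
    # Discriminator below returns both a real/fake validity score (trained
    # with BCELoss) and an auxiliary class prediction (trained with
    # CrossEntropyLoss on real images), which later lets generate_samples()
    # label the replayed images.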
    def train(self, x_train, y_train):
        N = x_train.shape[0]
        # Scale inputs from [0, 1] to [-1, 1] to match the generator's Tanh
        # output (note: this modifies x_train in place)
        x_train -= 0.5
        x_train /= 0.5
        for epoch in range(n_epochs):
            total_batch = int(np.ceil(N * 1.0 / batch_size))
            perm_inds = np.arange(x_train.shape[0])
            np.random.shuffle(perm_inds)
            cur_x_train = x_train[perm_inds]
            cur_y_train = y_train[perm_inds]
            # Loop over all batches
            for i in range(total_batch):
                start_ind = i * batch_size
                end_ind = np.min([(i + 1) * batch_size, N])
                batch_x = torch.Tensor(cur_x_train[start_ind:end_ind, :]).to(device=device)
                batch_y = torch.Tensor(cur_y_train[start_ind:end_ind]).to(device=device)
                batch_x = batch_x.reshape(-1, img_size, img_size).unsqueeze(1)
                bsize = batch_x.shape[0]

                # Adversarial ground truths
                valid = Variable(FloatTensor(bsize, 1).fill_(1.0), requires_grad=False)
                fake = Variable(FloatTensor(bsize, 1).fill_(0.0), requires_grad=False)

                # Configure input
                real_imgs = Variable(batch_x.type(FloatTensor))
                labels = Variable(batch_y.type(LongTensor))

                # -----------------
                #  Train Generator
                # -----------------
                self.optimizer_G.zero_grad()

                # Sample noise as generator input
                z = Variable(FloatTensor(np.random.normal(0, 1, (bsize, latent_dim))))

                # Generate a batch of images
                gen_imgs = self.generator(z)

                # Loss measures the generator's ability to fool the discriminator
                validity, _ = self.discriminator(gen_imgs)
                g_loss = self.adversarial_loss(validity, valid)

                g_loss.backward()
                self.optimizer_G.step()

                # ---------------------
                #  Train Discriminator
                # ---------------------
                self.optimizer_D.zero_grad()

                # Loss for real images (adversarial + auxiliary classification)
                real_pred, real_aux = self.discriminator(real_imgs)
                d_real_loss = (self.adversarial_loss(real_pred, valid) + self.auxiliary_loss(real_aux, labels)) / 2

                # Loss for fake images
                fake_pred, fake_aux = self.discriminator(gen_imgs.detach())
                d_fake_loss = self.adversarial_loss(fake_pred, fake)

                # Total discriminator loss
                d_loss = (d_real_loss + d_fake_loss) / 2

                # Calculate discriminator accuracy on the auxiliary head
                pred = real_aux.data.cpu().numpy()
                gt = labels.data.cpu().numpy()
                d_acc = np.mean(np.argmax(pred, axis=1) == gt)

                d_loss.backward()
                self.optimizer_D.step()

            print("Epoch {}/{}, Discriminator loss: {}, acc: {}%, Generator loss: {}".format(
                epoch, n_epochs, d_loss.item(), 100 * d_acc, g_loss.item()))

        self.generator.cpu()
        self.discriminator.cpu()
        self.adversarial_loss.cpu()
        self.auxiliary_loss.cpu()
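    # generate_samples (below) keeps only images whose auxiliary class
    # probability exceeds `threshold` (0.99) and recurses until `no_samples`
    # confident samples have been accumulated; images are mapped back from
    # [-1, 1] to [0, 1] on return.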
    def generate_samples(self, no_samples, task_id, current_nb=0):
        # Sample noise as generator input
        z = Variable(torch.FloatTensor(np.random.normal(0, 1, (no_samples, latent_dim))))
        # Generate a batch of images
        gen_imgs = self.generator(z)
        _, labels = self.discriminator(gen_imgs)
        # Indices of samples classified with probability >= threshold
        kept_indices = torch.nonzero(torch.where(torch.max(labels, 1)[0] < threshold,
                                                 torch.zeros(labels.shape[0]),
                                                 torch.ones(labels.shape[0]))).squeeze(1)
        print(kept_indices.shape[0])
        labels = labels.index_select(0, kept_indices)
        gen_imgs = gen_imgs.index_select(0, kept_indices)
        if current_nb == 0:
            save_image(gen_imgs.data[:25], 'images/%d.png' % task_id, nrow=5, normalize=True)
        gen_imgs = gen_imgs.squeeze(1).reshape(kept_indices.shape[0], 784)
        labels = labels.argmax(1).type(torch.FloatTensor)
        gen_imgs = gen_imgs.data.cpu()
        labels = labels.data.cpu()

        new_current_nb = current_nb + kept_indices.shape[0]
        if new_current_nb < no_samples:
            # Recurse until enough confident samples have been generated
            new_gen_imgs, new_labels = self.generate_samples(no_samples, task_id, new_current_nb)
            gen_imgs = np.vstack((gen_imgs, new_gen_imgs))
            labels = np.hstack((labels, new_labels))

        # Map images back from [-1, 1] to [0, 1]
        return gen_imgs * 0.5 + 0.5, labels

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.init_size = img_size // 4  # Initial size before upsampling
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 5, stride=1, padding=2),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 5, stride=1, padding=2),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 5, stride=1, padding=2),
            nn.Tanh()
        )

    def forward(self, noise):
        out = self.l1(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Returns the layers of one discriminator block"""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Output layers (28x28 input -> four stride-2 blocks -> 128 x 2 x 2 = 512 features)
        self.adv_layer = nn.Sequential(nn.Linear(512, 1),
                                       nn.Sigmoid())
        self.aux_layer = nn.Sequential(nn.Linear(512, num_classes),
                                       nn.Softmax(dim=1))

    def forward(self, img):
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)

        return validity, label
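A minimal usage sketch for `VGR` (a hypothetical driver; `vcl.run_vcl` with `gan_bol=True` is presumably the real entry point). It assumes float32 inputs in [0, 1] with shape (N, 784) and binary per-task labels, and that an `images/` directory exists for the sample grid that `generate_samples` writes:

```python
import numpy as np
from discriminative.utils.GAN import VGR

x_task = np.random.rand(512, 784).astype(np.float32)           # stand-in for one task's images
y_task = np.random.randint(0, 2, size=512).astype(np.float32)  # binary labels for the task
vgr = VGR(task_id=0)
vgr.train(x_task, y_task)                    # note: rescales x_task to [-1, 1] in place
gen_x, gen_y = vgr.generate_samples(no_samples=200, task_id=0)
print(gen_x.shape, gen_y.shape)              # replayed images back in [0, 1], with labels
```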
--------------------------------------------------------------------------------
/discriminative/utils/coreset.py:
--------------------------------------------------------------------------------
#### Code by nvcuong ####
#
# Nothing to change here to make it work with PyTorch.

import sklearn.decomposition as decomp
import numpy as np

""" Random coreset selection """
def rand_from_batch(x_coreset, y_coreset, x_train, y_train, coreset_size):
    # Randomly select from (x_train, y_train) and add to the current coreset (x_coreset, y_coreset)
    idx = np.random.choice(x_train.shape[0], coreset_size, False)
    x_coreset.append(x_train[idx, :])
    y_coreset.append(y_train[idx])
    x_train = np.delete(x_train, idx, axis=0)
    y_train = np.delete(y_train, idx, axis=0)
    return x_coreset, y_coreset, x_train, y_train

""" K-center coreset selection """
def k_center(x_coreset, y_coreset, x_train, y_train, coreset_size):
    # Greedily select K centers from (x_train, y_train) and add them to the current coreset
    dists = np.full(x_train.shape[0], np.inf)
    current_id = 0
    dists = update_distance(dists, x_train, current_id)
    idx = [current_id]

    for i in range(1, coreset_size):
        current_id = np.argmax(dists)
        dists = update_distance(dists, x_train, current_id)
        idx.append(current_id)

    x_coreset.append(x_train[idx, :])
    y_coreset.append(y_train[idx])
    x_train = np.delete(x_train, idx, axis=0)
    y_train = np.delete(y_train, idx, axis=0)

    return x_coreset, y_coreset, x_train, y_train


""" K-center coreset selection on PCA-reduced data """
def pca_k_center(x_coreset, y_coreset, x_train, y_train, coreset_size):
    # Select K centers using distances in a 20-dimensional PCA projection
    pca = decomp.PCA(20)
    pca.fit(x_train)
    x_train_reduced = pca.transform(x_train)
    dists = np.full(x_train_reduced.shape[0], np.inf)
    current_id = 0
    dists = update_distance(dists, x_train_reduced, current_id)
    idx = [current_id]

    for i in range(1, coreset_size):
        current_id = np.argmax(dists)
        dists = update_distance(dists, x_train_reduced, current_id)
        idx.append(current_id)

    x_coreset.append(x_train[idx, :])
    y_coreset.append(y_train[idx])
    x_train = np.delete(x_train, idx, axis=0)
    y_train = np.delete(y_train, idx, axis=0)

    return x_coreset, y_coreset, x_train, y_train


def attention_like_coreset(x_coreset, y_coreset, x_train, y_train, coreset_size):
    """TODO: learn a subnetwork that chooses the coreset (attention-like) """
    return x_coreset, y_coreset, x_train, y_train


def uncertainty_like_coreset(x_coreset, y_coreset, x_train, y_train, coreset_size):
    """TODO: keep only the instances that were not classified with high certainty """
    return x_coreset, y_coreset, x_train, y_train


def update_distance(dists, x_train, current_id):
    # Distance of every point to its nearest selected center so far
    for i in range(x_train.shape[0]):
        current_dist = np.linalg.norm(x_train[i, :] - x_train[current_id, :])
        dists[i] = np.minimum(current_dist, dists[i])
    return dists
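A minimal usage sketch (hypothetical data) showing the calling convention shared by the selection functions above: coresets are lists with one array appended per task, and selected points are removed from the remaining training set:

```python
import numpy as np
from discriminative.utils import coreset

x_coreset, y_coreset = [], []                 # one array appended per task
x_train = np.random.rand(1000, 784)           # stand-in task data
y_train = np.random.randint(0, 2, size=1000)
x_coreset, y_coreset, x_train, y_train = coreset.k_center(
    x_coreset, y_coreset, x_train, y_train, coreset_size=40)
print(len(x_coreset), x_coreset[0].shape, x_train.shape)  # 1 (40, 784) (960, 784)
```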
--------------------------------------------------------------------------------
/discriminative/utils/multihead_models.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
from scipy.stats import truncnorm
from copy import deepcopy

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

np.random.seed(0)

# variable initialization functions
def truncated_normal(size, stddev=1, variable=False, mean=0):
    mu, sigma = mean, stddev
    lower, upper = -2 * sigma, 2 * sigma
    X = truncnorm(
        (lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
    X_tensor = torch.Tensor(data=X.rvs(size)).to(device=device)
    X_tensor.requires_grad = variable
    return X_tensor

def init_tensor(value, dout, din=1, variable=False):
    if din != 1:
        x = value * torch.ones([din, dout]).to(device=device)
    else:
        x = value * torch.ones([dout]).to(device=device)
    x.requires_grad = variable

    return x

class Cla_NN(object):
    def __init__(self, input_size, hidden_size, output_size, training_size):
        return

    def train(self, x_train, y_train, task_idx, no_epochs=1000, batch_size=100, display_epoch=5):
        N = x_train.shape[0]
        self.training_size = N
        if batch_size > N:
            batch_size = N

        costs = []
        # Training cycle
        for epoch in range(no_epochs):
            perm_inds = np.arange(x_train.shape[0])
            np.random.shuffle(perm_inds)
            cur_x_train = x_train[perm_inds]
            cur_y_train = y_train[perm_inds]

            avg_cost = 0.
            total_batch = int(np.ceil(N * 1.0 / batch_size))
            # Loop over all batches
            for i in range(total_batch):
                start_ind = i * batch_size
                end_ind = np.min([(i + 1) * batch_size, N])
                batch_x = torch.Tensor(cur_x_train[start_ind:end_ind, :]).to(device=device)
                batch_y = torch.Tensor(cur_y_train[start_ind:end_ind]).to(device=device)

                ##TODO: check if we need to lock the gradient somewhere
                self.optimizer.zero_grad()
                cost = self.get_loss(batch_x, batch_y, task_idx)
                cost.backward()
                self.optimizer.step()

                # Compute average loss (.item() detaches it, so the graph is freed)
                avg_cost += cost.item() / total_batch
            # Display logs per epoch step
            if epoch % display_epoch == 0:
                print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
            costs.append(avg_cost)
        print("Optimization Finished!")
        return costs

    def prediction_prob(self, x_test, task_idx):
        prob = F.softmax(self._prediction(x_test, task_idx, self.no_pred_samples), dim=-1)
        return prob
""" Neural Network Model """
class Vanilla_NN(Cla_NN):
    def __init__(self, input_size, hidden_size, output_size, training_size, learning_rate=0.001):
        super(Vanilla_NN, self).__init__(input_size, hidden_size, output_size, training_size)
        # init weights and biases
        self.W, self.b, self.W_last, self.b_last, self.size = self.create_weights(
            input_size, hidden_size, output_size)
        self.no_layers = len(hidden_size) + 1
        self.weights = self.W + self.b + self.W_last + self.b_last
        self.training_size = training_size
        self.optimizer = optim.Adam(self.weights, lr=learning_rate)

    def _prediction(self, inputs, task_idx):
        act = inputs
        for i in range(self.no_layers - 1):
            pre = torch.add(torch.matmul(act, self.W[i]), self.b[i])
            act = F.relu(pre)
        pre = torch.add(torch.matmul(act, self.W_last[task_idx]), self.b_last[task_idx])
        return pre

    def _logpred(self, inputs, targets, task_idx):
        loss = torch.nn.CrossEntropyLoss()
        pred = self._prediction(inputs, task_idx)
        log_lik = -loss(pred, targets.type(torch.long))
        return log_lik

    def prediction_prob(self, x_test, task_idx):
        prob = F.softmax(self._prediction(x_test, task_idx), dim=-1)
        return prob

    def get_loss(self, batch_x, batch_y, task_idx):
        return -self._logpred(batch_x, batch_y, task_idx)

    def create_weights(self, in_dim, hidden_size, out_dim):
        hidden_size = deepcopy(hidden_size)
        hidden_size.append(out_dim)
        hidden_size.insert(0, in_dim)

        no_layers = len(hidden_size) - 1
        W = []
        b = []
        W_last = []
        b_last = []
        for i in range(no_layers - 1):
            din = hidden_size[i]
            dout = hidden_size[i + 1]

            # Initialization values of means
            Wi_m = truncated_normal([din, dout], stddev=0.1, variable=True)
            bi_m = truncated_normal([dout], stddev=0.1, variable=True)

            # Append to weight lists
            W.append(Wi_m)
            b.append(bi_m)

        Wi = truncated_normal([hidden_size[-2], out_dim], stddev=0.1, variable=True)
        bi = truncated_normal([out_dim], stddev=0.1, variable=True)
        W_last.append(Wi)
        b_last.append(bi)
        return W, b, W_last, b_last, hidden_size

    def get_weights(self):
        weights = [self.weights[:self.no_layers - 1],
                   self.weights[self.no_layers - 1:2 * (self.no_layers - 1)],
                   [self.weights[-2]], [self.weights[-1]]]
        return weights
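# Note (assumption, mirroring the reference VCL implementation): vcl.run_vcl
# presumably trains a Vanilla_NN by maximum likelihood on the first task and
# passes its weights (via get_weights()) to MFVI_NN as prev_means, so the
# variational posterior means start from a sensible point estimate.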
""" Bayesian Neural Network with Mean field VI approximation """
class MFVI_NN(Cla_NN):
    def __init__(self, input_size, hidden_size, output_size, training_size,
                 no_train_samples=10, no_pred_samples=100, single_head=False, prev_means=None, learning_rate=0.001):
        ##TODO: handle single head
        super(MFVI_NN, self).__init__(input_size, hidden_size, output_size, training_size)

        m1, v1, hidden_size = self.create_weights(
            input_size, hidden_size, output_size, prev_means)

        self.input_size = input_size
        self.out_size = output_size
        self.size = hidden_size
        self.single_head = single_head

        self.W_m, self.b_m = m1[0], m1[1]
        self.W_v, self.b_v = v1[0], v1[1]

        self.W_last_m, self.b_last_m = [], []
        self.W_last_v, self.b_last_v = [], []

        m2, v2 = self.create_prior(input_size, self.size, output_size)

        self.prior_W_m, self.prior_b_m = m2[0], m2[1]
        self.prior_W_v, self.prior_b_v = v2[0], v2[1]

        self.prior_W_last_m, self.prior_b_last_m = [], []
        self.prior_W_last_v, self.prior_b_last_v = [], []

        self.W_m_copy, self.W_v_copy, self.b_m_copy, self.b_v_copy = None, None, None, None
        self.W_last_m_copy, self.W_last_v_copy, self.b_last_m_copy, self.b_last_v_copy = None, None, None, None
        self.prior_W_m_copy, self.prior_W_v_copy, self.prior_b_m_copy, self.prior_b_v_copy = None, None, None, None
        self.prior_W_last_m_copy, self.prior_W_last_v_copy, self.prior_b_last_m_copy, self.prior_b_last_v_copy = None, None, None, None

        self.no_layers = len(self.size) - 1
        self.no_train_samples = no_train_samples
        self.no_pred_samples = no_pred_samples
        self.training_size = training_size
        self.learning_rate = learning_rate

        if prev_means is not None:
            self.init_first_head(prev_means)
        else:
            self.create_head()

        m1.append(self.W_last_m)
        m1.append(self.b_last_m)
        v1.append(self.W_last_v)
        v1.append(self.b_last_v)

        r1 = m1 + v1
        self.weights = [item for sublist in r1 for item in sublist]

        self.optimizer = optim.Adam(self.weights, lr=learning_rate)

    def get_loss(self, batch_x, batch_y, task_idx):
        return torch.div(self._KL_term(), self.training_size) - self._logpred(batch_x, batch_y, task_idx)

    def _prediction(self, inputs, task_idx, no_samples):
        K = no_samples
        size = self.size

        act = torch.unsqueeze(inputs, 0).repeat([K, 1, 1])
        for i in range(self.no_layers - 1):
            din = self.size[i]
            dout = self.size[i + 1]
            eps_w = torch.normal(torch.zeros((K, din, dout)), torch.ones((K, din, dout))).to(device=device)
            eps_b = torch.normal(torch.zeros((K, 1, dout)), torch.ones((K, 1, dout))).to(device=device)
            # Reparameterization trick: W_v and b_v store log-variances
            weights = torch.add(eps_w * torch.exp(0.5 * self.W_v[i]), self.W_m[i])
            biases = torch.add(eps_b * torch.exp(0.5 * self.b_v[i]), self.b_m[i])
            pre = torch.add(torch.einsum('mni,mio->mno', act, weights), biases)
            act = F.relu(pre)

        din = self.size[-2]
        dout = self.size[-1]

        eps_w = torch.normal(torch.zeros((K, din, dout)), torch.ones((K, din, dout))).to(device=device)
        eps_b = torch.normal(torch.zeros((K, 1, dout)), torch.ones((K, 1, dout))).to(device=device)
        Wtask_m = self.W_last_m[task_idx]
        Wtask_v = self.W_last_v[task_idx]
        btask_m = self.b_last_m[task_idx]
        btask_v = self.b_last_v[task_idx]

        weights = torch.add(eps_w * torch.exp(0.5 * Wtask_v), Wtask_m)
        biases = torch.add(eps_b * torch.exp(0.5 * btask_v), btask_m)
        act = torch.unsqueeze(act, 3)
        weights = torch.unsqueeze(weights, 1)
        pre = torch.add(torch.sum(act * weights, dim=2), biases)
        return pre

    def _logpred(self, inputs, targets, task_idx):
        loss = torch.nn.CrossEntropyLoss()
        pred = self._prediction(inputs, task_idx, self.no_train_samples).view(-1, self.out_size)
        targets = targets.repeat([self.no_train_samples, 1]).view(-1)
        log_liks = -loss(pred, targets.type(torch.long))
        log_lik = log_liks.mean()
        return log_lik
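    # Closed-form KL between factorized Gaussians, as computed in _KL_term
    # below. With posterior mean m and log-variance v, and prior mean m0 and
    # variance v0, each parameter contributes
    #     0.5 * (log v0 - v + (exp(v) + (m0 - m)^2) / v0 - 1),
    # i.e. const_term + log_std_diff + mu_diff_term; summing over all weights,
    # biases and head parameters gives the KL term of the VCL objective.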
243 | 
244 |     def _logpred(self, inputs, targets, task_idx):
245 |         loss = torch.nn.CrossEntropyLoss()
246 |         pred = self._prediction(inputs, task_idx, self.no_train_samples).view(-1, self.out_size)
247 |         targets = targets.repeat([self.no_train_samples, 1]).view(-1)
248 |         log_liks = -loss(pred, targets.type(torch.long))
249 |         log_lik = log_liks.mean()
250 |         return log_lik
251 | 
252 | 
253 |     def _KL_term(self):
254 |         kl = 0
255 |         for i in range(self.no_layers-1):
256 |             din = self.size[i]
257 |             dout = self.size[i+1]
258 |             m, v = self.W_m[i], self.W_v[i]
259 |             m0, v0 = self.prior_W_m[i], self.prior_W_v[i]
260 | 
261 |             const_term = -0.5 * dout * din
262 |             log_std_diff = 0.5 * torch.sum(torch.log(v0) - v)
263 |             mu_diff_term = 0.5 * torch.sum((torch.exp(v) + (m0 - m)**2) / v0)
264 |             kl += const_term + log_std_diff + mu_diff_term
265 | 
266 |             m, v = self.b_m[i], self.b_v[i]
267 |             m0, v0 = self.prior_b_m[i], self.prior_b_v[i]
268 | 
269 |             const_term = -0.5 * dout
270 |             log_std_diff = 0.5 * torch.sum(torch.log(v0) - v)
271 |             mu_diff_term = 0.5 * torch.sum((torch.exp(v) + (m0 - m)**2) / v0)
272 |             kl += log_std_diff + mu_diff_term + const_term
273 | 
274 |         no_tasks = len(self.W_last_m)
275 |         din = self.size[-2]
276 |         dout = self.size[-1]
277 | 
278 |         for i in range(no_tasks):
279 |             m, v = self.W_last_m[i], self.W_last_v[i]
280 |             m0, v0 = self.prior_W_last_m[i], self.prior_W_last_v[i]
281 | 
282 |             const_term = -0.5 * dout * din
283 |             log_std_diff = 0.5 * torch.sum(torch.log(v0) - v)
284 |             mu_diff_term = 0.5 * torch.sum((torch.exp(v) + (m0 - m)**2) / v0)
285 |             kl += const_term + log_std_diff + mu_diff_term
286 | 
287 |             m, v = self.b_last_m[i], self.b_last_v[i]
288 |             m0, v0 = self.prior_b_last_m[i], self.prior_b_last_v[i]
289 | 
290 |             const_term = -0.5 * dout
291 |             log_std_diff = 0.5 * torch.sum(torch.log(v0) - v)
292 |             mu_diff_term = 0.5 * torch.sum((torch.exp(v) + (m0 - m)**2) / v0)
293 |             kl += const_term + log_std_diff + mu_diff_term
294 |         return kl
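# The summands above implement the closed-form KL between diagonal Gaussians,
# per dimension KL(q||p) = 0.5 * (log v0 - v + (exp(v) + (m - m0)**2) / v0 - 1),
# with posterior q = N(m, exp(v)) and prior p = N(m0, v0). A quick standalone
# sanity check against torch.distributions:
#
#   import torch
#   from torch.distributions import Normal, kl_divergence
#   m, v = torch.tensor(0.3), torch.tensor(-2.0)    # posterior mean, log-variance
#   m0, v0 = torch.tensor(0.0), torch.tensor(1.0)   # prior mean, variance
#   closed = 0.5 * (torch.log(v0) - v + (torch.exp(v) + (m - m0)**2) / v0 - 1)
#   ref = kl_divergence(Normal(m, torch.exp(0.5 * v)), Normal(m0, torch.sqrt(v0)))
#   assert torch.allclose(closed, ref)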
295 | 
296 |     def save_weights(self):
297 |         ''' Save the weights before fine-tuning on the coreset, prior to computing the test accuracy '''
298 | 
299 |         print("Saving weights before coreset training")
300 |         self.W_m_copy = [self.W_m[i].clone().detach().data for i in range(len(self.W_m))]
301 |         self.W_v_copy = [self.W_v[i].clone().detach().data for i in range(len(self.W_v))]
302 |         self.b_m_copy = [self.b_m[i].clone().detach().data for i in range(len(self.b_m))]
303 |         self.b_v_copy = [self.b_v[i].clone().detach().data for i in range(len(self.b_v))]
304 | 
305 |         self.W_last_m_copy = [self.W_last_m[i].clone().detach().data for i in range(len(self.W_last_m))]
306 |         self.W_last_v_copy = [self.W_last_v[i].clone().detach().data for i in range(len(self.W_last_v))]
307 |         self.b_last_m_copy = [self.b_last_m[i].clone().detach().data for i in range(len(self.b_last_m))]
308 |         self.b_last_v_copy = [self.b_last_v[i].clone().detach().data for i in range(len(self.b_last_v))]
309 | 
310 |         self.prior_W_m_copy = [self.prior_W_m[i].data for i in range(len(self.prior_W_m))]
311 |         self.prior_W_v_copy = [self.prior_W_v[i].data for i in range(len(self.prior_W_v))]
312 |         self.prior_b_m_copy = [self.prior_b_m[i].data for i in range(len(self.prior_b_m))]
313 |         self.prior_b_v_copy = [self.prior_b_v[i].data for i in range(len(self.prior_b_v))]
314 | 
315 |         self.prior_W_last_m_copy = [self.prior_W_last_m[i].data for i in range(len(self.prior_W_last_m))]
316 |         self.prior_W_last_v_copy = [self.prior_W_last_v[i].data for i in range(len(self.prior_W_last_v))]
317 |         self.prior_b_last_m_copy = [self.prior_b_last_m[i].data for i in range(len(self.prior_b_last_m))]
318 |         self.prior_b_last_v_copy = [self.prior_b_last_v[i].data for i in range(len(self.prior_b_last_v))]
319 | 
320 |         return
321 | 
322 |     def load_weights(self):
323 |         ''' Re-load the saved weights after computing the test accuracy '''
324 | 
325 |         print("Reloading previous weights after coreset training")
326 |         self.weights = []
327 |         self.W_m = [self.W_m_copy[i].clone().detach().data for i in range(len(self.W_m))]
328 |         self.W_v = [self.W_v_copy[i].clone().detach().data for i in range(len(self.W_v))]
329 |         self.b_m = [self.b_m_copy[i].clone().detach().data for i in range(len(self.b_m))]
330 |         self.b_v = [self.b_v_copy[i].clone().detach().data for i in range(len(self.b_v))]
331 | 
332 |         for i in range(len(self.W_m)):
333 |             self.W_m[i].requires_grad = True
334 |             self.W_v[i].requires_grad = True
335 |             self.b_m[i].requires_grad = True
336 |             self.b_v[i].requires_grad = True
337 | 
338 |         self.weights += self.W_m
339 |         self.weights += self.W_v
340 |         self.weights += self.b_m
341 |         self.weights += self.b_v
342 | 
343 | 
344 |         self.W_last_m = [self.W_last_m_copy[i].clone().detach().data for i in range(len(self.W_last_m))]
345 |         self.W_last_v = [self.W_last_v_copy[i].clone().detach().data for i in range(len(self.W_last_v))]
346 |         self.b_last_m = [self.b_last_m_copy[i].clone().detach().data for i in range(len(self.b_last_m))]
347 |         self.b_last_v = [self.b_last_v_copy[i].clone().detach().data for i in range(len(self.b_last_v))]
348 | 
349 |         for i in range(len(self.W_last_m)):
350 |             self.W_last_m[i].requires_grad = True
351 |             self.W_last_v[i].requires_grad = True
352 |             self.b_last_m[i].requires_grad = True
353 |             self.b_last_v[i].requires_grad = True
354 | 
355 |         self.weights += self.W_last_m
356 |         self.weights += self.W_last_v
357 |         self.weights += self.b_last_m
358 |         self.weights += self.b_last_v
359 | 
360 |         self.optimizer = optim.Adam(self.weights, lr=self.learning_rate)
361 |         self.prior_W_m = [self.prior_W_m_copy[i].data for i in range(len(self.prior_W_m))]
362 |         self.prior_W_v = [self.prior_W_v_copy[i].data for i in range(len(self.prior_W_v))]
363 |         self.prior_b_m = [self.prior_b_m_copy[i].data for i in range(len(self.prior_b_m))]
364 |         self.prior_b_v = [self.prior_b_v_copy[i].data for i in range(len(self.prior_b_v))]
365 | 
366 |         self.prior_W_last_m = [self.prior_W_last_m_copy[i].data for i in range(len(self.prior_W_last_m))]
367 |         self.prior_W_last_v = [self.prior_W_last_v_copy[i].data for i in range(len(self.prior_W_last_v))]
368 |         self.prior_b_last_m = [self.prior_b_last_m_copy[i].data for i in range(len(self.prior_b_last_m))]
369 |         self.prior_b_last_v = [self.prior_b_last_v_copy[i].data for i in range(len(self.prior_b_last_v))]
370 | 
371 |         return
372 | 
373 |     def clean_copy_weights(self):
374 |         self.W_m_copy, self.W_v_copy, self.b_m_copy, self.b_v_copy = None, None, None, None
375 |         self.W_last_m_copy, self.W_last_v_copy, self.b_last_m_copy, self.b_last_v_copy = None, None, None, None
376 |         self.prior_W_m_copy, self.prior_W_v_copy, self.prior_b_m_copy, self.prior_b_v_copy = None, None, None, None
377 |         self.prior_W_last_m_copy, self.prior_W_last_v_copy, self.prior_b_last_m_copy, self.prior_b_last_v_copy = None, None, None, None
378 | 
379 |     def create_head(self):
380 |         '''Create a new head when a new task is detected'''
381 |         print("creating a new head")
382 |         din = self.size[-2]
383 |         dout = self.size[-1]
384 | 
385 |         W_m = truncated_normal([din, dout], stddev=0.1, variable=True)
386 |         b_m = truncated_normal([dout], stddev=0.1, variable=True)
387 |         W_v = init_tensor(-6.0, dout = dout, din = din, variable= True)
388 |         b_v = init_tensor(-6.0, dout = dout, variable= True)
389 | 
390 |         self.W_last_m.append(W_m)
391 |         self.W_last_v.append(W_v)
392 |         self.b_last_m.append(b_m)
393 |         self.b_last_v.append(b_v)
394 | 
395 | 
396 |         W_m_p = torch.zeros([din, dout]).to(device = device)
397 |         b_m_p = torch.zeros([dout]).to(device = device)
398 |         W_v_p = init_tensor(1, dout = dout, din = din)
399 |         b_v_p = init_tensor(1, dout = dout)
400 | 
401 |         self.prior_W_last_m.append(W_m_p)
402 |         self.prior_W_last_v.append(W_v_p)
403 |         self.prior_b_last_m.append(b_m_p)
404 |         self.prior_b_last_v.append(b_v_p)
405 |         self.weights = []
406 |         self.weights += self.W_m
407 |         self.weights += self.W_v
408 |         self.weights += self.b_m
409 |         self.weights += self.b_v
410 |         self.weights += self.W_last_m
411 |         self.weights += self.W_last_v
412 |         self.weights += self.b_last_m
413 |         self.weights += self.b_last_v
414 |         self.optimizer = optim.Adam(self.weights, lr=self.learning_rate)
415 | 
416 |         return
417 | 
418 | 
419 |     def init_first_head(self, prev_means):
420 |         '''When the MFVI_NN is instantiated, initialize its weights with those of the trained Vanilla_NN'''
421 |         print("initializing first head")
422 |         din = self.size[-2]
423 |         dout = self.size[-1]
424 |         self.prior_W_last_m = [torch.zeros([din, dout]).to(device = device)]
425 |         self.prior_b_last_m = [torch.zeros([dout]).to(device = device)]
426 |         self.prior_W_last_v = [init_tensor(1, dout = dout, din = din)]
427 |         self.prior_b_last_v = [init_tensor(1, dout = dout)]
428 | 
429 |         W_last_m = prev_means[2][0].detach().data
430 |         W_last_m.requires_grad = True
431 |         self.W_last_m = [W_last_m]
432 |         self.W_last_v = [init_tensor(-6.0, dout = dout, din = din, variable= True)]
433 | 
434 | 
435 |         b_last_m = prev_means[3][0].detach().data
436 |         b_last_m.requires_grad = True
437 |         self.b_last_m = [b_last_m]
438 |         self.b_last_v = [init_tensor(-6.0, dout = dout, variable= True)]
439 | 
440 |         return
441 | 
442 |     def create_weights(self, in_dim, hidden_size, out_dim, prev_means):
443 |         hidden_size = deepcopy(hidden_size)
444 |         hidden_size.append(out_dim)
445 |         hidden_size.insert(0, in_dim)
446 | 
447 |         no_layers = len(hidden_size) - 1
448 |         W_m = []
449 |         b_m = []
450 |         W_v = []
451 |         b_v = []
452 | 
453 |         for i in range(no_layers-1):
454 |             din = hidden_size[i]
455 |             dout = hidden_size[i+1]
456 |             if prev_means is not None:
457 |                 W_m_i = prev_means[0][i].detach().data
458 |                 W_m_i.requires_grad = True
459 |                 bi_m_i = prev_means[1][i].detach().data
460 |                 bi_m_i.requires_grad = True
461 |             else:
462 |                 # Initialization values of means
463 |                 W_m_i = truncated_normal([din, dout], stddev=0.1, variable=True)
464 |                 bi_m_i = truncated_normal([dout], stddev=0.1, variable=True)
465 |             # Initialization values of variances (always learned from scratch)
466 |             W_v_i = init_tensor(-6.0, dout = dout, din = din, variable = True)
467 |             bi_v_i = init_tensor(-6.0, dout = dout, variable = True)
468 | 
469 |             # Append to list of weights
470 |             W_m.append(W_m_i)
471 |             b_m.append(bi_m_i)
472 |             W_v.append(W_v_i)
473 |             b_v.append(bi_v_i)
474 | 
475 |         return [W_m, b_m], [W_v, b_v], hidden_size
476 | 
477 |     def create_prior(self, in_dim, hidden_size, out_dim, initial_mean = 0, initial_variance = 1):
478 | 
479 |         no_layers = len(hidden_size) - 1
480 |         W_m = []
481 |         b_m = []
482 | 
483 |         W_v = []
484 |         b_v = []
485 | 
486 |         for i in range(no_layers - 1):
487 |             din = hidden_size[i]
488 |             dout = hidden_size[i + 1]
489 | 
490 |             # Initialization values of means (a tensor filled with initial_mean)
491 |             W_m_val = initial_mean + torch.zeros([din, dout]).to(device = device)
492 |             bi_m_val = initial_mean + torch.zeros([dout]).to(device = device)
493 | 
494 |             # Initialization values of variances
495 |             W_v_val = initial_variance * init_tensor(1, dout = dout, din = din)
496 |             bi_v_val = initial_variance * init_tensor(1, dout = dout)
497 | 
498 |             # Append to list of weights
499 |             W_m.append(W_m_val)
500 |             b_m.append(bi_m_val)
501 |             W_v.append(W_v_val)
502 |             b_v.append(bi_v_val)
503 | 
504 |         return [W_m, b_m], [W_v, b_v]
505 | 
506 | 
507 | 
508 | 
509 |     def update_prior(self):
510 |         print("updating prior...")
511 |         for i in range(len(self.W_m)):
512 |             self.prior_W_m[i].data.copy_(self.W_m[i].clone().detach().data)
513 |             self.prior_b_m[i].data.copy_(self.b_m[i].clone().detach().data)
514 |             self.prior_W_v[i].data.copy_(torch.exp(self.W_v[i].clone().detach().data))
515 |             self.prior_b_v[i].data.copy_(torch.exp(self.b_v[i].clone().detach().data))
516 | 
517 |         length = len(self.W_last_m)
518 | 
519 |         for i in range(length):
520 |             self.prior_W_last_m[i].data.copy_(self.W_last_m[i].clone().detach().data)
521 |             self.prior_b_last_m[i].data.copy_(self.b_last_m[i].clone().detach().data)
522 |             self.prior_W_last_v[i].data.copy_(torch.exp(self.W_last_v[i].clone().detach().data))
523 |             self.prior_b_last_v[i].data.copy_(torch.exp(self.b_last_v[i].clone().detach().data))
524 | 
525 |         return
--------------------------------------------------------------------------------
/discriminative/utils/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use('agg')
4 | import torch
5 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
6 | 
7 | 
8 | def merge_coresets(x_coresets, y_coresets):
9 |     merged_x, merged_y = x_coresets[0], y_coresets[0]
10 |     for i in range(1, len(x_coresets)):
11 |         merged_x = np.vstack((merged_x, x_coresets[i]))
12 |         merged_y = np.hstack((merged_y, y_coresets[i]))
13 |     return merged_x, merged_y
14 | 
15 | 
16 | def get_coreset(x_coresets, y_coresets, single_head, coreset_size = 5000, gans = None, task_id=0):
17 |     if gans is not None:
18 |         if single_head:
19 |             merged_x, merged_y = gans[0].generate_samples(coreset_size, task_id)
20 |             for i in range(1, len(gans)):
21 |                 new_x, new_y = gans[i].generate_samples(coreset_size, task_id)
22 |                 merged_x = np.vstack((merged_x, new_x))
23 |                 merged_y = np.hstack((merged_y, new_y))
24 |             return merged_x, merged_y
25 |         else:
26 |             return gans.generate_samples(coreset_size, task_id)[:coreset_size]
27 |     else:
28 |         if single_head:
29 |             return merge_coresets(x_coresets, y_coresets)
30 |         else:
31 |             return x_coresets, y_coresets
32 | 
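# Toy check of the two helpers above (the array shapes are made up for the example):
#
#   import numpy as np
#   xs = [np.ones((2, 4)), np.zeros((3, 4))]     # two per-task coresets
#   ys = [np.array([0, 1]), np.array([2, 3, 4])]
#   mx, my = merge_coresets(xs, ys)
#   assert mx.shape == (5, 4) and my.shape == (5,)
#   # With single_head=True and no GANs, get_coreset returns the same merge:
#   assert get_coreset(xs, ys, single_head=True)[0].shape == (5, 4)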
33 | 
34 | def get_scores(model, x_testsets, y_testsets, no_epochs, single_head, x_coresets, y_coresets, batch_size=None, just_vanilla = False, gans = None):
35 | 
36 |     acc = []
37 |     if single_head:
38 |         if len(x_coresets) > 0 or gans is not None:
39 |             x_train, y_train = get_coreset(x_coresets, y_coresets, single_head, coreset_size = 6000, gans = gans, task_id=0)
40 | 
41 |             bsize = x_train.shape[0] if (batch_size is None) else batch_size
42 |             x_train = torch.Tensor(x_train)
43 |             y_train = torch.Tensor(y_train)
44 |             model.train(x_train, y_train, 0, no_epochs, bsize)
45 | 
46 |     for i in range(len(x_testsets)):
47 |         if not single_head:
48 |             if len(x_coresets) > 0 or gans is not None:
49 |                 model.load_weights()
50 |                 gan_i = None
51 |                 if gans is not None:
52 |                     gan_i = gans[i]
53 |                     x_train, y_train = get_coreset(None, None, single_head, coreset_size = 6000, gans= gan_i, task_id=i)
54 |                 else:
55 |                     x_train, y_train = get_coreset(x_coresets[i], y_coresets[i], single_head, coreset_size = 6000, gans= None, task_id=i)
56 |                 bsize = x_train.shape[0] if (batch_size is None) else batch_size
57 |                 x_train = torch.Tensor(x_train)
58 |                 y_train = torch.Tensor(y_train)
59 |                 model.train(x_train, y_train, i, no_epochs, bsize)
60 | 
61 |         head = 0 if single_head else i
62 |         x_test, y_test = x_testsets[i], y_testsets[i]
63 |         N = x_test.shape[0]
64 |         bsize = N if (batch_size is None) else batch_size
65 |         cur_acc = 0
66 |         total_batch = int(np.ceil(N * 1.0 / bsize))
67 |         # Loop over all test batches (batch_id avoids shadowing the task index i)
68 |         for batch_id in range(total_batch):
69 |             start_ind = batch_id*bsize
70 |             end_ind = np.min([(batch_id+1)*bsize, N])
71 |             batch_x_test = torch.Tensor(x_test[start_ind:end_ind, :]).to(device = device)
72 |             batch_y_test = torch.Tensor(y_test[start_ind:end_ind]).type(torch.LongTensor).to(device = device)
73 |             pred = model.prediction_prob(batch_x_test, head)
74 |             if not just_vanilla:
75 |                 pred_mean = pred.mean(0)
76 |             else:
77 |                 pred_mean = pred
78 |             pred_y = torch.argmax(pred_mean, dim=1)
79 |             cur_acc += end_ind - start_ind - (pred_y - batch_y_test).nonzero().shape[0]
80 | 
81 |         cur_acc = float(cur_acc)
82 |         cur_acc /= N
83 |         acc.append(cur_acc)
84 |         print("Accuracy is {}".format(cur_acc))
85 |     return acc
86 | 
87 | def concatenate_results(score, all_score):
88 |     if all_score.size == 0:
89 |         all_score = np.reshape(score, (1, -1))
90 |     else:
91 |         new_arr = np.empty((all_score.shape[0], all_score.shape[1]+1))
92 |         new_arr[:] = np.nan
93 |         new_arr[:, :-1] = all_score
94 |         all_score = np.vstack((new_arr, score))
95 |     return all_score
--------------------------------------------------------------------------------
/discriminative/utils/vcl.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import discriminative.utils.test as test
3 | from discriminative.utils.multihead_models import Vanilla_NN, MFVI_NN
4 | import torch
5 | import discriminative.utils.GAN as GAN
6 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
7 | try:
8 |     from torchviz import make_dot, make_dot_from_trace
9 | except ImportError:
10 |     print("Torchviz was not found.")
11 | 
12 | def run_vcl(hidden_size, no_epochs, data_gen, coreset_method, coreset_size=0, batch_size=None, single_head=True, gan_bol = False):
13 |     in_dim, out_dim = data_gen.get_dims()
14 |     x_coresets, y_coresets = [], []
15 |     x_testsets, y_testsets = [], []
16 |     gans = []
17 |     all_acc = np.array([])
18 | 
19 |     for task_id in range(data_gen.max_iter):
20 |         x_train, y_train, x_test, y_test = data_gen.next_task()
21 |         x_testsets.append(x_test)
22 |         y_testsets.append(y_test)
23 | 
24 |         # Set the readout head to train
25 |         head = 0 if single_head else task_id
26 |         bsize = x_train.shape[0] if (batch_size is None) else batch_size
27 | 
28 |         # Train the network with maximum likelihood to initialize the first model
29 |         if task_id == 0:
30 |             print_graph_bol = False  # set to True if you want to see the graph
31 |             ml_model = Vanilla_NN(in_dim, hidden_size, out_dim, x_train.shape[0])
32 |             ml_model.train(x_train, y_train, task_id, no_epochs, bsize)
33 |             mf_weights = ml_model.get_weights()
34 |             mf_model = MFVI_NN(in_dim, hidden_size, out_dim, x_train.shape[0], single_head = single_head, prev_means=mf_weights)
35 | 
36 |         if not gan_bol:
37 |             if coreset_size > 0:
38 |                 x_coresets, y_coresets, x_train, y_train = coreset_method(x_coresets, y_coresets, x_train, y_train, coreset_size)
39 |             gans = None
40 |             if print_graph_bol:
41 |                 # Only if you want to see the computational graph
42 |                 output_tensor = mf_model._KL_term() #mf_model.get_loss(torch.Tensor(x_train).to(device), torch.Tensor(y_train).to(device), task_idx), params=params)
43 |                 print_graph(mf_model, output_tensor)
44 |                 print_graph_bol = False
45 | 
46 |         if gan_bol:
47 |             gan_i = GAN.VGR(task_id)
48 |             gan_i.train(x_train, y_train)
49 |             gans.append(gan_i)
50 |         mf_model.train(x_train, y_train, head, no_epochs, bsize)
51 | 
52 |         mf_model.update_prior()
53 |         # Save weights before testing (and the last-minute training on the coreset)
54 |         mf_model.save_weights()
55 | 
56 |         acc = test.get_scores(mf_model, x_testsets, y_testsets, no_epochs, single_head, x_coresets, y_coresets, batch_size, False, gans)
57 |         all_acc = test.concatenate_results(acc, all_acc)
58 | 
59 |         mf_model.load_weights()
60 |         mf_model.clean_copy_weights()
61 | 
62 | 
63 |         if not single_head:
64 |             mf_model.create_head()
65 | 
66 |     return all_acc
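# Hedged usage sketch for run_vcl. The generator and coreset-function names
# below are assumptions about DataGenerator.py and coreset.py (check those
# files for the exact identifiers); run_vcl itself only needs an object with
# get_dims(), max_iter and next_task(), plus a coreset selection function:
#
#   from discriminative.utils.DataGenerator import PermutedMnistGenerator
#   import discriminative.utils.coreset as coreset
#   data_gen = PermutedMnistGenerator(max_iter=10)
#   all_acc = run_vcl(hidden_size=[100, 100], no_epochs=100, data_gen=data_gen,
#                     coreset_method=coreset.rand_from_batch, coreset_size=200,
#                     batch_size=256, single_head=True)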
67 | 
68 | def run_coreset_only(hidden_size, no_epochs, data_gen, coreset_method, coreset_size=0, batch_size=None, single_head=True):
69 |     in_dim, out_dim = data_gen.get_dims()
70 |     x_coresets, y_coresets = [], []
71 |     x_testsets, y_testsets = [], []
72 |     all_acc = np.array([])
73 | 
74 |     for task_id in range(data_gen.max_iter):
75 |         x_train, y_train, x_test, y_test = data_gen.next_task()
76 |         x_testsets.append(x_test)
77 |         y_testsets.append(y_test)
78 | 
79 |         head = 0 if single_head else task_id
80 |         bsize = x_train.shape[0] if (batch_size is None) else batch_size
81 | 
82 |         if task_id == 0:
83 |             mf_model = MFVI_NN(in_dim, hidden_size, out_dim, x_train.shape[0], single_head = single_head, prev_means=None)
84 | 
85 |         if coreset_size > 0:
86 |             x_coresets, y_coresets, x_train, y_train = coreset_method(x_coresets, y_coresets, x_train, y_train, coreset_size)
87 | 
88 | 
89 |         mf_model.save_weights()
90 | 
91 |         acc = test.get_scores(mf_model, x_testsets, y_testsets, no_epochs, single_head, x_coresets, y_coresets, batch_size, just_vanilla=False)
92 | 
93 |         all_acc = test.concatenate_results(acc, all_acc)
94 | 
95 |         mf_model.load_weights()
96 |         mf_model.clean_copy_weights()
97 | 
98 |         if not single_head:
99 |             mf_model.create_head()
100 | 
101 |     return all_acc
102 | 
103 | def print_graph(model, output):
104 |     params = dict()  # "{}" in every key keeps the per-layer entries from overwriting each other
105 |     for i in range(len(model.W_m)):
106 |         params["W_m{}".format(i)] = model.W_m[i]
107 |         params["W_v{}".format(i)] = model.W_v[i]
108 |         params["b_m{}".format(i)] = model.b_m[i]
109 |         params["b_v{}".format(i)] = model.b_v[i]
110 |         params["prior_W_m{}".format(i)] = model.prior_W_m[i]
111 |         params["prior_W_v{}".format(i)] = model.prior_W_v[i]
112 |         params["prior_b_m{}".format(i)] = model.prior_b_m[i]
113 |         params["prior_b_v{}".format(i)] = model.prior_b_v[i]
114 | 
115 |     for i in range(len(model.W_last_m)):
116 |         params["W_last_m{}".format(i)] = model.W_last_m[i]
117 |         params["W_last_v{}".format(i)] = model.W_last_v[i]
118 |         params["b_last_m{}".format(i)] = model.b_last_m[i]
119 |         params["b_last_v{}".format(i)] = model.b_last_v[i]
120 |         params["prior_W_last_m{}".format(i)] = model.prior_W_last_m[i]
121 |         params["prior_W_last_v{}".format(i)] = model.prior_W_last_v[i]
122 |         params["prior_b_last_m{}".format(i)] = model.prior_b_last_m[i]
123 |         params["prior_b_last_v{}".format(i)] = model.prior_b_last_v[i]
124 |     dot = make_dot(output, params=params)
125 |     dot.view()
126 | 
127 |     return
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | tensorflow==1.4.0
4 | matplotlib==1.5.3
5 | 
6 | 
--------------------------------------------------------------------------------