├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── conf └── data_paths.yaml ├── datasets └── README.md ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── basic_model.cpython-36.pyc │ ├── basic_model.cpython-37.pyc │ ├── modified_linear.cpython-36.pyc │ └── modified_linear.cpython-37.pyc ├── basic_model.py └── modified_linear.py ├── requirements.txt ├── runner.py ├── runner.sh ├── train ├── __init__.py ├── __init__.pyc ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── architecture_update.cpython-36.pyc │ ├── architecture_update.cpython-37.pyc │ ├── compute_cosine_features.cpython-36.pyc │ ├── compute_cosine_features.cpython-37.pyc │ ├── customized_distill_trainer.cpython-36.pyc │ ├── customized_distill_trainer.cpython-37.pyc │ ├── customized_two.cpython-36.pyc │ ├── ewc.cpython-36.pyc │ ├── ewc.cpython-37.pyc │ ├── exemplar_selection.cpython-36.pyc │ ├── exemplar_selection.cpython-37.pyc │ ├── fisher_matrix.cpython-36.pyc │ ├── gem_solver.cpython-36.pyc │ ├── gem_solver.cpython-37.pyc │ ├── prediction_analyzer.cpython-36.pyc │ ├── prediction_analyzer.cpython-37.pyc │ ├── result_analyser.cpython-36.pyc │ ├── result_analyser.cpython-37.pyc │ ├── trainer.cpython-36.pyc │ └── trainer.cpython-37.pyc ├── architecture_update.py ├── compute_cosine_features.py ├── customized_distill_trainer.py ├── customized_two.py ├── ewc.py ├── exemplar_selection.py ├── exemplar_strategies │ ├── __pycache__ │ │ ├── class_boundary.cpython-36.pyc │ │ ├── class_boundary.cpython-37.pyc │ │ ├── fwsr.cpython-36.pyc │ │ ├── fwsr.cpython-37.pyc │ │ ├── herding.cpython-36.pyc │ │ ├── herding.cpython-37.pyc │ │ ├── kmeans.cpython-36.pyc │ │ └── kmeans.cpython-37.pyc │ ├── class_boundary.py │ ├── fwsr.py │ ├── herding.py │ └── kmeans.py ├── fisher_matrix.py ├── gem_solver.py ├── losses │ ├── BinBranchLoss.py │ ├── BinDevianceLoss.py │ ├── MultiClassCrossEntropy.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── BinBranchLoss.cpython-36.pyc │ │ ├── BinBranchLoss.cpython-37.pyc │ │ ├── BinDevianceLoss.cpython-36.pyc │ │ ├── BinDevianceLoss.cpython-37.pyc │ │ ├── MultiClassCrossEntropy.cpython-36.pyc │ │ ├── MultiClassCrossEntropy.cpython-37.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── amsoftmax.cpython-36.pyc │ │ ├── angular.cpython-36.pyc │ │ ├── angular.cpython-37.pyc │ │ ├── class_balanced_loss.cpython-36.pyc │ │ ├── class_balanced_loss.cpython-37.pyc │ │ ├── margin_ranking_loss.cpython-36.pyc │ │ ├── margin_ranking_loss.cpython-37.pyc │ │ ├── msloss.cpython-36.pyc │ │ ├── msloss.cpython-37.pyc │ │ ├── prob_diff_loss.cpython-36.pyc │ │ ├── prob_diff_loss.cpython-37.pyc │ │ └── triplet_loss.cpython-36.pyc │ ├── amsoftmax.py │ ├── angular.py │ ├── class_balanced_loss.py │ ├── margin_ranking_loss.py │ ├── msloss.py │ ├── prob_diff_loss.py │ └── triplet_loss.py ├── prediction_analyzer.py ├── record_test_scores.py ├── result_analyser.py ├── trainer.py └── visualisations │ ├── __pycache__ │ ├── exemplar_visualizer.cpython-36.pyc │ ├── exemplar_visualizer.cpython-37.pyc │ ├── stability_visualizer.cpython-36.pyc │ ├── stability_visualizer.cpython-37.pyc │ ├── training_visualizer.cpython-36.pyc │ ├── training_visualizer.cpython-37.pyc │ ├── vis_by_person.cpython-36.pyc │ └── vis_by_person.cpython-37.pyc │ ├── exemplar_visualizer.py │ ├── stability_visualizer.py │ ├── training_visualizer.py │ └── vis_by_person.py └── utils ├── __pycache__ ├── data_handler.cpython-36.pyc └── data_handler.cpython-37.pyc ├── data_handler.py ├── img └── incremental_learning.png └── join_pickles.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | sauravonn@gmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2018 Saurav Jha 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Continual Learning Benchmark [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 2 | 3 | This repo contains the code for reproducing the results of the following papers (done as part of my Master's thesis at St Andrews): 4 | 5 | 1. [Benchmarking Continual Learning in Sensor-based Human Activity Recognition: an Empirical Analysis](http://arxiv.org/abs/2104.09396) [Accepted in the _Information Sciences_ (April 2021)] 6 | 2. [Continual Learning in Human Activity Recognition (HAR): An Emperical Analysis of Regularization](https://research-repository.st-andrews.ac.uk/handle/10023/20242) [ICML workshop on Continual Learning (July 2020)] 7 | 8 | ![Incremental learning](https://github.com/srvCodes/continual-learning-benchmark/blob/master/utils/img/incremental_learning.png) 9 | 10 | A sub-total of 11 recent continual learning techniques have been implemented on a component-wise basis: 11 | 12 | 1. Maintaining Discrimination and Fairness in Class Incremental Learning (WA-MDF) [[Paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Zhao_Maintaining_Discrimination_and_Fairness_in_Class_Incremental_Learning_CVPR_2020_paper.pdf)] 13 | 2. Adjusting Decision Boundary for Class Imbalanced Learning (WA-ADB) [[Paper](https://ieeexplore.ieee.org/document/9081988)] 14 | 3. Large Scale Incremental Learning (BiC) [[Paper](https://openaccess.thecvf.com/content_CVPR_2019/papers/Wu_Large_Scale_Incremental_Learning_CVPR_2019_paper.pdf)] 15 | 4. Learning a Unified Classifier Incrementally via Rebalancing (LUCIR) [[Paper](http://dahualin.org/publications/dhl19_increclass.pdf)] 16 | 5. Incremental Learning in Online Scenario (ILOS) [[Paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/He_Incremental_Learning_in_Online_Scenario_CVPR_2020_paper.pdf)] 17 | 6. Gradient Episodic Memory for Continual Learning (GEM) [[Paper](https://papers.nips.cc/paper/7225-gradient-episodic-memory-for-continual-learning.pdf)] 18 | 7. Efficient Lifelong Learning with A-GEM [[Paper](https://openreview.net/forum?id=Hkf2_sC5FX)] 19 | 8. Elastic Weight Consolidation (EWC) [[Paper](https://arxiv.org/pdf/1612.00796.pdf)] 20 | 9. Rotated Elastic Weight Consolidation (R-EWC) [[Paper](https://arxiv.org/abs/1802.02950)] 21 | 10. Learning without Forgetting (LwF) [[Paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8107520)] 22 | 11. Memory Aware Synapses (MAS) [[Paper](https://link.springer.com/chapter/10.1007/978-3-030-01219-9_9)] 23 | 24 | Additionally, the following six exemplar-selection techniques are available (for memory-rehearsal): 25 | 26 | 1. Herding from ICaRL [[Paper](https://openaccess.thecvf.com/content_cvpr_2017/papers/Rebuffi_iCaRL_Incremental_Classifier_CVPR_2017_paper.pdf)] 27 | 2. Frank-Wolfe Sparse Regression (FWSR) [[Paper](https://arxiv.org/abs/1811.02702)] 28 | 3. K-means sampling 29 | 4. DPP sampling 30 | 5. Boundary-based sampling [[Paper](https://ieeexplore.ieee.org/document/8986833)] 31 | 6. Sensitivity-based sampling [[Paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8949290)] 32 | 33 | ## Running the code 34 | 35 | For training, please execute the `runner.sh` script that creates all the directories required for logging the outputs. One can add further similar commands for running further experiments. 36 | 37 | For instance, training on *ARUBA* dataset with FWSR-styled exemplar selection: 38 | 39 | ```python 40 | >>> python runner.py --dataset 'aruba' --total_classes 11 --base_classes 2 --new_classes 2 --epochs 160 --method 'kd_kldiv_wa1' --exemplar 'fwsr' # e.g. for FWSR-styled exemplar selection 41 | 42 | ``` 43 | 44 | ## Proposed Forgetting Score 45 | 46 | The existing forgetting measure metric [1] suffers from self-relativeness, i.e., the forgetting score will remain low throughout the training if the model did not learn much information about the class at the beginning. Class-imbalance scenarios (as in our case) further amplify its ramifications [2]. Code for our correction to the forgetting score can be found [here](https://github.com/srvCodes/continual-learning-benchmark/blob/master/train/result_analyser.py#L211). 47 | 48 | ## Datasets 49 | 50 | The experiments were performed on 8 publicly available HAR datasets. These can downloaded from the drive link in `datasets/`. 51 | 52 | ## Experimental protocol 53 | 54 | The experiments for each dataset and for each train set / exemplar size were performed on 30 random sequences of tasks. The logs in `output_reports/[dataname]` (created after executing the bash script) contain the performances of each individual task sequence as the incremental learning progresses. The final accuracy is then reported as the average over the 30 runs (see instructions below for evaluation). 55 | 56 | ## Evaluating the logs 57 | 58 | For evaluation, please uncomment the lines per the instructions in `runner.py`. This can be used to measure forgetting scores [2], base-new-old accuracies, and average report by holdout sizes. 59 | 60 | ## Combination of techniques 61 | 62 | The component-wise implementation of techniques nevertheless helps in playing with two or more techniques. This can be done by tweaking the `--method` argument. The table below details some of these combinations: 63 | 64 | Technique | Argument for `--method` 65 | ------------ | ------------- 66 | Knowledge distillation with margin ranking loss (KD_MR) | kd_kldiv_mr 67 | KD_MR with WA-MDF | kd_kldiv_mr_wa1 68 | KD_MR with WA-ADB | kd_kldiv_mr_wa2 69 | KD_MR with less forget constraint loss (KD_LFC_MR) | kd_kldiv_lfc_mr 70 | KD_LFC_MR with WA-MDF | kd_kldiv_lfc_mr_wa1 71 | KD_LFC_MR with WA-ADB | kd_kldiv_lfc_mr_wa2 72 | Cosine normalisation with knowledge distillation | cn_kd_kldiv 73 | 74 | Furthermore, the logits replacement tweak of ILOS and weight initialisation from LUCIR can be used with either of the above methods by simply setting the following arguments: 75 | 76 | Technique | Argument 77 | ------------ | ---------------- 78 | ILOS (with either of above) | `--replace_new_logits = True` 79 | LUCIR-styled weight initialisation (with either of above) | `--wt_init = True` 80 | 81 | Please feel free to play around with these. We would be interested in knowing if the combinations deliver better results for you! 82 | 83 | ## Notes on incremental classes 84 | 85 | - All the experiments in our papers used number of base classes and incremental classes as 2. For replicating this, set `--base_classes = 2` and `--new_classes = 2`. 86 | 87 | - For offline learning (_i.e._, without incremental training), set `--base_classes` to the total number of classes in the dataset and `--new_classes = 0`. 88 | 89 | - For experiments with permuted datasets, set `--base_classes = --new_classes` where `--base_classes` = the total number of classes in the dataset. 90 | 91 | ## Verification 92 | 93 | The implementations have been verified through runs on Split-MNIST and Permumted-MNIST - also available for download in `datasets/`. 94 | 95 | 96 | ## Acknowledgement 97 | 98 | Special thanks to [sairin1202](https://github.com/sairin1202)'s implementation of [BiC](https://github.com/sairin1202/BIC) and [Electronic Tomato](https://github.com/ElectronicTomato)'s implementation of [GEM/AGEM/EWC/MAS](https://github.com/ElectronicTomato/continue_leanrning_agem/tree/master/agents). 99 | 100 | ## References 101 | 102 | [1] Chaudhry, A., Dokania, P.K., Ajanthan, T., & Torr, P.H. (2018). Riemannian Walk for Incremental Learning: Understanding Forgetting and Intransigence. ECCV. 103 | 104 | [2] Kim, C. D., Jeong, J., & Kim, G. (2020). Imbalanced continual learning with partitioning reservoir sampling. ECCV. 105 | 106 | 107 | 108 | ## Cite 109 | 110 | If you found this repo useful in your work, please feel free to cite us: 111 | 112 | ```bibtex 113 | @article{jha2021continual, 114 | title={Continual Learning in Sensor-based Human Activity Recognition: an Empirical Benchmark Analysis}, 115 | author={Jha, Saurav and Schiemer, Martin and Zambonelli, Franco and Ye, Juan}, 116 | journal={Information Sciences}, 117 | year={2021}, 118 | publisher={Elsevier} 119 | } 120 | ``` 121 | 122 | ```bibtex 123 | @article{jha2020continual, 124 | title={Continual learning in human activity recognition: an empirical analysis of regularization}, 125 | author={Jha, Saurav and Schiemer, Martin and Ye, Juan}, 126 | journal={Proceedings of Machine Learning Research}, 127 | year={2020} 128 | } 129 | ``` 130 | -------------------------------------------------------------------------------- /conf/data_paths.yaml: -------------------------------------------------------------------------------- 1 | large_data: 2 | milan: datasets/milan.xlsx 3 | twor: datasets/twor.xlsx 4 | aruba: datasets/aruba(1).xlsx 5 | medium_data: 6 | dsads: datasets/Dataset_PerCom18_STL/ 7 | pamap: datasets/Dataset_PerCom18_STL/ 8 | opp: datasets/Dataset_PerCom18_STL/ 9 | hapt: datasets/hapt_data/ 10 | small_data: 11 | hatn6: datasets/hatn6.csv 12 | ws: datasets/ws.csv 13 | cifar100: 14 | train: datasets/cifar-100-python/train 15 | test: datasets/cifar-100-python/test 16 | meta: datasets/cifar-100-python/meta 17 | mnist: 18 | train: datasets/MNIST/mnist_train.csv 19 | test: datasets/MNIST/mnist_test.csv 20 | 21 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | Please download the datasets from this [google drive link](https://drive.google.com/file/d/1e2f0DsZpf-brsjwrvY99TrutN-aOl2ym/view?usp=sharing) and place them in this directory. 2 | 3 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/models/__init__.py -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/basic_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/models/__pycache__/basic_model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/basic_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/models/__pycache__/basic_model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/modified_linear.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/models/__pycache__/modified_linear.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/modified_linear.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/models/__pycache__/modified_linear.cpython-37.pyc -------------------------------------------------------------------------------- /models/basic_model.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import LongTensor 6 | from torch import from_numpy, ones, zeros 7 | from torch.utils import data 8 | from . import modified_linear 9 | 10 | PATH_TO_SAVE_WEIGHTS = 'saved_weights/' 11 | 12 | def get_layer_dims(dataname): 13 | res_ = [1,2,2,4] if dataname in ['dsads'] else [1,2,4] if dataname in ['opp'] else [0.5, 1, 2] \ 14 | if dataname in ['hapt', 'milan', 'pamap', 'aruba'] else [500, 500] if dataname in ['cifar100'] else [100, 100, 100] \ 15 | if dataname in ['mnist', 'permuted_mnist'] else [1,2,2] 16 | return res_ 17 | 18 | class Net(nn.Module): 19 | def __init__(self, input_dim, n_classes, dataname, lwf=False, cosine_liner=False): 20 | super(Net, self).__init__() 21 | self.dataname = dataname 22 | layer_nums = get_layer_dims(self.dataname) 23 | self.layer_sizes = layer_nums if self.dataname in ['cifar100', 'mnist'] else\ 24 | [int(input_dim / num) for num in layer_nums] 25 | self.fc0 = nn.Linear(input_dim, self.layer_sizes[0]) 26 | if len(self.layer_sizes) == 2: 27 | self.fc_penultimate = nn.Linear(self.layer_sizes[0], self.layer_sizes[1]) 28 | elif len(self.layer_sizes) == 3: 29 | self.fc1 = nn.Linear(self.layer_sizes[0], self.layer_sizes[1]) 30 | self.fc_penultimate = nn.Linear(self.layer_sizes[1], self.layer_sizes[2]) 31 | elif (len(self.layer_sizes) == 4): 32 | self.fc1 = nn.Linear(self.layer_sizes[0], self.layer_sizes[1]) 33 | self.fc2 = nn.Linear(self.layer_sizes[1], self.layer_sizes[2]) 34 | self.fc_penultimate = nn.Linear(self.layer_sizes[2], self.layer_sizes[3]) 35 | final_dim = self.fc_penultimate.out_features 36 | self.fc = modified_linear.CosineLinear(final_dim, n_classes) if cosine_liner \ 37 | else nn.Linear(final_dim, n_classes, bias=lwf==False) # no biases for LwF 38 | 39 | def forward(self, x): 40 | x = F.relu(self.fc0(x)) 41 | if len(self.layer_sizes) > 2: 42 | x = F.relu(self.fc1(x)) 43 | if len(self.layer_sizes) > 3: 44 | x = F.relu(self.fc2(x)) 45 | x = F.relu(self.fc_penultimate(x)) 46 | x = x.view(x.size(0), -1) 47 | x = self.fc(x) 48 | return x 49 | 50 | 51 | class Dataset(data.Dataset): 52 | def __init__(self, features, labels): 53 | self.labels = labels 54 | self.features = features 55 | 56 | def __len__(self): 57 | return len(self.features) 58 | 59 | def __getitem__(self, idx): 60 | X = from_numpy(self.features[idx]) 61 | y = self.labels[idx] 62 | y = LongTensor([y]) 63 | return X, y 64 | 65 | def get_sample(self, sample_size): 66 | return random.sample(self.features, sample_size) 67 | 68 | 69 | class BiasLayer(nn.Module): 70 | def __init__(self, device): 71 | super(BiasLayer, self).__init__() 72 | self.beta = nn.Parameter(ones(1, requires_grad=True, device=device)) 73 | self.gamma = nn.Parameter(zeros(1, requires_grad=True, device=device)) 74 | 75 | def forward(self, x): 76 | return self.beta * x + self.gamma 77 | 78 | def printParam(self, i): 79 | print(i, self.beta.item(), self.gamma.item()) 80 | 81 | def get_beta(self): 82 | return self.beta 83 | 84 | def get_gamma(self): 85 | return self.gamma 86 | 87 | def set_beta(self, new_beta): 88 | self.beta = new_beta 89 | 90 | def set_gamma(self, new_gamma): 91 | self.gamma = new_gamma 92 | 93 | def set_grad(self, bool_value): 94 | self.beta.requires_grad = bool_value 95 | self.gamma.requires_grad = bool_value 96 | -------------------------------------------------------------------------------- /models/modified_linear.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn.parameter import Parameter 5 | from torch.nn import functional as F 6 | from torch.nn import Module 7 | 8 | class CosineLinear(Module): 9 | def __init__(self, in_features, out_features, sigma=True): 10 | super(CosineLinear, self).__init__() 11 | self.in_features = in_features 12 | self.out_features = out_features 13 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 14 | if sigma: 15 | self.sigma = Parameter(torch.Tensor(1)) 16 | else: 17 | self.register_parameter('sigma', None) 18 | self.reset_parameters() 19 | 20 | def reset_parameters(self): 21 | stdv = 1. / math.sqrt(self.weight.size(1)) 22 | self.weight.data.uniform_(-stdv, stdv) 23 | if self.sigma is not None: 24 | self.sigma.data.fill_(1) #for initializaiton of sigma 25 | 26 | def forward(self, input): 27 | out = F.linear(F.normalize(input, p=2,dim=1), \ 28 | F.normalize(self.weight, p=2, dim=1)) # experiment with adding bias? 29 | if self.sigma is not None: 30 | out = self.sigma * out 31 | return out 32 | 33 | class SplitCosineLinear(Module): 34 | # consists of two fc layers and concatenate their outputs 35 | def __init__(self, in_features, out_features1, out_features2, sigma=True): 36 | super(SplitCosineLinear, self).__init__() 37 | self.in_features = in_features 38 | self.out_features = out_features1 + out_features2 39 | self.fc1 = CosineLinear(in_features, out_features1, False) 40 | self.fc2 = CosineLinear(in_features, out_features2, False) 41 | if sigma: 42 | self.sigma = Parameter(torch.Tensor(1)) 43 | self.sigma.data.fill_(1) 44 | else: 45 | self.register_parameter('sigma', None) 46 | 47 | def forward(self, x): 48 | out1 = self.fc1(x) 49 | out2 = self.fc2(x) 50 | out = torch.cat((out1, out2), dim=1) # concatenate along the channel 51 | if self.sigma is not None: 52 | out = self.sigma * out 53 | return out -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _tflow_select=2.1.0=gpu 6 | absl-py=0.7.1=pypi_0 7 | alabaster=0.7.12=pypi_0 8 | argparse=1.4.0=pypi_0 9 | asn1crypto=1.3.0=py36_0 10 | astor=0.8.0=pypi_0 11 | attrs=19.1.0=pypi_0 12 | babel=2.8.0=pypi_0 13 | backcall=0.1.0=pypi_0 14 | blas=1.0=mkl 15 | bleach=1.5.0=pypi_0 16 | boto=2.49.0=pypi_0 17 | boto3=1.12.11=pypi_0 18 | botocore=1.15.11=pypi_0 19 | bpemb=0.3.0=pypi_0 20 | bzip2=1.0.8=h7b6447c_0 21 | c-ares=1.15.0=h7b6447c_1001 22 | ca-certificates=2020.6.24=0 23 | cachetools=4.0.0=pypi_0 24 | cairo=1.14.12=h8948797_3 25 | certifi=2020.6.20=py36_0 26 | cffi=1.14.0=py36h2e261b9_0 27 | chardet=3.0.4=py36_1003 28 | click=7.0=pypi_0 29 | cloudpickle=1.3.0=py_0 30 | cryptography=2.8=py36h1ba5d50_0 31 | cuda101=1.0=h589bae5_0 32 | cudatoolkit=10.0.130=0 33 | cudnn=7.6.5=cuda10.0_0 34 | cupti=10.0.130=0 35 | cvxopt=1.2.1=pypi_0 36 | cycler=0.10.0=py36_0 37 | cymem=1.31.2=py36h6bb024c_0 38 | cython=0.29.13=pypi_0 39 | cytoolz=0.9.0.1=py36h14c3975_1 40 | dask-core=2.14.0=py_0 41 | dbus=1.13.12=h746ee38_0 42 | decorator=4.4.2=py_0 43 | defusedxml=0.6.0=pypi_0 44 | deprecated=1.2.7=pypi_0 45 | dialogflow=0.7.2=pypi_0 46 | dill=0.2.9=py36_0 47 | distro=1.4.0=pypi_0 48 | docutils=0.15.2=pypi_0 49 | dppy=0.3.0=pypi_0 50 | en-core-web-sm=2.0.0=pypi_0 51 | entrypoints=0.3=pypi_0 52 | et_xmlfile=1.0.1=py36_0 53 | expat=2.2.6=he6710b0_0 54 | ffmpeg=4.0=hcdf2ecd_0 55 | filelock=3.0.12=pypi_0 56 | flair=0.4.3=pypi_0 57 | fontconfig=2.13.0=h9420a91_0 58 | freeglut=3.0.0=hf484d3e_5 59 | freetype=2.9.1=h8a8886c_1 60 | future=0.18.2=pypi_0 61 | gast=0.2.2=pypi_0 62 | gensim=3.8.1=pypi_0 63 | glib=2.63.1=h5a9c865_0 64 | google-api-core=1.16.0=pypi_0 65 | google-auth=1.11.2=pypi_0 66 | google-pasta=0.1.7=pypi_0 67 | googleapis-common-protos=1.51.0=pypi_0 68 | graphite2=1.3.13=h23475e2_0 69 | grpcio=1.21.1=pypi_0 70 | gst-plugins-base=1.14.0=hbbd80ab_1 71 | gstreamer=1.14.0=hb453b48_1 72 | h5py=2.9.0=pypi_0 73 | harfbuzz=1.8.8=hffaf4a1_0 74 | hdf5=1.10.2=hba1933b_1 75 | html5lib=0.9999999=pypi_0 76 | hyperopt=0.2.3=pypi_0 77 | icu=58.2=h9c2bf20_1 78 | idna=2.8=py36_0 79 | imageio=2.8.0=py_0 80 | imagesize=1.2.0=pypi_0 81 | imbalanced-learn=0.6.2=py_0 82 | importlib-metadata=1.5.0=pypi_0 83 | intel-openmp=2019.4=243 84 | ipykernel=5.1.1=pypi_0 85 | ipython=7.6.1=pypi_0 86 | ipython-genutils=0.2.0=pypi_0 87 | ipywidgets=7.5.0=pypi_0 88 | jasper=2.0.14=h07fcdf6_1 89 | jdcal=1.4.1=py_0 90 | jedi=0.14.1=pypi_0 91 | jinja2=2.10.1=pypi_0 92 | jmespath=0.9.5=pypi_0 93 | joblib=0.14.1=py_0 94 | jpeg=9b=h024ee3a_2 95 | jsonschema=3.0.1=pypi_0 96 | jupyter=1.0.0=pypi_0 97 | jupyter-client=5.3.1=pypi_0 98 | jupyter-console=6.0.0=pypi_0 99 | jupyter-core=4.5.0=pypi_0 100 | keras=2.2.5=pypi_0 101 | keras-applications=1.0.8=py_0 102 | keras-preprocessing=1.1.0=py_1 103 | keyboard=0.13.4=pypi_0 104 | kiwisolver=1.1.0=py36he6710b0_0 105 | knn-cuda=0.2=pypi_0 106 | langdetect=1.0.7=pypi_0 107 | latexcodec=2.0.0=pypi_0 108 | libedit=3.1.20181209=hc058e9b_0 109 | libffi=3.2.1=hd88cf55_4 110 | libgcc-ng=9.1.0=hdf63c60_0 111 | libgfortran-ng=7.3.0=hdf63c60_0 112 | libglu=9.0.0=hf484d3e_1 113 | libopenblas=0.2.20=h9ac9557_7 114 | libopencv=3.4.2=hb342d67_1 115 | libopus=1.3.1=h7b6447c_0 116 | libpng=1.6.37=hbc83047_0 117 | libprotobuf=3.11.4=hd408876_0 118 | libstdcxx-ng=9.1.0=hdf63c60_0 119 | libtiff=4.1.0=h2733197_0 120 | libuuid=1.0.3=h1bed415_2 121 | libvpx=1.7.0=h439df22_0 122 | libxcb=1.13=h1bed415_1 123 | libxml2=2.9.9=hea5a465_1 124 | markdown=3.1.1=py36_0 125 | markupsafe=1.1.1=pypi_0 126 | matplotlib=3.1.0=py36h5429711_0 127 | mistune=0.8.4=pypi_0 128 | mkl=2019.4=243 129 | mkl-service=2.3.0=py36he904b0f_0 130 | mkl_fft=1.0.15=py36ha843d7b_0 131 | mkl_random=1.1.0=py36hd6b4f25_0 132 | mock=4.0.1=py_1 133 | more-itertools=8.2.0=pypi_0 134 | mpld3=0.3=pypi_0 135 | msgpack-numpy=0.4.4.3=py_0 136 | msgpack-python=0.5.6=py36h6bb024c_1 137 | murmurhash=0.28.0=py36hf484d3e_0 138 | nbconvert=5.5.0=pypi_0 139 | nbformat=4.4.0=pypi_0 140 | ncurses=6.2=he6710b0_0 141 | networkx=2.2=pypi_0 142 | ninja=1.9.0=py36hfd86e86_0 143 | nltk=3.4.5=pypi_0 144 | notebook=6.0.0=pypi_0 145 | numpy=1.17.0=pypi_0 146 | numpy-base=1.18.1=py36hde5b4d6_1 147 | olefile=0.46=py36_0 148 | openblas=0.2.20=4 149 | openblas-devel=0.2.20=7 150 | opencv=3.4.2=py36h6fd60c2_1 151 | openpyxl=3.0.1=py_0 152 | openssl=1.1.1g=h7b6447c_0 153 | opt-einsum=3.0.1=pypi_0 154 | oset=0.1.3=pypi_0 155 | packaging=20.1=pypi_0 156 | pandas=0.24.2=py36he6710b0_0 157 | pandocfilters=1.4.2=pypi_0 158 | parso=0.5.1=pypi_0 159 | pcre=8.43=he6710b0_0 160 | pexpect=4.7.0=pypi_0 161 | pickleshare=0.7.5=pypi_0 162 | pillow=7.0.0=py36hb39fc2d_0 163 | pip=19.1.1=py36_0 164 | pixman=0.38.0=h7b6447c_0 165 | plac=0.9.6=py36_0 166 | pluggy=0.13.1=pypi_0 167 | preshed=1.0.1=py36he6710b0_0 168 | prometheus-client=0.7.1=pypi_0 169 | prompt-toolkit=2.0.9=pypi_0 170 | protobuf=3.8.0=pypi_0 171 | ptyprocess=0.6.0=pypi_0 172 | py=1.8.1=pypi_0 173 | py-opencv=3.4.2=py36hb342d67_1 174 | pyasn1=0.4.8=pypi_0 175 | pyasn1-modules=0.2.8=pypi_0 176 | pybtex=0.22.2=pypi_0 177 | pybtex-docutils=0.2.2=pypi_0 178 | pycparser=2.20=py_0 179 | pydpp=0.2.1=pypi_0 180 | pygments=2.4.2=pypi_0 181 | pyopenssl=19.1.0=py36_0 182 | pyparsing=2.4.6=py_0 183 | pyqt=5.9.2=py36h05f1152_2 184 | pyrsistent=0.15.3=pypi_0 185 | pysocks=1.7.1=py36_0 186 | pytest=5.3.5=pypi_0 187 | python=3.6.8=h0371630_0 188 | python-dateutil=2.8.1=py_0 189 | pytorch=1.4.0=py3.6_cuda10.0.130_cudnn7.6.3_0 190 | pytorch-metric-learning=0.9.81=pypi_0 191 | pytorch-transformers=1.2.0=pypi_0 192 | pytz=2019.3=py_0 193 | pywavelets=1.1.1=py36h7b6447c_0 194 | pyyaml=5.1.1=pypi_0 195 | pyzmq=18.0.2=pypi_0 196 | qt=5.9.7=h5867ecd_1 197 | qtconsole=4.5.2=pypi_0 198 | quadprog=0.1.6=py36_0 199 | readline=7.0=h7b6447c_5 200 | regex=2018.07.11=py36h14c3975_0 201 | requests=2.22.0=py36_0 202 | rsa=4.0=pypi_0 203 | s3transfer=0.3.3=pypi_0 204 | sacremoses=0.0.38=pypi_0 205 | scikit-image=0.15.0=py36he6710b0_0 206 | scikit-learn=0.22.1=py36hd81dba3_0 207 | scipy=1.3.0=pypi_0 208 | seaborn=0.10.0=py_0 209 | segtok=1.5.7=pypi_0 210 | send2trash=1.5.0=pypi_0 211 | sentencepiece=0.1.85=pypi_0 212 | sentiment-classifier=0.7=pypi_0 213 | setuptools=41.2.0=pypi_0 214 | sip=4.19.8=py36hf484d3e_0 215 | six=1.12.0=pypi_0 216 | sklearn=0.0=pypi_0 217 | smart-open=1.9.0=pypi_0 218 | snowballstemmer=2.0.0=pypi_0 219 | spacy=2.0.12=py36h962f231_0 220 | sphinx=2.4.2=pypi_0 221 | sphinx-rtd-theme=0.4.3=pypi_0 222 | sphinxcontrib-applehelp=1.0.1=pypi_0 223 | sphinxcontrib-bibtex=1.0.0=pypi_0 224 | sphinxcontrib-devhelp=1.0.1=pypi_0 225 | sphinxcontrib-htmlhelp=1.0.2=pypi_0 226 | sphinxcontrib-jsmath=1.0.1=pypi_0 227 | sphinxcontrib-qthelp=1.0.2=pypi_0 228 | sphinxcontrib-serializinghtml=1.1.3=pypi_0 229 | sqlite=3.31.1=h7b6447c_0 230 | sqlitedict=1.6.0=pypi_0 231 | tabula-py=1.4.3=pypi_0 232 | tabulate=0.8.6=pypi_0 233 | tb-nightly=1.14.0a20190301=pypi_0 234 | tensorboard=1.8.0=pypi_0 235 | tensorflow=1.13.1=gpu_py36h3991807_0 236 | tensorflow-base=1.13.1=gpu_py36h8d69cac_0 237 | tensorflow-estimator=1.13.0=py_0 238 | tensorflow-gpu=1.13.1=h0d30ee6_0 239 | termcolor=1.1.0=pypi_0 240 | terminado=0.8.2=pypi_0 241 | testpath=0.4.2=pypi_0 242 | tf-estimator-nightly=1.14.0.dev2019080601=pypi_0 243 | thinc=6.10.3=py36h962f231_0 244 | tk=8.6.8=hbc83047_0 245 | tokenizers=0.5.2=pypi_0 246 | toolz=0.10.0=py_0 247 | torchvision=0.5.0=py36_cu100 248 | tornado=6.0.4=py36h7b6447c_1 249 | tqdm=4.33.0=py_0 250 | traitlets=4.3.2=pypi_0 251 | transformers=2.5.1=pypi_0 252 | tsnecuda=2.1.0=pypi_0 253 | ujson=2.0.3=py36he6710b0_0 254 | urllib3=1.24.3=py36_0 255 | wcwidth=0.1.7=pypi_0 256 | werkzeug=0.15.4=pypi_0 257 | wheel=0.33.4=pypi_0 258 | widgetsnbextension=3.5.0=pypi_0 259 | wrapt=1.11.2=pypi_0 260 | xlrd=1.2.0=py36_0 261 | xz=5.2.5=h7b6447c_0 262 | zipp=3.0.0=pypi_0 263 | zlib=1.2.11=h7b6447c_3 264 | zstd=1.3.7=h0b5b093_0 265 | -------------------------------------------------------------------------------- /runner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Saurav Jha 3 | MSc, Advanced Systems Dependability 4 | University of St Andrews 5 | """ 6 | 7 | import argparse 8 | import time 9 | 10 | 11 | from train import trainer, result_analyser 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--dataset', default='dsads', type=str, 15 | help='Possible values: dsads, pamap, opp, hapt, ws, hatn6, milan,' 16 | ' aruba, and twor') 17 | parser.add_argument('--total_classes', default=19, type=int, 18 | help='12 for pamap, 12 for hapt, 19 for dsads, 15 for milan, ' 19 | '11 for aruba, 23 for twor, 9 for ws, 7 for hatn6, 6 for opp') 20 | parser.add_argument('--new_classes', default=2, type=int, help='number of new classes per incremental batch') 21 | parser.add_argument('--base_classes', default=2, type=int, help='number of classes in first batch') 22 | parser.add_argument('--epochs', default=200, type=int, help='number of training epochs: 200 for dsads/pamap') 23 | parser.add_argument('--T', default=2, type=float, help='temperature value for distillation loss') 24 | parser.add_argument('--average_over', default='holdout', type=str, 25 | help="whether to average over different holdout sizes: " 26 | " 'holdout', different train percents: 'tp'" 27 | "or a single run: 'na'") 28 | parser.add_argument('--tp', default=1.0, type=int, 29 | help='Fixed train percent to use if "average_over" = "holdout"') 30 | parser.add_argument('--exemp_size', default=6, type=int, 31 | help="Fixed holdout size to use if 'average_over' = 'tp'") 32 | parser.add_argument('--method', default='kd_kldiv', type=str, 33 | help="distillation method to use: 'ce' for only cross entropy" 34 | "'kd_kldiv' for base distillaiton loss with kl divergence " 35 | "'kd_kldiv_bic' for Large Scale Incremental Learning, " 36 | "'kd_kldiv_wa1' for Maintaining Discrimination and Fairness in Class Incremental Learning," 37 | "'kd_kldiv_wa2' for Adjusting Decision Boundary for Class Imbalanced Learning" 38 | " 'cn': cosine norm with basic distillation loss 'cn_lfc': " 39 | "cosine normaliztion with less forget constraint as distillation loss, " 40 | "'cn_lfc_mr' : cosine norm + less forget constraint + margin ranking loss," 41 | " 'ewc' for elastic weight consolidation with each task having their own importance matrix," 42 | "'online_ewc' for regularised ewc where there will be only one importance matrix across all tasks, " 43 | "'lwf': learning without forgetting, 'gem': gradient episodic memory," 44 | " 'agem': averaged gem, 'ce_holdout': cross entropy with memory replay," 45 | "'ce_ewc': EWC with memory replay," 46 | "'ce_lfc': Cross entropy (CE) with less forget constraint, 'ce_mr': CE with margin ranking loss," 47 | "'ce_replaced': CE with ILOS (--replace_new_logits should be set to True for this to work)") 48 | parser.add_argument('--exemplar', default='random', type=str, help="exemplar selection strategy: 'random', 'icarl', " 49 | "'kmeans', 'dpp', 'boundary', 'sensitivity' or 'fwsr'") 50 | parser.add_argument('--replace_new_logits', default=False, type=bool, help='if True, replace logits for new class (Incremental Learning in Online Scenario paper)') 51 | parser.add_argument('--wt_init', default=False, type=bool, 52 | help="whether to initialize the weights for old classes using " 53 | "data stats or not") 54 | parser.add_argument('--weighted', default=False, type=bool, 55 | help="whether to weight the new and old class samples or not") 56 | parser.add_argument('--rs_ratio', default=0.7, type=float, help='0 <= rescale ratio <= 1 to use if --weighted is True') 57 | parser.add_argument('--lwf_lamda', default=1.6, type=float, help="loss balance weight for LwF whose higher values favor" 58 | " old task performance.") 59 | parser.add_argument('--lamda_base', default=5.0, type=float, 60 | help='Base lamda for weighting less forget constraint loss.') 61 | parser.add_argument('--wa2_gamma', default=0.1, type=float, help='Rescaling factor for wa2 method.') 62 | parser.add_argument('--vis', default=False, type=bool, help='visualizing the raw dataset by persons') 63 | parser.add_argument('--tsne_vis', default=False, type=bool, 64 | help='tsne visualisations of the intermediate model features') 65 | parser.add_argument('--norm_vis', default=False, type=bool, help='visualising norms of final layer weights by classes ' 66 | 'and by epochs.') 67 | parser.add_argument('--acc_vis', default=False, type=bool, help='visualising accuracies of old and new classes.') 68 | parser.add_argument('-corr_vis', default=False, type=bool, help='correlation heatmaps of classes using raw data ' 69 | 'as well as predictions') 70 | parser.add_argument('--exemp_vis', default=False, type=bool, help='help visualising the space occupied by the selected ' 71 | 'exemplars within the class') 72 | parser.add_argument('--reg_coef', default=.2, type=float, help='Regularization coefficient for "online_ewc": a larger ' 73 | 'value means less plasticity') 74 | args = parser.parse_args() 75 | 76 | 77 | 78 | def main(): 79 | """ 80 | Main function to train and test. Also for analysing forgetting and accuracy scores 81 | :return: none 82 | """ 83 | start_time = time.time() 84 | model_trainer = trainer.Trainer(args) 85 | print(f"Total elpased time: {time.time() - start_time}") 86 | 87 | """Uncomment for analysing the saved results in the text files:""" 88 | # result_analyser.visualize_size_wise_sampling_scores('twor') 89 | # result_analyser.visualize_tp_wise_sampling_scores('pamap') 90 | # result_analyser.visualize_size_wise_scores('dsads', baseline=True) 91 | # result_analyser.visualize_base_old_all_scores('hatn6', baseline=True) 92 | # result_analyser.visualize_forgetting_measure(filename='hatn6', replay=True, baseline=True) 93 | 94 | """ Uncomment for analysing the forgetting scores, and reports by holdout sizes.""" 95 | filename = 'output_reports/mnist/kd_kldiv_wa1_random_1.0_200' 96 | # analyser = result_analyser.ResultAnalysis(filename, 30) 97 | # analyser.parse_text_results() 98 | # analyser.compute_avg_report_by_sizes() 99 | # analyser.compute_avg_detailed_accs() 100 | # analyser.plot_detailed_acc() 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /runner.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Author : Saurav Jha 3 | # Shell script for creating all result directories and guiding over running the commands 4 | mkdir -p output_reports/{aruba,dsads,hapt,hatn6,milan,opp,pamap,twor,ws,mnist,permuted_mnist} 5 | mkdir -p vis_outputs/{accuracy_vis/{acc_by_batches,acc_by_classes,detail_acc},corr_vis/{by_predictions,by_raw_features},exemp_vis,norm_vis/{by_batches,by_epochs},per_person,tsne_vis} 6 | 7 | python runner.py --dataset 'hapt' --total_classes 12 --new_classes 2 --base_classes 2 --epochs 200 --method 'kd_kldiv_wa1' # for incremental learning 8 | python runner.py --dataset 'hapt' --total_classes 12 --base_classes 12 --new_classes 0 --epochs 200 --method 'kd_kldiv_wa1' # for offline learning in single batch 9 | python runner.py --dataset 'dsads' --total_classes 19 --base_classes 2 --new_classes 2 --epochs 200 --method 'kd_kldiv_wa1' --exemplar 'icarl' # e.g. for ICaRL-styled exemplar selection 10 | 11 | python runner.py --dataset 'permuted_mnist' --total_classes 10 --new_classes 10 --base_classes 10 --epochs 5 --method 'agem' # for verification on permuted-mnist 12 | python runner.py --dataset 'mnist' --total_classes 10 --new_classes 2 --base_classes 2 --epochs 5 --method 'agem' # for verification on split-mnist 13 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__init__.py -------------------------------------------------------------------------------- /train/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__init__.pyc -------------------------------------------------------------------------------- /train/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /train/__pycache__/architecture_update.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/architecture_update.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/architecture_update.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/architecture_update.cpython-37.pyc -------------------------------------------------------------------------------- /train/__pycache__/compute_cosine_features.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/compute_cosine_features.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/compute_cosine_features.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/compute_cosine_features.cpython-37.pyc -------------------------------------------------------------------------------- /train/__pycache__/customized_distill_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/customized_distill_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/customized_distill_trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/customized_distill_trainer.cpython-37.pyc -------------------------------------------------------------------------------- /train/__pycache__/customized_two.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/customized_two.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/ewc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/ewc.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/ewc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/ewc.cpython-37.pyc -------------------------------------------------------------------------------- /train/__pycache__/exemplar_selection.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/exemplar_selection.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/exemplar_selection.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/exemplar_selection.cpython-37.pyc -------------------------------------------------------------------------------- /train/__pycache__/fisher_matrix.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/fisher_matrix.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/gem_solver.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/gem_solver.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/gem_solver.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/gem_solver.cpython-37.pyc -------------------------------------------------------------------------------- /train/__pycache__/prediction_analyzer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/prediction_analyzer.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/prediction_analyzer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/prediction_analyzer.cpython-37.pyc -------------------------------------------------------------------------------- /train/__pycache__/result_analyser.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/result_analyser.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/result_analyser.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/result_analyser.cpython-37.pyc -------------------------------------------------------------------------------- /train/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /train/__pycache__/trainer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/__pycache__/trainer.cpython-37.pyc -------------------------------------------------------------------------------- /train/architecture_update.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Wed Jan 12 12:51:47 2020 3 | Contains methods to extend pytorch neural network outputlayers 4 | @author: Martin Schiemer 5 | """ 6 | 7 | import random 8 | from copy import deepcopy 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | def create_proxy_outputs_from_onehot(model, task_pool, batch_y): 16 | current_output_neurons = list(model.children())[-1].in_features 17 | 18 | # inititalize array with 0s 19 | proxy_outputs = torch.zeros([batch_y.shape[0], current_output_neurons]) 20 | # convert one hot to class 21 | batch_classes = np.where(batch_y == 1)[1] 22 | 23 | # get index of the class in pool 24 | pool_ind = [task_pool.index(c) for c in batch_y] 25 | 26 | for i, j in enumerate(pool_ind): 27 | proxy_outputs[i, j] = 1 28 | 29 | return proxy_outputs 30 | 31 | 32 | def create_proxy_outputs(model, task_pool, batch_y): 33 | current_output_neurons = list(model.children())[-1].in_features 34 | 35 | # inititalize array with 0s 36 | proxy_outputs = torch.zeros(len(batch_y), dtype=torch.long) 37 | 38 | # get index of the class in pool 39 | pool_ind = [task_pool.index(c) for c in batch_y] 40 | 41 | for i, j in enumerate(pool_ind): 42 | proxy_outputs[i] = j 43 | 44 | return proxy_outputs 45 | 46 | 47 | def batch_transform(batch, model, task_pool, device): 48 | batch_x, batch_y = batch 49 | batch_y = create_proxy_outputs(model, task_pool, batch_y) 50 | 51 | # calculate model outputs and loss and backprop 52 | if model.type == "fc": 53 | shape_length = len(batch_x.shape) 54 | batch_x = batch_x.view(-1, batch_x.shape[shape_length - 2] * 55 | batch_x.shape[shape_length - 1] * 56 | batch_x.shape[shape_length - 3]) 57 | 58 | batch_x, batch_y = batch_x.to(device), batch_y.to(device) 59 | return batch_x, batch_y 60 | 61 | 62 | def pad_random_weights(vec, pad_width, *_, **__): 63 | vec[vec.size - pad_width[1]:] = random.uniform(-1, 1) 64 | 65 | 66 | def pad_normal_dist_weights(vec, pad_width, iaxis, kwargs, mean, var): 67 | vec[vec.size - pad_width[1]:] = random.uniform(-1, 1) 68 | 69 | 70 | def update_weights(net, amnt_new_classes): 71 | layer_key = list(net.state_dict().keys())[-2] 72 | weights = net.state_dict()[layer_key].cpu().detach().numpy() 73 | # weights = net.state_dict()[f"fc{output_layer_nr}.weight"].cpu().detach().numpy() 74 | w_mean = np.mean(weights, axis=0) 75 | w_std = np.std(weights, axis=0) 76 | new_weights = np.pad(weights, ((0, amnt_new_classes), (0, 0)), mode="constant", constant_values=0) 77 | for i in reversed(range(amnt_new_classes)): 78 | for j in range(new_weights.shape[1]): 79 | new_weights[new_weights.shape[0] - 1 - i][j] = np.random.normal(w_mean[j], w_std[j]) 80 | return new_weights 81 | 82 | 83 | def update_bias(net, amnt_new_classes): 84 | bias_key = list(net.state_dict().keys())[-1] 85 | bias = net.state_dict()[bias_key].cpu().detach().numpy() 86 | b_mean = np.mean(bias) 87 | b_std = np.std(bias) 88 | new_bias = np.zeros(len(bias) + amnt_new_classes, dtype="f") 89 | new_bias[:len(bias)] = bias 90 | for i in range(amnt_new_classes): 91 | new_bias[-1 - i] = np.random.normal(b_mean, b_std) - np.log(amnt_new_classes) 92 | return new_bias 93 | 94 | 95 | def transform_state_dic(state_dict, old_keys): 96 | new_dic = {key: data[1] for key, data in zip(old_keys, state_dict.items())} 97 | 98 | return new_dic 99 | 100 | 101 | def add_output_neurons(net, amnt_old_classes, amnt_new_classes): 102 | newmodel = torch.nn.Sequential(*(list(net.modules())[:-1]), 103 | nn.Linear(list(net.modules())[-1].in_features, 104 | amnt_old_classes + amnt_new_classes)) 105 | return newmodel 106 | 107 | 108 | def kaiming_normal_init(m): 109 | # Source: https://github.com/ngailapdi/LWF/blob/baa07ee322d4b2f93a28eba092ad37379f565aca/model.py#L28 110 | if isinstance(m, nn.Conv2d): 111 | nn.init.kaiming_normal_(m.weight, nonlinearity='relu') 112 | elif isinstance(m, nn.Linear): 113 | nn.init.kaiming_normal_(m.weight, nonlinearity='sigmoid') 114 | 115 | 116 | def update_model(NET_FUNCTION, net, data_shape, dataname, amnt_new_classes, device, lwf=False): 117 | amnt_old_classes = list(net.children())[-1].out_features 118 | if amnt_new_classes >= 1: 119 | if lwf: 120 | # Source: https://github.com/ngailapdi/LWF/blob/baa07ee322d4b2f93a28eba092ad37379f565aca/model.py#L73 121 | in_features = list(net.children())[-1].in_features 122 | out_features = amnt_old_classes 123 | weights = net.fc.weight.data 124 | new_out_features = out_features + amnt_new_classes 125 | net.fc = torch.nn.Linear(in_features, 126 | new_out_features, bias=False) 127 | kaiming_normal_init(net.fc.weight) 128 | net.fc.weight.data[:out_features] = weights 129 | net = net.to(device) 130 | return net 131 | else: 132 | new_output_layer_w = update_weights(net, amnt_new_classes) 133 | new_output_layer_bias = update_bias(net, amnt_new_classes) 134 | new_model = NET_FUNCTION(data_shape, amnt_old_classes + amnt_new_classes, dataname) 135 | 136 | for i, l in enumerate(zip(new_model.children(), net.children())): 137 | if i == len(list(new_model.children())) - 1: 138 | l[0].weight = torch.nn.Parameter(torch.from_numpy(new_output_layer_w)) 139 | l[0].bias = torch.nn.Parameter(torch.from_numpy(new_output_layer_bias)) 140 | else: 141 | l[0].weight = l[1].weight 142 | l[0].bias = l[1].bias 143 | del net 144 | new_model = new_model.to(device) 145 | return new_model 146 | else: 147 | return net 148 | 149 | 150 | def check_model_integrity(old_model, new_model, verbose=False): 151 | for i in old_model.state_dict().keys(): 152 | if (np.array_equal(old_model.state_dict()[i].cpu().numpy(), new_model.state_dict()[i].cpu().numpy())): 153 | if verbose: 154 | print(f"key {i} is the same for both nets") 155 | else: 156 | if verbose: 157 | print("\n", i, "\n") 158 | for h in range(len(old_model.state_dict()[i])): 159 | try: 160 | if np.array_equal(old_model.state_dict()[i][h].numpy(), new_model.state_dict()[i][h].numpy()): 161 | if verbose: 162 | print(f"key {i} weights of neuron {h} are the same for both nets\n") 163 | else: 164 | 165 | print(f"key {i} weights of neuron {h} are different for both nets\n Differces at:") 166 | print(old_model.state_dict()[i][h].numpy() - new_model.state_dict()[i][h].numpy()) 167 | print("\n") 168 | return False 169 | except: 170 | pass 171 | return True 172 | -------------------------------------------------------------------------------- /train/compute_cosine_features.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # !/usr/bin/env python 4 | # coding=utf-8 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def compute_features(tg_feature_model, cls_idx, evalloader, num_samples, num_features, device=None): 10 | tg_feature_model.eval() 11 | 12 | features = np.zeros([num_samples, num_features]) 13 | start_idx = 0 14 | with torch.no_grad(): 15 | for inputs, targets in evalloader: 16 | outputs = tg_feature_model(inputs.float().cuda()) 17 | # print(num_samples, num_features, inputs.shape, start_idx, len(evalloader.dataset), targets, outputs.shape) 18 | features[start_idx:start_idx + inputs.shape[0], :] = np.squeeze(outputs.cpu()) 19 | start_idx = start_idx + inputs.shape[0] 20 | assert (start_idx == num_samples) 21 | return features 22 | -------------------------------------------------------------------------------- /train/customized_distill_trainer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from scipy.spatial.distance import cosine 7 | from torch.autograd import Variable 8 | from tqdm import tqdm 9 | 10 | from train.losses import class_balanced_loss, MultiClassCrossEntropy 11 | from .losses import margin_ranking_loss 12 | 13 | cur_features = [] 14 | ref_features = [] 15 | old_scores = [] 16 | new_scores = [] 17 | lpl_features_student, lpl_features_teacher = [], [] 18 | 19 | 20 | def get_ref_features(self, inputs, outputs): 21 | global ref_features 22 | ref_features = inputs[0] 23 | 24 | 25 | def get_cur_features(self, inputs, outputs): 26 | global cur_features 27 | cur_features = inputs[0] 28 | 29 | 30 | def get_old_scores_before_scale(self, inputs, outputs): 31 | global old_scores 32 | old_scores = outputs 33 | 34 | 35 | def get_new_scores_before_scale(self, inputs, outputs): 36 | global new_scores 37 | new_scores = outputs 38 | 39 | 40 | def get_teacher_features(self, inputs, outputs): 41 | global lpl_features_teacher 42 | lpl_features_teacher = outputs 43 | 44 | 45 | def get_student_features(self, inputs, outputs): 46 | global lpl_features_student 47 | lpl_features_student = outputs 48 | 49 | 50 | loss_by_epoch = dict() 51 | 52 | 53 | class CustomizedTrainer(): 54 | def __init__(self, args, itera, seen_cls, train_loader, model, previous_model, lamda, 55 | bias_layers, virtual_map, classes_by_groups, device, visualizer=None): 56 | self.args = args 57 | self.itera = itera 58 | self.seen_class = seen_cls 59 | self.train_loader = train_loader 60 | self.device = device 61 | self.cur_lamda = lamda 62 | self.model, self.previous_model = model, previous_model 63 | self.data_visualizer = visualizer 64 | self.virtual_map = virtual_map 65 | self.handle_ref_features = self.previous_model.fc.register_forward_hook(get_ref_features) 66 | self.handle_cur_features = self.model.fc.register_forward_hook(get_cur_features) 67 | if 'cn' in self.args.method: 68 | self.handle_old_scores_bs = self.model.fc.fc1.register_forward_hook(get_old_scores_before_scale) 69 | self.handle_new_scores_bs = self.model.fc.fc2.register_forward_hook(get_new_scores_before_scale) 70 | if 'bic' in self.args.method: 71 | self.bias_layers = bias_layers 72 | self.classes_by_groups = classes_by_groups 73 | if 'lpl' in self.args.method: 74 | self.handle_student_features = self.model.fc_penultimate.register_forward_hook(get_student_features) 75 | self.handle_teacher_features = self.model.fc_penultimate.register_forward_hook(get_teacher_features) 76 | 77 | def align_weights(self): 78 | print("Aligning weights: ") 79 | if '_wa1' in self.args.method: 80 | with torch.no_grad(): 81 | # print("before clamp: ", self.model.fc4.weight.data) 82 | for p in self.model.fc.parameters(): 83 | p.data.clamp_(0) 84 | # print("after clamp: ", self.model.fc4.weight.data) 85 | elif '_wa2' in self.args.method: 86 | if 'cn' in self.args.method: 87 | # customize the normalization so as to fit for the fc1 and fc2 layers of SplitCosineLayer() 88 | with torch.no_grad(): 89 | w1, w2 = self.model.fc.fc1.weight.data, self.model.fc.fc2.weight.data 90 | w1_norm, w2_norm = [w / torch.norm(w, p=2, dim=1, keepdim=True) for w in (w1, w2)] 91 | self.model.fc.fc1.weight.data.copy_(w1_norm) 92 | self.model.fc.fc2.weight.data.copy_(w2_norm) 93 | else: 94 | # carry out regular normalization with fc layers 95 | with torch.no_grad(): 96 | # https://github.com/feidfoe/AdjustBnd4Imbalance/blob/master/models/cifar/resnet.py#L135 97 | w = self.model.fc.weight.data 98 | w_norm = w / torch.norm(w, p=2, dim=1, keepdim=True) 99 | self.model.fc.weight.data.copy_(w_norm) 100 | 101 | def distill_training(self, optimizer, num_new_classes, last_epoch=False, new_class_avg=None, permuted=False): 102 | print("Training with distillation losses... ") 103 | losses = [] 104 | lambda_ = (self.seen_class - num_new_classes) / self.seen_class 105 | dataloader = self.train_loader 106 | if self.args.tsne_vis and last_epoch: 107 | tsne_features, tsne_labels = np.empty(shape=[0, self.model.fc.in_features]), [] 108 | 109 | for i, (feature, label) in enumerate(tqdm(dataloader)): 110 | feature, label = Variable(feature), Variable(label) 111 | if self.args.dataset == 'cifar100': 112 | feature = feature.reshape(-1, 32 * 32) 113 | feature = feature.to(self.device) 114 | # label = label.type(torch.LongTensor) 115 | label = label.view(-1).to(self.device) 116 | optimizer.zero_grad() 117 | p = self.model(feature.float()) 118 | if '_bic' in self.args.method and self.itera > 1: 119 | p = self.bias_forward(p) 120 | 121 | logp = F.log_softmax(p[:, :self.seen_class - num_new_classes] / self.args.T, dim=1) 122 | 123 | with torch.no_grad(): 124 | if self.args.weighted: 125 | sample_weights = self.get_sample_weights(feature, label, new_class_avg, num_new_classes) 126 | sample_weights = torch.exp(torch.Tensor(sample_weights).to(self.device)) 127 | pre_p = self.previous_model(feature.float()) 128 | if 'bic' in self.args.method and self.itera > 1: 129 | pre_p = self.bias_forward(pre_p) 130 | if self.args.method != 'lwf': 131 | # apply distilling loss giving soft labels (T = wts. of small values to introduce rescaling) 132 | if permuted: 133 | assert pre_p.shape[1] == num_new_classes and self.seen_class % num_new_classes == 0 134 | else: 135 | assert pre_p.shape[1] == self.seen_class - num_new_classes, print("Shape mismatch between previous " 136 | "model and no. of old classes.") 137 | pre_p = F.softmax(pre_p / self.args.T, dim=1) 138 | 139 | if self.args.replace_new_logits: 140 | print("Adjusting old class logits..") 141 | p = self.modify_new_logits(p, pre_p, num_new_classes) 142 | 143 | loss_hard_target = torch.nn.CrossEntropyLoss()(p, label) 144 | if any([x in self.args.method for x in ['ce', 'cn']]): 145 | # for both base_ce and while using cosine normalisation, whole of the loss_hard_target is added 146 | if self.args.method == "cn": 147 | # if only cosine normalization, loss = ce_loss + kldiv_loss 148 | loss_soft_target = torch.nn.KLDivLoss()(logp, pre_p) * self.args.T * self.args.T * lambda_ 149 | loss = loss_hard_target + loss_soft_target 150 | loss_stats = f"CE loss: {loss_hard_target}, KD loss: {loss_soft_target}" 151 | else: 152 | loss = loss_hard_target 153 | loss_stats = f"CE loss: {loss_hard_target}" 154 | elif self.args.method == 'lwf': 155 | # learning without forgetting using multi class cross entropy loss 156 | logits_dist = p[:, :self.seen_class - num_new_classes] 157 | loss_soft_target = MultiClassCrossEntropy.MultiClassCrossEntropy(logits_dist, pre_p, self.args.T, device=self.device) 158 | loss = self.args.lwf_lamda * loss_soft_target + loss_hard_target 159 | loss_stats = f"\nCE loss: {loss_hard_target}, KD loss: {loss_soft_target}" 160 | elif 'kd' in self.args.method: 161 | # preserve previous knowledge by encouraging current predictions on old classes to match soft labels of previous model 162 | if '_prod' in self.args.method: 163 | prod = pre_p * logp 164 | # T * T factor multiplied following https://github.com/peterliht/knowledge-distillation-pytorch/issues/10 165 | loss_soft_target = -torch.mean(torch.sum(prod, dim=1)) * self.args.T * self.args.T * lambda_ # knowledge distillation loss 166 | elif '_kldiv' in self.args.method: 167 | loss_soft_target = torch.nn.KLDivLoss()(logp, pre_p) * self.args.T * self.args.T * lambda_ 168 | loss_hard_target = (1. - lambda_) * loss_hard_target 169 | loss = loss_hard_target + loss_soft_target # cross distillation loss 170 | loss_stats = f"\nCE loss: {loss_hard_target}, KD loss: {loss_soft_target}" 171 | if 'cb' in self.args.method: 172 | samples_per_cls = self.get_count_by_classes(label.detach().cpu().numpy(), self.seen_class) 173 | cb_loss = class_balanced_loss.CB_loss(label, p, samples_per_cls, 174 | no_of_classes=self.seen_class, loss_type='softmax', 175 | device=self.device) 176 | loss = loss + cb_loss 177 | loss_stats += f"CB loss: {cb_loss}" 178 | else: 179 | print("No valid distill method: 'ce' or 'kd' or 'lwf' or 'cn' found !!!") 180 | if 'lfc' in self.args.method: 181 | # less forget constraint loss 182 | cur_features_ = F.normalize(cur_features, p=2, dim=1) 183 | ref_features_ = F.normalize(ref_features.detach(), p=2, dim=1) 184 | less_forget_constraint = torch.nn.CosineEmbeddingLoss()(cur_features_, ref_features_, 185 | torch.ones(feature.shape[0]).to( 186 | self.device)) * self.cur_lamda 187 | loss += less_forget_constraint 188 | loss_stats += f" LFC loss: {less_forget_constraint}" 189 | if 'mr' in self.args.method: 190 | # compute margin ranking loss 191 | if 'cn' in self.args.method: 192 | output_bs = torch.cat((old_scores, new_scores), dim=1) 193 | else: 194 | output_bs = p 195 | output_bs = F.normalize(output_bs, p=2, dim=1) 196 | # mr_loss = triplet_loss.batch_hard_triplet_loss(label, output_bs, margin=1., device=self.device) 197 | mr_loss = margin_ranking_loss.compute_margin_ranking_loss(p, label, num_new_classes, self.seen_class, 198 | self.device, output_bs) 199 | loss += mr_loss 200 | loss_stats += f" MR loss: {mr_loss}" 201 | if 'lpl' in self.args.method: 202 | # locality preserving loss 203 | k, gamma = 5 if len(label) > 5 else math.ceil(len(label) / 2), 1.5 204 | lpl_loss = 0 205 | for i, data in enumerate(lpl_features_student): 206 | f_s_i = data 207 | for j, data_ in enumerate(lpl_features_student): 208 | if j != i: 209 | alpha_i_j= self.get_locality_preserving_alpha(i, j, k) 210 | if alpha_i_j > 0: 211 | temp_ = torch.norm(f_s_i - data_, dim=0, p=None).pow(2) 212 | lpl_loss += temp_.item() * alpha_i_j 213 | lpl_loss = gamma * lpl_loss / (label.shape[0] * k) # scale by factor: gamma / (k * batch_size) 214 | loss += lpl_loss 215 | loss_stats += f" LPL loss: {lpl_loss}" 216 | 217 | print(loss_stats) 218 | loss.backward(retain_graph=True) 219 | optimizer.step() 220 | if 'wa' in self.args.method: 221 | self.align_weights() 222 | # look into tsne visualisations 223 | if self.args.tsne_vis and last_epoch: 224 | tsne_features = np.vstack((tsne_features, cur_features.detach().cpu().numpy())) 225 | tsne_labels += label.cpu().tolist() 226 | losses.append(loss.item()) 227 | if self.args.tsne_vis and last_epoch: 228 | self.data_visualizer.plot_tsne(tsne_features, tsne_labels, itera=self.itera) 229 | 230 | return sum(losses) / len(dataloader.dataset) 231 | 232 | def get_locality_preserving_alpha(self, i, j, k=5): 233 | sigma = math.sqrt(2) # normalizing constant 234 | f_T_i = lpl_features_teacher[i] 235 | dist = torch.norm(lpl_features_teacher - f_T_i, dim=1, p=None) 236 | knn_indices = dist.topk(k+1, largest=False).indices[1:] # 0th index is always the element itself 237 | if j in knn_indices: 238 | alpha_i_j = - dist[j].float().pow(2) / sigma ** 2 239 | alpha_i_j = torch.exp(alpha_i_j).item() 240 | else: 241 | alpha_i_j = 0. 242 | return alpha_i_j 243 | 244 | def modify_new_logits(self, p, p_old, m): 245 | """ 246 | Adapted from https://arxiv.org/pdf/2003.13191.pdf 247 | :param p: output logits of new classifier (o_1...o_n, o_n+1...o_n+m) 248 | :param p_old: old classifier output logits (o_1...o_n) 249 | :param m: num of new classes 250 | :return: modified logits of new classifier 251 | """ 252 | beta = 0.5 # beta = 0.5 used in the paper 253 | p[:, :self.seen_class - m] = p[:, :self.seen_class - m] * beta + p_old * (1 - beta) 254 | return p 255 | 256 | def bias_forward(self, input): 257 | input_groups = [] 258 | for idx, classes in enumerate(self.classes_by_groups): 259 | temp_tensor = torch.Tensor().to(self.device) 260 | for each in classes: 261 | each = self.virtual_map[each] 262 | temp_tensor = torch.cat([temp_tensor, input[:, (int)(each):(int)(each + 1)]], dim=1) 263 | input_groups.append(temp_tensor) 264 | output_by_groups = [self.bias_layers[idx](item) for idx, item in enumerate(input_groups)] 265 | output_by_groups = torch.cat(output_by_groups, dim=1) 266 | return output_by_groups 267 | 268 | @staticmethod 269 | def get_count_by_classes(array_of_labels, seen_classes): 270 | classes_seen = [i for i in range(seen_classes)] 271 | counts = [] 272 | for label in classes_seen: 273 | counts.append(np.count_nonzero(array_of_labels == label)) 274 | assert len(counts) == seen_classes 275 | return counts 276 | 277 | def get_sample_weights(self, features, labels, label_averaged_dict, num_new_classes): 278 | # feature_extractor = basic_model.Net(self.model.fc1.in_features, self.args.total_classes) 279 | # feature_extractor = torch.nn.Sequential(*list(feature_extractor.children())[:-1]) 280 | past_classes = [i for i in range(self.seen_class - num_new_classes)] 281 | max_elem = 0 282 | batch_distance = [] 283 | for feature, label in zip(cur_features, labels): 284 | feature = feature.div(feature.norm(p=2, dim=0, keepdim=True).expand_as(feature)) 285 | if label in past_classes: 286 | cosine_dist = tuple((cosine(feature.detach().cpu(), averaged_vec) for _, averaged_vec in 287 | label_averaged_dict.items())) 288 | batch_distance.append(np.array([max(cosine_dist)])) 289 | else: 290 | batch_distance.append(np.array([0])) 291 | max_elem = max(batch_distance) 292 | assert (len(batch_distance) == len(features)) 293 | distances = list(map(lambda x: x - max_elem, batch_distance)) # x = x - max(x_i) for avoiding underflow 294 | return distances 295 | 296 | def remove_hooks(self): 297 | # remove the registered hook after model has been trained for the incremental batch 298 | self.handle_ref_features.remove() 299 | self.handle_cur_features.remove() 300 | if 'cn' in self.args.method: 301 | self.handle_old_scores_bs.remove() 302 | self.handle_new_scores_bs.remove() 303 | if 'lpl' in self.args.method: 304 | self.handle_student_features.remove() 305 | self.handle_teacher_features.remove() 306 | 307 | def get_model(self): 308 | return self.model -------------------------------------------------------------------------------- /train/customized_two.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from scipy.spatial.distance import cosine 7 | from torch.autograd import Variable 8 | from tqdm import tqdm 9 | 10 | from train.losses import class_balanced_loss 11 | from .ewc import EWC 12 | from .losses import margin_ranking_loss 13 | 14 | cur_features = [] 15 | ref_features = [] 16 | old_scores = [] 17 | new_scores = [] 18 | lpl_features_student, lpl_features_teacher = [], [] 19 | 20 | 21 | def get_ref_features(self, inputs, outputs): 22 | global ref_features 23 | ref_features = inputs[0] 24 | 25 | 26 | def get_cur_features(self, inputs, outputs): 27 | global cur_features 28 | cur_features = inputs[0] 29 | 30 | 31 | def get_old_scores_before_scale(self, inputs, outputs): 32 | global old_scores 33 | old_scores = outputs 34 | 35 | 36 | def get_new_scores_before_scale(self, inputs, outputs): 37 | global new_scores 38 | new_scores = outputs 39 | 40 | 41 | def get_teacher_features(self, inputs, outputs): 42 | global lpl_features_teacher 43 | lpl_features_teacher = outputs 44 | 45 | 46 | def get_student_features(self, inputs, outputs): 47 | global lpl_features_student 48 | lpl_features_student = outputs 49 | 50 | 51 | loss_by_epoch = dict() 52 | 53 | 54 | class CustomizedTrainer(): 55 | def __init__(self, args, itera, seen_cls, train_loader, model, previous_model, lamda, 56 | bias_layers, virtual_map, classes_by_groups, device, visualizer=None, old_tasks=None): 57 | self.args = args 58 | self.itera = itera 59 | self.seen_class = seen_cls 60 | self.train_loader = train_loader 61 | self.device = device 62 | self.cur_lamda = lamda 63 | self.model, self.previous_model = model, previous_model 64 | self.data_visualizer = visualizer 65 | self.virtual_map = virtual_map 66 | self.handle_ref_features = self.previous_model.fc.register_forward_hook(get_ref_features) 67 | self.handle_cur_features = self.model.fc.register_forward_hook(get_cur_features) 68 | if 'cn' in self.args.method: 69 | self.handle_old_scores_bs = self.model.fc.fc1.register_forward_hook(get_old_scores_before_scale) 70 | self.handle_new_scores_bs = self.model.fc.fc2.register_forward_hook(get_new_scores_before_scale) 71 | if 'bic' in self.args.method: 72 | self.bias_layers = bias_layers 73 | self.classes_by_groups = classes_by_groups 74 | if 'ewc' in self.args.method: 75 | self.old_tasks = old_tasks 76 | if 'lpl' in self.args.method: 77 | self.handle_student_features = self.model.fc_penultimate.register_forward_hook(get_student_features) 78 | self.handle_teacher_features = self.model.fc_penultimate.register_forward_hook(get_teacher_features) 79 | 80 | def align_weights(self): 81 | print("Aligning weights: ") 82 | if 'wa1' in self.args.method: 83 | with torch.no_grad(): 84 | for p in self.model.fc.parameters(): 85 | p.data.clamp_(0) 86 | elif 'wa2' in self.args.method: 87 | if 'cn' in self.args.method: 88 | # customize the normalization so as to fit for the fc1 and fc2 layers of SplitCosineLayer() 89 | with torch.no_grad(): 90 | w1, w2 = self.model.fc.fc1.weight.data, self.model.fc.fc2.weight.data 91 | w1_norm, w2_norm = [w / torch.norm(w, p=2, dim=1, keepdim=True) for w in (w1, w2)] 92 | self.model.fc.fc1.weight.data.copy_(w1_norm) 93 | self.model.fc.fc2.weight.data.copy_(w2_norm) 94 | else: 95 | # carry out regular normalization with fc layers 96 | with torch.no_grad(): 97 | # https://github.com/feidfoe/AdjustBnd4Imbalance/blob/master/models/cifar/resnet.py#L135 98 | w = self.model.fc.weight.data 99 | w_norm = w / torch.norm(w, p=2, dim=1, keepdim=True) 100 | self.model.fc.weight.data.copy_(w_norm) 101 | 102 | def distill_training(self, optimizer, num_new_classes, last_epoch=False, new_class_avg=None): 103 | print("Training with distillation losses... ") 104 | losses = [] 105 | lambda_ = (self.seen_class - num_new_classes) / self.seen_class 106 | dataloader = self.train_loader 107 | if self.args.tsne_vis and last_epoch: 108 | tsne_features, tsne_labels = np.empty(shape=[0, self.model.fc.in_features]), [] 109 | 110 | if 'ewc' in self.args.method: 111 | ewc = EWC(self.model, self.old_tasks, self.device) 112 | 113 | for i, (feature, label) in enumerate(tqdm(dataloader)): 114 | feature, label = Variable(feature), Variable(label) 115 | feature = feature.to(self.device) 116 | # label = label.type(torch.LongTensor) 117 | label = label.view(-1).to(self.device) 118 | optimizer.zero_grad() 119 | p = self.model(feature.float()) 120 | if 'bic' in self.args.method and self.itera > 1: 121 | p = self.bias_forward(p) 122 | 123 | logp = F.log_softmax(p[:, :self.seen_class - num_new_classes] / self.args.T, dim=1) 124 | 125 | with torch.no_grad(): 126 | if self.args.weighted: 127 | sample_weights = self.get_sample_weights(feature, label, new_class_avg, num_new_classes) 128 | sample_weights = torch.exp(torch.Tensor(sample_weights).to(self.device)) 129 | pre_p = self.previous_model(feature.float()) 130 | if 'bic' in self.args.method and self.itera > 1: 131 | pre_p = self.bias_forward(pre_p) 132 | # apply distilling loss giving soft labels (T = wts. of small values to introduce rescaling) 133 | pre_p = F.softmax(pre_p[:, :self.seen_class - num_new_classes] / self.args.T, dim=1) 134 | 135 | if self.args.extra_loss == 'modified_cd': 136 | p = self.modify_new_logits(p, pre_p, num_new_classes) 137 | 138 | loss_hard_target = torch.nn.CrossEntropyLoss()(p[:, :self.seen_class], 139 | label) 140 | if any([x in self.args.method for x in ['ce', 'cn']]): 141 | # for both base_ce and while using cosine normalisation, whole of the loss_hard_target is added 142 | loss = loss_hard_target 143 | loss_stats = f"CE loss: {loss_hard_target}" 144 | else: 145 | print("No valid distill method: 'ce' or 'kd' or 'cn' found !!!") 146 | if 'lfc' in self.args.method: 147 | # less forget constraint loss 148 | cur_features_ = F.normalize(cur_features, p=2, dim=1) 149 | ref_features_ = F.normalize(ref_features.detach(), p=2, dim=1) 150 | less_forget_constraint = torch.nn.CosineEmbeddingLoss()(cur_features_, ref_features_, 151 | torch.ones(feature.shape[0]).to( 152 | self.device)) * self.cur_lamda 153 | loss += less_forget_constraint 154 | loss_stats += f" LFC loss: {less_forget_constraint}" 155 | if 'mr' in self.args.method: 156 | # compute margin ranking loss 157 | if 'cn' in self.args.method: 158 | output_bs = torch.cat((old_scores, new_scores), dim=1) 159 | else: 160 | output_bs = p 161 | output_bs = F.normalize(output_bs, p=2, dim=1) 162 | mr_loss = margin_ranking_loss.compute_margin_ranking_loss(p, label, num_new_classes, self.seen_class, 163 | self.device, output_bs) 164 | loss += mr_loss 165 | loss_stats += f" MR loss: {mr_loss}" 166 | if 'lpl' in self.args.method: 167 | # locality preserving loss 168 | k, gamma = 5 if len(label) > 5 else math.ceil(len(label) / 2), 1.5 169 | lpl_loss = 0 170 | for i, data in enumerate(lpl_features_student): 171 | f_s_i = data 172 | for j, data_ in enumerate(lpl_features_student): 173 | if j != i: 174 | alpha_i_j= self.get_locality_preserving_alpha(i, j, k) 175 | if alpha_i_j > 0: 176 | temp_ = torch.norm(f_s_i - data_, dim=0, p=None).pow(2) 177 | lpl_loss += temp_.item() * alpha_i_j 178 | lpl_loss = gamma * lpl_loss / (label.shape[0] * k) # scale by factor: gamma / (k * batch_size) 179 | loss += lpl_loss 180 | loss_stats += f" LPL loss: {lpl_loss}" 181 | 182 | print(loss_stats) 183 | if 'wa' in self.args.method: 184 | self.align_weights() 185 | 186 | loss.backward(retain_graph=True) 187 | optimizer.step() 188 | # look into tsne visualisations 189 | if self.args.tsne_vis and last_epoch: 190 | tsne_features = np.vstack((tsne_features, cur_features.detach().cpu().numpy())) 191 | tsne_labels += label.cpu().numpy().tolist() 192 | self.data_visualizer.plot_tsne(tsne_features, tsne_labels, itera=self.itera) 193 | losses.append(loss.item()) 194 | 195 | return sum(losses) / len(dataloader.dataset) 196 | 197 | def get_locality_preserving_alpha(self, i, j, k=5): 198 | sigma = math.sqrt(2) # normalizing constant 199 | f_T_i = lpl_features_teacher[i] 200 | dist = torch.norm(lpl_features_teacher - f_T_i, dim=1, p=None) 201 | knn_indices = dist.topk(k+1, largest=False).indices[1:] # 0th index is always the element itself 202 | if j in knn_indices: 203 | alpha_i_j = - dist[j].float().pow(2) / sigma ** 2 204 | alpha_i_j = torch.exp(alpha_i_j).item() 205 | else: 206 | alpha_i_j = 0. 207 | return alpha_i_j 208 | 209 | def modify_new_logits(self, p, p_old, m): 210 | """ 211 | Adapted from https://arxiv.org/pdf/2003.13191.pdf 212 | :param p: output logits of new classifier (o_1...o_n, o_n+1...o_n+m) 213 | :param p_old: old classifier output logits (o_1...o_n) 214 | :param m: num of new classes 215 | :return: modified logits of new classifier 216 | """ 217 | beta = 0.8 # beta = 0.5 used in the paper 218 | p[:, :self.seen_class - m] = p[:, :self.seen_class - m] * beta + p_old * (1 - beta) 219 | return p 220 | 221 | def bias_forward(self, input): 222 | input_groups = [] 223 | for idx, classes in enumerate(self.classes_by_groups): 224 | temp_tensor = torch.Tensor().to(self.device) 225 | for each in classes: 226 | each = self.virtual_map[each] 227 | temp_tensor = torch.cat([temp_tensor, input[:, (int)(each):(int)(each + 1)]], dim=1) 228 | input_groups.append(temp_tensor) 229 | output_by_groups = [self.bias_layers[idx](item) for idx, item in enumerate(input_groups)] 230 | output_by_groups = torch.cat(output_by_groups, dim=1) 231 | return output_by_groups 232 | 233 | @staticmethod 234 | def get_count_by_classes(array_of_labels, seen_classes): 235 | classes_seen = [i for i in range(seen_classes)] 236 | counts = [] 237 | for label in classes_seen: 238 | counts.append(np.count_nonzero(array_of_labels == label)) 239 | assert len(counts) == seen_classes 240 | return counts 241 | 242 | def get_sample_weights(self, features, labels, label_averaged_dict, num_new_classes): 243 | # feature_extractor = basic_model.Net(self.model.fc1.in_features, self.args.total_classes) 244 | # feature_extractor = torch.nn.Sequential(*list(feature_extractor.children())[:-1]) 245 | past_classes = [i for i in range(self.seen_class - num_new_classes)] 246 | max_elem = 0 247 | batch_distance = [] 248 | for feature, label in zip(cur_features, labels): 249 | feature = feature.div(feature.norm(p=2, dim=0, keepdim=True).expand_as(feature)) 250 | if label in past_classes: 251 | cosine_dist = tuple((cosine(feature.detach().cpu(), averaged_vec) for _, averaged_vec in 252 | label_averaged_dict.items())) 253 | batch_distance.append(np.array([max(cosine_dist)])) 254 | else: 255 | batch_distance.append(np.array([0])) 256 | max_elem = max(batch_distance) 257 | assert (len(batch_distance) == len(features)) 258 | distances = list(map(lambda x: x - max_elem, batch_distance)) # x = x - max(x_i) for avoiding underflow 259 | return distances 260 | 261 | def remove_hooks(self): 262 | # remove the registered hook after model has been trained for the incremental batch 263 | self.handle_ref_features.remove() 264 | self.handle_cur_features.remove() 265 | if 'cn' in self.args.method: 266 | self.handle_old_scores_bs.remove() 267 | self.handle_new_scores_bs.remove() 268 | if 'lpl' in self.args.method: 269 | self.handle_student_features.remove() 270 | self.handle_teacher_features.remove() 271 | 272 | def get_model(self): 273 | return self.model -------------------------------------------------------------------------------- /train/ewc.py: -------------------------------------------------------------------------------- 1 | 2 | from copy import deepcopy 3 | 4 | from torch import nn, Tensor 5 | from torch.autograd import Variable 6 | from torch.nn import functional as F 7 | 8 | 9 | class EWC(object): 10 | """ 11 | Class for defining the diagonal fisher matrix. 12 | Author: https://github.com/moskomule/ewc.pytorch/blob/master/utils.py 13 | """ 14 | 15 | def __init__(self, model: nn.Module, dataset: list, device): 16 | self.model = model 17 | self.dataset = dataset 18 | self.device = device 19 | self.params = {n: p for n, p in self.model.named_parameters() if p.requires_grad} 20 | self._means = {} 21 | self._precision_matrices = self._diag_fisher() 22 | 23 | for n, p in deepcopy(self.params).items(): 24 | self._means[n] = Variable(p.data.to(self.device)) 25 | 26 | def _diag_fisher(self): 27 | precision_matrices = {} 28 | for n, p in deepcopy(self.params).items(): 29 | p.data.zero_() # sets all data to zeros 30 | precision_matrices[n] = Variable(p.data.to(self.device)) # fisher matrix whose diagonal = precision of posterior p(theta | data) 31 | 32 | self.model.eval() 33 | for input, label in self.dataset: 34 | self.model.zero_grad() 35 | input, label = Variable(input), Variable(label) 36 | input = input.to(self.device) 37 | label = label.view(-1).to(self.device) 38 | output = self.model(input.float()) 39 | # label = output.max(1)[1].view(-1) 40 | loss = F.nll_loss(F.log_softmax(output, dim=1), label) 41 | loss.backward() 42 | 43 | for n, p in self.model.named_parameters(): 44 | precision_matrices[n].data += p.grad.data.pow(2) / len(self.dataset) # grad = first order derivatives; point (b) - EWC paper 45 | 46 | precision_matrices = {n: p for n, p in precision_matrices.items()} 47 | return precision_matrices 48 | 49 | def penalty(self, model: nn.Module): 50 | loss = 0 51 | for n, p in model.named_parameters(): 52 | _loss = self._precision_matrices[n] * (p - self._means[n]) ** 2 # F_i * (theta_i - theta'_i) ** 2 53 | loss += _loss.sum() 54 | return loss -------------------------------------------------------------------------------- /train/exemplar_strategies/__pycache__/class_boundary.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/exemplar_strategies/__pycache__/class_boundary.cpython-36.pyc -------------------------------------------------------------------------------- /train/exemplar_strategies/__pycache__/class_boundary.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/exemplar_strategies/__pycache__/class_boundary.cpython-37.pyc -------------------------------------------------------------------------------- /train/exemplar_strategies/__pycache__/fwsr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/exemplar_strategies/__pycache__/fwsr.cpython-36.pyc -------------------------------------------------------------------------------- /train/exemplar_strategies/__pycache__/fwsr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/exemplar_strategies/__pycache__/fwsr.cpython-37.pyc -------------------------------------------------------------------------------- /train/exemplar_strategies/__pycache__/herding.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/exemplar_strategies/__pycache__/herding.cpython-36.pyc -------------------------------------------------------------------------------- /train/exemplar_strategies/__pycache__/herding.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/exemplar_strategies/__pycache__/herding.cpython-37.pyc -------------------------------------------------------------------------------- /train/exemplar_strategies/__pycache__/kmeans.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/exemplar_strategies/__pycache__/kmeans.cpython-36.pyc -------------------------------------------------------------------------------- /train/exemplar_strategies/__pycache__/kmeans.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/exemplar_strategies/__pycache__/kmeans.cpython-37.pyc -------------------------------------------------------------------------------- /train/exemplar_strategies/class_boundary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.neighbors import NearestNeighbors 3 | from .herding import herding_selection 4 | from .kmeans import kmeans_sample 5 | 6 | def get_overlap_region_exemplars(train_dict): 7 | all_features = np.vstack(list(train_dict.values())) 8 | k = int(8 * np.log10(len(all_features))) # math.ceil(exemp_size / 2) 9 | nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(all_features) 10 | 11 | length_by_classes = [len(each) for each in list(train_dict.values())] 12 | list_of_sums = [sum(length_by_classes[:idx + 1]) for idx in range(len(length_by_classes))] 13 | dict_of_sums = {sum_val: key for sum_val, key in zip(list_of_sums, train_dict.keys())} 14 | 15 | label_to_indices_dict = {key: [] for key in train_dict.keys()} 16 | for label, features in train_dict.items(): 17 | for idx, data_point in enumerate(features): 18 | count_of_neighbours = {} 19 | distances, indices = nbrs.kneighbors(data_point.reshape(1, -1)) 20 | for index in indices[0][1:]: # since first index always contains the element itself 21 | # get sum value that is closest to this index 22 | nearest_sum = min(list_of_sums, key = lambda x: (x-index) if x >= index else max(list_of_sums)) 23 | nearest_class = dict_of_sums[nearest_sum] 24 | # print(nearest_class, label) 25 | if nearest_class not in count_of_neighbours: 26 | count_of_neighbours[nearest_class] = 1 27 | else: 28 | count_of_neighbours[nearest_class] += 1 29 | dominant_neighbour = max(count_of_neighbours.items(), key=lambda x: x[1])[0] 30 | if dominant_neighbour == label: 31 | n_c = len(count_of_neighbours) 32 | if n_c > 1: 33 | lamda = 0.3 # suggested value [0.1, 0.3] 34 | del count_of_neighbours[dominant_neighbour] 35 | second_dominant_nbr = max(count_of_neighbours.items(), key=lambda x: x[1])[1] / k 36 | # print(second_dominant_nbr*2, 1/ n_c, 1 / n_c + lamda) 37 | if second_dominant_nbr >= 1 / n_c and second_dominant_nbr <= 1 / n_c + lamda: 38 | label_to_indices_dict[label].append(idx) 39 | 40 | return label_to_indices_dict 41 | 42 | def get_edge_region_exemplars(train_dict): 43 | all_features = np.vstack(list(train_dict.values())) 44 | k = int(8 * np.log10(len(all_features)))#math.ceil(exemp_size / 3) # int(5 * np.log10(len(all_features))) 45 | nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto').fit(all_features) 46 | 47 | length_by_classes = [len(each) for each in list(train_dict.values())] 48 | list_of_sums = [sum(length_by_classes[:idx + 1]) for idx in range(len(length_by_classes))] 49 | dict_of_sums = {sum_val: key for sum_val, key in zip(list_of_sums, train_dict.keys())} 50 | 51 | label_to_indices_dict = {key: [] for key in train_dict.keys()} 52 | for label, features in train_dict.items(): 53 | if len(features) > 1: 54 | k_e = int(3 * np.log(len(features))) # math.ceil(exemp_size / 3) # int(5 * np.log10(len(all_features))) 55 | nbrs_l = NearestNeighbors(n_neighbors=k_e, algorithm='auto').fit(features) 56 | for idx, data_point in enumerate(features): 57 | count_of_neighbours = {} 58 | distances, indices = nbrs.kneighbors(data_point.reshape(1, -1)) 59 | for index in indices[0][1:]: # since first index always contains the element itself 60 | # get sum value that is closest to this index 61 | nearest_sum = min(list_of_sums, key = lambda x: (x-index) if x >= index else max(list_of_sums)) 62 | nearest_class = dict_of_sums[nearest_sum] 63 | # print(nearest_class, label) 64 | if nearest_class not in count_of_neighbours: 65 | count_of_neighbours[nearest_class] = 1 66 | else: 67 | count_of_neighbours[nearest_class] += 1 68 | dominant_neighbour = max(count_of_neighbours.items(), key=lambda x: x[1])[0] 69 | if dominant_neighbour == label: 70 | lamda, gamma = 0.1, 0.25 71 | n_c = len(count_of_neighbours) 72 | if n_c > 1: 73 | del count_of_neighbours[dominant_neighbour] 74 | second_dominant_nbr_score = max(count_of_neighbours.items(), key=lambda x: x[1])[1] / k 75 | if second_dominant_nbr_score*2 > (lamda + 1 / n_c): # eqn. (4) 76 | # check for k_e nearest neighbours that: 77 | distances_, indices_ = nbrs_l.kneighbors(data_point.reshape(1, -1)) 78 | normal_vecs, difference_vecs = [], [] 79 | for index in indices_[0][1:]: 80 | difference = features[index] - data_point 81 | difference_vecs.append(difference) 82 | normal_vecc = difference / np.linalg.norm(difference) 83 | normal_vecs.append(normal_vecc) 84 | sum_normal_vec = np.sum(np.array(normal_vecs), axis=0) 85 | I = 0 86 | for differ in difference_vecs: 87 | theta = differ @ sum_normal_vec 88 | if theta > 0: 89 | I += 1 90 | l_i = I / k_e 91 | if l_i >= (1 - gamma): 92 | label_to_indices_dict[label].append(idx) 93 | return label_to_indices_dict 94 | 95 | def get_interior_region_exemplars(train_dict, dict_of_means, exemp_size_per_class): 96 | label_to_indices_dict = {} 97 | for label, size in exemp_size_per_class.items(): 98 | if size > 0: 99 | mean_of_class = dict_of_means[label] 100 | # top_k_indices = herding_selection(train_dict[label], size, mean_=mean_of_class) 101 | top_k_indices = kmeans_sample(train_dict[label], size) 102 | label_to_indices_dict[label] = top_k_indices 103 | return label_to_indices_dict 104 | 105 | # def get_interior_region_exemplars(train_dict, dict_of_means, exemp_size_per_class): 106 | # label_to_indices_dict = {} 107 | # for label, size in exemp_size_per_class.items(): 108 | # if size > 0: 109 | # mean_of_class = dict_of_means[label] 110 | # cosine_sims = np.array([np.dot(x, mean_of_class) / (np.linalg.norm(x) * np.linalg.norm(mean_of_class)) for 111 | # x in train_dict[label]]) 112 | # try: 113 | # top_k_indices = np.argpartition(cosine_sims, -size)[-size:] # https://stackoverflow.com/questions/6910641/how-do-i-get-indices-of-n-maximum-values-in-a-numpy-array 114 | # label_to_indices_dict[label] = top_k_indices 115 | # except ValueError: 116 | # print(cosine_sims, size, label, len(train_dict[label])) 117 | # return label_to_indices_dict 118 | 119 | # def get_new_exemplars(dict_of_features, dict_of_means, exemp_size): 120 | # normalised_features_dict = {key: feature / np.linalg.norm(feature) for key, feature in dict_of_features.items()} 121 | # overlapping_exemplars_indices = get_overlap_region_exemplars(normalised_features_dict, exemp_size) 122 | # overlapping_exemplars = {label: np.array(features)[overlapping_exemplars_indices[label][:exemp_size]] for label, features in 123 | # dict_of_features.items()} 124 | # filtered_features = {label: np.delete(features, overlapping_exemplars_indices[label][:exemp_size], axis=0) for 125 | # label, features in dict_of_features.items() } 126 | # normalised_features_dict = {key: feature / np.linalg.norm(feature) for key, feature in filtered_features.items()} 127 | # edge_exemplar_indices = get_edge_region_exemplars(normalised_features_dict, exemp_size) 128 | # edge_exemplars = {label: np.array(features)[edge_exemplar_indices[label]] for label, features in 129 | # filtered_features.items()} 130 | # reqd_exemplars_per_class = {key: exemp_size - len(features) for key, features in overlapping_exemplars.items()} 131 | # # print(reqd_exemplars_per_class) 132 | # total_exemplars = {label: np.vstack((overlapping_exemplars[label], edge_exemplars[label][:reqd_exemplars_per_class[label]])) for label in 133 | # dict_of_features.keys()} 134 | # 135 | # filtered_features = {label: np.delete(features, edge_exemplar_indices[label][:exemp_size], axis=0) for 136 | # label, features in filtered_features.items()} 137 | # normalised_features_dict = {key: feature / np.linalg.norm(feature) for key, feature in filtered_features.items()} 138 | # reqd_exemplars_per_class = {key: exemp_size - len(features) for key, features in total_exemplars.items()} 139 | # # print(reqd_exemplars_per_class, exemp_size) 140 | # interior_exemplar_indices = get_interior_region_exemplars(normalised_features_dict, dict_of_means, reqd_exemplars_per_class) 141 | # interior_exemplars = {label: np.array(features)[interior_exemplar_indices[label]] for label, features in filtered_features.items() 142 | # if reqd_exemplars_per_class[label] > 0} 143 | # total_exemplars = {label: np.vstack((interior_exemplars[label], total_exemplars[label])) if 144 | # reqd_exemplars_per_class[label] > 0 else total_exemplars[label] for label in dict_of_features.keys()} 145 | # # print({key: len(features) for key, features in total_exemplars.items()}) 146 | # return total_exemplars 147 | 148 | def get_new_exemplars(dict_of_features, normalised_features_dict, dict_of_means, exemp_size_dict): 149 | overlapping_exemplars_indices = get_overlap_region_exemplars(normalised_features_dict) 150 | overlapping_exemplars = {label: np.array(features)[overlapping_exemplars_indices[label][:exemp_size_dict[label]]] for label, features in 151 | dict_of_features.items()} 152 | filtered_features = {label: np.delete(features, overlapping_exemplars_indices[label][:exemp_size_dict[label]], axis=0) for 153 | label, features in dict_of_features.items()} 154 | normalised_features_dict = {label: np.delete(features, overlapping_exemplars_indices[label][:exemp_size_dict[label]], axis=0) for 155 | label, features in normalised_features_dict.items()} 156 | 157 | edge_exemplar_indices = get_edge_region_exemplars(normalised_features_dict) 158 | edge_exemplars = {label: np.array(features)[edge_exemplar_indices[label]] for label, features in 159 | filtered_features.items()} 160 | reqd_exemplars_per_class = {key: exemp_size_dict[key] - len(features) for key, features in overlapping_exemplars.items()} 161 | # print(reqd_exemplars_per_class) 162 | total_exemplars = {label: np.vstack((overlapping_exemplars[label], edge_exemplars[label][:reqd_exemplars_per_class[label]])) for label in 163 | dict_of_features.keys()} 164 | 165 | filtered_features = {label: np.delete(features, edge_exemplar_indices[label][:exemp_size_dict[label]], axis=0) for 166 | label, features in filtered_features.items()} 167 | normalised_features_dict = {label: np.delete(features, edge_exemplar_indices[label][:exemp_size_dict[label]], axis=0) for 168 | label, features in normalised_features_dict.items()} 169 | reqd_exemplars_per_class = {key: exemp_size_dict[key] - len(features) for key, features in total_exemplars.items()} 170 | interior_exemplar_indices = get_interior_region_exemplars(normalised_features_dict, dict_of_means, reqd_exemplars_per_class) 171 | interior_exemplars = {label: np.array(features)[interior_exemplar_indices[label]] for label, features in filtered_features.items() 172 | if reqd_exemplars_per_class[label] > 0} 173 | total_exemplars = {label: np.vstack((interior_exemplars[label], total_exemplars[label])) if 174 | reqd_exemplars_per_class[label] > 0 else total_exemplars[label] for label in dict_of_features.keys()} 175 | return total_exemplars -------------------------------------------------------------------------------- /train/exemplar_strategies/fwsr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | 4 | 5 | def FWSR_identify_exemplars(beta, A, K, max_iterations, num_exemp, greedy=True, verbose=True, zeta=0, epsilon=0, 6 | order=2, positive=False, term_thres=1e-8): 7 | """ 8 | zeta, epsilon, order, beta correspond to the variables of the following problem 9 | minimize_X || A @ X - A ||_F**2 + ||zeta(X.T @ 1 - 1)||_2 ** 2 + epsilon ||X||_F^2 10 | s.t. ||X||_{1, order} = sum_{i} ||X^(i)||_order <= beta 11 | where X^(i) denotes the ith row of X 12 | order can be equal to 1, 2, or infinity (corresponding to the l1, l2, or l-infinity norm) 13 | K = A.T @ A 14 | if greedy == True, then the method will terminate if and when the number of non-zero rows of X is >= num_exemp 15 | otherwise, the method will run until either max_iterations is hit or termination condition is reached (dictated by term_thres) 16 | If postive == True, then method will optimize the obove objective with the added constraint: X >=0 elementwise 17 | see https://arxiv.org/abs/1811.02702 for more details about the algorithm itself 18 | """ 19 | max_iterations += 1 # this is useful for consistency in testing 20 | if zeta != 0: 21 | K += zeta ** 2 22 | 23 | trK = np.trace(K) 24 | n, m = A.shape 25 | X = np.zeros((m, m)) 26 | 27 | exemplar_index_lst = [] 28 | cost_lst = [] 29 | G_lst = [] 30 | 31 | prev_KX = 0 32 | S = None 33 | step_size = None 34 | max_index = None 35 | trXTK = 0 36 | trXTKX = 0 37 | 38 | D = 0 39 | trSTKS = 0 40 | trSTKX = 0 41 | trSTK = 0 42 | step_size = 0 43 | 44 | row_norm_X = np.linalg.norm(X, axis=1, ord=2) 45 | exemplar_index_lst = np.where(row_norm_X != 0)[0] 46 | len_of_exemplar_index_lst = [] 47 | 48 | pbar = tqdm(total=int(max_iterations), unit="iter", unit_scale=False, leave=False, disable=not verbose) 49 | 50 | for iteration in range(int(max_iterations)): 51 | if greedy and len(exemplar_index_lst) >= num_exemp: 52 | pbar.update(int(max_iterations) - iteration) 53 | break 54 | pbar.set_postfix(num_ex=len_of_exemplar_index_lst[-1] if len(len_of_exemplar_index_lst) > 0 else 0, 55 | tol=G_lst[-1] if len(G_lst) > 0 else np.inf, refresh=False) 56 | pbar.update(1) 57 | 58 | X += step_size * D 59 | row_norm_X = np.linalg.norm(X, axis=1, ord=2) 60 | exemplar_index_lst = np.where(row_norm_X != 0)[0] 61 | len_of_exemplar_index_lst.append(len(exemplar_index_lst)) 62 | 63 | if len(exemplar_index_lst) == 0: 64 | KX = np.zeros((m, m)) 65 | elif step_size is None or prev_KX is None or S is None or max_index is None: 66 | # print("SHOULD NEVER GET HERE") # this case is just defensive programming 67 | KX = K[:, exemplar_index_lst].dot(X[exemplar_index_lst]) 68 | elif max_index == -1: 69 | # print("max index is -1") 70 | KX = prev_KX * (1 - step_size) 71 | KX1 = K[:, exemplar_index_lst].dot(X[exemplar_index_lst]) 72 | assert np.all(np.isclose(KX, KX1, atol=0)) 73 | else: 74 | KX = prev_KX * (1 - step_size) + step_size * np.outer(K[:, max_index], S[max_index]) 75 | 76 | prev_KX = KX 77 | trXTK = step_size * trSTK + (1 - step_size) * trXTK 78 | trXTKX = (1 - step_size) ** 2 * trXTKX + step_size ** 2 * trSTKS + 2 * step_size * (1 - step_size) * trSTKX 79 | cost_lst.append(trXTKX - 2 * trXTK + trK) 80 | 81 | if epsilon == 0: 82 | gradient = KX - K # with respect to Z 83 | else: 84 | gradient = KX - K + epsilon * X # with respect to Z 85 | 86 | max_index = get_max_index(gradient=gradient, order=order, positive=positive) # next index to update 87 | if max_index == -1 and positive: 88 | S = np.zeros((m, m)) 89 | D = -X 90 | numerator = - trXTK + trXTKX 91 | denominator = trXTKX 92 | else: 93 | gradient_max_row = gradient[max_index].flatten() 94 | S = np.zeros((m, m)) 95 | S[max_index] = make_S_row(gradient_max_row=gradient_max_row, beta=beta, order=order, positive=positive) 96 | D = S - X 97 | 98 | trSTK = np.inner(S[max_index], K[max_index]) 99 | trSTKS = K[max_index, max_index] * np.inner(S[max_index], S[max_index]) 100 | trSTKX = np.inner(S[max_index], KX[max_index]) 101 | 102 | numerator = trSTK - trXTK - trSTKX + trXTKX 103 | denominator = trSTKS - 2 * trSTKX + trXTKX 104 | 105 | G = -2 * np.einsum("ij, ij ->", gradient, D) 106 | G_lst.append(G) 107 | 108 | if G < term_thres: 109 | # myprint("EARLY TERMINATION", verbose) 110 | break 111 | 112 | step_size = max(0, min(1, numerator / denominator)) 113 | 114 | pbar.close() 115 | if not greedy and num_exemp is None: 116 | exemplar_indices = exemplar_index_lst 117 | elif not greedy: 118 | exemplar_indices = make_exemplar_indices(X.T, num_exemp) 119 | else: 120 | exemplar_indices = exemplar_index_lst 121 | # if len(exemplar_indices) < num_exemp: 122 | # myprint("ALERT: less than num_exemp were selected: " + str(len(exemplar_indices))) 123 | 124 | return exemplar_indices 125 | 126 | 127 | def get_max_index(gradient, order, positive): 128 | if positive: 129 | if np.all(gradient >= 0): 130 | return -1 131 | gradient = np.where(gradient < 0, gradient, 0) 132 | 133 | if order == 2: 134 | return np.argmax(np.linalg.norm(gradient, axis=1, ord=2)) 135 | elif np.isinf(order): 136 | return np.argmax(np.linalg.norm(gradient, axis=1, ord=1)) 137 | elif order == 1: 138 | return np.argmax(np.linalg.norm(gradient, axis=1, ord=np.inf)) 139 | raise Exception("Improper ord arguement; ord = " + str(ord)) 140 | 141 | 142 | def make_S_row(gradient_max_row, beta, order, positive): 143 | if positive: 144 | return make_S_row_positive(gradient_max_row, beta, order) 145 | 146 | if order == 2: 147 | if np.linalg.norm(gradient_max_row, ord=2) == 0: 148 | val = np.zeros_like(gradient_max_row) 149 | val[0] = beta 150 | return val 151 | return -1 * gradient_max_row / np.linalg.norm(gradient_max_row, ord=2) * beta + 0. 152 | if np.isinf(order): 153 | sign_vec = np.sign(gradient_max_row) 154 | sign_vec[sign_vec == 0] = 1 # this is just to make sure a vertex of the ball is selected 155 | return -1 * sign_vec * beta + 0. 156 | if order == 1: 157 | max_index = np.argmax(np.abs(gradient_max_row)) 158 | max_sign = np.sign(gradient_max_row[max_index]) 159 | 160 | if max_sign == 0: 161 | max_sign = 1 # this is just to make sure a vertex of the ball is selected 162 | 163 | return_vec = np.zeros_like(gradient_max_row) 164 | return_vec[max_index] = -1 * max_sign * beta 165 | return return_vec + 0. 166 | 167 | 168 | def make_S_row_positive(gradient_max_row, beta, order): 169 | gradient_max_row = np.where(gradient_max_row < 0, gradient_max_row, 0.) 170 | if order == 2: 171 | return -1 * gradient_max_row / np.linalg.norm(gradient_max_row, ord=2) * beta + 0. 172 | if np.isinf(order): 173 | sign_vec = np.sign(gradient_max_row) 174 | return -1 * sign_vec * beta + 0. 175 | if order == 1: 176 | max_index = np.argmax(np.abs(gradient_max_row)) 177 | max_sign = np.sign(gradient_max_row[max_index]) 178 | 179 | return_vec = np.zeros_like(gradient_max_row) 180 | return_vec[max_index] = -1 * max_sign * beta 181 | return return_vec + 0. 182 | 183 | 184 | def compute_inner_product_of_S_max_row(m, beta, order): 185 | """ 186 | To compute the optimal step size, one of the trace terms need the inner product of s_max^T s_max 187 | This calculation depends on the order of the group lasso ball. 188 | """ 189 | if order == 2 or order == 1: 190 | return beta ** 2 191 | elif np.isinf(order): 192 | return (beta ** 2) * m 193 | 194 | 195 | def make_exemplar_indices(Z, num_exemp): 196 | """ 197 | horizontal_norms refers to the horizontal norms of ZT which are the vertical norms of Z 198 | """ 199 | horizontal_norms = np.linalg.norm(Z, ord=2, axis=0) 200 | total_norm_sum = np.sum(horizontal_norms) 201 | sorted_indices = np.flipud(np.argsort(horizontal_norms))[:num_exemp] 202 | 203 | m = Z.shape[0] 204 | 205 | # don't pick coefficients that aren't used at all 206 | last_index = num_exemp 207 | for idx in range(len(sorted_indices)): 208 | og_idx = sorted_indices[idx] 209 | if horizontal_norms[og_idx] == 0.0: 210 | last_index = idx 211 | myprint("ALERT: less than num_exemp were selected") 212 | break 213 | 214 | return sorted_indices[:last_index] 215 | 216 | 217 | def fw_objective(AX, X): 218 | return np.linalg.norm(AX - A, ord="fro") ** 2 219 | 220 | 221 | def myprint(s, to_print=True): 222 | if to_print: 223 | print(s) -------------------------------------------------------------------------------- /train/exemplar_strategies/herding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def herding_selection(x, m, mean_=None): 4 | """ 5 | Source: https://github.com/PatrickZH/End-to-End-Incremental-Learning/blob/39d6f4e594e805a713aa7a1deedbcb03d1f2c9cc/utils.py#L176 6 | Parameters 7 | ---------- 8 | x: the features, n * dimension 9 | m: the number of selected exemplars 10 | Returns 11 | ---------- 12 | pos_s: the position of selected exemplars 13 | """ 14 | 15 | pos_s = [] 16 | comb = 0 17 | mu = np.mean(x, axis=0, keepdims=False) if mean_ is None else mean_ 18 | for k in range(m): 19 | det = mu * (k + 1) - comb 20 | dist = np.zeros(shape=(np.shape(x)[0])) 21 | for i in range(np.shape(x)[0]): 22 | if i in pos_s: 23 | dist[i] = np.inf 24 | else: 25 | dist[i] = np.linalg.norm(det - x[i]) 26 | pos = np.argmin(dist) 27 | pos_s.append(pos) 28 | comb += x[pos] 29 | 30 | return pos_s -------------------------------------------------------------------------------- /train/exemplar_strategies/kmeans.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | from sklearn.metrics import pairwise_distances_argmin_min 3 | 4 | 5 | def kmeans_sample(X, k): 6 | kmeans = KMeans(n_clusters=k, random_state=10).fit(X) 7 | closest, _ = pairwise_distances_argmin_min(kmeans.cluster_centers_, X) 8 | return closest.tolist() 9 | -------------------------------------------------------------------------------- /train/fisher_matrix.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import train.losses as losses 3 | 4 | embeddings = [] 5 | def get_embeddings(self, inputs, outputs): 6 | global embeddings 7 | embeddings = inputs[0] 8 | 9 | def fisher_matrix_diag(model, loss_name, train_loader, device): 10 | criterion = losses.create(loss_name).to(device) 11 | 12 | model.fc.register_forward_hook(get_embeddings) # Init 13 | fisher = {} 14 | for n, p in model.named_parameters(): 15 | fisher[n] = 0*p.data 16 | 17 | model.train() 18 | count = 0 19 | for i, data in enumerate(train_loader, 0): 20 | count += 1 21 | inputs, labels = data 22 | # wrap them in Variable 23 | inputs = Variable(inputs).to(device) 24 | labels = Variable(labels).to(device) 25 | 26 | # Forward and backward 27 | model.zero_grad() 28 | outputs = model(inputs.float()) 29 | if loss_name == 'msloss': 30 | loss = criterion(embeddings, labels, device) 31 | else: 32 | loss, _, _, _ = criterion(embeddings, labels, device) 33 | loss.backward() 34 | 35 | for n, p in model.named_parameters(): 36 | if p.grad is not None: 37 | fisher[n] += p.grad.data.pow(2) 38 | 39 | for n, _ in model.named_parameters(): 40 | fisher[n] = fisher[n]/float(count) 41 | fisher[n] = Variable(fisher[n], requires_grad=False) 42 | return fisher 43 | -------------------------------------------------------------------------------- /train/gem_solver.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import quadprog 4 | import torch 5 | 6 | 7 | def project2cone2(gradient, memories, device, margin=0.5, eps=1e-3): 8 | """ 9 | Source: https://github.com/ElectronicTomato/continue_leanrning_agem/blob/master/agents/exp_replay.py#L317 10 | Solves the GEM dual QP described in the paper given a proposed 11 | gradient "gradient", and a memory of task gradients "memories". 12 | Overwrites "gradient" with the final projected update. 13 | input: gradient, p-vector 14 | input: memories, (t * p)-vector 15 | output: x, p-vector 16 | Modified from: https://github.com/facebookresearch/GradientEpisodicMemory/blob/master/model/gem.py#L70 17 | """ 18 | memories_np = memories.cpu().contiguous().double().numpy() 19 | gradient_np = gradient.cpu().contiguous().view(-1).double().numpy() 20 | t = memories_np.shape[0] 21 | # print(memories_np.shape, gradient_np.shape) 22 | P = np.dot(memories_np, memories_np.transpose()) 23 | P = 0.5 * (P + P.transpose()) 24 | q = np.dot(memories_np, gradient_np) * -1 25 | G = np.eye(t) 26 | P = P + G * eps 27 | h = np.zeros(t) + margin 28 | v = quadprog.solve_qp(P, q, G, h)[0] 29 | x = np.dot(v, memories_np) + gradient_np 30 | new_grad = torch.Tensor(x).view(-1) 31 | new_grad = new_grad.to(device) 32 | return new_grad -------------------------------------------------------------------------------- /train/losses/BinBranchLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | from .BinDevianceLoss import BinDevianceLoss 5 | import numpy as np 6 | 7 | 8 | def similarity(inputs_): 9 | # Compute similarity mat of deep feature 10 | # n = inputs_.size(0) 11 | sim = torch.matmul(inputs_, inputs_.t()) 12 | return sim 13 | 14 | 15 | class BinBranchLoss(nn.Module): 16 | def __init__(self, margin=0.5, slice=[0, 170, 341, 512]): 17 | super(BinBranchLoss, self).__init__() 18 | self.s = slice 19 | self.margin = margin 20 | 21 | def forward(self, inputs, targets): 22 | inputs = [inputs[:, self.s[i]:self.s[i+1]] 23 | for i in range(len(self.s)-1)] 24 | loss_list, prec_list, pos_d_list, neg_d_list = [], [], [], [] 25 | 26 | for input in inputs: 27 | loss, prec, pos_d, neg_d = BinDevianceLoss(margin=self.margin)(input, targets) 28 | loss_list.append(loss) 29 | prec_list.append(prec) 30 | pos_d_list.append(pos_d) 31 | neg_d_list.append(neg_d) 32 | 33 | loss = torch.mean(torch.cat(loss_list)) 34 | prec = np.mean(prec_list) 35 | pos_d = np.mean((pos_d_list)) 36 | neg_d = np.mean((neg_d_list)) 37 | 38 | return loss, prec, pos_d, neg_d -------------------------------------------------------------------------------- /train/losses/BinDevianceLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | 7 | def similarity(inputs_): 8 | # Compute similarity mat of deep feature 9 | # n = inputs_.size(0) 10 | sim = torch.matmul(inputs_, inputs_.t()) 11 | return sim 12 | 13 | 14 | class BinDevianceLoss(nn.Module): 15 | def __init__(self, margin=0.5): 16 | super(BinDevianceLoss, self).__init__() 17 | self.margin = margin 18 | 19 | def forward(self, inputs, targets): 20 | n = inputs.size(0) 21 | # Compute similarity matrix 22 | sim_mat = similarity(inputs) 23 | # print(sim_mat) 24 | targets = targets.cuda() 25 | # split the positive and negative pairs 26 | eyes_ = Variable(torch.eye(n, n)).cuda() 27 | # eyes_ = Variable(torch.eye(n, n)) 28 | pos_mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 29 | neg_mask = eyes_.eq(eyes_) - pos_mask 30 | pos_mask = pos_mask - eyes_.eq(1) 31 | 32 | pos_sim = torch.masked_select(sim_mat, pos_mask) 33 | neg_sim = torch.masked_select(sim_mat, neg_mask) 34 | 35 | num_instances = len(pos_sim)//n + 1 36 | num_neg_instances = n - num_instances 37 | 38 | pos_sim = pos_sim.resize(len(pos_sim)//(num_instances-1), num_instances-1) 39 | neg_sim = neg_sim.resize( 40 | len(neg_sim) // num_neg_instances, num_neg_instances) 41 | 42 | # clear way to compute the loss first 43 | loss = list() 44 | c = 0 45 | 46 | for i, pos_pair in enumerate(pos_sim): 47 | # print(i) 48 | pos_pair = torch.sort(pos_pair)[0] 49 | neg_pair = torch.sort(neg_sim[i])[0] 50 | 51 | neg_pair = torch.masked_select(neg_pair, neg_pair > pos_pair[0] - 0.05) 52 | # pos_pair = pos_pair[1:] 53 | if len(neg_pair) < 1: 54 | c += 1 55 | continue 56 | 57 | neg_pair = torch.sort(neg_pair)[0] 58 | 59 | if i == 1 and np.random.randint(199) == 1: 60 | print('neg_pair is ---------', neg_pair) 61 | print('pos_pair is ---------', pos_pair.data) 62 | pos_loss = torch.mean(torch.log(1 + torch.exp(-2*(pos_pair - self.margin)))) 63 | neg_loss = 0.04*torch.mean(torch.log(1 + torch.exp(50*(neg_pair - self.margin)))) 64 | loss.append(pos_loss + neg_loss) 65 | 66 | loss = torch.sum(torch.cat(loss))/n 67 | 68 | prec = float(c)/n 69 | neg_d = torch.mean(neg_sim).data[0] 70 | pos_d = torch.mean(pos_sim).data[0] 71 | 72 | return loss, prec, pos_d, neg_d 73 | -------------------------------------------------------------------------------- /train/losses/MultiClassCrossEntropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | 5 | def MultiClassCrossEntropy(logits, labels, T, device): 6 | """ 7 | Source: https://github.com/ngailapdi/LWF/blob/baa07ee322d4b2f93a28eba092ad37379f565aca/model.py#L16 8 | :param logits: output logits of the model 9 | :param labels: ground truth labels 10 | :param T: temperature scaler 11 | :return: the loss value wrapped in torch.autograd.Variable 12 | """ 13 | labels = Variable(labels.data, requires_grad=False).to(device) 14 | outputs = torch.log_softmax(logits / T, dim=1) # compute the log of softmax values 15 | labels = torch.softmax(labels / T, dim=1) 16 | # print('outputs: ', outputs) 17 | # print('labels: ', labels.shape) 18 | outputs = torch.sum(outputs * labels, dim=1, keepdim=False) 19 | outputs = -torch.mean(outputs, dim=0, keepdim=False) 20 | # print('OUT: ', outputs) 21 | return Variable(outputs.data, requires_grad=True).to(device) 22 | -------------------------------------------------------------------------------- /train/losses/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: https://github.com/yulu0724/SDC-IL/blob/master/losses/__init__.py 3 | """ 4 | 5 | 6 | from .BinBranchLoss import BinBranchLoss 7 | from .BinDevianceLoss import BinDevianceLoss 8 | from .msloss import MultiSimilarityLoss 9 | from .angular import AngularLoss 10 | 11 | __factory = { 12 | 'binbranch': BinBranchLoss, 13 | 'bin': BinDevianceLoss, 14 | 'msloss': MultiSimilarityLoss, 15 | 'angular': AngularLoss 16 | } 17 | 18 | def names(): 19 | return sorted(__factory.keys()) 20 | 21 | 22 | def create(name, *args, **kwargs): 23 | """ 24 | Create a loss instance. 25 | Parameters 26 | ---------- 27 | name : str 28 | the name of loss function 29 | """ 30 | if name not in __factory: 31 | raise KeyError("Unknown loss:", name) 32 | return __factory[name](*args, **kwargs) 33 | -------------------------------------------------------------------------------- /train/losses/__pycache__/BinBranchLoss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/BinBranchLoss.cpython-36.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/BinBranchLoss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/BinBranchLoss.cpython-37.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/BinDevianceLoss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/BinDevianceLoss.cpython-36.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/BinDevianceLoss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/BinDevianceLoss.cpython-37.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/MultiClassCrossEntropy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/MultiClassCrossEntropy.cpython-36.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/MultiClassCrossEntropy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/MultiClassCrossEntropy.cpython-37.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/amsoftmax.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/amsoftmax.cpython-36.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/angular.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/angular.cpython-36.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/angular.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/angular.cpython-37.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/class_balanced_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/class_balanced_loss.cpython-36.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/class_balanced_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/class_balanced_loss.cpython-37.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/margin_ranking_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/margin_ranking_loss.cpython-36.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/margin_ranking_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/margin_ranking_loss.cpython-37.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/msloss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/msloss.cpython-36.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/msloss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/msloss.cpython-37.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/prob_diff_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/prob_diff_loss.cpython-36.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/prob_diff_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/prob_diff_loss.cpython-37.pyc -------------------------------------------------------------------------------- /train/losses/__pycache__/triplet_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/losses/__pycache__/triplet_loss.cpython-36.pyc -------------------------------------------------------------------------------- /train/losses/amsoftmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class AdMSoftmaxLoss(nn.Module): 7 | 8 | def __init__(self, in_features, out_features, s=30.0, m=0.4): 9 | ''' 10 | AM Softmax Loss: https://github.com/Leethony/Additive-Margin-Softmax-Loss-Pytorch/blob/master/AdMSLoss.py 11 | ''' 12 | super(AdMSoftmaxLoss, self).__init__() 13 | self.s = s 14 | self.m = m 15 | self.in_features = in_features 16 | self.out_features = out_features 17 | self.fc = nn.Linear(in_features, out_features, bias=False) 18 | 19 | def forward(self, x, labels): 20 | ''' 21 | input shape (N, in_features) 22 | ''' 23 | assert len(x) == len(labels) 24 | assert torch.min(labels) >= 0 25 | assert torch.max(labels) < self.out_features 26 | 27 | for W in self.fc.parameters(): 28 | W = F.normalize(W, dim=1) 29 | 30 | x = F.normalize(x, dim=1) 31 | 32 | wf = self.fc(x) 33 | numerator = self.s * (torch.diagonal(wf.transpose(0, 1)[labels]) - self.m) 34 | excl = torch.cat([torch.cat((wf[i, :y], wf[i, y + 1:])).unsqueeze(0) for i, y in enumerate(labels)], dim=0) 35 | denominator = torch.exp(numerator) + torch.sum(torch.exp(self.s * excl), dim=1) 36 | L = numerator - torch.log(denominator) 37 | return -torch.mean(L) -------------------------------------------------------------------------------- /train/losses/angular.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | # Constants 7 | N_PAIR = 'n-pair' 8 | ANGULAR = 'angular' 9 | N_PAIR_ANGULAR = 'n-pair-angular' 10 | MAIN_LOSS_CHOICES = (N_PAIR, ANGULAR, N_PAIR_ANGULAR) 11 | 12 | CROSS_ENTROPY = 'cross-entropy' 13 | 14 | 15 | class BlendedLoss(object): 16 | def __init__(self, main_loss_type, cross_entropy_flag): 17 | super(BlendedLoss, self).__init__() 18 | self.main_loss_type = main_loss_type 19 | assert main_loss_type in MAIN_LOSS_CHOICES, "invalid main loss: %s" % main_loss_type 20 | 21 | self.metrics = [] 22 | if self.main_loss_type == N_PAIR: 23 | self.main_loss_fn = NPairLoss() 24 | elif self.main_loss_type == ANGULAR: 25 | self.main_loss_fn = AngularLoss() 26 | elif self.main_loss_type == N_PAIR_ANGULAR: 27 | self.main_loss_fn = NPairAngularLoss() 28 | else: 29 | raise ValueError 30 | 31 | self.cross_entropy_flag = cross_entropy_flag 32 | self.lambda_blending = 0 33 | if cross_entropy_flag: 34 | self.cross_entropy_loss_fn = nn.CrossEntropyLoss() 35 | self.lambda_blending = 0.3 36 | 37 | def calculate_loss(self, target, output_embedding, output_cross_entropy=None): 38 | if target is not None: 39 | target = (target, ) 40 | 41 | loss_dict = {} 42 | blended_loss = 0 43 | if self.cross_entropy_flag: 44 | assert output_cross_entropy is not None, "Outputs for cross entropy loss is needed" 45 | 46 | loss_inputs = self._gen_loss_inputs(target, output_cross_entropy) 47 | cross_entropy_loss = self.cross_entropy_loss_fn(*loss_inputs) 48 | blended_loss += self.lambda_blending * cross_entropy_loss 49 | loss_dict[CROSS_ENTROPY + '-loss'] = [cross_entropy_loss.item()] 50 | 51 | loss_inputs = self._gen_loss_inputs(target, output_embedding) 52 | main_loss_outputs = self.main_loss_fn(*loss_inputs) 53 | main_loss = main_loss_outputs[0] if type(main_loss_outputs) in (tuple, list) else main_loss_outputs 54 | blended_loss += (1 - self.lambda_blending) * main_loss 55 | loss_dict[self.main_loss_type + '-loss'] = [main_loss.item()] 56 | 57 | for metric in self.metrics: 58 | metric(output_embedding, target, main_loss_outputs) 59 | 60 | return blended_loss, loss_dict 61 | 62 | @staticmethod 63 | def _gen_loss_inputs(target, embedding): 64 | if type(embedding) not in (tuple, list): 65 | embedding = (embedding, ) 66 | loss_inputs = embedding 67 | if target is not None: 68 | if type(target) not in (tuple, list): 69 | target = (target, ) 70 | loss_inputs += target 71 | return loss_inputs 72 | 73 | 74 | class NPairLoss(nn.Module): 75 | """ 76 | N-Pair loss 77 | Sohn, Kihyuk. "Improved Deep Metric Learning with Multi-class N-pair Loss Objective," Advances in Neural Information 78 | Processing Systems. 2016. 79 | http://papers.nips.cc/paper/6199-improved-deep-metric-learning-with-multi-class-n-pair-loss-objective 80 | """ 81 | def __init__(self, l2_reg=0.02, **kwargs): 82 | super(NPairLoss, self).__init__() 83 | self.l2_reg = l2_reg 84 | 85 | def forward(self, embeddings, target): 86 | n_pairs, n_negatives = self.get_n_pairs(target) 87 | 88 | if embeddings.is_cuda: 89 | n_pairs = n_pairs.cuda() 90 | n_negatives = n_negatives.cuda() 91 | 92 | anchors = embeddings[n_pairs[:, 0]] # (n, embedding_size) 93 | positives = embeddings[n_pairs[:, 1]] # (n, embedding_size) 94 | negatives = embeddings[n_negatives] # (n, n-1, embedding_size) 95 | 96 | losses = self.n_pair_loss(anchors, positives, negatives) \ 97 | + self.l2_reg * self.l2_loss(anchors, positives) 98 | 99 | return losses 100 | 101 | @staticmethod 102 | def get_n_pairs(labels): 103 | """ 104 | Get index of n-pairs and n-negatives 105 | :param labels: label vector of mini-batch 106 | :return: A tuple of n_pairs (n, 2) 107 | and n_negatives (n, n-1) 108 | """ 109 | labels = labels.cpu().data.numpy() 110 | n_pairs = [] 111 | for label in set(labels): 112 | label_mask = (labels == label) 113 | label_indices = np.where(label_mask)[0] 114 | if len(label_indices) < 2: 115 | continue 116 | anchor, positive = np.random.choice(label_indices, 2, replace=False) 117 | n_pairs.append([anchor, positive]) 118 | 119 | n_pairs = np.array(n_pairs) 120 | n_negatives = [] 121 | for i in range(len(n_pairs)): 122 | negative = np.concatenate([n_pairs[:i, 1], n_pairs[i + 1:, 1]]) 123 | n_negatives.append(negative) 124 | 125 | n_negatives = np.array(n_negatives) 126 | return torch.LongTensor(n_pairs), torch.LongTensor(n_negatives) 127 | 128 | @staticmethod 129 | def n_pair_loss(anchors, positives, negatives): 130 | """ 131 | Calculates N-Pair loss 132 | :param anchors: A torch.Tensor, (n, embedding_size) 133 | :param positives: A torch.Tensor, (n, embedding_size) 134 | :param negatives: A torch.Tensor, (n, n-1, embedding_size) 135 | :return: A scalar 136 | """ 137 | anchors = torch.unsqueeze(anchors, dim=1) # (n, 1, embedding_size) 138 | positives = torch.unsqueeze(positives, dim=1) # (n, 1, embedding_size) 139 | 140 | x = torch.matmul(anchors, (negatives - positives).transpose(1, 2)) # (n, 1, n-1) 141 | x = torch.sum(torch.exp(x), 2) # (n, 1) 142 | loss = torch.mean(torch.log(1 + x)) 143 | return loss 144 | 145 | @staticmethod 146 | def l2_loss(anchors, positives): 147 | """ 148 | Calculates L2 norm regularization loss 149 | :param anchors: A torch.Tensor, (n, embedding_size) 150 | :param positives: A torch.Tensor, (n, embedding_size) 151 | :return: A scalar 152 | """ 153 | return torch.sum(anchors**2 + positives**2) / anchors.shape[0] 154 | 155 | 156 | class AngularLoss(NPairLoss): 157 | """ 158 | Angular loss 159 | Wang, Jian. "Deep Metric Learning with Angular Loss," CVPR, 2017 160 | https://arxiv.org/pdf/1708.01682.pdf 161 | """ 162 | def __init__(self, l2_reg=0.02, angle_bound=1., lambda_ang=2, **kwargs): 163 | super(AngularLoss, self).__init__() 164 | self.l2_reg = l2_reg 165 | self.angle_bound = angle_bound 166 | self.lambda_ang = lambda_ang 167 | self.softplus = nn.Softplus() 168 | 169 | def forward(self, embeddings, target): 170 | n_pairs, n_negatives = self.get_n_pairs(target) 171 | if embeddings.is_cuda: 172 | n_pairs = n_pairs.cuda() 173 | n_negatives = n_negatives.cuda() 174 | 175 | anchors = embeddings[n_pairs[:, 0]] # (n, embedding_size) 176 | positives = embeddings[n_pairs[:, 1]] # (n, embedding_size) 177 | negatives = embeddings[n_negatives] # (n, n-1, embedding_size) 178 | 179 | losses = self.angular_loss(anchors, positives, negatives, self.angle_bound) \ 180 | + self.l2_reg * self.l2_loss(anchors, positives) 181 | 182 | return losses, 0, 0, 0 183 | 184 | @staticmethod 185 | def angular_loss(anchors, positives, negatives, angle_bound=1.): 186 | """ 187 | Calculates angular loss 188 | :param anchors: A torch.Tensor, (n, embedding_size) 189 | :param positives: A torch.Tensor, (n, embedding_size) 190 | :param negatives: A torch.Tensor, (n, n-1, embedding_size) 191 | :param angle_bound: tan^2 angle 192 | :return: A scalar 193 | """ 194 | print(anchors, positives, negatives);exit(1) 195 | anchors = torch.unsqueeze(anchors, dim=1) # (n, 1, embedding_size) 196 | positives = torch.unsqueeze(positives, dim=1) # (n, 1, embedding_size) 197 | 198 | x = 4. * angle_bound * torch.matmul((anchors + positives), negatives.transpose(1, 2)) \ 199 | - 2. * (1. + angle_bound) * torch.matmul(anchors, positives.transpose(1, 2)) # (n, 1, n-1) 200 | 201 | # Preventing overflow 202 | with torch.no_grad(): 203 | t = torch.max(x, dim=2)[0] 204 | 205 | x = torch.exp(x - t.unsqueeze(dim=1)) 206 | x = torch.log(torch.exp(-t) + torch.sum(x, 2)) 207 | loss = torch.mean(t + x) 208 | 209 | return loss 210 | 211 | 212 | class NPairAngularLoss(AngularLoss): 213 | """ 214 | Angular loss 215 | Wang, Jian. "Deep Metric Learning with Angular Loss," CVPR, 2017 216 | https://arxiv.org/pdf/1708.01682.pdf 217 | """ 218 | def __init__(self, l2_reg=0.02, angle_bound=1., lambda_ang=2, **kwargs): 219 | super(NPairAngularLoss, self).__init__() 220 | self.l2_reg = l2_reg 221 | self.angle_bound = angle_bound 222 | self.lambda_ang = lambda_ang 223 | 224 | def forward(self, embeddings, target): 225 | n_pairs, n_negatives = self.get_n_pairs(target) 226 | 227 | if embeddings.is_cuda: 228 | n_pairs = n_pairs.cuda() 229 | n_negatives = n_negatives.cuda() 230 | 231 | anchors = embeddings[n_pairs[:, 0]] # (n, embedding_size) 232 | positives = embeddings[n_pairs[:, 1]] # (n, embedding_size) 233 | negatives = embeddings[n_negatives] # (n, n-1, embedding_size) 234 | 235 | losses = self.n_pair_angular_loss(anchors, positives, negatives, self.angle_bound) \ 236 | + self.l2_reg * self.l2_loss(anchors, positives) 237 | 238 | return losses, 0, 0, 0 239 | 240 | def n_pair_angular_loss(self, anchors, positives, negatives, angle_bound=1.): 241 | """ 242 | Calculates N-Pair angular loss 243 | :param anchors: A torch.Tensor, (n, embedding_size) 244 | :param positives: A torch.Tensor, (n, embedding_size) 245 | :param negatives: A torch.Tensor, (n, n-1, embedding_size) 246 | :param angle_bound: tan^2 angle 247 | :return: A scalar, n-pair_loss + lambda * angular_loss 248 | """ 249 | n_pair = self.n_pair_loss(anchors, positives, negatives) 250 | angular = self.angular_loss(anchors, positives, negatives, angle_bound) 251 | 252 | return (n_pair + self.lambda_ang * angular) / (1 + self.lambda_ang) -------------------------------------------------------------------------------- /train/losses/class_balanced_loss.py: -------------------------------------------------------------------------------- 1 | """Pytorch implementation of Class-Balanced-Loss 2 | Reference: "Class-Balanced Loss Based on Effective Number of Samples" 3 | Authors: Yin Cui and 4 | Menglin Jia and 5 | Tsung Yi Lin and 6 | Yang Song and 7 | Serge J. Belongie 8 | https://arxiv.org/abs/1901.05555, CVPR'19. 9 | """ 10 | 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | beta, gamma = 0.999, 0.5 17 | 18 | def focal_loss(labels, logits, alpha, no_of_classes): 19 | """Compute the focal loss between `logits` and the ground truth `labels`. 20 | Focal loss = -alpha_t * (1-pt)^gamma * log(pt) 21 | where pt is the probability of being classified to the true class. 22 | pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit). 23 | Args: 24 | labels: A float tensor of size [batch, num_classes]. 25 | logits: A float tensor of size [batch, num_classes]. 26 | alpha: A float tensor of size [batch_size] 27 | specifying per-example weight for balanced cross entropy. 28 | gamma: A float scalar modulating loss from hard and easy examples. 29 | Returns: 30 | focal_loss: A float32 scalar representing normalized total loss. 31 | """ 32 | BCLoss = F.binary_cross_entropy_with_logits(input = logits[:, :no_of_classes], target = labels,reduction = "none") 33 | 34 | if gamma == 0.0: 35 | modulator = 1.0 36 | else: 37 | modulator = torch.exp(-gamma * labels * logits[:, :no_of_classes] - gamma * torch.log(1 + 38 | torch.exp(-1.0 * logits[:, :no_of_classes]))) 39 | 40 | loss = modulator * BCLoss 41 | 42 | weighted_loss = alpha * loss 43 | focal_loss = torch.sum(weighted_loss) 44 | 45 | focal_loss /= torch.sum(labels) 46 | return focal_loss 47 | 48 | 49 | 50 | def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, device): 51 | """Compute the Class Balanced Loss between `logits` and the ground truth `labels`. 52 | Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits) 53 | where Loss is one of the standard losses used for Neural Networks. 54 | Args: 55 | labels: A int tensor of size [batch]. 56 | logits: A float tensor of size [batch, no_of_classes]. 57 | samples_per_cls: A python list of size [no_of_classes]. 58 | no_of_classes: total number of classes. int 59 | loss_type: string. One of "sigmoid", "focal", "softmax". 60 | beta: float. Hyperparameter for Class balanced loss. 61 | gamma: float. Hyperparameter for Focal loss. 62 | Returns: 63 | cb_loss: A float tensor representing class balanced loss 64 | """ 65 | if min(samples_per_cls) == 0: 66 | samples_per_cls = [each + 1 for each in samples_per_cls] 67 | effective_num = 1.0 - np.power(beta, samples_per_cls) # classes with 0 samples have effective num = 0 68 | weights = (1.0 - beta) / np.array(effective_num) # divivsion by zero error here because of effective num = 0 69 | weights = weights / np.sum(weights) * no_of_classes 70 | labels_one_hot = F.one_hot(labels, no_of_classes).float() 71 | 72 | weights = torch.tensor(weights).float() 73 | weights = weights.unsqueeze(0) 74 | weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot.detach().cpu() 75 | weights = weights.sum(1) 76 | weights = weights.unsqueeze(1) 77 | weights = weights.repeat(1,no_of_classes).to(device) 78 | 79 | if loss_type == "focal": 80 | cb_loss = focal_loss(labels_one_hot, logits, weights, no_of_classes) 81 | elif loss_type == "sigmoid": 82 | cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot, weights = weights) 83 | elif loss_type == "softmax": 84 | pred = logits.softmax(dim = 1) 85 | cb_loss = torch.tensor(F.binary_cross_entropy(input = pred, target = labels_one_hot, weight = weights), 86 | requires_grad=True).to(device) 87 | # loss = torch.autograd.Variable(cb_loss.data, requires_grad=True) 88 | # cb_loss.requires_grad = True 89 | return cb_loss 90 | 91 | 92 | 93 | if __name__ == '__main__': 94 | no_of_classes = 5 95 | logits = torch.rand(10,no_of_classes).float() 96 | labels = torch.randint(0,no_of_classes, size = (10,)) 97 | beta = 0.9999 98 | gamma = 2.0 99 | samples_per_cls = [2,3,1,2,2] 100 | loss_type = "focal" 101 | cb_loss = CB_loss(labels, logits, samples_per_cls, no_of_classes,loss_type) 102 | print(cb_loss) -------------------------------------------------------------------------------- /train/losses/margin_ranking_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | def compute_margin_ranking_loss(logits, minibatch_labels, num_new_classes, seen_classes, device, outputs_bs): 5 | K, dist, lw_mr = 2, 0.5, 1 6 | # get scores before [-1,1] scaling 7 | assert (outputs_bs.size() == logits.size()) 8 | # compute ground truth scores 9 | high_response_index = torch.zeros(outputs_bs.size()).to(device) 10 | high_response_index = high_response_index.scatter(1, minibatch_labels.view(-1, 1), 1).ge(dist) # indices of actual class = True, rest = False 11 | high_response_scores = outputs_bs.masked_select(high_response_index) # scores of actual labels 12 | # compute top-K scores on none high response classes 13 | none_gt_index = torch.zeros(outputs_bs.size()).to(device) 14 | none_gt_index = none_gt_index.scatter(1, minibatch_labels.view(-1, 1), 1).le(dist) # True for all negative classes, False for actual class 15 | none_gt_scores = outputs_bs.masked_select(none_gt_index).reshape((outputs_bs.size(0), logits.size(1) - 1)) # scores for negative classes 16 | hard_negatives_scores = none_gt_scores.topk(K, dim=1)[0] # top k negative classes 17 | hard_negatives_index = minibatch_labels.lt(seen_classes - num_new_classes) # True for all old classes 18 | hard_negatives_num = torch.nonzero(hard_negatives_index).size(0) # number of old class labels 19 | if hard_negatives_num > 0: 20 | gt_scores = high_response_scores[hard_negatives_index].view(-1, 1).repeat(1, K) # logits score for old class instances repeated k times 21 | hard_scores = hard_negatives_scores[hard_negatives_index] # score for top-k instances the old class is most confused with 22 | assert (gt_scores.size() == hard_scores.size()) 23 | assert (gt_scores.size(0) == hard_negatives_num) 24 | mr_loss = torch.nn.MarginRankingLoss(margin=dist)(gt_scores.view(-1, 1), hard_scores.view(-1, 1), 25 | torch.ones(hard_negatives_num * K).to(device)) * lw_mr 26 | else: 27 | mr_loss = torch.tensor(0.).to(device) 28 | return mr_loss 29 | 30 | def compute_triplet_loss(logits, minibatch_labels, num_new_classes, seen_classes, device, outputs_bs): 31 | K, dist, lw_mr = 2, 0.5, 1 32 | # get scores before [-1,1] scaling 33 | assert (outputs_bs.size() == logits.size()) 34 | # compute ground truth scores 35 | high_response_index = torch.zeros(outputs_bs.size()).to(device) 36 | high_response_index = high_response_index.scatter(1, minibatch_labels.view(-1, 1), 1).ge(dist) # indices of actual labels 37 | print(high_response_index) 38 | high_response_scores = outputs_bs.masked_select(high_response_index) # scores of actual labels 39 | # compute top-K scores on none high response classes 40 | none_gt_index = torch.zeros(outputs_bs.size()).to(device) 41 | none_gt_index = none_gt_index.scatter(1, minibatch_labels.view(-1, 1), 1).le(0.8) # True for all negative classes 42 | none_gt_scores = outputs_bs.masked_select(none_gt_index).reshape( 43 | (outputs_bs.size(0), logits.size(1) - 1)) # scores for negative classes 44 | print(none_gt_index, minibatch_labels, none_gt_scores); exit(1) 45 | hard_negatives_scores = none_gt_scores.topk(K, dim=1)[0] # top k negative classes 46 | hard_negatives_index = minibatch_labels.lt(seen_classes - num_new_classes) # True for all old classes 47 | hard_negatives_num = torch.nonzero(hard_negatives_index).size(0) # number of old class labels -------------------------------------------------------------------------------- /train/losses/msloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class MultiSimilarityLoss(nn.Module): 6 | ''' 7 | Base source code taken from the orig. implementation: 8 | https://github.com/MalongTech/research-ms-loss/ 9 | ''' 10 | def __init__(self, thresh=0.5, _margin=0.1, scale_pos=2.0, scale_neg=40.0, **kwargs): 11 | super(MultiSimilarityLoss, self).__init__() 12 | self.thresh = thresh 13 | self.margin = _margin 14 | self.scale_pos = scale_pos 15 | self.scale_neg = scale_neg 16 | self.epsilon = 1e-5 17 | 18 | def forward(self, feats, labels, device, loss_ = None): 19 | assert feats.size(0) == labels.size(0), \ 20 | "feats.size(0): {feats.size(0)} is not equal to labels.size(0): {labels.size(0)}" 21 | 22 | batch_size = feats.size(0) 23 | sim_mat = torch.matmul(feats, torch.t(feats)) 24 | loss = loss_ if loss_ is not None else torch.tensor(0.0).to(device) 25 | 26 | for i in range(batch_size): 27 | pos_pair_ = sim_mat[i][labels == labels[i]] 28 | pos_pair_ = pos_pair_[pos_pair_ < 1 - self.epsilon] 29 | neg_pair_ = sim_mat[i][labels != labels[i]] 30 | 31 | if len(neg_pair_) < 1 or len(pos_pair_) < 1: 32 | continue 33 | 34 | neg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)] 35 | pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)] 36 | 37 | if len(neg_pair) < 1 or len(pos_pair) < 1: 38 | continue 39 | 40 | # weighting step 41 | pos_loss = 1.0 / self.scale_pos * torch.log( 42 | 1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))) 43 | ) 44 | neg_loss = 1.0 / self.scale_neg * torch.log( 45 | 1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))) 46 | ) 47 | loss += pos_loss + neg_loss 48 | 49 | if loss == 0: 50 | return torch.zeros([], requires_grad=True).to(device) 51 | 52 | return loss / batch_size -------------------------------------------------------------------------------- /train/losses/prob_diff_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def probability_differences(network_output, correct_vector, loss): 5 | """ 6 | Calculates the difference between the two most probable classes 7 | """ 8 | detached_output = network_output.cpu().detach().numpy() 9 | 10 | # if logit probabilities are used we need to get normal probabilities 11 | if np.sum(detached_output) != 1: 12 | detached_output = np.array([np.exp(x) / np.sum(np.exp(detached_output)) 13 | for x in detached_output]) 14 | 15 | ind = np.argpartition(detached_output, -2, axis=1)[:, -2:] 16 | # https://stackoverflow.com/questions/20103779/index-2d-numpy-array-by-a-2d-array-of-indices-without-loops 17 | maxs = detached_output[np.arange(network_output.shape[0])[:, None], ind] 18 | abs_diffs = np.abs(np.diff(maxs)) 19 | # adjust differences such that only correct classification is pushed 20 | abs_diffs = abs_diffs * correct_vector 21 | differences = np.sum(abs_diffs) 22 | 23 | return -(differences) 24 | 25 | 26 | def norm_probability_differences(detached_output, true_labels, loss): 27 | """ 28 | Calculates the difference between the two most probable classes and normalises it with tanh 29 | such that the resulting difference is between 0 and 1. This is multiplied by the current loss 30 | to bound the additonal loss term to numbers between 0 and the current cross entropy loss. 31 | This alleviates the CE loss the bigger the difference is 32 | """ 33 | # if logit probabilities are used we need to get normal probabilities 34 | if np.sum(detached_output) != 1: 35 | detached_output = np.array([np.exp(x) / np.sum(np.exp(detached_output)) 36 | for x in detached_output]) 37 | 38 | predicted_labels = np.argmax(detached_output, 1) 39 | correct_vector = true_labels == predicted_labels 40 | 41 | ind = np.argpartition(detached_output, -2, axis=1)[:, -2:] 42 | # https://stackoverflow.com/questions/20103779/index-2d-numpy-array-by-a-2d-array-of-indices-without-loops 43 | maxs = detached_output[np.arange(detached_output.shape[0])[:, None], ind] 44 | abs_diffs = np.abs(np.diff(maxs)) 45 | # adjust differences such that only correct classification is pushed 46 | abs_diffs = abs_diffs * correct_vector 47 | differences = np.sum(abs_diffs) 48 | 49 | # now we need to scale the difference such that it is between 0 and 1 50 | # to do so we use tanh scaled such that tanh(x) ~ 1 for the max distance 51 | # of all sample in the batch 52 | # since we know that the maximum distance between the first and second highest class 53 | # ~ 100 and tanh ~ 1 starting from 2 we can calculate a scaling factor 54 | batch_size = len(detached_output) 55 | scaling_factor = 2 / (batch_size) 56 | differences = np.tanh(differences * scaling_factor) 57 | 58 | return -(differences * loss) 59 | 60 | 61 | def neg_norm_probability_differences(detached_output, true_labels, loss): 62 | """ 63 | Calculates the difference between the two most probable classes and normalises it with tanh 64 | such that the resulting difference is between 0 and 1. The inverse of this difference is taken. 65 | This is multiplied by the current loss to bound the additonal loss term to numbers between 0 and 66 | the current cross entropy loss. This increases the loss the smaller the differences are. 67 | """ 68 | 69 | # if logit probabilities are used we need to get normal probabilities 70 | if np.sum(detached_output) != 1: 71 | detached_output = np.array([np.exp(x) / np.sum(np.exp(detached_output)) 72 | for x in detached_output]) 73 | 74 | predicted_labels = np.argmax(detached_output, 1) 75 | correct_vector = true_labels == predicted_labels 76 | 77 | ind = np.argpartition(detached_output, -2, axis=1)[:, -2:] 78 | # https://stackoverflow.com/questions/20103779/index-2d-numpy-array-by-a-2d-array-of-indices-without-loops 79 | maxs = detached_output[np.arange(detached_output.shape[0])[:, None], ind] 80 | abs_diffs = np.abs(np.diff(maxs)) 81 | # adjust differences such that only correct classification is pushed 82 | abs_diffs = abs_diffs * correct_vector 83 | differences = np.sum(abs_diffs) 84 | 85 | # now we need to scale the difference such that it is between 0 and 1 86 | # to do so we use tanh scaled such that tanh(x) ~ 1 for the max distance 87 | # of all sample in the batch 88 | # since we know that the maximum distance between the first and second highest class 89 | # ~ 100 and tanh ~ 1 starting from 2 we can calculate a scaling factor 90 | # the result of tanh gets subtracted from 1 to get a factor that is higher if the differences are lower 91 | batch_size = len(detached_output) 92 | scaling_factor = 2 / (batch_size) 93 | differences = 1 - np.tanh(differences * scaling_factor) 94 | 95 | return differences * loss 96 | 97 | 98 | def output_probability_entropy(network_output): 99 | network_output = network_output.cpu().detach().numpy() 100 | if np.sum(network_output[1]) < 0: 101 | network_output = network_output * -1 102 | 103 | entropy = np.sum(np.sum(np.log2(network_output) * network_output, axis=1)) 104 | 105 | return entropy -------------------------------------------------------------------------------- /train/losses/triplet_loss.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: triplet_loss.ipynb (unless otherwise specified). 2 | 3 | __all__ = ['batch_hard_triplet_loss', 'batch_all_triplet_loss'] 4 | 5 | # Cell 6 | import torch 7 | def _pairwise_distances(embeddings, squared=False): 8 | """Compute the 2D matrix of distances between all the embeddings. 9 | Args: 10 | embeddings: tensor of shape (batch_size, embed_dim) 11 | squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. 12 | If false, output is the pairwise euclidean distance matrix. 13 | Returns: 14 | pairwise_distances: tensor of shape (batch_size, batch_size) 15 | """ 16 | dot_product = torch.matmul(embeddings, embeddings.t()) 17 | 18 | # Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`. 19 | # This also provides more numerical stability (the diagonal of the result will be exactly 0). 20 | # shape (batch_size,) 21 | square_norm = torch.diag(dot_product) 22 | 23 | # Compute the pairwise distance matrix as we have: 24 | # ||a - b||^2 = ||a||^2 - 2 + ||b||^2 25 | # shape (batch_size, batch_size) 26 | distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1) 27 | 28 | # Because of computation errors, some distances might be negative so we put everything >= 0.0 29 | distances[distances < 0] = 0 30 | 31 | if not squared: 32 | # Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal) 33 | # we need to add a small epsilon where distances == 0.0 34 | mask = distances.eq(0).float() 35 | distances = distances + mask * 1e-16 36 | 37 | distances = (1.0 -mask) * torch.sqrt(distances) 38 | 39 | return distances 40 | 41 | def _get_triplet_mask(labels): 42 | """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid. 43 | A triplet (i, j, k) is valid if: 44 | - i, j, k are distinct 45 | - labels[i] == labels[j] and labels[i] != labels[k] 46 | Args: 47 | labels: tf.int32 `Tensor` with shape [batch_size] 48 | """ 49 | # Check that i, j and k are distinct 50 | indices_equal = torch.eye(labels.size(0)).bool() 51 | indices_not_equal = ~indices_equal 52 | i_not_equal_j = indices_not_equal.unsqueeze(2) 53 | i_not_equal_k = indices_not_equal.unsqueeze(1) 54 | j_not_equal_k = indices_not_equal.unsqueeze(0) 55 | 56 | distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k 57 | 58 | 59 | label_equal = labels.unsqueeze(0) == labels.unsqueeze(1) 60 | i_equal_j = label_equal.unsqueeze(2) 61 | i_equal_k = label_equal.unsqueeze(1) 62 | 63 | valid_labels = ~i_equal_k & i_equal_j 64 | 65 | return valid_labels & distinct_indices 66 | 67 | 68 | def _get_anchor_positive_triplet_mask(labels, device): 69 | """Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label. 70 | Args: 71 | labels: tf.int32 `Tensor` with shape [batch_size] 72 | Returns: 73 | mask: tf.bool `Tensor` with shape [batch_size, batch_size] 74 | """ 75 | # Check that i and j are distinct 76 | indices_equal = torch.eye(labels.size(0)).bool().to(device) 77 | indices_not_equal = ~indices_equal 78 | 79 | # Check if labels[i] == labels[j] 80 | # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) 81 | labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) 82 | 83 | return labels_equal & indices_not_equal 84 | 85 | 86 | def _get_anchor_negative_triplet_mask(labels): 87 | """Return a 2D mask where mask[a, n] is True iff a and n have distinct labels. 88 | Args: 89 | labels: tf.int32 `Tensor` with shape [batch_size] 90 | Returns: 91 | mask: tf.bool `Tensor` with shape [batch_size, batch_size] 92 | """ 93 | # Check if labels[i] != labels[k] 94 | # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) 95 | 96 | return ~(labels.unsqueeze(0) == labels.unsqueeze(1)) 97 | 98 | 99 | # Cell 100 | def batch_hard_triplet_loss(labels, embeddings, margin, squared=False, device='cpu'): 101 | """Build the triplet loss over a batch of embeddings. 102 | For each anchor, we get the hardest positive and hardest negative to form a triplet. 103 | Args: 104 | labels: labels of the batch, of size (batch_size,) 105 | embeddings: tensor of shape (batch_size, embed_dim) 106 | margin: margin for triplet loss 107 | squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. 108 | If false, output is the pairwise euclidean distance matrix. 109 | Returns: 110 | triplet_loss: scalar tensor containing the triplet loss 111 | """ 112 | # Get the pairwise distance matrix 113 | pairwise_dist = _pairwise_distances(embeddings, squared=squared) 114 | 115 | # For each anchor, get the hardest positive 116 | # First, we need to get a mask for every valid positive (they should have same label) 117 | mask_anchor_positive = _get_anchor_positive_triplet_mask(labels, device).float() 118 | 119 | # We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p)) 120 | anchor_positive_dist = mask_anchor_positive * pairwise_dist 121 | 122 | # shape (batch_size, 1) 123 | hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True) 124 | 125 | # For each anchor, get the hardest negative 126 | # First, we need to get a mask for every valid negative (they should have different labels) 127 | mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float() 128 | 129 | # We add the maximum value in each row to the invalid negatives (label(a) == label(n)) 130 | max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True) 131 | anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative) 132 | 133 | # shape (batch_size,) 134 | hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True) 135 | 136 | # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss 137 | tl = hardest_positive_dist - hardest_negative_dist + margin 138 | tl[tl < 0] = 0 139 | triplet_loss = tl.mean() 140 | 141 | return triplet_loss 142 | 143 | # Cell 144 | def batch_all_triplet_loss(labels, embeddings, margin, squared=False): 145 | """Build the triplet loss over a batch of embeddings. 146 | We generate all the valid triplets and average the loss over the positive ones. 147 | Args: 148 | labels: labels of the batch, of size (batch_size,) 149 | embeddings: tensor of shape (batch_size, embed_dim) 150 | margin: margin for triplet loss 151 | squared: Boolean. If true, output is the pairwise squared euclidean distance matrix. 152 | If false, output is the pairwise euclidean distance matrix. 153 | Returns: 154 | triplet_loss: scalar tensor containing the triplet loss 155 | """ 156 | # Get the pairwise distance matrix 157 | pairwise_dist = _pairwise_distances(embeddings, squared=squared) 158 | 159 | anchor_positive_dist = pairwise_dist.unsqueeze(2) 160 | anchor_negative_dist = pairwise_dist.unsqueeze(1) 161 | 162 | # Compute a 3D tensor of size (batch_size, batch_size, batch_size) 163 | # triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k 164 | # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1) 165 | # and the 2nd (batch_size, 1, batch_size) 166 | triplet_loss = anchor_positive_dist - anchor_negative_dist + margin 167 | 168 | 169 | 170 | # Put to zero the invalid triplets 171 | # (where label(a) != label(p) or label(n) == label(a) or a == p) 172 | mask = _get_triplet_mask(labels) 173 | triplet_loss = mask.float() * triplet_loss 174 | 175 | # Remove negative losses (i.e. the easy triplets) 176 | triplet_loss[triplet_loss < 0] = 0 177 | 178 | # Count number of positive triplets (where triplet_loss > 0) 179 | valid_triplets = triplet_loss[triplet_loss > 1e-16] 180 | num_positive_triplets = valid_triplets.size(0) 181 | num_valid_triplets = mask.sum() 182 | 183 | fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16) 184 | 185 | # Get final mean triplet loss over the positive valid triplets 186 | triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16) 187 | print(triplet_loss, fraction_positive_triplets) 188 | 189 | return triplet_loss, fraction_positive_triplets -------------------------------------------------------------------------------- /train/prediction_analyzer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class PredictionAnalysis(): 4 | def __init__(self, y_true, y_pred, dataset, reversed_original_mapping, reversed_virtual_mapping): 5 | self.y_true, self.y_pred = np.array(y_true), np.array(y_pred) 6 | self.dataset = dataset 7 | self.original_mapping, self.virtual_mapping = reversed_original_mapping, reversed_virtual_mapping 8 | 9 | def analyze_misclassified_instances(self, batch_num): 10 | new_classes, old_classes = [[self.original_mapping[item] for item in each] for each in 11 | (self.dataset.classes_by_groups[batch_num], 12 | [item for each in self.dataset.classes_by_groups[:batch_num] for item in each])] 13 | misclassified_new_classes, total_new_classes = self.get_wrongly_predicted_count(new_classes) 14 | if batch_num > 0: 15 | misclassified_old_classes, total_old_classes = self.get_wrongly_predicted_count(old_classes) 16 | old_misclassified_old, old_misclassified_new = self.get_wrongly_predicted_old_classes(old_classes, new_classes) 17 | result = (misclassified_new_classes, total_new_classes, misclassified_old_classes, total_old_classes, old_misclassified_new, old_misclassified_old) 18 | else: 19 | result = (misclassified_new_classes, total_new_classes) 20 | return result 21 | 22 | def get_wrongly_predicted_count(self, true_class_labels): 23 | true_label_indices = np.where(np.in1d(self.y_true, true_class_labels))[0] 24 | filtered_y_true, filtered_y_pred = self.y_true[true_label_indices], self.y_pred[true_label_indices] 25 | array_of_equals = np.equal(filtered_y_true, filtered_y_pred) 26 | return array_of_equals.shape[0] - sum(array_of_equals), filtered_y_true.shape[0] 27 | 28 | def get_wrongly_predicted_old_classes(self, old_class_labels, new_class_labels): 29 | old_labels_indices = np.where(np.in1d(self.y_true, old_class_labels))[0] 30 | filtered_y_true, filtered_y_pred = self.y_true[old_labels_indices], self.y_pred[old_labels_indices] 31 | 32 | old_misclassified_old = [i for i,j in zip(filtered_y_true, filtered_y_pred) if j in old_class_labels and j != i] 33 | old_misclassified_new = [i for i,j in zip(filtered_y_true, filtered_y_pred) if j in new_class_labels] 34 | return len(old_misclassified_old), len(old_misclassified_new) 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /train/record_test_scores.py: -------------------------------------------------------------------------------- 1 | class CurrentInstanceRecorder(): 2 | def __init__(self): 3 | self.previous_class_scores = {} 4 | self.current_class_scores = {} -------------------------------------------------------------------------------- /train/visualisations/__pycache__/exemplar_visualizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/visualisations/__pycache__/exemplar_visualizer.cpython-36.pyc -------------------------------------------------------------------------------- /train/visualisations/__pycache__/exemplar_visualizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/visualisations/__pycache__/exemplar_visualizer.cpython-37.pyc -------------------------------------------------------------------------------- /train/visualisations/__pycache__/stability_visualizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/visualisations/__pycache__/stability_visualizer.cpython-36.pyc -------------------------------------------------------------------------------- /train/visualisations/__pycache__/stability_visualizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/visualisations/__pycache__/stability_visualizer.cpython-37.pyc -------------------------------------------------------------------------------- /train/visualisations/__pycache__/training_visualizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/visualisations/__pycache__/training_visualizer.cpython-36.pyc -------------------------------------------------------------------------------- /train/visualisations/__pycache__/training_visualizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/visualisations/__pycache__/training_visualizer.cpython-37.pyc -------------------------------------------------------------------------------- /train/visualisations/__pycache__/vis_by_person.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/visualisations/__pycache__/vis_by_person.cpython-36.pyc -------------------------------------------------------------------------------- /train/visualisations/__pycache__/vis_by_person.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/train/visualisations/__pycache__/vis_by_person.cpython-37.pyc -------------------------------------------------------------------------------- /train/visualisations/exemplar_visualizer.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | # from tsnecuda import TSNE 4 | import numpy as np 5 | import seaborn as sns 6 | from matplotlib import pyplot as plt 7 | from matplotlib.font_manager import FontProperties 8 | from sklearn.decomposition import PCA 9 | from sklearn.manifold import TSNE 10 | 11 | fontP = FontProperties() 12 | fontP.set_size('small') 13 | 14 | OUT_DIR = 'vis_outputs/exemp_vis/' 15 | PAMAP_COLOR_DICT = dict({1: 'black', 2: 'red', 3: 'gold', 4: 'deepskyblue', 5: 'grey', 16 | 6: 'olive', 7: 'indigo', 12: 'deeppink', 13: 'orange', 17 | 16: 'lightblue', 17: 'teal', 24: 'brown'}) 18 | DSADS_COLOR_DICT = dict({1: 'black', 2: 'red', 3: 'gold', 4: 'deepskyblue', 5: 'grey', 19 | 6: 'olive', 7: 'indigo', 8: 'deeppink', 9: 'orange', 20 | 10: 'lightblue', 11: 'teal', 12: 'brown', 13: 'lime', 14: 'mediumblue', 21 | 15: 'mediumspringgreen', 16: 'lightsalmon', 17: 'lightsteelblue', 22 | 18: 'orchid', 19: 'sandybrown'}) 23 | 24 | 25 | def scatter_plot_exemps(label_to_features_all, label_to_indices_exemp, virtual_map, original_map, strategy, data_name): 26 | all_values = np.array(list(chain(*label_to_features_all.values()))) 27 | pca_50 = PCA(n_components=50) 28 | all_values = pca_50.fit_transform(all_values) 29 | label_to_indices_adjusted = {label: np.array(indices) + len(label_to_indices_exemp[idx - 1]) if idx > 0 else indices 30 | for idx, (label, indices) in enumerate(label_to_indices_exemp.items())} 31 | color_coding = PAMAP_COLOR_DICT if data_name == 'pamap' or data_name == 'hapt' else DSADS_COLOR_DICT 32 | tsne_feats = TSNE(n_components=2, perplexity=15, learning_rate=100).fit_transform(all_values) 33 | sns_plot = sns.scatterplot(tsne_feats[:, 0], tsne_feats[:, 1], color='grey') 34 | for label, indices in label_to_indices_adjusted.items(): 35 | label = original_map[virtual_map[label]] 36 | sns.scatterplot(tsne_feats[indices, 0], tsne_feats[indices, 1], color=color_coding[label], legend='full') 37 | 38 | fig = sns_plot.get_figure() 39 | box = sns_plot.get_position() 40 | sns_plot.set_position([box.x0, box.y0, box.width * 0.85, box.height]) # resize position 41 | sns_plot.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., 42 | prop=fontP) 43 | fig.savefig(OUT_DIR + f'{data_name}_{strategy}_exemps_{len(label_to_features_all)}.png') 44 | plt.show() 45 | fig.clf() 46 | -------------------------------------------------------------------------------- /train/visualisations/stability_visualizer.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | from train.visualisations import training_visualizer 3 | from matplotlib import pyplot as plt 4 | import matplotlib 5 | from train.visualisations import training_visualizer 6 | import matplotlib.ticker as ticker 7 | 8 | plt.tight_layout() 9 | matplotlib.rc('xtick', labelsize=15) 10 | matplotlib.rc('ytick', labelsize=15) 11 | 12 | def draw_lines_with_stddev(class_to_means_dict, class_to_stddev_dict, filename, size_val): 13 | legends = list(class_to_means_dict.keys()) 14 | out_path = training_visualizer.OUT_DIR + 'accuracy_vis/detail_acc/' 15 | fig = plt.figure() 16 | ax1 = fig.add_subplot(111) 17 | print(class_to_means_dict) 18 | batches = [i for i in range(2, len(class_to_means_dict['base']) + 2)] 19 | 20 | for class_id in legends: 21 | print(class_to_means_dict[class_id]) 22 | ax1.plot(batches, class_to_means_dict[class_id]) 23 | ax1.set_xticks(batches) 24 | ax1.set_xticklabels(batches) 25 | ax1.legend(labels=['Base classes', 'Old classes', 'New classes', 'All classes']) 26 | ax1.set_title("Detailed accuracy comparison") 27 | fig.savefig(out_path + f'{filename}_size_{size_val}.png') 28 | plt.show() 29 | plt.close(fig) 30 | 31 | 32 | def draw_multiple_lines(method_to_scores_dict, label, size, filename): 33 | legends = list(method_to_scores_dict.keys()) 34 | out_path = training_visualizer.OUT_DIR + 'accuracy_vis/detail_acc/' 35 | fig = plt.figure() 36 | ax1 = fig.add_subplot(111) 37 | batches = [i for i in range(2, len(list(method_to_scores_dict.values())[0]) + 2)] 38 | 39 | for x_label in batches: 40 | ax1.axvline(x=x_label, linestyle='--', color='grey') 41 | for method in legends: 42 | ax1.plot(batches, method_to_scores_dict[method]) 43 | ax1.set_xticks(batches) 44 | ax1.set_xticklabels(batches) 45 | ax1.legend(labels=legends) 46 | ax1.set_title(label) 47 | fig.savefig(out_path + f'{filename}_holdout_{size}.png') 48 | # plt.show() 49 | plt.close(fig) 50 | 51 | 52 | def draw_multiple_lines_and_plots(base_to_scores_dict, old_to_scores_dict, new_to_scores_dict, all_to_scores_dict, size, 53 | filename, baseline=True): 54 | legends = list(base_to_scores_dict.keys()) 55 | if baseline: 56 | labels = ['WA-MDF', 'WA-ADB', 'BiC', 'LUCIR', 'iCaRL', 'ILOS', 'GEM', 'R-EWC', 'MAS', 'LwF'] 57 | colors = ['firebrick', 'green', 'deepskyblue', 'steelblue', 'chocolate', 'gold', 'orangered', 'lightseagreen', 'darkviolet', 58 | 'indigo', 'darkkhaki', 'dimgray'] 59 | else: 60 | labels = ['CE', 'LwF', 'R-EWC', 'MAS', 'LUCIR-DIS', 'LUCIR-MR', 'LUCIR-DIS+MR', 'ILOS'] 61 | colors = ['royalblue', 'limegreen', 'gold', 'orangered', 'chocolate', 'teal', 'dimgray', 'mediumorchid' ] 62 | out_path = training_visualizer.OUT_DIR + 'accuracy_vis/detail_acc/base_old_new/' 63 | 64 | fig, ax1 = plt.subplots(1, 4, figsize=(13,6.5)) 65 | # plt.tick_params(axis='both', which='major', labelsize=13) 66 | 67 | batches = [i for i in range(2, len(list(base_to_scores_dict.values())[0]) + 2)] 68 | accs = [i for i in range(0, 120, 20)] 69 | for method, color, label in zip(legends, colors, labels): 70 | ax1[0].plot(batches, base_to_scores_dict[method], color=color, label=label) 71 | ax1[1].plot(batches, old_to_scores_dict[method], color=color, label=label) 72 | ax1[2].plot(batches, new_to_scores_dict[method], color=color, label=label) 73 | ax1[3].plot(batches, all_to_scores_dict[method], color=color, label=label) 74 | for ax in ax1: 75 | ax.set_xticks(batches) 76 | ax.set_xticklabels(batches) 77 | ax.set_yticks(accs) 78 | ax.set_yticklabels(accs) 79 | if filename == 'twor': 80 | for label in ax.get_xticklabels()[1::2]: 81 | label.set_visible(False) 82 | for x_label in batches: 83 | for each in ax1: 84 | each.axvline(x=x_label, linestyle='--', color='aliceblue', 85 | zorder=0) # zorder for shifting line to background 86 | box = each.get_position() 87 | each.set_position([box.x0, box.y0, box.width * 0.9, box.height]) 88 | ax1[0].set_title("Base",fontdict={'fontsize': 16, 'fontweight': 'medium'}) 89 | ax1[1].set_title("Old",fontdict={'fontsize': 16, 'fontweight': 'medium'}) 90 | ax1[2].set_title("New",fontdict={'fontsize': 16, 'fontweight': 'medium'}) 91 | ax1[3].set_title("All",fontdict={'fontsize': 16, 'fontweight': 'medium'}) 92 | lgd = plt.legend(bbox_to_anchor=(1, 0.5), loc='center left', fontsize=13) 93 | # lgd.get_title().set_fontsize('15') # legend 'Title' fontsize 94 | # plt.legend().set_visible(False) 95 | 96 | fig.text(0.5, 0.03, 'No. of tasks', ha='center', va='center', fontsize=13) 97 | fig.text(0.06, 0.5, 'Micro-F1 (%)', ha='center', va='center', rotation='vertical', fontsize=13) 98 | 99 | """ 100 | Sort the legends based on scores of last task for all classes. 101 | Source: https://stackoverflow.com/a/27512450/5140684 102 | """ 103 | # final_task_scores = [each[-1] for each in list(all_to_scores_dict.values())] 104 | # order = sorted(range(len(final_task_scores)), key=lambda x: final_task_scores[x], reverse=True) 105 | # handles, labels_ = ax1[3].get_legend_handles_labels() 106 | # ax1[3].legend([handles[idx] for idx in order], [labels_[idx] for idx in order], bbox_to_anchor=(1, 0.5), loc='center left') 107 | """End of sorting legends""" 108 | plt.subplots_adjust(wspace=0.5) 109 | # plt.text(0.5, 1.08, f"{'DSADS' if filename == 'dsads' else 'WS'}", 110 | # horizontalalignment='center', 111 | # fontsize=14, 112 | # transform=ax1[1].transAxes) 113 | fig.savefig(out_path + f'{filename}_{"baseline" if baseline else "regularized"}_detailed_acc_holdout_{size}.pdf', bbox_inches='tight', dpi=700) 114 | plt.show() 115 | plt.close(fig) 116 | 117 | def draw_accs_by_size(size_to_all_accs, filename, key, baseline=False, method='', tp=False): 118 | legends = list(size_to_all_accs.keys()) 119 | if baseline: 120 | labels = ['WA-MDF', 'WA-ADB', 'BiC', 'LUCIR', 'iCaRL', 'ILOS', 'GEM'] 121 | colors = ['firebrick', 'green', 'deepskyblue', 'steelblue', 'chocolate', 'gold', 'orangered', 'lightseagreen', 'darkviolet'] 122 | elif tp: 123 | labels = ['WA-MDF', 'LUCIR', 'ILOS', 'R-EWC'] 124 | colors = ['orangered', 'green', 'gold', 'indigo'] 125 | method_to_label = {'kd_kldiv_wa1': 'WA-MDF', 'cn_lfc_mr': 'LUCIR', 'kd_kldiv_ilos' : 'ILOS', 'ce_online_ewc': 'R-EWC'} 126 | elif len(method) > 0: 127 | labels = ['Random', 'Herding', 'Exemplar', 'Boundary', 'FWSR'] 128 | colors = ['firebrick', 'green', 'deepskyblue', 'gold', 'darkviolet', 129 | 'indigo'] 130 | method_to_label = {'kd_kldiv_wa1': 'WA-MDF', 'cn_lfc_mr': 'LUCIR', 'kd_kldiv_ilos' : 'ILOS', 'ce_online_ewc': 'R-EWC'} 131 | else: 132 | labels = ['CE', 'LwF', 'R-EWC', 'MAS', 'LUCIR-DIS', 'LUCIR-MR', 'LUCIR-DIS+MR', 'ILOS'] 133 | colors = ['royalblue', 'limegreen', 'gold', 'orangered', 'chocolate', 'teal', 'dimgray', 'mediumorchid' ] 134 | 135 | out_path = training_visualizer.OUT_DIR + 'accuracy_vis/detail_acc/holdout_sizes/' 136 | fig, ax1 = plt.subplots() 137 | # plt.tick_params(axis='both', which='major', labelsize=13) 138 | batches = [0, 10, 30, 50, 70] if tp else [0]+ [i for i in list(list(size_to_all_accs.values())[0].keys())] 139 | yticks = [i for i in range(0, 120, 20)] 140 | order = [100 * (i+1) for i in range(len(list(list(size_to_all_accs.values())[0].values())))] 141 | for method_, color, label in zip(legends, colors, labels): 142 | ax1.plot(order, list(size_to_all_accs[method_].values()), color=color, label=label) 143 | 144 | for x_label in order: 145 | ax1.axvline(x=x_label, linestyle='--', color='aliceblue', zorder=0) # zorder for shifting line to background 146 | ax1.set_yticks(yticks) 147 | plt.locator_params(axis='x', nbins=len(order)) 148 | ax1.set_yticklabels(yticks) 149 | ax1.set_xticklabels(batches) 150 | box = ax1.get_position() 151 | # ax1.set_position([box.x0, box.y0, box.width * 0.9, box.height]) 152 | 153 | plt.legend(bbox_to_anchor=(1, 0.5), loc='center left', fontsize=13) 154 | # ax1.legend().set_visible(False) 155 | 156 | ax1.set_xlabel('Train set size (%)' if tp else 'Holdout size per class', fontsize=13) 157 | ax1.set_ylabel(f'{"Macro" if key == "macro" else "Micro"}-F1 (%)', fontsize=13) 158 | if len(method) > 0 or tp: 159 | plt.title(f"{'DSADS' if filename == 'dsads' else 'WS' if filename == 'ws' else 'HA' if filename == 'hatn6' else 'HAPT' if filename == 'hapt' else 'PAMAP2' if filename == 'pamap' else filename.upper()}" 160 | f"{':' + method_to_label[method] if len(method) > 0 else ''}", fontsize=14) 161 | 162 | fig.savefig(out_path + f'{filename}_{key}_{"baseline" if baseline else method if len(method) > 0 else "regularized"}_all_holdouts.pdf', bbox_inches='tight', dpi=700) 163 | # plt.show() 164 | plt.close(fig) 165 | 166 | def draw_scores_by_task(scores_by_task, filename, methods, replay=True, baseline = True): 167 | legends = list(methods) 168 | if baseline: 169 | labels = ['WA-MDF', 'WA-ADB', 'BiC', 'LUCIR', 'iCaRL', 'ILOS', 'GEM', 'R-EWC', 'MAS', 'LwF'] 170 | colors = ['firebrick', 'green', 'deepskyblue', 'steelblue', 'chocolate', 'gold', 'orangered', 'lightseagreen', 'darkviolet', 171 | 'indigo', 'darkkhaki', 'dimgray'] 172 | else: 173 | labels = ['CE', 'LwF', 'R-EWC', 'MAS', 'LUCIR-DIS', 'LUCIR-MR', 'LUCIR-DIS+MR', 'ILOS'] if replay else ['CE', 'LwF', 'R-EWC', 'MAS', 'LUCIR-DIS', 'ILOS'] 174 | colors = ['royalblue', 'limegreen', 'gold', 'orangered', 'chocolate', 'teal', 'dimgray', 'mediumorchid' ] if replay \ 175 | else ['royalblue', 'limegreen', 'gold', 'orangered', 'chocolate', 'mediumorchid' ] 176 | tasks = [i for i in range(1, len(scores_by_task[0])+1)] 177 | out_path = training_visualizer.OUT_DIR + 'accuracy_vis/detail_acc/forgetting/' 178 | fig, ax1 = plt.subplots() 179 | plt.tick_params(axis='both', which='major', labelsize=13) 180 | # 181 | for idx, (scores, color, label) in enumerate(zip(scores_by_task, colors, labels)): 182 | ax1.plot(tasks, scores, color=color, label=label) 183 | 184 | for x_label in tasks: 185 | ax1.axvline(x=x_label, linestyle='--', color='aliceblue', 186 | zorder=0) # zorder for shifting line to background 187 | ax1.set_xticks(tasks) 188 | ax1.set_xticklabels(tasks) 189 | 190 | 191 | box = ax1.get_position() 192 | ax1.set_position([box.x0, box.y0, box.width * 0.9, box.height]) 193 | 194 | plt.legend(bbox_to_anchor=(1, 0.5), loc='center left', fontsize=13) 195 | # ax1.legend().set_visible(False) 196 | ax1.set_xlabel('Incremental task', fontsize=13) 197 | ax1.set_ylabel(f'Forgetting score', fontsize=13) 198 | # plt.title(f"{'With rehearsal' if replay else 'Without rehearsal'}", fontsize=14) 199 | # plt.title(f"{'DSADS' if filename == 'dsads' else 'WS' if filename == 'ws' else 'HA' if filename == 'hatn6' else 'HAPT' if filename == 'hapt' else 'PAMAP2' if filename == 'pamap' else filename.upper()}", fontsize=14) 200 | fig.savefig(out_path + f'{filename}_forgetting{"_baseline" if baseline else "_regularized"}{"_replay" if replay else "_blank"}.pdf', 201 | bbox_inches='tight', dpi=700) 202 | plt.show() 203 | plt.close(fig) 204 | 205 | -------------------------------------------------------------------------------- /train/visualisations/vis_by_person.py: -------------------------------------------------------------------------------- 1 | from scipy.stats import pearsonr, spearmanr, kendalltau 2 | from scipy.spatial.distance import cosine 3 | import matplotlib.pyplot as plt 4 | from matplotlib.font_manager import FontProperties 5 | 6 | ##### from https://stackoverflow.com/a/4700674/5140684 [for matplotlib legend] ####### 7 | fontP = FontProperties() 8 | fontP.set_size('small') 9 | ####################### ######################### ################################ 10 | 11 | PLOT_DIR = 'vis_outputs/per_person/' 12 | 13 | 14 | def drop_column_by_idx(df, index=None): 15 | df = df.drop(df.columns[index], axis=1) 16 | return df 17 | 18 | class VisualizeStatsPerPerson(): 19 | def __init__(self, dataname, df): 20 | self.dataname = dataname 21 | self.df = df 22 | if self.dataname == 'dsads' or self.dataname == 'pamap': 23 | self.person_column = 407 if self.dataname == 'dsads' else 244 24 | self.class_column = 406 if self.dataname == 'dsads' else 243 25 | self.df_by_mean_of_classes = drop_column_by_idx(self.get_mean_by_columns([self.df.columns[self.class_column]]), 26 | self.person_column-1) 27 | self.df_by_mean_of_persons_and_classes = self.get_mean_by_columns([self.df.columns[self.person_column], 28 | self.df.columns[self.class_column]]) 29 | 30 | def get_mean_by_columns(self, list_of_columns): 31 | return self.df.groupby(list_of_columns).mean() 32 | 33 | def plot_variance_by_persons(self): 34 | cosine_similarities, pearsons, spearmans, kendalls = [], [], [], [] 35 | for i, new_df in self.df_by_mean_of_persons_and_classes.groupby(level=0): 36 | for idx, row in new_df.iterrows(): 37 | class_num = idx[1] 38 | x = list(row) 39 | y = list(self.df_by_mean_of_classes.loc[class_num]) 40 | cosine_similarities.append(cosine(x,y)) 41 | pearsons.append(pearsonr(x,y)[0]) 42 | kendalls.append(kendalltau(x,y)[0]) 43 | spearmans.append(spearmanr(x,y)[0]) 44 | #### uncomment if only one similarity to use ###### 45 | # new_df['cosine'] = list_of_similarities 46 | # new_df.reset_index(inplace=True) 47 | # new_df.rename(columns={244:'persons', 243: 'classes'}, inplace=True) 48 | # new_df.boxplot(by='classes', column=['cosine'], grid=False) 49 | # plt.show() 50 | ##################################################### 51 | 52 | self.df_by_mean_of_persons_and_classes['cosine distance from mean'] = cosine_similarities 53 | self.df_by_mean_of_persons_and_classes['pearson'] = pearsons 54 | self.df_by_mean_of_persons_and_classes['spearman'] = spearmans 55 | self.df_by_mean_of_persons_and_classes['kendalltau'] = kendalls 56 | self.df_by_mean_of_persons_and_classes.reset_index(inplace=True) 57 | self.df_by_mean_of_persons_and_classes.rename(columns={self.df_by_mean_of_persons_and_classes.columns[0]:'persons', 58 | self.df_by_mean_of_persons_and_classes.columns[1]:'classes'}, 59 | inplace=True) 60 | myfig = plt.figure() 61 | fig = self.df_by_mean_of_persons_and_classes.boxplot(by='persons', column=['cosine distance from mean', 62 | 'pearson', 'spearman', 'kendalltau'], 63 | grid=False) 64 | plt.title("box plot by correlation") 65 | plt.savefig(PLOT_DIR + f"{self.dataname}_boxplot.png") 66 | plt.close(myfig) 67 | 68 | def visualise_imbalance_by_persons(self, persons_list=None): 69 | ''' 70 | Called for PAMAP/HAPT. 71 | :param df: dataframe read from mat files 72 | :param persons_list: list of person IDs, supplied for hapt dataset 73 | :return: None 74 | ''' 75 | df = self.df.copy() 76 | if self.dataname in ['pamap', 'dsads', 'opp']: 77 | person_column = 244 if self.dataname == 'pamap' else 407 if self.dataname == 'dsads' else 461 78 | activity_column = 243 if self.dataname == 'pamap' else 406 if self.dataname == 'dsads' else 460 79 | df = df.rename(columns={df.columns[person_column]: 'Persons', df.columns[activity_column]: 'Activities'}) 80 | df['Activities'] = df['Activities'].astype(int) 81 | df['Persons'] = df['Persons'].astype(int) 82 | ax = df.groupby([df.columns[person_column], df.columns[activity_column]]).size().unstack().plot(kind='bar', 83 | stacked=True, 84 | colormap='jet', 85 | ) 86 | activity_labels = {1: 'sit', 2:'stand', 3:'lie on back', 4:'lie to right', 5:'ascend stairs', 6:'descend stairs', 87 | 7: 'stand in elevator', 8: 'move in elevator', 9: 'walk in parking lot', 10:'walk flat at 4 km/h', 88 | 11: 'walk inclined at 4 km/h', 12: 'run at 8 km/h', 13: 'stepper exercise', 14: 'cross trainer exercise', 89 | 15: 'cycle horizontally', 16: 'cycle vertically', 17: 'row', 18: 'jump', 19: 'play basketball'} 90 | activity_names = list(activity_labels.values()) 91 | elif self.dataname == 'hapt': 92 | df.insert(0, 'Persons', persons_list) 93 | ax = df.groupby(['Persons', 'AID']).size().unstack().plot(kind='bar',stacked=True, colormap='jet') 94 | else: 95 | ax = df.groupby(['AID']).size().plot(kind='bar',stacked=True, colormap='jet') 96 | # ax.set_xticklabels(['cook', 'eat', 'enter/leave house', 'living room activity', 'use toilet', 'use mirror', 97 | # 'read', 'sleep', 'work'], rotation=50, fontsize=11, horizontalalignment='right') 98 | next_width = 0 99 | # for i, p in enumerate(ax.patches): 100 | # width, height = p.get_width(), p.get_height() 101 | # x, y = p.get_xy() 102 | # if height > 0: 103 | # ax.text(x + width / 2, 104 | # y + height / 2, 105 | # '{:.0f}'.format(height), 106 | # horizontalalignment='center', 107 | # verticalalignment='center') 108 | # L = plt.legend() 109 | # for i in range(0, len(activity_names)): 110 | # L.get_texts()[i].set_text(activity_names[i]) 111 | if self.dataname in ['pamap', 'dsads', 'hapt']: 112 | box = ax.get_position() 113 | ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) 114 | ax.legend(bbox_to_anchor=(1, 0.5), fontsize=13, title='Activity ID', loc='center left', 115 | fancybox=True, shadow=True) 116 | import numpy as np 117 | ax.set_yticks(np.arange(0, 200, 30)) 118 | plt.grid(True, linestyle=':', linewidth=0.3) 119 | plt.tick_params(axis='both', which='major', labelsize=13) 120 | # ax.yaxis.grid(True, which='minor', linestyle='-.', linewidth=0.25) 121 | ax.set_xlabel("Activity ID", fontsize=13) 122 | ax.set_ylabel("Count", fontsize=13) 123 | 124 | ax.figure.savefig(PLOT_DIR + 'imbalance_vis_' + self.dataname + '.pdf', dpi=600,bbox_inches='tight') 125 | plt.show() -------------------------------------------------------------------------------- /utils/__pycache__/data_handler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/utils/__pycache__/data_handler.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_handler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/utils/__pycache__/data_handler.cpython-37.pyc -------------------------------------------------------------------------------- /utils/data_handler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | import random 4 | from random import shuffle 5 | from operator import itemgetter 6 | import numpy as np 7 | import pandas as pd 8 | import yaml 9 | from scipy.io import loadmat 10 | from sklearn.model_selection import train_test_split 11 | from collections import Counter, defaultdict 12 | from train.visualisations import vis_by_person 13 | from train.visualisations.training_visualizer import plot_heatmap 14 | from sklearn.preprocessing import LabelEncoder 15 | # np.random.seed(42) 16 | TEST_SIZE = 0.3 17 | TRAIN_VAL_SPLIT = 0.9 18 | NUM_PERMUTATIONS = 3 19 | 20 | label_enc_dict = defaultdict(LabelEncoder) 21 | 22 | def drop_column_by_idx(df, index=None): 23 | df = df.drop(df.columns[index], axis=1) 24 | return df 25 | 26 | def pickle_loader(filepath, encoding='latin1'): 27 | pickle_file = pickle.load(open(filepath, 'rb'), encoding=encoding) 28 | return pickle_file 29 | 30 | class DataHandler: 31 | def __init__(self, dataname, base_classes, per_batch_classes, train_percent, seed_value, vis, corr_vis, keep_val): 32 | self.seed_value = seed_value 33 | self.vis = vis 34 | self.corr_vis = corr_vis 35 | self.seed_randomness() 36 | self.dataname = dataname 37 | self.tp = train_percent # considering separate train and test directories for hapt (0.7+0.3 = 1.0) 38 | self.original_mapping = {} 39 | self.train_data, self.test_data, self.train_labels, self.test_labels = self.get_features_and_labels() 40 | self.nb_cl = base_classes 41 | self.num_classes = per_batch_classes 42 | self.classes_by_groups = [] 43 | self.label_map = {} 44 | self.keep_val = keep_val 45 | self.train_groups, self.val_groups, self.test_groups = self.initialize() 46 | 47 | def seed_randomness(self): 48 | np.random.seed(self.seed_value) 49 | 50 | @staticmethod 51 | def read_config(): 52 | with open('conf/data_paths.yaml', 'r') as stream: 53 | try: 54 | return yaml.safe_load(stream) 55 | except yaml.YAMLError as err: 56 | print(err) 57 | 58 | def get_file_path(self): 59 | config_dict = self.read_config() 60 | if self.dataname == 'milan' or self.dataname == 'twor' or self.dataname == 'aruba': 61 | dir_path = config_dict['large_data'][self.dataname] 62 | elif self.dataname == 'hatn6' or self.dataname == 'ws': 63 | dir_path = config_dict['small_data'][self.dataname] 64 | elif self.dataname in ['pamap', 'dsads', 'hapt']: 65 | dir_path = config_dict['medium_data'][self.dataname] 66 | else: 67 | dir_path = config_dict[self.dataname.split('_')[1] if '_' in self.dataname else self.dataname] 68 | return dir_path 69 | 70 | def get_df_from_mat(self, filename): 71 | dir_path = self.get_file_path() 72 | data_path = dir_path + filename + '.mat' 73 | mat = loadmat(data_path) 74 | data = mat['data_' + filename] 75 | df = pd.DataFrame(data=data) 76 | return df 77 | 78 | def read_data(self): 79 | config_dict = self.read_config() 80 | train_data_path = config_dict[self.dataname]['train'] 81 | test_data_path = config_dict[self.dataname]['test'] 82 | train_df = pd.read_csv(train_data_path) 83 | test_df = pd.read_csv(test_data_path) 84 | return train_df, test_df 85 | 86 | def get_features_and_labels(self): 87 | """ 88 | Function to separate features and labels from the dataframe. 89 | @param df: is the dataframe read from the csv file. 90 | @return: a tuple of two lists: (features, labels) 91 | """ 92 | train_df, test_df = self.generate_dataframe() 93 | train_values, test_values = [df.values.tolist() for df in [train_df, test_df]] 94 | X_train, X_test = [[np.array(item[1:]) for item in each] for each in [train_values, test_values]] 95 | y_train, y_test = [[item[0] for item in each ] for each in [train_values, test_values]] 96 | """ Oversampling of data""" 97 | # if self.dataname == 'hapt': 98 | # counts_dict = Counter(y_train) 99 | # sample_nums = {self.original_mapping[i]: math.ceil(counts_dict[self.original_mapping[i]]*2.8) if i != 8 else 100 | # math.ceil(counts_dict[self.original_mapping[i]]*5) for i in range(7, 13)} 101 | # smote_enn = SMOTEENN(random_state=0, sampling_strategy=sample_nums) 102 | # print(f"Before resampling: {counts_dict}, mapping: {self.original_mapping}") 103 | # X_train, y_train = smote_enn.fit_resample(X_train, y_train) 104 | # X_train = [np.array(item) for item in X_train] 105 | # print(f"After resampling: {Counter(y_train)}, mapping: {self.original_mapping}") 106 | """ End of oversampling""" 107 | return X_train, X_test, y_train, y_test 108 | 109 | 110 | def get_reversed_original_label_maps(self): 111 | return dict(map(reversed, self.original_mapping.items())) 112 | 113 | def replace_class_labels(self, df, train=False): 114 | if train: 115 | sorted_elements = [i for i in range(len(df.AID.unique()))] 116 | self.original_mapping = dict(zip(df.AID.unique(), sorted_elements)) 117 | df.AID = df.AID.map(self.original_mapping) 118 | return df 119 | 120 | @staticmethod 121 | def reorder_columns(df, mnist=False): 122 | cols = list(df.columns.values) 123 | if not mnist: 124 | cols = cols[-1:] + cols[:-1] 125 | df = df[cols] 126 | new_col_names = ['AID' if idx==0 else 'S'+str(each) for idx,each in enumerate(cols)] 127 | df.rename(columns=dict(zip(df.columns, new_col_names)), inplace=True) 128 | df.reset_index(drop=True, inplace=True) 129 | df = df.sort_values(by=['AID'], ascending=True) 130 | return df 131 | 132 | 133 | def get_hapt_data(self, full_path, train=True): 134 | file_mode = 'train' if train else 'test' 135 | features_df = pd.read_csv(full_path + 'X_' + file_mode + '.txt', sep=' ', header=None) 136 | feature_cols = list(features_df.columns.values) 137 | new_col_names = ['S' + str(each) for idx, each in enumerate(feature_cols)] 138 | features_df.rename(columns=dict(zip(features_df.columns, new_col_names)), inplace=True) 139 | labels_df = pd.read_csv(full_path + 'y_' + file_mode + '.txt', sep=' ', header=None) 140 | features_df.insert(0, 'AID', labels_df[labels_df.columns[0]]) 141 | persons_list = None 142 | path = 'train' if train else 'test' 143 | with open(full_path+f'subject_id_{path}.txt', 'r') as FileObj: 144 | persons_list = np.array([int(line[:-1]) for line in FileObj.readlines()]) 145 | all_persons = np.unique(persons_list) 146 | if train: 147 | train_persons = all_persons[:int(len(all_persons) * self.tp)] 148 | index_of_train_persons = np.concatenate([np.where(persons_list == person)[0] for person in 149 | train_persons]).ravel() 150 | features_df = features_df[features_df.index.isin(index_of_train_persons)] 151 | return (features_df, persons_list) 152 | 153 | def generate_dataframe(self): 154 | persons_list = None # used for visualizing imbalance in data for 'hapt' dataset, None for all others 155 | if self.dataname in ['dsads', 'pamap', 'opp']: 156 | df = self.get_df_from_mat(self.dataname + '_loco' if self.dataname == 'opp' else self.dataname) 157 | df_to_visualize = df if self.vis else None 158 | person_column = 407 if self.dataname == 'dsads' else 244 if self.dataname == 'pamap' else 461 159 | print("Count: ", df.groupby(df.columns[243]).count()); 160 | if self.dataname == 'pamap': 161 | df.drop(df.loc[(df[df.columns[person_column]] == 9.0) | 162 | (df[df.columns[person_column]] == 3.0)].index, 163 | inplace=True) 164 | persons = df[df.columns[person_column]].unique() 165 | train_persons = [i + 1 for i in range(math.ceil(max(persons) * (1 - TEST_SIZE)))] 166 | train_persons = train_persons[:math.ceil(self.tp * len(train_persons))] 167 | test_persons = [i for i in persons[-int(TEST_SIZE*len(persons)):]] 168 | # train_persons = np.random.choice(list(persons), math.ceil(max(persons) * (1 - TEST_SIZE)), replace=False).tolist() 169 | print(f"Persons: {persons}, train: {train_persons}, test: {test_persons}") 170 | train_df = drop_column_by_idx(df.loc[df[df.columns[person_column]].isin(train_persons)], 171 | index=person_column) 172 | test_df = drop_column_by_idx(df.loc[df[df.columns[person_column]].isin(test_persons)], 173 | index=person_column) 174 | if self.dataname == 'dsads': 175 | # From data descriptions: Column 406 is the activity sequence indicating the executing of activities (usually not used in experiments). 176 | train_df, test_df = [drop_column_by_idx(df, 405) for df in [train_df, test_df]] 177 | if self.dataname == 'opp': 178 | # From data descriptions: Column 461 is the activity drill. 179 | train_df, test_df = [drop_column_by_idx(df, 460) for df in [train_df, test_df]] 180 | train_df, test_df = [self.reorder_columns(each) for each in [train_df, test_df]] 181 | elif self.dataname == 'hapt': 182 | dir_path = self.get_file_path() 183 | train_path, test_path = dir_path + 'Train/', dir_path + 'Test/' 184 | (train_df, persons_list_train), (test_df, persons_list_test) = [self.get_hapt_data(full_path, train=idx==0) for idx, full_path in 185 | enumerate([train_path, test_path])] 186 | persons_list = np.concatenate((persons_list_train, persons_list_test)) 187 | df_to_visualize = pd.concat((train_df, test_df)) if self.vis else None 188 | 189 | elif self.dataname in ['ws', 'hatn6', 'milan', 'aruba', 'twor']: 190 | # for ha_t_n6 and ws datasets 191 | dir_path = self.get_file_path() 192 | if self.dataname in ['ws', 'hatn6']: 193 | df = pd.read_csv(dir_path) 194 | else: 195 | df = pd.read_excel(dir_path) 196 | df = df.rename(columns={'activity_label': 'AID'}) 197 | # Source: https://stackoverflow.com/a/31939145/5140684 198 | # df['AID'] = df.AID.apply(lambda x: label_enc_dict[x].fit_transform(x)) 199 | # Source: https://pbpython.com/categorical-encoding.html 200 | df['AID'] = df['AID'].astype('category') 201 | df['AID'] = df['AID'].cat.codes 202 | df = self.reorder_columns(df) 203 | df_to_visualize = df.copy() if self.vis else None 204 | train_df, test_df, _, _ = train_test_split(df, df['AID'], test_size=TEST_SIZE, random_state=self.seed_value, 205 | stratify=df['AID']) 206 | train_df_ = train_df.copy() 207 | # sample dataframe based on count of individual labels: a minimum of 1 sample per class should be present 208 | train_df = train_df_.groupby('AID').apply(lambda x: x.sample(max(1, int(len(x) * self.tp)))).reset_index(drop=True) 209 | 210 | elif self.dataname == 'cifar100': 211 | dir_path = self.get_file_path() 212 | train_set, test_set = [pickle_loader(dir_path[key]) for key in ['train', 'test']] 213 | train_df = pd.DataFrame(train_set['data']) 214 | train_df['AID'] = train_set['fine_labels'] 215 | test_df = pd.DataFrame(test_set['data']) 216 | test_df['AID'] = test_set['fine_labels'] 217 | train_df = self.reorder_columns(train_df) 218 | test_df = self.reorder_columns(test_df) 219 | 220 | elif 'mnist' in self.dataname: 221 | dir_path = self.get_file_path() 222 | train_df, test_df = [pd.read_csv(dir_path[key], header=None, index_col=None) for key in ['train', 'test']] 223 | train_df, test_df = [self.reorder_columns(df, mnist=True) for df in [train_df, test_df]] 224 | train_df_ = train_df.copy() 225 | train_df = train_df_.groupby('AID').apply(lambda x: x.sample(max(1, int(len(x) * self.tp)))).reset_index(drop=True) 226 | 227 | train_df, test_df = [self.replace_class_labels(df, train=idx == 0) for idx, df in enumerate([train_df, test_df])] 228 | if self.vis: 229 | self.visualize_by_persons(df_to_visualize, persons_list) 230 | if self.corr_vis: 231 | corr = self.compute_correlation_mat(train_df) 232 | plot_heatmap(corr, out_path=f'corr_vis/by_raw_features/{self.dataname}_tp_{self.tp}.pdf', original_map=self.original_mapping) 233 | return train_df, test_df 234 | 235 | def compute_correlation_mat(self, df): 236 | df_of_means = self.compute_mean_of_df(df) 237 | # df_of_means = df_of_means.rename(columns = self.label_map) 238 | corr = df_of_means.corr() 239 | return corr 240 | 241 | @staticmethod 242 | def compute_mean_of_df(df): 243 | df_of_means = df.groupby('AID', as_index=True).mean() 244 | df_of_means = df_of_means.T.iloc[1:] 245 | return df_of_means 246 | 247 | def visualize_by_persons(self, df, persons_list): 248 | visualizer = vis_by_person.VisualizeStatsPerPerson(self.dataname, df) 249 | # if self.dataname == 'pamap' or self.dataname == 'dsads': 250 | # visualizer.plot_variance_by_persons() 251 | visualizer.visualise_imbalance_by_persons(persons_list) 252 | 253 | def reshape_img_dataframes(self, train=False): 254 | data = self.train_data if train else self.test_data 255 | data_r = np.array([each[:1024].reshape(32,32) for each in data]) 256 | data_g = np.array([each[1024:2048].reshape(32,32) for each in data]) 257 | data_b = np.array([each[2048:].reshape(32,32) for each in data]) 258 | data = np.dstack((data_r, data_g, data_b)) 259 | if train: 260 | self.train_data = data 261 | else: 262 | self.test_data = data 263 | return data 264 | 265 | @staticmethod 266 | def permutate_img_pixels(image, permutation): 267 | image = image[permutation] 268 | return image 269 | 270 | def get_data_by_groups(self, train=True): 271 | if self.dataname == 'cifar100': 272 | _ = self.reshape_img_dataframes(train=train) 273 | if train: 274 | _labels = sorted(set(self.train_labels)) 275 | shuffled_labels = np.random.choice(_labels, len(_labels), replace=False).tolist() 276 | original_labels = [each for each in range(len(shuffled_labels))] 277 | self.label_map = dict(zip(shuffled_labels, original_labels)) 278 | print(self.label_map) 279 | self.classes_by_groups.append(shuffled_labels[:self.nb_cl]) 280 | if 'permuted' in self.dataname: 281 | for idx in range(NUM_PERMUTATIONS - 1): 282 | self.classes_by_groups.append(shuffled_labels) 283 | else: 284 | for idx in range(self.nb_cl, len(shuffled_labels), self.num_classes): 285 | temp = [] 286 | for i in range(self.num_classes): 287 | if (idx + i) < len(shuffled_labels): 288 | temp.append(shuffled_labels[idx + i]) 289 | self.classes_by_groups.append(temp) 290 | self.num_tasks = len(self.classes_by_groups) 291 | grouped_data = [[] for _ in range(self.num_tasks)] 292 | print(f"Classes in each group: {self.classes_by_groups}") 293 | if 'permuted' in self.dataname: 294 | idx = list(range(28 * 28)) 295 | for p_i in range(NUM_PERMUTATIONS): 296 | random.shuffle(idx) 297 | X_train_new, X_test_new = [[self.permutate_img_pixels(item, permutation=idx) for item in each] for each in 298 | [self.train_data, self.test_data]] 299 | data_to_consider = zip(X_train_new, self.train_labels) if train else zip(X_test_new, self.test_labels) 300 | group_ID = p_i 301 | for data, label in data_to_consider: 302 | grouped_data[group_ID].append((data, self.label_map[label])) 303 | else: 304 | data_to_consider = zip(self.train_data, self.train_labels) if train else zip(self.test_data, self.test_labels) 305 | print(self.classes_by_groups) 306 | for data, label in data_to_consider: 307 | group_ID = [idx for idx in range(self.num_tasks) if label in self.classes_by_groups[idx]][0] 308 | # print(f"label: {label}, id: {group_ID}") 309 | grouped_data[group_ID].append((data, self.label_map[label])) 310 | if self.dataname == 'cifar100' and train: 311 | for group in grouped_data: 312 | assert len(group) == 10000, len(group) 313 | return grouped_data 314 | 315 | def initialize(self): 316 | train_groups = self.get_data_by_groups(train=True) 317 | test_groups = self.get_data_by_groups(train=False) 318 | val_groups = [[] for i in range(self.num_tasks)] 319 | if self.keep_val: 320 | for i, train_group in enumerate(train_groups): 321 | _, labels = zip(*train_group) 322 | labels = np.array(labels) 323 | temp_train, temp_val = [], [] 324 | for label in set(labels): 325 | indices = np.where(labels == label)[0] 326 | temp_val.extend([each for idx, each in enumerate(train_group) if idx in 327 | indices[(int)(TRAIN_VAL_SPLIT * len(indices)):]]) 328 | temp_train.extend([each for idx, each in enumerate(train_group) if idx in 329 | indices[:(int)(TRAIN_VAL_SPLIT * len(indices))]]) 330 | train_groups[i] = temp_train 331 | val_groups[i] = temp_val 332 | return train_groups, val_groups, test_groups 333 | 334 | def getNextClasses(self, i): 335 | return self.train_groups[i], self.val_groups[i], self.test_groups[i] 336 | 337 | def getInputDim(self): 338 | return 32*32 if self.dataname == 'cifar100' else self.train_data[0].shape[0] -------------------------------------------------------------------------------- /utils/img/incremental_learning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srvCodes/continual-learning-benchmark/36673a25d9555d6afaf6bd25ebc3ec9641c51599/utils/img/incremental_learning.png -------------------------------------------------------------------------------- /utils/join_pickles.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | def manually_join_pickle(): 4 | path_name = f"../output_reports/hapt/gem_random_1.0_" 5 | # full_dict = pickle.load(open(path_name + '4_logs.pkl', 'rb'))['reports'] 6 | # full_dict[6] = pickle.load(open(path_name + '6_logs.pkl', 'rb'))['reports'][6] 7 | # full_dict[10] = pickle.load(open(path_name + '10_logs.pkl', 'rb'))['reports'][10] 8 | # full_dict[15] = pickle.load(open(path_name + '15_logs.pkl', 'rb'))['reports'][15] 9 | # 10 | # detailed_dict = pickle.load(open(path_name + '8_logs.pkl', 'rb'))['detailed_acc'] 11 | # print(detailed_dict) 12 | sizes = [2,4,6,8,10,15] 13 | dict_so_far = {'reports': {size: {} for size in sizes}, 'errors': {size: {} for size in sizes}, 'detailed_acc': {size: {} for size in sizes}} 14 | for size in sizes: 15 | cur_name = path_name + str(size) + '_logs.pkl' 16 | pickle_dict = pickle.load(open(cur_name, 'rb')) 17 | for key in list(pickle_dict.keys()): 18 | dict_so_far[key].update({size: pickle_dict[key][size]}) 19 | pickle.dump(dict_so_far, open(cur_name, 'wb')) 20 | manually_join_pickle() --------------------------------------------------------------------------------