├── .gitignore ├── LICENSE ├── Model.png ├── README.md ├── scripts └── cglm_scrape.py └── src ├── features ├── main.py ├── opts.py ├── torchutils.py └── trainutils.py ├── knn ├── main.py ├── online_clfs.py └── opts.py └── run_blind.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ameya Prabhu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/drimpossible/ACM/ccb70b138942f422ba25e58b26bbe24522140afc/Model.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ACM 2 | 3 | This repository contains the code for the paper: 4 | 5 | **Online Continual Learning Without the Storage Constraint** 6 | [Ameya Prabhu](https://drimpossible.github.io), [Zhipeng Cai](https://zhipengcai.github.io/), [Puneet Dokania](https://puneetkdokania.github.io), [Philip Torr](https://www.robots.ox.ac.uk/~phst/), [Vladlen Koltun](https://vladlen.info/), [Ozan Sener](https://ozansener.net/) 7 | [[Arxiv](https://arxiv.org/abs/2305.09253)] 8 | [[PDF](https://drimpossible.github.io/documents/ACM.pdf)] 9 | [[Bibtex](https://github.com/drimpossible/ACM/#citation)] 10 | 11 |

12 | Figure which describes our ACM model 13 |

14 | 15 | ## Installation and Dependencies 16 | 17 | Our code was run on a 16GB RTX 3080Ti Laptop GPU with 64GB RAM and PyTorch >=1.13; a larger GPU and more RAM will allow faster experimentation. 18 | 19 | * Install all requirements needed to run the code in a Python >=3.9 environment: 20 | ``` 21 | # First, activate a new virtual environment 22 | pip3 install -r requirements.txt 23 | ``` 24 | 25 | ### Fast Dataset Setup 26 | 27 | - A fast, direct way to download and use our datasets is implemented in [this repository](https://github.com/hammoudhasan/CLDatasets). 28 | - Set the `data_dir` argument in `src/features/opts.py` to the directory the dataset was downloaded into. 29 | - All code in this repository was run on this dataset. 30 | 31 | ## Recreating the Datasets 32 | 33 | - `YOUR_DATA_DIR` should contain two subfolders, `cglm` and `cloc`. Instructions to set up each dataset follow: 34 | 35 | ### Continual Google Landmarks V2 (CGLM) 36 | 37 | #### Download Images 38 | 39 | * Download the Continual Google Landmarks V2 images by following the instructions in the Google Landmarks GitHub repository; run the following in the `YOUR_DATA_DIR` directory: 40 | ``` 41 | wget -c https://raw.githubusercontent.com/cvdfoundation/google-landmark/master/download-dataset.sh 42 | mkdir train && cd train 43 | bash ../download-dataset.sh train 499 44 | ``` 45 | 46 | #### Recreating Metadata 47 | 48 | * Download the metadata by running the following commands in the `scripts` directory: 49 | ``` 50 | wget -c https://s3.amazonaws.com/google-landmark/metadata/train_attribution.csv 51 | python cglm_scrape.py 52 | ``` 53 | * Parse the XML files and organize them as a dictionary. 54 | * The ordering used in the paper is available to download [from here](). 55 | * Finally, keep only the images listed in the order file; the dataset is then ready.
56 | 57 | ### Continual YFCC100M (CLOC) 58 | 59 | #### Extremely Fast Image Downloader 60 | 61 | * Download the `cloc.txt` file from [this link](https://www.robots.ox.ac.uk/~ameya/cloc.txt) inside the `YOUR_DATASET_DIR/cloc` directory. 62 | * The `cloc.txt` file contains 36.8M image links, removing missing/broken links from the original download file of CLOC. 63 | * Download the dataset parallely and scalably using img2dataset, finishes in Extracting features from idx '+str(offset)+'..') 71 | model.cuda() 72 | model.eval() 73 | 74 | if ftmodel is not None: 75 | ftmodel.cuda() 76 | ftmodel.eval() 77 | 78 | # We will collect predictions, labels and features in corresponding numpy arrays 79 | with torch.inference_mode(): 80 | for (image, label, _, sel_idx, _) in loader: 81 | image = image.cuda(non_blocking=True) 82 | feat = model(image) 83 | if ftmodel is not None: 84 | feat = ftmodel.embed(feat) 85 | predprobs = ftmodel.fc(ftmodel.norm(feat)) 86 | pred = torch.argmax(predprobs, dim=1) 87 | predarr[sel_idx%num_per_chunk] = pred.cpu().numpy() 88 | 89 | labelarr[sel_idx%num_per_chunk] = label.cpu().numpy() 90 | featarr[sel_idx%num_per_chunk] = feat.cpu().numpy() 91 | 92 | if ((sel_test_idx.max()+1)//opt.num_per_chunk) > opt.chunk_idx: 93 | np.save(opt.log_dir+'/'+opt.exp_name+f'/features_{chunk_idx}.npy', featarr) 94 | if ftmodel is not None: np.save(opt.log_dir+'/'+opt.exp_name+f'/preds_{chunk_idx}.npy', predarr) 95 | np.save(opt.log_dir+'/'+opt.exp_name+f'/labels_{chunk_idx}.npy', labelarr) 96 | return 97 | 98 | 99 | if __name__ == '__main__': 100 | # Parse arguments and init loggers 101 | torch.multiprocessing.set_sharing_strategy('file_system') 102 | opt = parse_args() 103 | 104 | opt.exp_name = f'{opt.dataset}_{opt.model}_{opt.embed_size}' 105 | 106 | console_logger = get_logger(folder=opt.log_dir+'/'+opt.exp_name+'/') 107 | console_logger.info('==> Params for this experiment:'+str(opt)) 108 | seed_everything(opt.seed) 109 | 110 | if opt.model == 'resnet50': 111 | 
model = models.resnet50(weights="IMAGENET1K_V2") 112 | elif opt.model == 'resnet50_I1B': 113 | model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet50_swsl') 114 | elif opt.model == 'xcit_dino': 115 | model = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8') 116 | if 'xcit' not in opt.model: model.fc = torch.nn.Identity() 117 | 118 | if opt.fc_only: 119 | for param in model.parameters(): 120 | param.requires_grad = False 121 | 122 | dim = 512 if opt.model=='xcit_dino' else 2048 123 | ftmodel = EmbedLinearClassifier(dim=dim, embed_size=opt.embed_size, num_classes=opt.num_classes, cosfc=opt.cosine) 124 | 125 | if opt.mode=='AdaptedACM': 126 | assert(ftmodel is not None) 127 | if exists(opt.log_dir+'/'+opt.exp_name+'/ftmodel.pt'): 128 | ftmodel.load_state_dict(torch.load(opt.log_dir+'/'+opt.exp_name+'/ftmodel.pt')) 129 | else: 130 | AdaptedACM(opt=opt, model=model, ftmodel=ftmodel, logger=console_logger) 131 | ACM(opt=opt, model=model, ftmodel=ftmodel, logger=console_logger) 132 | elif opt.mode=='ACM': 133 | ACM(opt=opt, model=model, logger=console_logger) -------------------------------------------------------------------------------- /src/features/opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser(description='main.py') 5 | # Changing options -- Apart from these arguments, we do not mess with other arguments 6 | 7 | ## Paths 8 | parser.add_argument('--data_dir', type=str, default='TYPE HERE', help='Full path to directory where all datasets are stored') 9 | parser.add_argument('--log_dir', type=str, default='TYPE HERE', help='Full path to the directory where all logs are stored') 10 | parser.add_argument('--order_file_dir', type=str, default='TYPE HERE', help='Full path to the order file') 11 | 12 | ## Dataset 13 | parser.add_argument('--dataset', type=str, default='cglm', help='Name of Dataset', 
choices=['cglm', 'imagenet', 'cyfcc']) 14 | parser.add_argument('--num_classes', type=int, default=713, choices=[10788, 1000, 713], help='Number of number of classes') 15 | parser.add_argument('--train_batch_size', type=int, default=896, help='Batch size to be used in training') 16 | parser.add_argument('--test_batch_size', type=int, default=4608, help='Batch size to be used in training') 17 | parser.add_argument('--total_iterations', type=int, default=1800, help='Fraction of total dataset used for ptraining') 18 | 19 | ## Model 20 | parser.add_argument('--model', type=str, default='resnet50', choices=['resnet50','resnet50_random','resnet50_I1B','resnet50_dino','xcit_dino','resnet50_V2'], help='Model architecture') 21 | parser.add_argument('--mode', type=str, default='ACM', choices=['AdaptedACM', 'ACM', 'ER'], help='Training type') 22 | parser.add_argument('--embed_size', type=int, default=512, help='Embedding dimensions for retrieval tasks') 23 | 24 | ## Experiment Deets 25 | parser.add_argument('--exp_name', type=str, default='test', help='Full path to the order file') 26 | parser.add_argument('--maxlr', type=float, default=0.2, help='Starting Learning rate') 27 | parser.add_argument('--num_gpus', type=int, default=8, help="Number of GPUs used in training") 28 | 29 | # Default options 30 | parser.add_argument('--seed', type=int, default=0, help='Seed for reproducibility') 31 | parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay') 32 | parser.add_argument('--clip', type=float, default=2.0, help='Gradient Clipped if val >= clip') 33 | parser.add_argument('--num_workers', type=int, default=8, help='Starting Learning rate') 34 | parser.add_argument('--print_freq', type=int, default=1000, help='Printing utils') 35 | parser.add_argument('--prefrac', type=float, default=0.2, help='Fraction of total dataset used for pretraining (in CGLM)') 36 | parser.add_argument('--chunk_idx', type=int, default=0, help='Parallelize ACM running by sending 
multiple jobs which compute it on parts') 37 | parser.add_argument('--num_per_chunk', type=int, default=524288, help='Number of Chunks to Subdivide Data to') 38 | parser.add_argument('--num_gdsteps', type=int, default=1, help='Number of gradient descent steps') 39 | parser.add_argument('--extract_feats',action='store_true') 40 | 41 | parser.add_argument('--fc_only',action='store_true') 42 | parser.add_argument('--fc', type=str, default=None, help='Full path to the order file') 43 | 44 | parser.add_argument('--sampler', type=str, default='mixed', help='Full path to the order file') 45 | 46 | parser.add_argument('--delay', type=int, default=0, help='Sets delay in terms of training samples.') 47 | parser.add_argument('--cosine',action='store_true') 48 | 49 | 50 | 51 | opt = parser.parse_args() 52 | return opt 53 | -------------------------------------------------------------------------------- /src/features/torchutils.py: -------------------------------------------------------------------------------- 1 | 2 | import h5py 3 | import torch, math 4 | from PIL import Image 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.optim.lr_scheduler import _LRScheduler 8 | from torch.utils.data import Dataset, DataLoader 9 | import numpy as np 10 | 11 | class ListFolder(Dataset): 12 | def __init__(self, root, image_paths, labels_path, routing_idx_path=None, mask_values_path=None, isACE=False, offset=0, is_train=False, transform=None): 13 | super(ListFolder, self).__init__() 14 | # Get image list and labels per index 15 | self.root = root 16 | self.offset = offset 17 | self.image_paths = h5py.File(image_paths, "r")["store_list"] 18 | self.labels = h5py.File(labels_path, "r")["store_list"] 19 | self.is_train = is_train 20 | self.isACE = isACE 21 | 22 | if self.is_train: 23 | self.routing_idx = h5py.File(routing_idx_path, "r")["store_list"] 24 | if self.isACE: 25 | self.mask_values = h5py.File(mask_values_path, "r")["store_list"] 26 | 27 | self.transform = 
transform 28 | assert(len(self.image_paths)==len(self.labels)) 29 | 30 | def __getitem__(self, index): 31 | # get the index from the routing index 32 | sample_mask_value = np.array([1.0]) 33 | if self.is_train: 34 | sel_idx = self.routing_idx[index+self.offset] 35 | if self.isACE: 36 | sample_mask_value = self.mask_values[index+self.offset] 37 | else: 38 | sel_idx = index+self.offset 39 | 40 | # get the corresponding image and label 41 | img_path = self.root+'/'+self.image_paths[sel_idx].decode("utf-8").strip() 42 | label = self.labels[sel_idx] 43 | sample = pil_loader(img_path) 44 | 45 | if self.transform is not None: 46 | sample = self.transform(sample) 47 | 48 | return sample, label, sample_mask_value, sel_idx, index 49 | 50 | def __len__(self): 51 | assert(len(self.image_paths)==len(self.labels)), 'Length of image path array and labels different' 52 | if self.is_train: 53 | return len(self.routing_idx)-self.offset 54 | else: 55 | return len(self.labels) -self.offset 56 | 57 | 58 | class LinearLR(_LRScheduler): 59 | r"""Set the learning rate of each parameter group with a linear 60 | schedule: :math:`\eta_{t} = \eta_0*(1 - t/T)`, where :math:`\eta_0` is the 61 | initial lr, :math:`t` is the current epoch or iteration (zero-based) and 62 | :math:`T` is the total training epochs or iterations. It is recommended to 63 | use the iteration based calculation if the total number of epochs is small. 64 | When last_epoch=-1, sets initial lr as lr. 65 | It is studied in 66 | `Budgeted Training: Rethinking Deep Neural Network Training Under Resource 67 | Constraints`_. 68 | 69 | Args: 70 | optimizer (Optimizer): Wrapped optimizer. 71 | T (int): Total number of training epochs or iterations. 72 | last_epoch (int): The index of last epoch or iteration. Default: -1. 73 | 74 | .. 
_Budgeted Training\: Rethinking Deep Neural Network Training Under 75 | Resource Constraints: 76 | https://arxiv.org/abs/1905.04753 77 | """ 78 | 79 | def __init__(self, optimizer, T, last_epoch=-1): 80 | self.T = float(T) 81 | super(LinearLR, self).__init__(optimizer, last_epoch) 82 | 83 | def get_lr(self): 84 | rate = 1 - self.last_epoch/self.T 85 | return [rate*base_lr for base_lr in self.base_lrs] 86 | 87 | def _get_closed_form_lr(self): 88 | return self.get_lr() 89 | 90 | 91 | class CosineLinear(nn.Module): 92 | def __init__(self, in_features, out_features, sigma=True): 93 | super(CosineLinear, self).__init__() 94 | self.in_features = in_features 95 | self.out_features = out_features 96 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 97 | self.bias = None 98 | if sigma: 99 | self.sigma = nn.Parameter(torch.Tensor(1)) 100 | self.reset_parameters() 101 | 102 | def reset_parameters(self): 103 | stdv = 1. / math.sqrt(self.weight.size(1)) 104 | self.weight.data.uniform_(-stdv, stdv) 105 | if self.sigma is not None: 106 | self.sigma.data.fill_(1) 107 | 108 | def forward(self, input): 109 | out = F.linear(F.normalize(input, p=2,dim=1), \ 110 | F.normalize(self.weight, p=2, dim=1)) 111 | if self.sigma is not None: 112 | out = self.sigma * out 113 | 114 | return out 115 | 116 | 117 | class EmbedLinearClassifier(nn.Module): 118 | """Linear layer to train on top of frozen features""" 119 | def __init__(self, dim, embed_size, num_classes, cosfc=False): 120 | super(EmbedLinearClassifier, self).__init__() 121 | self.embed = None 122 | self.embed = nn.Linear(dim, embed_size) 123 | self.norm = nn.Sequential(nn.BatchNorm1d(embed_size), nn.ReLU(inplace=True)) 124 | self.embed.weight.data.normal_(mean=0.0, std=0.01) 125 | self.embed.bias.data.zero_() 126 | if cosfc: 127 | self.fc = CosineLinear(embed_size, num_classes) 128 | else: 129 | self.fc = nn.Linear(embed_size, num_classes) 130 | self.fc.weight.data.normal_(mean=0.0, std=0.01) 131 | 
self.fc.bias.data.zero_() 132 | 133 | def forward(self, x): 134 | # flatten 135 | x = x.view(x.size(0), -1) 136 | if self.embed is not None: 137 | x = self.embed(x) 138 | x = self.norm(x) 139 | return self.fc(x) 140 | 141 | 142 | def pil_loader(path): 143 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 144 | with open(path, 'rb') as f: 145 | img = Image.open(f) 146 | return img.convert('RGB') 147 | -------------------------------------------------------------------------------- /src/features/trainutils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch, os, logging, random 4 | from os.path import isfile, isdir, join, exists 5 | 6 | 7 | def make_data(fname): 8 | """ 9 | """ 10 | # read data 11 | fval = open(fname, 'r') 12 | lines_val = fval.readlines() 13 | labels = [None] * len(lines_val) 14 | time = [None] * len(lines_val) 15 | user = [None] * len(lines_val) 16 | store_loc = [None] * len(lines_val) 17 | for i in range(len(lines_val)): 18 | line_splitted = lines_val[i].split(',') 19 | labels[i] = int(line_splitted[0]) 20 | time[i] = int(line_splitted[2]) 21 | user[i] = line_splitted[3] 22 | store_loc[i] = line_splitted[-1][:-1] 23 | return store_loc, labels, time, user 24 | 25 | 26 | def load_images(image_paths): 27 | """ 28 | """ 29 | imgs = [] 30 | for i in range(len(image_paths)): 31 | img = pil_loader(image_paths[i]) 32 | imgs.append(img) 33 | return imgs 34 | 35 | 36 | def convert_imgfolder_to_imglist(data_dir, dataset, split): 37 | """ 38 | Takes an input datset directory, and converts 39 | """ 40 | classes = [f for f in os.listdir(data_dir+'/'+dataset+'/'+split+'/') if isdir(join(data_dir+'/'+dataset+'/'+split,f))] 41 | img_paths, labels = [], [] 42 | count = 0 43 | 44 | for cls in classes: 45 | folder = join(data_dir+'/'+dataset+'/'+split+'/', cls) 46 | imgs = ['/'+split+'/'+cls+'/'+f for f in 
os.listdir(folder) if (isfile(join(folder,f)))] 47 | lbls = [count for img in imgs] 48 | assert(len(imgs)==len(lbls)) 49 | count+=1 50 | img_paths.extend(imgs) 51 | labels.extend(lbls) 52 | 53 | assert(len(img_paths)==len(labels)) 54 | return img_paths, labels 55 | 56 | 57 | def load_filelist(filepath, root_dir, check=False): 58 | """ 59 | Checks whether images exist, if yes then load it into the filelist. Do check=True for the first run, then turn it off 60 | """ 61 | imglist, y = [], [] 62 | with open(filepath,'r') as f: 63 | for line in f: 64 | label, image, _ = line.strip().split('\t') 65 | if check: 66 | assert(exists(root_dir+'/'+image)) 67 | imglist.append(image.strip()) 68 | y.append(int(label)) 69 | return imglist, y 70 | 71 | 72 | def get_logger(folder): 73 | """ 74 | Initializes the logger for logging opts and intermediate checks in runs. 75 | Note: One logger per experiment where reruns get appended to the log. 76 | """ 77 | # global logger 78 | logger = logging.getLogger(__name__) 79 | logger.setLevel(logging.DEBUG) 80 | formatter = logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s") 81 | # file logger 82 | if not os.path.isdir(folder): 83 | os.makedirs(folder) 84 | fh = logging.FileHandler(os.path.join(folder, 'checkpoint.log'), mode='a') 85 | fh.setLevel(logging.INFO) 86 | fh.setFormatter(formatter) 87 | logger.addHandler(fh) 88 | # console logger 89 | ch = logging.StreamHandler() 90 | ch.setLevel(logging.DEBUG) 91 | ch.setFormatter(formatter) 92 | logger.addHandler(ch) 93 | return logger 94 | 95 | 96 | def seed_everything(seed): 97 | """ 98 | Fixes the class-to-task assignments and most other sources of randomness, except CUDA training aspects. 
99 | """ 100 | # Avoid all sorts of randomness for better replication 101 | random.seed(seed) 102 | torch.manual_seed(seed) 103 | torch.cuda.manual_seed_all(seed) 104 | np.random.seed(seed) 105 | os.environ['PYTHONHASHSEED'] = str(seed) 106 | if torch.cuda.is_available(): 107 | torch.backends.cudnn.benchmark = True # An exemption for speed :P 108 | 109 | 110 | def save_model(opt, model): 111 | """ 112 | Used for saving the pretrained model, not for intermediate breaks in running the code. 113 | """ 114 | state = {'opt': opt, 115 | 'state_dict': model.state_dict()} 116 | filename = opt.log_dir+'/'+opt.exp_name+'/model.pt' 117 | torch.save(state, filename) -------------------------------------------------------------------------------- /src/knn/main.py: -------------------------------------------------------------------------------- 1 | import random, os, online_clfs, time, torch 2 | import numpy as np 3 | from opts import parse_args 4 | from sklearn.preprocessing import LabelEncoder 5 | 6 | 7 | def load_dataset(model, dataset): 8 | #pretrain_X = np.load(os.path.join(opt.feature_path, f"xcit_dino_{dataset}_pretrain_features.npy")) 9 | pretrain_y = np.load(os.path.join(opt.feature_path, f"xcit_dino_{dataset}_pretrain_labels.npy")) 10 | pretrain_X = None 11 | #pretrain_y = None 12 | train_X = np.load(os.path.join(opt.feature_path, f"{model}_{dataset}_train_features.npy")) 13 | train_y = np.load(os.path.join(opt.feature_path, f"{model}_{dataset}_train_labels.npy")) 14 | test_X = np.load(os.path.join(opt.feature_path, f"{model}_{dataset}_test_features.npy")) 15 | test_y = np.load(os.path.join(opt.feature_path, f"{model}_{dataset}_test_labels.npy")) 16 | 17 | # Checks 18 | #assert(pretrain_X.shape[0] == pretrain_y.shape[0]) 19 | assert(train_X.shape[0] == train_y.shape[0]) 20 | assert(test_X.shape[0] == test_y.shape[0]) 21 | #assert(pretrain_X.shape[1] == train_X.shape[1] == test_X.shape[1]) 22 | 23 | #print("Total pretrain rows in the dataset:", pretrain_X.shape[0]) 24 | 
print("Total train rows in the dataset:", train_X.shape[0]) 25 | print("Total test rows in the dataset:", test_X.shape[0]) 26 | 27 | # Normalize labels 28 | le = LabelEncoder() 29 | le.fit(np.concatenate((train_y, pretrain_y))) 30 | #le.fit(train_y) 31 | train_y = le.transform(train_y) 32 | test_y = le.transform(test_y) 33 | #pretrain_y = le.transform(pretrain_y) 34 | 35 | return pretrain_X, pretrain_y, train_X, train_y, test_X, test_y, le.classes_.shape[0] 36 | 37 | 38 | if __name__ == '__main__': 39 | # Parse arguments and init loggers 40 | opt = parse_args() 41 | random.seed(opt.seed) 42 | os.environ['PYTHONHASHSEED'] = str(opt.seed) 43 | np.random.seed(opt.seed) 44 | print('==> Params for this experiment:'+str(opt)) 45 | 46 | pretrain_X, pretrain_y, train_X, train_y, test_X, test_y, num_classes = load_dataset(opt.model, dataset=opt.dataset) 47 | opt.feature_dim, opt.num_classes = train_X.shape[1], num_classes 48 | 49 | normalizer = online_clfs.Normalizer(dim=opt.feature_dim) 50 | online_clf = getattr(online_clfs, opt.online_classifier)(opt=opt) 51 | 52 | predarr, labelarr, acc = np.zeros(train_y.shape[0], dtype='u2'), np.zeros(train_y.shape[0], dtype='u2'), np.zeros(train_y.shape[0], dtype='bool') 53 | start_time = time.time() 54 | 55 | for i in range(train_X.shape[0]- opt.delay): 56 | feat_learn = train_X[i] 57 | feat_pred = train_X[i+opt.delay] 58 | 59 | if opt.normalize_input: 60 | feat_learn = normalizer.update_and_transform(feat_learn) 61 | feat_pred = normalizer.transform(feat_pred) 62 | 63 | if i >= 256+opt.delay: # Slightly shifted ahead start point to warmup all classifiers, avoids weird jagged artifacts in plots 64 | if opt.online_classifier == 'ApproxKNearestNeighbours': 65 | if (i+1)%opt.update_k <= (opt.update_size-1): 66 | pred = online_clf.full_predict_step(x=feat_pred, y=train_y[i+opt.delay]) 67 | else: 68 | pred = online_clf.predict_step(x=feat_pred) 69 | if i>opt.update_k and (i+1)%opt.update_k == (opt.update_size-1+opt.delay): 70 | 
online_clf.deploy_num_neighbours() 71 | else: 72 | pred = online_clf.predict_step(x=feat_pred) 73 | 74 | predarr[i+opt.delay] = int(pred) 75 | labelarr[i+opt.delay] = int(train_y[i+opt.delay]) 76 | is_correct = (int(pred)==int(train_y[i+opt.delay])) 77 | acc[i+opt.delay] = is_correct*1.0 78 | 79 | if (i+1)%opt.print_freq == 0: 80 | cum_acc = np.array(acc[:i]).mean() 81 | print(f'Step:\t{i}\tCumul Acc:{cum_acc}') 82 | 83 | online_clf.learn_step(x=feat_learn, y=np.array([train_y[i]])) 84 | 85 | total_time = time.time() - start_time 86 | print(f'Total time taken: {total_time:.4f}') 87 | os.makedirs(opt.log_dir, exist_ok=True) 88 | np.save(os.path.join(opt.log_dir, f"{opt.model}_{opt.dataset}_{opt.online_classifier}_{opt.lr}_{opt.wd}_online_preds_{opt.online_exp_name}.npy"), predarr[256:]) 89 | np.save(os.path.join(opt.log_dir, f"{opt.model}_{opt.dataset}_{opt.online_classifier}_{opt.lr}_{opt.wd}_online_labels_{opt.online_exp_name}.npy"), labelarr[256:]) 90 | 91 | print('==> Testing..') 92 | start_time = time.time() 93 | 94 | preds = online_clf.predict_step(x=test_X) 95 | np.save(os.path.join(opt.log_dir, f"{opt.model}_{opt.dataset}_{opt.online_classifier}_{opt.lr}_{opt.wd}_test_preds_{opt.online_exp_name}.npy"), preds) 96 | np.save(os.path.join(opt.log_dir, f"{opt.model}_{opt.dataset}_{opt.online_classifier}_{opt.lr}_{opt.wd}_test_labels_{opt.online_exp_name}.npy"), test_y) 97 | -------------------------------------------------------------------------------- /src/knn/online_clfs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import threading, pickle, math, torch, scipy 4 | from sklearn.linear_model import SGDClassifier 5 | from sklearn.metrics.pairwise import cosine_similarity 6 | 7 | class Normalizer(): 8 | def __init__(self, size=None, dim=None, mean=None, unnormalized_var=None): 9 | assert(dim is not None) 10 | self.dim = dim 11 | if mean is not None: 12 | assert(size is not 
None and unnormalized_var is not None) 13 | self.size = size 14 | self.mean = mean 15 | self.var_unnormalized = unnormalized_var 16 | else: 17 | self.size = 0 18 | self.mean = np.zeros(dim) 19 | self.var_unnormalized = np.zeros(dim) 20 | 21 | def update_and_transform(self, x): 22 | if x.ndim == 1: 23 | x = x[np.newaxis, :] 24 | 25 | for idx in range(x.shape[0]): 26 | self.size += 1 27 | new_mean = self.mean + (x[idx] - self.mean)/self.size 28 | self.var_unnormalized = self.var_unnormalized + (x[idx] - self.mean)*(x[idx] - new_mean) 29 | self.mean = new_mean 30 | std = np.sqrt(self.var_unnormalized/(self.size-1)) 31 | x[idx] = x[idx] - self.mean/std 32 | return x 33 | 34 | def transform(self, x): 35 | if x.ndim == 1: 36 | x = x[np.newaxis, :] 37 | 38 | if self.size > 2: 39 | std = np.sqrt(self.var_unnormalized/(self.size-1)) 40 | return (x - self.mean)/std 41 | else: 42 | return x 43 | 44 | 45 | class OnlineLogisticClassification_VowpalWabbit(): 46 | def __init__(self, opt): 47 | from vowpalwabbit import Workspace 48 | from vowpalwabbit.dftovw import DFtoVW, SimpleLabel, Feature 49 | self.model = Workspace(oaa=opt.num_classes, loss_function='logistic', b=30, l=opt.lr) 50 | self.cols = ['label'] 51 | for i in range(opt.feature_dim): self.cols.append('f_'+str(i)) 52 | 53 | def learn_step(self, x, y): 54 | if x.ndim == 1: 55 | x = x[np.newaxis, :] 56 | df = pd.DataFrame(x, columns=self.cols[1:]) 57 | df['label'] = y[0].tolist() 58 | feat = DFtoVW(df=df, label=SimpleLabel('label'), features=[Feature(col) for col in self.cols[1:]]) 59 | example = feat.convert_df()[0] 60 | self.model.learn(example) 61 | 62 | def predict_step(self, x): 63 | if x.ndim == 1: 64 | x = x[np.newaxis, :] 65 | df = pd.DataFrame(x, columns=self.cols[1:]) 66 | df['label'] = [-1] 67 | feat = DFtoVW(df=df, label=SimpleLabel('label'), features=[Feature(col) for col in self.cols[1:]]) 68 | example = feat.convert_df()[0] 69 | return self.model.predict(example) 70 | 71 | 72 | class 
OnlineSVM_VowpalWabbit(): 73 | def __init__(self, opt): 74 | from vowpalwabbit import Workspace 75 | from vowpalwabbit.dftovw import DFtoVW, SimpleLabel, Feature 76 | self.model = Workspace(oaa=opt.num_classes, loss_function='hinge', b=30, l=opt.lr, l2=opt.wd) 77 | self.cols = ['label'] 78 | for i in range(opt.feature_dim): self.cols.append('f_'+str(i)) 79 | 80 | def learn_step(self, x, y): 81 | if x.ndim == 1: 82 | x = x[np.newaxis, :] 83 | df = pd.DataFrame(x, columns=self.cols[1:]) 84 | df['label'] = y[0].tolist() 85 | feat = DFtoVW(df=df, label=SimpleLabel('label'), features=[Feature(col) for col in self.cols[1:]]) 86 | example = feat.convert_df()[0] 87 | self.model.learn(example) 88 | 89 | def predict_step(self, x): 90 | if x.ndim == 1: 91 | x = x[np.newaxis, :] 92 | df = pd.DataFrame(x, columns=self.cols[1:]) 93 | df['label'] = [-1] 94 | feat = DFtoVW(df=df, label=SimpleLabel('label'), features=[Feature(col) for col in self.cols[1:]]) 95 | example = feat.convert_df()[0] 96 | return self.model.predict(example) 97 | 98 | 99 | class OnlineSVM_Scikit(): 100 | def __init__(self, opt): 101 | 102 | self.clf = SGDClassifier(loss='hinge', penalty='l2', alpha=opt.wd, fit_intercept=True, learning_rate='optimal', warm_start=True) 103 | 104 | def learn_step(self, x, y): 105 | if x.ndim == 1: 106 | x = x[np.newaxis, :] 107 | self.clf.partial_fit(x, y, classes=np.arange(self.num_classes)) 108 | 109 | def predict_step(self, x): 110 | if x.ndim == 1: 111 | x = x[np.newaxis, :] 112 | return self.clf.predict(x) 113 | 114 | 115 | class OnlineLogisticClassification_Scikit(): 116 | def __init__(self, opt): 117 | self.clf = SGDClassifier(loss='log_loss', penalty='l2', alpha=opt.wd, fit_intercept=True, learning_rate='optimal', warm_start=True) 118 | self.num_classes = opt.num_classes 119 | 120 | def learn_step(self, x, y): 121 | if x.ndim == 1: 122 | x = x[np.newaxis, :] 123 | self.clf.partial_fit(x, y, classes=np.arange(self.num_classes)) 124 | 125 | def predict_step(self, x): 126 
| if x.ndim == 1: 127 | x = x[np.newaxis, :] 128 | return self.clf.predict(x) 129 | 130 | 131 | class HuberLossClassifier_Scikit(): 132 | def __init__(self, opt): 133 | self.clf = SGDClassifier(loss='modified_huber', penalty='l2', alpha=opt.wd, fit_intercept=True, learning_rate='optimal', warm_start=True) 134 | self.num_classes = opt.num_classes 135 | 136 | def learn_step(self, x, y): 137 | if x.ndim == 1: 138 | x = x[np.newaxis, :] 139 | self.clf.partial_fit(x, y, classes=np.arange(self.num_classes)) 140 | 141 | def predict_step(self, x): 142 | if x.ndim == 1: 143 | x = x[np.newaxis, :] 144 | return self.clf.predict(x) 145 | 146 | 147 | class ContextualMemoryTree(): 148 | def __init__(self, opt): 149 | from vowpalwabbit import Workspace 150 | from vowpalwabbit.dftovw import DFtoVW, SimpleLabel, Feature 151 | num_nodes = opt.num_classes/(np.log(opt.num_classes)/np.log(2)*10) 152 | self.model = Workspace("--memory_tree "+str(num_nodes)+" --max_number_of_labels "+str(opt.num_classes)+' --online --dream_at_update 1 --leaf_example_multiplier 10 --dream_repeats 12 --learn_at_leaf --alpha 0.1 -l '+str(opt.lr)+' -b 30 -c --loss_function squared --sort_features') 153 | self.cols = ['label'] 154 | for i in range(opt.feature_dim): self.cols.append('f_'+str(i)) 155 | 156 | def learn_step(self, x, y): 157 | if x.ndim == 1: 158 | x = x[np.newaxis, :] 159 | df = pd.DataFrame(x, columns=self.cols[1:]) 160 | df['label'] = y[0].tolist() 161 | feat = DFtoVW(df=df, label=SimpleLabel('label'), features=[Feature(col) for col in self.cols[1:]]) 162 | example = feat.convert_df()[0] 163 | self.model.learn(example) 164 | 165 | def predict_step(self, x): 166 | if x.ndim == 1: 167 | x = x[np.newaxis, :] 168 | df = pd.DataFrame(x, columns=self.cols[1:]) 169 | df['label'] = [-1] 170 | feat = DFtoVW(df=df, label=SimpleLabel('label'), features=[Feature(col) for col in self.cols[1:]]) 171 | example = feat.convert_df()[0] 172 | return self.model.predict(example) 173 | 174 | 175 | class 
KNearestNeighbours(): 176 | def __init__(self, opt): 177 | self.num_neighbours = opt.num_neighbours 178 | self.train_x, self.train_y = None, None 179 | self.num_neighbours = 1 180 | 181 | # Set distance function 182 | if opt.search_metric == 'cosine': 183 | self.dist = torch.nn.CosineSimilarity(dim=1, eps=1e-6) 184 | elif opt.search_metric == 'l2': 185 | self.dist = torch.nn.PairwiseDistance(p=2) 186 | assert(opt.search_metric in ['cosine', 'l2']) 187 | 188 | 189 | def learn_step(self, x, y): 190 | with torch.no_grad(): 191 | if x.ndim == 1: 192 | x = x.unsqueeze(0) 193 | 194 | if self.train_x is not None: 195 | self.train_y = torch.cat((self.train_y, y), dim=0) 196 | self.train_x = torch.cat((self.train_x, x), dim=0) 197 | else: 198 | self.train_y = y 199 | self.train_x = x 200 | 201 | 202 | def predict_step(self, x): 203 | with torch.no_grad(): 204 | if x.ndim == 1: 205 | x = x.unsqueeze(0) 206 | 207 | _, idxes = torch.topk(self.dist(x, self.train_x), 1, largest=False) 208 | labels, _ = torch.mode(self.train_y[idxes], dim=1) 209 | 210 | return labels 211 | 212 | 213 | class ApproxKNearestNeighbours(): 214 | # https://raw.githubusercontent.com/nmslib/nmslib/master/manual/latex/manual.pdf 215 | def __init__(self, opt): 216 | import hnswlib 217 | self.index = hnswlib.Index(space=opt.search_metric, dim=opt.feature_dim) 218 | self.lock = threading.Lock() 219 | self.cur_idx = 0 220 | self.dset_size = 65536 221 | self.idx2label = np.zeros(self.dset_size, dtype=np.int16) 222 | self.index.init_index(max_elements=self.dset_size, ef_construction=opt.HNSW_ef, M=opt.HNSW_M) 223 | self.num_neighbours = opt.num_neighbours 224 | self.preds = [[], [], [], [], []] 225 | self.labels = [] 226 | 227 | def learn_step(self, x, y): 228 | if x.ndim == 1: 229 | x = x[np.newaxis, :] 230 | assert(x.shape[0]==y.shape[0]) 231 | 232 | num_added = x.shape[0] 233 | start_idx = self.cur_idx 234 | self.cur_idx += num_added 235 | 236 | if self.cur_idx >= self.dset_size - 2: 237 | with self.lock: 
238 | self.dset_size = pow(2, math.ceil(math.log2(self.cur_idx))) 239 | self.dset_size *= 4 240 | self.index.resize_index(self.dset_size) 241 | 242 | new_idx2label = np.zeros(self.dset_size, dtype=np.int16) 243 | new_idx2label[:start_idx] = self.idx2label[:start_idx] 244 | self.idx2label = new_idx2label 245 | 246 | idx = np.arange(start_idx, start_idx + num_added) 247 | self.idx2label[start_idx:start_idx+num_added] = y 248 | 249 | self.index.add_items(data=x, ids=np.asarray(idx)) 250 | 251 | def set_ef(self, ef): 252 | self.index.set_ef(ef) 253 | 254 | def load_index(self, path): 255 | self.index.load_index(path) 256 | 257 | with open(path + ".pkl", "rb") as f: 258 | self.cur_idx, self.idx2label = pickle.load(f) 259 | 260 | def save_index(self, path): 261 | self.index.save_index(path) 262 | with open(path + ".pkl", "wb") as f: 263 | pickle.dump((self.cur_idx, self.idx2label), f) 264 | 265 | def set_num_threads(self, num_threads): 266 | self.index.set_num_threads(num_threads) 267 | 268 | def predict_step(self, x): 269 | # Note: y is used only for selecting k for the next step 270 | # Ideally this should be done in learn_step but to avoid computing neighbours twice, we do it here. 
271 | if x.ndim == 1: 272 | x = x[np.newaxis, :] 273 | 274 | idxes, _ = self.index.knn_query(data=x, k=self.num_neighbours) 275 | 276 | neighbour_labels = self.idx2label[idxes] 277 | 278 | pred_labels, _ = scipy.stats.mode(neighbour_labels, axis=1) 279 | return pred_labels 280 | 281 | def full_predict_step(self, x, y): 282 | idxes, _ = self.index.knn_query(data=x, k=16) 283 | neighbour_labels = self.idx2label[idxes] 284 | 285 | for j in range(5): 286 | pred_labels, _ = scipy.stats.mode(neighbour_labels[:,:2**j], axis=1) 287 | self.preds[j].append(int(pred_labels)) 288 | if 2**j == self.num_neighbours: 289 | out_pred = pred_labels 290 | self.labels.append(y) 291 | 292 | return out_pred 293 | 294 | def deploy_num_neighbours(self): 295 | best_acc, best_k = -1, 0 296 | labels = np.array(self.labels) 297 | 298 | for j in range(5): 299 | preds = np.array(self.preds[j]) 300 | acc = ((preds == labels)*1.0).mean() 301 | if acc > best_acc: 302 | best_acc = acc 303 | best_k = 2**j 304 | 305 | self.num_neighbours = best_k 306 | print(self.num_neighbours) 307 | self.preds = [[], [], [], [], []] 308 | self.labels = [] 309 | 310 | 311 | class NearestClassMeanCosine(): 312 | def __init__(self, opt): 313 | with torch.no_grad(): 314 | # Class means is class sums, divided by number of samples 315 | self.class_sums = np.zeros((opt.num_classes, opt.feature_dim)) 316 | self.num_samples = np.zeros((opt.num_classes,1)) 317 | 318 | 319 | def learn_step(self, x, y): 320 | if x.ndim == 1: 321 | x = x[np.newaxis, :] 322 | 323 | with torch.no_grad(): 324 | # Update class mean and number of samples 325 | if self.num_samples.shape[0] <= y.shape[0]: 326 | for index in range(self.num_samples.shape[0]): 327 | if (y==index).sum() == 0: 328 | continue 329 | self.class_sums[index,:] += x[y==index].sum(axis=0) 330 | self.num_samples[index] += (y==index).sum() 331 | else: 332 | for index in range(y.shape[0]): 333 | self.class_sums[y[index],:] += x[index].squeeze() 334 | self.num_samples[y[index]] += 1 
335 | 336 | 337 | def predict_step(self, x): 338 | if x.ndim == 1: 339 | x = x[np.newaxis, :] 340 | 341 | class_means = self.class_sums / (self.num_samples+1e-6) 342 | distances = cosine_similarity(x, class_means) 343 | return np.argmax(distances, axis=1) 344 | 345 | 346 | class NearestClassMeanL2(): 347 | def __init__(self, opt): 348 | with torch.no_grad(): 349 | # Class means is class sums, divided by number of samples 350 | self.class_sums = torch.zeros((opt.num_classes, opt.feature_dim)) 351 | self.num_samples = torch.zeros((opt.num_classes,1)) 352 | 353 | if opt.gpu: 354 | self.class_sums = self.class_sums.cuda() 355 | self.num_samples = self.num_samples.cuda() 356 | self.dist = self.dist.cuda() 357 | 358 | 359 | def learn_step(self, x, y): 360 | x = torch.from_numpy(x) 361 | if x.ndim == 1: 362 | x = x.unsqueeze(0) 363 | 364 | with torch.no_grad(): 365 | # Update class mean and number of samples 366 | if self.num_samples.shape[0] <= y.shape[0]: 367 | for index in range(self.num_samples.shape[0]): 368 | if (y==index).sum() == 0: 369 | continue 370 | self.class_sums[index,:] += x[y==index].sum(dim=0).squeeze() 371 | self.num_samples[index] += (y==index).sum() 372 | else: 373 | for index in range(y.shape[0]): 374 | self.class_sums[y[index],:] += x[index].squeeze() 375 | self.num_samples[y[index]] += 1 376 | 377 | 378 | def predict_step(self, x): 379 | x = torch.from_numpy(x) 380 | if x.ndim == 1: 381 | x = x.unsqueeze(0) 382 | 383 | with torch.no_grad(): 384 | class_means = (self.class_sums / (self.num_samples+1e-6)).unsqueeze(0) 385 | x = x.unsqueeze(0) 386 | distances = torch.cdist(x, class_means, p=2.0).squeeze(dim=0) 387 | distances = torch.where(distances!=0, distances, 1e5) 388 | return torch.argmin(distances, dim=1) 389 | 390 | 391 | class StreamingLinearDiscriminantAnalysis(): 392 | def __init__(self, opt): 393 | with torch.no_grad(): 394 | self.feature_dim = opt.feature_dim 395 | self.num_classes = opt.num_classes 396 | self.shrinkage_param = 1e-4 397 | 
self.streaming_update_sigma = True 398 | 399 | # setup weights for SLDA 400 | self.muK = torch.zeros((opt.num_classes, opt.feature_dim)) 401 | self.cK = torch.zeros(opt.num_classes) 402 | self.Sigma = torch.ones((opt.feature_dim, opt.feature_dim)) 403 | self.Lambda = torch.zeros_like(self.Sigma) 404 | self.num_updates = 0 405 | self.prev_num_updates = -1 406 | 407 | if opt.gpu: 408 | self.muK = self.muK.cuda() 409 | self.cK = self.cK.cuda() 410 | self.Sigma = self.Sigma.cuda() 411 | self.Lambda = self.Lambda.cuda() 412 | 413 | 414 | def learn_step(self, x, y): 415 | # make sure things are the right shape 416 | x = torch.from_numpy(x) 417 | if x.ndim == 1: 418 | x = x.unsqueeze(0) 419 | 420 | with torch.no_grad(): 421 | # covariance updates 422 | if self.streaming_update_sigma: 423 | x_minus_mu = (x - self.muK[y]) 424 | mult = torch.matmul(x_minus_mu.transpose(1, 0), x_minus_mu) 425 | delta = mult * self.num_updates / (self.num_updates + 1) 426 | self.Sigma = (self.num_updates * self.Sigma + delta) / (self.num_updates + 1) 427 | 428 | # update class means 429 | self.muK[y, :] += (x - self.muK[y, :]) / (self.cK[y] + 1).unsqueeze(1) 430 | self.cK[y] += 1 431 | self.num_updates += 1 432 | 433 | 434 | def predict_step(self, x): 435 | x = torch.from_numpy(x) 436 | if x.ndim == 1: 437 | x = x.unsqueeze(0) 438 | 439 | with torch.no_grad(): 440 | # initialize parameters for testing 441 | num_samples = x.shape[0] 442 | scores = torch.empty((num_samples, self.num_classes)) 443 | mb = num_samples 444 | 445 | # compute/load Lambda matrix 446 | if self.prev_num_updates != self.num_updates: 447 | # there have been updates to the model, compute Lambda 448 | # print('\nFirst predict since model update...computing Lambda matrix...') 449 | Lambda = torch.pinverse( 450 | (1 - self.shrinkage_param) * self.Sigma + self.shrinkage_param * torch.eye(self.feature_dim)).to(self.Lambda.device) 451 | self.Lambda = Lambda 452 | self.prev_num_updates = self.num_updates 453 | else: 454 | Lambda = 
self.Lambda 455 | 456 | # parameters for predictions 457 | M = self.muK.transpose(1, 0) 458 | W = torch.matmul(Lambda, M) 459 | c = 0.5 * torch.sum(M * W, dim=0) 460 | 461 | # loop in mini-batches over test samples 462 | for i in range(0, num_samples, mb): 463 | start = min(i, num_samples - mb) 464 | end = i + mb 465 | X = x[start:end] 466 | scores[start:end, :] = torch.matmul(X, W) - c 467 | 468 | # return predictions or probabilities 469 | return torch.argmax(scores, dim=1) 470 | 471 | 472 | def fit_base(self, x, y): 473 | print('\nFitting Base...') 474 | x = torch.from_numpy(x) 475 | # update class means 476 | for k in torch.unique(y): 477 | self.muK[k] = x[y == k].mean(0) 478 | self.cK[k] = x[y == k].shape[0] 479 | self.num_updates = x.shape[0] 480 | 481 | print('\nEstimating initial covariance matrix...') 482 | from sklearn.covariance import OAS 483 | cov_estimator = OAS(assume_centered=True) 484 | cov_estimator.fit((x - self.muK[y]).cpu().numpy()) 485 | self.Sigma = torch.from_numpy(cov_estimator.covariance_).float().to(self.Sigma.device) -------------------------------------------------------------------------------- /src/knn/opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser(description='main.py') 5 | parser.add_argument('--feature_path', type=str, default='../../../improved_knn/saved_features/', help='Full path to the directory where all logs are stored') 6 | parser.add_argument('--log_dir', type=str, default='../../logs/', help='Full path to the directory where all logs are stored') 7 | parser.add_argument('--model', type=str, default='xcit_dino', choices=['xcit_dino','r50','rsnet50_i1b'], help='Model used for classification') 8 | parser.add_argument('--online_classifier', type=str, default='ApproxKNearestNeighbours', help='Name of online classifier', choices=['OnlineLogisticClassification_VowpalWabbit', 9 | 'OnlineSVM_VowpalWabbit', 
'OnlineSVM_Scikit', 'OnlineLogisticClassification_Scikit', 10 | 'HuberLossClassifier_Scikit', 'ContextualMemoryTree', 'KNearestNeighbours', 'ApproxKNearestNeighbours', 11 | 'NearestClassMeanCosine', 'NearestClassMeanL2', 'StreamingLinearDiscriminantAnalysis']) 12 | parser.add_argument('--dataset', type=str, default='cglm', help='Name of dataset', choices=['clear10', 'clear100', 'cglm', 'cloc']) 13 | parser.add_argument('--search_metric', type=str, default='cosine', choices=['cosine', 'l2'], help='Types of search') 14 | parser.add_argument('--HNSW_ef', type=int, default=200, help='Types of search') 15 | parser.add_argument('--HNSW_M', type=int, default=25, help='Types of search') 16 | parser.add_argument('--lr', type=float, default=2.0, help='Types of search') 17 | parser.add_argument('--wd', type=float, default=1e-4, help='Types of search') 18 | parser.add_argument('--normalize_input', action="store_true", help='Normalize the input to the search') 19 | parser.add_argument('--gpu', action="store_true", help='Peform online learning on GPU') 20 | parser.add_argument('--online_exp_name', type=str, default='test', help='Full path to the order file') 21 | parser.add_argument('--seed', type=int, default=0, help='Seed for reproducibility') 22 | parser.add_argument('--print_freq', type=int, default=5000, help='Printing utils') 23 | parser.add_argument('--num_neighbours', type=int, default=2, help='k for kNN') 24 | parser.add_argument('--update_k', type=int, default=1, help='Update k for kNN after these many samples') 25 | parser.add_argument('--update_size', type=int, default=1, help='Consider these many samples for accuracy calculation for k update') 26 | parser.add_argument('--delay', type=int, default=1, help='Delay for ACM') 27 | opt = parser.parse_args() 28 | return opt 29 | 30 | -------------------------------------------------------------------------------- /src/run_blind.py: -------------------------------------------------------------------------------- 1 | import 
multiprocessing 2 | import numpy as np 3 | from scipy import stats 4 | from os.path import exists 5 | import h5py, os 6 | 7 | def load_filelist(filepath): 8 | imglist, y = [], [] 9 | with open(filepath,'r') as f: 10 | for line in f: 11 | label, _, _ = line.strip().split('\t') 12 | y.append(int(label)) 13 | return imglist, y 14 | 15 | 16 | def get_preds(k): 17 | pred = np.ones_like(labels)[1:] 18 | 19 | for i in range(1, len(labels)): 20 | if i <= k: 21 | pred[i-1] = labels[0] 22 | else: 23 | pred[i-1] = stats.mode(labels[i-(k):(i)])[0] 24 | np.save(logdir+'/'+dataset+'_blind_preds_'+str(k)+'.npy', pred) 25 | return pred 26 | 27 | 28 | if __name__ == '__main__': 29 | global logdir, labels 30 | dataset='CLOC' # Choose one from: ['CLEAR10', 'CLEAR100', 'CGLM', 'CLOC'] 31 | num_processes = 16 32 | path_to_cldatasets = '/media/bayesiankitten/Hanson/CLDatasets/data/' # Please set directory for labels in the CLDatasets folder 33 | logdir = '/media/bayesiankitten/Alexander/ACM/blind_logs/' # Please set directory for logs 34 | os.makedirs(logdir, exist_ok=True) 35 | 36 | # Load labels file 37 | # Note: When actually picking the k, use pretrain_labels.hdf5 file. This is for analysis plots, shown in the paper (both have same results, but different purpose). 
38 | with h5py.File(f'{path_to_cldatasets}/{dataset}/order_files/train_labels.hdf5', 'r') as f: 39 | labels = np.array(f['store_list'])[1:] 40 | 41 | labels = np.array(labels, dtype=np.uint16) 42 | pred = np.ones_like(labels)[1:] 43 | 44 | # Get dataset mode, get performance to check dataset imbalance 45 | if not exists(logdir+'/'+dataset+'_mode.npy'): 46 | modelabel = stats.mode(labels)[0] 47 | pred = pred*modelabel 48 | np.save(logdir+'/'+dataset+'_mode.npy', pred) 49 | 50 | pred = np.load(logdir+'/'+dataset+'_mode.npy') 51 | gt = labels[1:] 52 | acc = np.equal(gt,pred)*1.0 53 | idx = np.arange(acc.shape[0])+1 54 | cumacc = np.cumsum(acc)/idx 55 | print(f'IID Best Classifier: {cumacc.mean()}') 56 | 57 | # Get blind classifier performance 58 | for k in [1, 2, 3, 5, 7, 10, 20, 25, 35, 50, 75, 100, 150, 250, 500, 750, 1000, 2500, 5000, 7500, 10000, 15000, 25000, 50000, 75000]: 59 | if not exists(logdir+'/'+dataset+'_blind_preds_'+str(k)+'.npy'): 60 | print(f'Processing mode of past: {k}') 61 | p = multiprocessing.Pool(num_processes) 62 | result = [] 63 | result.append(p.apply_async(get_preds, [k])) 64 | for r in result: 65 | r.wait() 66 | 67 | # Show results, see the degree of label correlation 68 | for k in [1, 2, 3, 5, 7, 10, 20, 25, 35, 50, 75, 100, 150, 250, 500, 750, 1000, 2500, 5000, 7500, 10000, 15000, 25000, 50000, 75000]: 69 | pred = np.load(logdir+'/'+dataset+'_blind_preds_'+str(k)+'.npy') 70 | gt = labels[1:] 71 | acc = np.equal(gt,pred)*1.0 72 | idx = np.arange(acc.shape[0])+1 73 | cumacc = np.cumsum(acc)/idx 74 | print(f'Blind Classifier @ {k}: {cumacc.mean()}') 75 | 76 | print('Extracted all blind classifier results!') 77 | --------------------------------------------------------------------------------