├── .gitignore
├── LICENSE
├── Model.png
├── README.md
├── scripts
│   └── cglm_scrape.py
└── src
    ├── features
    │   ├── main.py
    │   ├── opts.py
    │   ├── torchutils.py
    │   └── trainutils.py
    ├── knn
    │   ├── main.py
    │   ├── online_clfs.py
    │   └── opts.py
    └── run_blind.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Ameya Prabhu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drimpossible/ACM/ccb70b138942f422ba25e58b26bbe24522140afc/Model.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ACM
2 |
3 | This repository contains the code for the paper:
4 |
5 | **Online Continual Learning Without the Storage Constraint**
6 | [Ameya Prabhu](https://drimpossible.github.io), [Zhipeng Cai](https://zhipengcai.github.io/), [Puneet Dokania](https://puneetkdokania.github.io), [Philip Torr](https://www.robots.ox.ac.uk/~phst/), [Vladlen Koltun](https://vladlen.info/), [Ozan Sener](https://ozansener.net/)
7 | [[Arxiv](https://arxiv.org/abs/2305.09253)]
8 | [[PDF](https://drimpossible.github.io/documents/ACM.pdf)]
9 | [[Bibtex](https://github.com/drimpossible/ACM/#citation)]
10 |
11 |
12 |
13 |
14 |
15 | ## Installation and Dependencies
16 |
17 | Our code was run on a 16GB RTX 3080Ti Laptop GPU with 64GB RAM and PyTorch >=1.13; a larger GPU and more RAM allow faster experimentation.
18 |
19 | * Install all requirements in a Python >=3.9 environment:
20 | ```
21 | # First, activate a new virtual environment
22 | pip3 install -r requirements.txt
23 | ```
24 |
25 | ### Fast Dataset Setup
26 |
27 | - A fast, direct way to download and use our datasets is implemented in [this repository](https://github.com/hammoudhasan/CLDatasets).
28 | - Set the `data_dir` field in `src/opts.py` to the directory the dataset was downloaded into.
29 | - All code in this repository was run on this dataset.
30 |
31 | ## Recreating the Datasets
32 |
33 | - `YOUR_DATA_DIR` should contain two subfolders, `cglm` and `cloc`. Instructions for setting up each dataset follow:
34 |
35 | ### Continual Google Landmarks V2 (CGLM)
36 |
37 | #### Download Images
38 |
39 | * Download the Continual Google Landmarks V2 images by following the instructions in the official Google Landmarks GitHub repository; run the following in the `DATA_DIR` directory:
40 | ```
41 | wget -c https://raw.githubusercontent.com/cvdfoundation/google-landmark/master/download-dataset.sh
42 | mkdir train && cd train
43 | bash ../download-dataset.sh train 499
44 | ```
45 |
46 | #### Recreating Metadata
47 |
48 | * Download metadata by running the following commands in the `scripts` directory:
49 | ```
50 | wget -c https://s3.amazonaws.com/google-landmark/metadata/train_attribution.csv
51 | python cglm_scrape.py
52 | ```
53 | * Parse the XML files and organize them as a dictionary.
54 | * The ordering used in the paper is available to download [from here]().
55 | * Finally, keep only the images listed in the order file, and your dataset is ready.
56 |
57 | ### Continual YFCC100M (CLOC)
58 |
59 | #### Extremely Fast Image Downloader
60 |
61 | * Download the `cloc.txt` file from [this link](https://www.robots.ox.ac.uk/~ameya/cloc.txt) inside the `YOUR_DATASET_DIR/cloc` directory.
62 | * The `cloc.txt` file contains 36.8M image links; missing and broken links from the original CLOC download file have been removed.
63 | * Download the dataset in parallel and at scale using img2dataset; this finishes in [...]
--------------------------------------------------------------------------------
70 | Extracting features from idx '+str(offset)+'..')
71 | model.cuda()
72 | model.eval()
73 |
74 | if ftmodel is not None:
75 | ftmodel.cuda()
76 | ftmodel.eval()
77 |
78 | # We will collect predictions, labels and features in corresponding numpy arrays
79 | with torch.inference_mode():
80 | for (image, label, _, sel_idx, _) in loader:
81 | image = image.cuda(non_blocking=True)
82 | feat = model(image)
83 | if ftmodel is not None:
84 | feat = ftmodel.embed(feat)
85 | predprobs = ftmodel.fc(ftmodel.norm(feat))
86 | pred = torch.argmax(predprobs, dim=1)
87 | predarr[sel_idx%num_per_chunk] = pred.cpu().numpy()
88 |
89 | labelarr[sel_idx%num_per_chunk] = label.cpu().numpy()
90 | featarr[sel_idx%num_per_chunk] = feat.cpu().numpy()
91 |
92 | if ((sel_test_idx.max()+1)//opt.num_per_chunk) > opt.chunk_idx:
93 | np.save(opt.log_dir+'/'+opt.exp_name+f'/features_{chunk_idx}.npy', featarr)
94 | if ftmodel is not None: np.save(opt.log_dir+'/'+opt.exp_name+f'/preds_{chunk_idx}.npy', predarr)
95 | np.save(opt.log_dir+'/'+opt.exp_name+f'/labels_{chunk_idx}.npy', labelarr)
96 | return
97 |
98 |
99 | if __name__ == '__main__':
100 |     # Parse arguments and init loggers
101 |     torch.multiprocessing.set_sharing_strategy('file_system')
102 |     opt = parse_args()
103 |
104 |     opt.exp_name = f'{opt.dataset}_{opt.model}_{opt.embed_size}'
105 |
106 |     console_logger = get_logger(folder=opt.log_dir+'/'+opt.exp_name+'/')
107 |     console_logger.info('==> Params for this experiment:'+str(opt))
108 |     seed_everything(opt.seed)
109 |
110 |     if opt.model == 'resnet50':
111 |         model = models.resnet50(weights="IMAGENET1K_V2")
112 |     elif opt.model == 'resnet50_I1B':
113 |         model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet50_swsl')
114 |     elif opt.model == 'xcit_dino':
115 |         model = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8')
116 |     else: raise ValueError(f'Unsupported --model for this script: {opt.model}')
117 |     if 'xcit' not in opt.model: model.fc = torch.nn.Identity()
118 |     if opt.fc_only:
119 |         for param in model.parameters():
120 |             param.requires_grad = False
121 |
122 |     dim = 512 if opt.model=='xcit_dino' else 2048
123 |     ftmodel = EmbedLinearClassifier(dim=dim, embed_size=opt.embed_size, num_classes=opt.num_classes, cosfc=opt.cosine)
124 |
125 |     if opt.mode=='AdaptedACM':
126 |         assert(ftmodel is not None)
127 |         if exists(opt.log_dir+'/'+opt.exp_name+'/ftmodel.pt'):
128 |             ftmodel.load_state_dict(torch.load(opt.log_dir+'/'+opt.exp_name+'/ftmodel.pt'))
129 |         else:
130 |             AdaptedACM(opt=opt, model=model, ftmodel=ftmodel, logger=console_logger)
131 |         ACM(opt=opt, model=model, ftmodel=ftmodel, logger=console_logger)
132 |     elif opt.mode=='ACM':
133 |         ACM(opt=opt, model=model, logger=console_logger)
--------------------------------------------------------------------------------
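The feature-extraction loop above writes the features of global sample `sel_idx` into row `sel_idx % num_per_chunk` of a per-chunk array and saves it as `features_{chunk_idx}.npy`; this is what lets `--chunk_idx` and `--num_per_chunk` split extraction across several jobs. A NumPy sketch of that index arithmetic, assuming the loader for chunk k yields only indices in `[k*num_per_chunk, (k+1)*num_per_chunk)` (the loader setup falls outside the excerpt, so this is an assumption); `chunk_position` is a hypothetical helper, not part of the repository:

```python
import numpy as np


def chunk_position(sel_idx, num_per_chunk):
    """Hypothetical helper: map global indices to (chunk id, row inside that chunk's array)."""
    sel_idx = np.asarray(sel_idx)
    return sel_idx // num_per_chunk, sel_idx % num_per_chunk


# Toy setup: chunks of 4 samples (the script's default --num_per_chunk is 524288).
num_per_chunk = 4
featarr = np.zeros((num_per_chunk, 2), dtype=np.float32)  # one chunk's feature buffer

chunk_idx = 1                       # this job is assumed to handle global indices 4..7
batch_idx = np.array([4, 5, 6, 7])  # global indices coming out of the loader
chunk, row = chunk_position(batch_idx, num_per_chunk)
assert (chunk == chunk_idx).all()   # every sample in the batch belongs to this job's chunk

# Same indexing pattern as featarr[sel_idx % num_per_chunk] = feat.cpu().numpy()
featarr[row] = np.arange(8, dtype=np.float32).reshape(4, 2)
print(f"features_{chunk_idx}.npy rows:", row.tolist())
```

Keeping the modulo row index means each chunk's array can stay a fixed size no matter how large the dataset is.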
/src/features/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 |     parser = argparse.ArgumentParser(description='main.py')
5 |     # Changing options -- apart from these arguments, we do not modify the other arguments
6 |
7 |     ## Paths
8 |     parser.add_argument('--data_dir', type=str, default='TYPE HERE', help='Full path to the directory where all datasets are stored')
9 |     parser.add_argument('--log_dir', type=str, default='TYPE HERE', help='Full path to the directory where all logs are stored')
10 |     parser.add_argument('--order_file_dir', type=str, default='TYPE HERE', help='Full path to the order file')
11 |
12 |     ## Dataset
13 |     parser.add_argument('--dataset', type=str, default='cglm', help='Name of the dataset', choices=['cglm', 'imagenet', 'cyfcc'])
14 |     parser.add_argument('--num_classes', type=int, default=713, choices=[10788, 1000, 713], help='Number of classes')
15 |     parser.add_argument('--train_batch_size', type=int, default=896, help='Batch size used in training')
16 |     parser.add_argument('--test_batch_size', type=int, default=4608, help='Batch size used in testing')
17 |     parser.add_argument('--total_iterations', type=int, default=1800, help='Total number of training iterations')
18 |
19 |     ## Model
20 |     parser.add_argument('--model', type=str, default='resnet50', choices=['resnet50','resnet50_random','resnet50_I1B','resnet50_dino','xcit_dino','resnet50_V2'], help='Model architecture')
21 |     parser.add_argument('--mode', type=str, default='ACM', choices=['AdaptedACM', 'ACM', 'ER'], help='Training type')
22 |     parser.add_argument('--embed_size', type=int, default=512, help='Embedding dimensions for retrieval tasks')
23 |
24 |     ## Experiment details
25 |     parser.add_argument('--exp_name', type=str, default='test', help='Experiment name')
26 |     parser.add_argument('--maxlr', type=float, default=0.2, help='Starting learning rate')
27 |     parser.add_argument('--num_gpus', type=int, default=8, help="Number of GPUs used in training")
28 |
29 |     # Default options
30 |     parser.add_argument('--seed', type=int, default=0, help='Seed for reproducibility')
31 |     parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay')
32 |     parser.add_argument('--clip', type=float, default=2.0, help='Gradient Clipped if val >= clip')
33 |     parser.add_argument('--num_workers', type=int, default=8, help='Number of data-loading workers')
34 |     parser.add_argument('--print_freq', type=int, default=1000, help='Print frequency')
35 |     parser.add_argument('--prefrac', type=float, default=0.2, help='Fraction of total dataset used for pretraining (in CGLM)')
36 |     parser.add_argument('--chunk_idx', type=int, default=0, help='Index of the data chunk to process; ACM runs can be parallelized by launching one job per chunk')
37 |     parser.add_argument('--num_per_chunk', type=int, default=524288, help='Number of samples in each chunk')
38 |     parser.add_argument('--num_gdsteps', type=int, default=1, help='Number of gradient descent steps')
39 |     parser.add_argument('--extract_feats',action='store_true')
40 |
41 |     parser.add_argument('--fc_only',action='store_true')
42 |     parser.add_argument('--fc', type=str, default=None, help='Full path to the order file')
43 |
44 |     parser.add_argument('--sampler', type=str, default='mixed', help='Sampling strategy')
45 |
46 |     parser.add_argument('--delay', type=int, default=0, help='Sets delay in terms of training samples.')
47 |     parser.add_argument('--cosine',action='store_true')
48 |
49 |
50 |
51 |     opt = parser.parse_args()
52 |     return opt
53 |
--------------------------------------------------------------------------------
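`torchutils.LinearLR` decays every learning rate linearly, η_t = η_0(1 − t/T). A pure-Python sketch of the resulting values under the defaults above, assuming `--maxlr` supplies η_0 and `--total_iterations` supplies T (the code that builds the scheduler is not shown, so this pairing is an assumption); `linear_lr` is a hypothetical helper that reproduces the formula in `LinearLR.get_lr`:

```python
def linear_lr(step, base_lr, total_steps):
    """Hypothetical helper: eta_t = eta_0 * (1 - t / T), the rule used by LinearLR.get_lr."""
    return base_lr * (1.0 - step / float(total_steps))


base_lr, total = 0.2, 1800  # defaults of --maxlr and --total_iterations
for t in (0, 450, 900, 1799):
    # The rate falls by base_lr / total every step and reaches 0 at t == total.
    print(t, round(linear_lr(t, base_lr, total), 5))
```

Because the decay depends only on t/T, the schedule always spends the full budget, which is the motivation the `LinearLR` docstring cites from the budgeted-training paper.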
/src/features/torchutils.py:
--------------------------------------------------------------------------------
1 |
2 | import h5py
3 | import torch, math
4 | from PIL import Image
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.optim.lr_scheduler import _LRScheduler
8 | from torch.utils.data import Dataset, DataLoader
9 | import numpy as np
10 |
11 | class ListFolder(Dataset):
12 |     def __init__(self, root, image_paths, labels_path, routing_idx_path=None, mask_values_path=None, isACE=False, offset=0, is_train=False, transform=None):
13 |         super(ListFolder, self).__init__()
14 |         # Get image list and labels per index
15 |         self.root = root
16 |         self.offset = offset
17 |         self.image_paths = h5py.File(image_paths, "r")["store_list"]
18 |         self.labels = h5py.File(labels_path, "r")["store_list"]
19 |         self.is_train = is_train
20 |         self.isACE = isACE
21 |
22 |         if self.is_train:
23 |             self.routing_idx = h5py.File(routing_idx_path, "r")["store_list"]
24 |             if self.isACE:
25 |                 self.mask_values = h5py.File(mask_values_path, "r")["store_list"]
26 |
27 |         self.transform = transform
28 |         assert(len(self.image_paths)==len(self.labels))
29 |
30 |     def __getitem__(self, index):
31 |         # get the index from the routing index
32 |         sample_mask_value = np.array([1.0])
33 |         if self.is_train:
34 |             sel_idx = self.routing_idx[index+self.offset]
35 |             if self.isACE:
36 |                 sample_mask_value = self.mask_values[index+self.offset]
37 |         else:
38 |             sel_idx = index+self.offset
39 |
40 |         # get the corresponding image and label
41 |         img_path = self.root+'/'+self.image_paths[sel_idx].decode("utf-8").strip()
42 |         label = self.labels[sel_idx]
43 |         sample = pil_loader(img_path)
44 |
45 |         if self.transform is not None:
46 |             sample = self.transform(sample)
47 |
48 |         return sample, label, sample_mask_value, sel_idx, index
49 |
50 |     def __len__(self):
51 |         assert(len(self.image_paths)==len(self.labels)), 'Length of image path array and labels different'
52 |         if self.is_train:
53 |             return len(self.routing_idx)-self.offset
54 |         else:
55 |             return len(self.labels) -self.offset
56 |
57 |
58 | class LinearLR(_LRScheduler):
59 |     r"""Set the learning rate of each parameter group with a linear
60 |     schedule: :math:`\eta_{t} = \eta_0*(1 - t/T)`, where :math:`\eta_0` is the
61 |     initial lr, :math:`t` is the current epoch or iteration (zero-based) and
62 |     :math:`T` is the total training epochs or iterations. It is recommended to
63 |     use the iteration based calculation if the total number of epochs is small.
64 |     When last_epoch=-1, sets initial lr as lr.
65 |     It is studied in
66 |     `Budgeted Training: Rethinking Deep Neural Network Training Under Resource
67 |     Constraints`_.
68 |
69 |     Args:
70 |         optimizer (Optimizer): Wrapped optimizer.
71 |         T (int): Total number of training epochs or iterations.
72 |         last_epoch (int): The index of last epoch or iteration. Default: -1.
73 |
74 |     .. _Budgeted Training\: Rethinking Deep Neural Network Training Under
75 |         Resource Constraints:
76 |         https://arxiv.org/abs/1905.04753
77 |     """
78 |
79 |     def __init__(self, optimizer, T, last_epoch=-1):
80 |         self.T = float(T)
81 |         super(LinearLR, self).__init__(optimizer, last_epoch)
82 |
83 |     def get_lr(self):
84 |         rate = 1 - self.last_epoch/self.T
85 |         return [rate*base_lr for base_lr in self.base_lrs]
86 |
87 |     def _get_closed_form_lr(self):
88 |         return self.get_lr()
90 |
91 | class CosineLinear(nn.Module):
92 |     def __init__(self, in_features, out_features, sigma=True):
93 |         super(CosineLinear, self).__init__()
94 |         self.in_features = in_features
95 |         self.out_features = out_features
96 |         self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
97 |         self.bias = None
98 |         # Optional learnable scale; stored as None when sigma=False so the `is not None` checks below work
99 |         self.sigma = nn.Parameter(torch.Tensor(1)) if sigma else None
100 |         self.reset_parameters()
101 |
102 |     def reset_parameters(self):
103 |         stdv = 1. / math.sqrt(self.weight.size(1))
104 |         self.weight.data.uniform_(-stdv, stdv)
105 |         if self.sigma is not None:
106 |             self.sigma.data.fill_(1)
107 |
108 |     def forward(self, input):
109 |         out = F.linear(F.normalize(input, p=2,dim=1), \
110 |                        F.normalize(self.weight, p=2, dim=1))
111 |         if self.sigma is not None:
112 |             out = self.sigma * out
113 |
114 |         return out
115 |
116 |
117 | class EmbedLinearClassifier(nn.Module):
118 |     """Linear embedding followed by a classifier head, trained on top of frozen features"""
119 |     def __init__(self, dim, embed_size, num_classes, cosfc=False):
120 |         super(EmbedLinearClassifier, self).__init__()
121 |         self.embed = None
122 |         self.embed = nn.Linear(dim, embed_size)
123 |         self.norm = nn.Sequential(nn.BatchNorm1d(embed_size), nn.ReLU(inplace=True))
124 |         self.embed.weight.data.normal_(mean=0.0, std=0.01)
125 |         self.embed.bias.data.zero_()
126 |         if cosfc:
127 |             self.fc = CosineLinear(embed_size, num_classes)
128 |         else:
129 |             self.fc = nn.Linear(embed_size, num_classes)  # CosineLinear has no bias, so this init is only for nn.Linear
130 |             self.fc.weight.data.normal_(mean=0.0, std=0.01)
131 |             self.fc.bias.data.zero_()
132 |
133 |     def forward(self, x):
134 |         # flatten
135 |         x = x.view(x.size(0), -1)
136 |         if self.embed is not None:
137 |             x = self.embed(x)
138 |             x = self.norm(x)
139 |         return self.fc(x)
140 |
141 |
142 | def pil_loader(path):
143 |     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
144 |     with open(path, 'rb') as f:
145 |         img = Image.open(f)
146 |         return img.convert('RGB')
147 |
--------------------------------------------------------------------------------
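`CosineLinear` scores each class by the cosine similarity between the L2-normalized input and the L2-normalized class weight vector, multiplied by the learnable scale `sigma`. A NumPy sketch of the same forward pass, with `sigma` fixed instead of learned and a small `eps` mirroring the norm clamp that `F.normalize` applies; `cosine_logits` is a hypothetical helper, not part of the repository:

```python
import numpy as np


def cosine_logits(x, weight, sigma=1.0, eps=1e-12):
    """Hypothetical helper: sigma * cos(x_i, w_c) for every sample i and class c."""
    # Divide each row by its L2 norm, clamped below by eps as F.normalize does.
    x_n = x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)
    w_n = weight / np.maximum(np.linalg.norm(weight, axis=1, keepdims=True), eps)
    return sigma * x_n @ w_n.T  # shape (num_samples, num_classes)


x = np.array([[3.0, 4.0], [1.0, 0.0]])  # 2 samples, 2 features
w = np.array([[1.0, 0.0], [0.0, 2.0]])  # 2 classes; class 1's weight has norm 2
logits = cosine_logits(x, w, sigma=10.0)
print(logits)                  # approximately [[6, 8], [10, 0]]
print(logits.argmax(axis=1))   # predicted class per sample
```

Since both sides are normalized, the weight norm of a class (2 for class 1 here) does not affect its score; only the direction matters, and `sigma` alone sets the logit range.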
/src/features/trainutils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch, os, logging, random
4 | from os.path import isfile, isdir, join, exists
5 |
6 |
7 | def make_data(fname):
8 |     """Parses a comma-separated metadata file (label in column 0, time in column 2, user in column 3,
9 |     image path in the last column) into store_loc, labels, time and user lists."""
10 |     # read data
11 |     with open(fname, 'r') as fval:
12 |         lines_val = fval.readlines()
13 |     labels = [None] * len(lines_val)
14 |     time = [None] * len(lines_val)
15 |     user = [None] * len(lines_val)
16 |     store_loc = [None] * len(lines_val)
17 |     for i in range(len(lines_val)):
18 |         line_splitted = lines_val[i].split(',')
19 |         labels[i] = int(line_splitted[0])
20 |         time[i] = int(line_splitted[2])
21 |         user[i] = line_splitted[3]
22 |         store_loc[i] = line_splitted[-1].rstrip('\n')  # drop the newline without clipping a last line that lacks one
23 |     return store_loc, labels, time, user
24 |
25 |
26 | def load_images(image_paths):
27 |     """Loads each path in image_paths as an RGB PIL image.
28 |     """
29 |     imgs = []
30 |     for path in image_paths:
31 |         with open(path, 'rb') as f:
32 |             imgs.append(Image.open(f).convert('RGB'))
33 |     return imgs
34 |
35 |
36 | def convert_imgfolder_to_imglist(data_dir, dataset, split):
37 |     """
38 |     Converts a split stored as data_dir/dataset/split/<class>/<image> into relative image paths and integer labels.
39 |     """
40 |     classes = [f for f in os.listdir(data_dir+'/'+dataset+'/'+split+'/') if isdir(join(data_dir+'/'+dataset+'/'+split,f))]
41 |     img_paths, labels = [], []
42 |     count = 0
43 |
44 |     for cls in classes:
45 |         folder = join(data_dir+'/'+dataset+'/'+split+'/', cls)
46 |         imgs = ['/'+split+'/'+cls+'/'+f for f in os.listdir(folder) if (isfile(join(folder,f)))]
47 |         lbls = [count for img in imgs]
48 |         assert(len(imgs)==len(lbls))
49 |         count+=1
50 |         img_paths.extend(imgs)
51 |         labels.extend(lbls)
52 |
53 |     assert(len(img_paths)==len(labels))
54 |     return img_paths, labels
55 |
56 |
57 | def load_filelist(filepath, root_dir, check=False):
58 |     """
59 |     Loads a tab-separated file list. With check=True, also asserts that every image exists under root_dir; use it on the first run, then turn it off.
60 |     """
61 |     imglist, y = [], []
62 |     with open(filepath,'r') as f:
63 |         for line in f:
64 |             label, image, _ = line.strip().split('\t')
65 |             if check:
66 |                 assert(exists(root_dir+'/'+image))
67 |             imglist.append(image.strip())
68 |             y.append(int(label))
69 |     return imglist, y
70 |
71 |
72 | def get_logger(folder):
73 |     """
74 |     Initializes the logger for logging opts and intermediate checks in runs.
75 |     Note: One logger per experiment where reruns get appended to the log.
76 |     """
77 |     # global logger
78 |     logger = logging.getLogger(__name__)
79 |     logger.setLevel(logging.DEBUG)
80 |     formatter = logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s")
81 |     # file logger
82 |     if not os.path.isdir(folder):
83 |         os.makedirs(folder)
84 |     fh = logging.FileHandler(os.path.join(folder, 'checkpoint.log'), mode='a')
85 |     fh.setLevel(logging.INFO)
86 |     fh.setFormatter(formatter)
87 |     logger.addHandler(fh)
88 |     # console logger
89 |     ch = logging.StreamHandler()
90 |     ch.setLevel(logging.DEBUG)
91 |     ch.setFormatter(formatter)
92 |     logger.addHandler(ch)
93 |     return logger
94 |
95 |
96 | def seed_everything(seed):
97 |     """
98 |     Fixes the class-to-task assignments and most other sources of randomness, except CUDA training aspects.
99 |     """
100 |     # Avoid all sorts of randomness for better replication
101 |     random.seed(seed)
102 |     torch.manual_seed(seed)
103 |     torch.cuda.manual_seed_all(seed)
104 |     np.random.seed(seed)
105 |     os.environ['PYTHONHASHSEED'] = str(seed)
106 |     if torch.cuda.is_available():
107 |         torch.backends.cudnn.benchmark = True # An exemption for speed :P
108 |
109 |
110 | def save_model(opt, model):
111 |     """
112 |     Used for saving the pretrained model, not for intermediate breaks in running the code.
113 |     """
114 |     state = {'opt': opt,
115 |              'state_dict': model.state_dict()}
116 |     filename = opt.log_dir+'/'+opt.exp_name+'/model.pt'
117 |     torch.save(state, filename)
--------------------------------------------------------------------------------
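`load_filelist` expects one sample per line with three tab-separated fields: an integer label, an image path relative to `root_dir`, and a third field that it discards. A small standard-library sketch that writes such a file and parses it the same way; the file name and the third-column value `x` are placeholders, since the real list files are not shown, and `read_filelist` is a hypothetical copy of the parsing logic without the existence check:

```python
import os
import tempfile


def read_filelist(filepath):
    """Hypothetical helper: parse 'label<TAB>image_path<TAB>ignored' lines like load_filelist."""
    images, labels = [], []
    with open(filepath, "r") as f:
        for line in f:
            label, image, _ = line.strip().split("\t")  # exactly three fields per line
            images.append(image.strip())
            labels.append(int(label))
    return images, labels


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train.txt")  # placeholder file name
    with open(path, "w") as f:
        f.write("3\ttrain/3/a.jpg\tx\n")
        f.write("7\ttrain/7/b.jpg\tx\n")
    imgs, ys = read_filelist(path)

print(imgs, ys)
```

Note that `line.strip()` also removes trailing tabs, so a line whose third field is empty would fail to unpack into three values.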
/src/knn/main.py:
--------------------------------------------------------------------------------
1 | import random, os, online_clfs, time, torch
2 | import numpy as np
3 | from opts import parse_args
4 | from sklearn.preprocessing import LabelEncoder
5 |
6 |
7 | def load_dataset(model, dataset):
8 |     #pretrain_X = np.load(os.path.join(opt.feature_path, f"xcit_dino_{dataset}_pretrain_features.npy"))
9 |     pretrain_y = np.load(os.path.join(opt.feature_path, f"xcit_dino_{dataset}_pretrain_labels.npy"))
10 |     pretrain_X = None
11 |     #pretrain_y = None
12 |     train_X = np.load(os.path.join(opt.feature_path, f"{model}_{dataset}_train_features.npy"))
13 |     train_y = np.load(os.path.join(opt.feature_path, f"{model}_{dataset}_train_labels.npy"))
14 |     test_X = np.load(os.path.join(opt.feature_path, f"{model}_{dataset}_test_features.npy"))
15 |     test_y = np.load(os.path.join(opt.feature_path, f"{model}_{dataset}_test_labels.npy"))
16 |
17 |     # Checks
18 |     #assert(pretrain_X.shape[0] == pretrain_y.shape[0])
19 |     assert(train_X.shape[0] == train_y.shape[0])
20 |     assert(test_X.shape[0] == test_y.shape[0])
21 |     #assert(pretrain_X.shape[1] == train_X.shape[1] == test_X.shape[1])
22 |
23 |     #print("Total pretrain rows in the dataset:", pretrain_X.shape[0])
24 |     print("Total train rows in the dataset:", train_X.shape[0])
25 |     print("Total test rows in the dataset:", test_X.shape[0])
26 |
27 |     # Normalize labels
28 |     le = LabelEncoder()
29 |     le.fit(np.concatenate((train_y, pretrain_y)))
30 |     #le.fit(train_y)
31 |     train_y = le.transform(train_y)
32 |     test_y = le.transform(test_y)
33 |     #pretrain_y = le.transform(pretrain_y)
34 |
35 |     return pretrain_X, pretrain_y, train_X, train_y, test_X, test_y, le.classes_.shape[0]
36 |
37 |
38 | if __name__ == '__main__':
39 |     # Parse arguments and init loggers
40 |     opt = parse_args()
41 |     random.seed(opt.seed)
42 |     os.environ['PYTHONHASHSEED'] = str(opt.seed)
43 |     np.random.seed(opt.seed)
44 |     print('==> Params for this experiment:'+str(opt))
45 |
46 |     pretrain_X, pretrain_y, train_X, train_y, test_X, test_y, num_classes = load_dataset(opt.model, dataset=opt.dataset)
47 |     opt.feature_dim, opt.num_classes = train_X.shape[1], num_classes
48 |
49 |     normalizer = online_clfs.Normalizer(dim=opt.feature_dim)
50 |     online_clf = getattr(online_clfs, opt.online_classifier)(opt=opt)
51 |
52 |     predarr, labelarr, acc = np.zeros(train_y.shape[0], dtype='u2'), np.zeros(train_y.shape[0], dtype='u2'), np.zeros(train_y.shape[0], dtype='bool')
53 |     start_time = time.time()
54 |
55 |     for i in range(train_X.shape[0]- opt.delay):
56 |         feat_learn = train_X[i]
57 |         feat_pred = train_X[i+opt.delay]
58 |
59 |         if opt.normalize_input:
60 |             feat_learn = normalizer.update_and_transform(feat_learn)
61 |             feat_pred = normalizer.transform(feat_pred)
62 |
63 |         if i >= 256+opt.delay: # Slightly shifted ahead start point to warmup all classifiers, avoids weird jagged artifacts in plots
64 |             if opt.online_classifier == 'ApproxKNearestNeighbours':
65 |                 if (i+1)%opt.update_k <= (opt.update_size-1):
66 |                     pred = online_clf.full_predict_step(x=feat_pred, y=train_y[i+opt.delay])
67 |                 else:
68 |                     pred = online_clf.predict_step(x=feat_pred)
69 |                 if i>opt.update_k and (i+1)%opt.update_k == (opt.update_size-1+opt.delay):
70 |                     online_clf.deploy_num_neighbours()
71 |             else:
72 |                 pred = online_clf.predict_step(x=feat_pred)
73 |
74 |             predarr[i+opt.delay] = int(pred)
75 |             labelarr[i+opt.delay] = int(train_y[i+opt.delay])
76 |             is_correct = (int(pred)==int(train_y[i+opt.delay]))
77 |             acc[i+opt.delay] = is_correct*1.0
78 |
79 |         if (i+1)%opt.print_freq == 0:
80 |             cum_acc = np.array(acc[:i]).mean()
81 |             print(f'Step:\t{i}\tCumul Acc:{cum_acc}')
82 |
83 |         online_clf.learn_step(x=feat_learn, y=np.array([train_y[i]]))
84 |
85 |     total_time = time.time() - start_time
86 |     print(f'Total time taken: {total_time:.4f}')
87 |     os.makedirs(opt.log_dir, exist_ok=True)
88 |     np.save(os.path.join(opt.log_dir, f"{opt.model}_{opt.dataset}_{opt.online_classifier}_{opt.lr}_{opt.wd}_online_preds_{opt.online_exp_name}.npy"), predarr[256:])
89 |     np.save(os.path.join(opt.log_dir, f"{opt.model}_{opt.dataset}_{opt.online_classifier}_{opt.lr}_{opt.wd}_online_labels_{opt.online_exp_name}.npy"), labelarr[256:])
90 |
91 |     print('==> Testing..')
92 |     start_time = time.time()
93 |
94 |     preds = online_clf.predict_step(x=test_X)
95 |     np.save(os.path.join(opt.log_dir, f"{opt.model}_{opt.dataset}_{opt.online_classifier}_{opt.lr}_{opt.wd}_test_preds_{opt.online_exp_name}.npy"), preds)
96 |     np.save(os.path.join(opt.log_dir, f"{opt.model}_{opt.dataset}_{opt.online_classifier}_{opt.lr}_{opt.wd}_test_labels_{opt.online_exp_name}.npy"), test_y)
97 |
--------------------------------------------------------------------------------
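`knn/main.py` evaluates in a test-then-train (prequential) fashion: at step i it predicts sample i+delay with a classifier that has learned samples 0..i-1, then learns sample i, so labels effectively arrive `delay` steps late; the first 256+delay steps are used only for learning. A minimal sketch of that loop, with a hypothetical nearest-class-mean classifier standing in for the `online_clfs` classes (it exposes the same `predict_step`/`learn_step` method names but is not one of them, and `delayed_online_accuracy` is likewise illustrative):

```python
import numpy as np


class NearestMean:
    """Hypothetical stand-in classifier: predicts the seen class whose running mean is closest."""

    def __init__(self, num_classes, dim):
        self.sums = np.zeros((num_classes, dim))
        self.counts = np.zeros(num_classes)

    def predict_step(self, x):
        seen = self.counts > 0
        if not seen.any():
            return 0  # nothing learned yet
        means = self.sums[seen] / self.counts[seen][:, None]
        nearest = np.argmin(((means - x) ** 2).sum(axis=1))
        return int(np.flatnonzero(seen)[nearest])

    def learn_step(self, x, y):
        self.sums[y] += x
        self.counts[y] += 1


def delayed_online_accuracy(X, y, clf, delay=0, warmup=0):
    """Predict sample i+delay, then learn sample i; score only steps i >= warmup + delay."""
    correct = []
    for i in range(len(X) - delay):
        if i >= warmup + delay:
            correct.append(clf.predict_step(X[i + delay]) == y[i + delay])
        clf.learn_step(X[i], y[i])
    return float(np.mean(correct)) if correct else float("nan")


rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [5.0, 5.0]])
y = rng.integers(0, 2, size=200)
X = centers[y] + rng.normal(scale=0.5, size=(200, 2))
acc = delayed_online_accuracy(X, y, NearestMean(2, 2), delay=10, warmup=5)
print(f"delayed online accuracy: {acc:.3f}")
```

Every prediction is made before its label is revealed, so the accuracy measures generalization to the incoming stream rather than fit to already-seen samples.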
/src/knn/online_clfs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import threading, pickle, math, torch, scipy
4 | from sklearn.linear_model import SGDClassifier
5 | from sklearn.metrics.pairwise import cosine_similarity
6 |
7 | class Normalizer():
8 |     def __init__(self, size=None, dim=None, mean=None, unnormalized_var=None):
9 |         assert(dim is not None)
10 |         self.dim = dim
11 |         if mean is not None:
12 |             assert(size is not None and unnormalized_var is not None)
13 |             self.size = size
14 |             self.mean = mean
15 |             self.var_unnormalized = unnormalized_var
16 |         else:
17 |             self.size = 0
18 |             self.mean = np.zeros(dim)
19 |             self.var_unnormalized = np.zeros(dim)
20 |
21 |     def update_and_transform(self, x):
22 |         # Copy as float64 (promoting 1-D input to 2-D) so the caller's array, e.g. a row of train_X, is not modified in place
23 |         x = np.array(x, dtype=np.float64, ndmin=2)
24 |
25 |         for idx in range(x.shape[0]):
26 |             self.size += 1
27 |             new_mean = self.mean + (x[idx] - self.mean)/self.size
28 |             self.var_unnormalized = self.var_unnormalized + (x[idx] - self.mean)*(x[idx] - new_mean)
29 |             self.mean = new_mean
30 |             if self.size > 2:  # same warm-up guard as transform(); the sample variance is undefined for size < 2
31 |                 x[idx] = (x[idx] - self.mean)/np.sqrt(self.var_unnormalized/(self.size-1))
32 |         return x
33 |
34 |     def transform(self, x):
35 |         if x.ndim == 1:
36 |             x = x[np.newaxis, :]
37 |
38 |         if self.size > 2:
39 |             std = np.sqrt(self.var_unnormalized/(self.size-1))
40 |             return (x - self.mean)/std
41 |         else:
42 |             return x
43 |
44 |
45 | class OnlineLogisticClassification_VowpalWabbit():
46 |     def __init__(self, opt):
47 |         from vowpalwabbit import Workspace, dftovw
48 |         self.dftovw = dftovw  # keep a module handle: names imported inside __init__ are not visible in the other methods
49 |         self.model = Workspace(oaa=opt.num_classes, loss_function='logistic', b=30, l=opt.lr)
50 |         self.cols = ['label']
51 |         for i in range(opt.feature_dim): self.cols.append('f_'+str(i))
52 |
53 |     def learn_step(self, x, y):
54 |         if x.ndim == 1:
55 |             x = x[np.newaxis, :]
56 |         df = pd.DataFrame(x, columns=self.cols[1:])
57 |         df['label'] = y[0].tolist()
58 |         feat = self.dftovw.DFtoVW(df=df, label=self.dftovw.SimpleLabel('label'), features=[self.dftovw.Feature(col) for col in self.cols[1:]])
59 |         example = feat.convert_df()[0]
60 |         self.model.learn(example)
61 |
62 |     def predict_step(self, x):
63 |         if x.ndim == 1:
64 |             x = x[np.newaxis, :]
65 |         df = pd.DataFrame(x, columns=self.cols[1:])
66 |         df['label'] = [-1]
67 |         feat = self.dftovw.DFtoVW(df=df, label=self.dftovw.SimpleLabel('label'), features=[self.dftovw.Feature(col) for col in self.cols[1:]])
68 |         example = feat.convert_df()[0]
69 |         return self.model.predict(example)
70 |
71 |
72 | class OnlineSVM_VowpalWabbit():
73 |     def __init__(self, opt):
74 |         from vowpalwabbit import Workspace, dftovw
75 |         self.dftovw = dftovw  # keep a module handle: names imported inside __init__ are not visible in the other methods
76 |         self.model = Workspace(oaa=opt.num_classes, loss_function='hinge', b=30, l=opt.lr, l2=opt.wd)
77 |         self.cols = ['label']
78 |         for i in range(opt.feature_dim): self.cols.append('f_'+str(i))
79 |
80 |     def learn_step(self, x, y):
81 |         if x.ndim == 1:
82 |             x = x[np.newaxis, :]
83 |         df = pd.DataFrame(x, columns=self.cols[1:])
84 |         df['label'] = y[0].tolist()
85 |         feat = self.dftovw.DFtoVW(df=df, label=self.dftovw.SimpleLabel('label'), features=[self.dftovw.Feature(col) for col in self.cols[1:]])
86 |         example = feat.convert_df()[0]
87 |         self.model.learn(example)
88 |
89 |     def predict_step(self, x):
90 |         if x.ndim == 1:
91 |             x = x[np.newaxis, :]
92 |         df = pd.DataFrame(x, columns=self.cols[1:])
93 |         df['label'] = [-1]
94 |         feat = self.dftovw.DFtoVW(df=df, label=self.dftovw.SimpleLabel('label'), features=[self.dftovw.Feature(col) for col in self.cols[1:]])
95 |         example = feat.convert_df()[0]
96 |         return self.model.predict(example)
97 |
98 |
99 | class OnlineSVM_Scikit():
100 |     def __init__(self, opt):
101 |         self.num_classes = opt.num_classes  # needed for partial_fit's `classes` argument in learn_step
102 |         self.clf = SGDClassifier(loss='hinge', penalty='l2', alpha=opt.wd, fit_intercept=True, learning_rate='optimal', warm_start=True)
103 |
104 |     def learn_step(self, x, y):
105 |         if x.ndim == 1:
106 |             x = x[np.newaxis, :]
107 |         self.clf.partial_fit(x, y, classes=np.arange(self.num_classes))
108 |
109 |     def predict_step(self, x):
110 |         if x.ndim == 1:
111 |             x = x[np.newaxis, :]
112 |         return self.clf.predict(x)
113 |
114 |
115 | class OnlineLogisticClassification_Scikit():
116 |     def __init__(self, opt):
117 |         self.clf = SGDClassifier(loss='log_loss', penalty='l2', alpha=opt.wd, fit_intercept=True, learning_rate='optimal', warm_start=True)
118 |         self.num_classes = opt.num_classes
119 |
120 |     def learn_step(self, x, y):
121 |         if x.ndim == 1:
122 |             x = x[np.newaxis, :]
123 |         self.clf.partial_fit(x, y, classes=np.arange(self.num_classes))
124 |
125 |     def predict_step(self, x):
126 |         if x.ndim == 1:
127 |             x = x[np.newaxis, :]
128 |         return self.clf.predict(x)
129 |
130 |
131 | class HuberLossClassifier_Scikit():
132 |     def __init__(self, opt):
133 |         self.clf = SGDClassifier(loss='modified_huber', penalty='l2', alpha=opt.wd, fit_intercept=True, learning_rate='optimal', warm_start=True)
134 |         self.num_classes = opt.num_classes
135 |
136 |     def learn_step(self, x, y):
137 |         if x.ndim == 1:
138 |             x = x[np.newaxis, :]
139 |         self.clf.partial_fit(x, y, classes=np.arange(self.num_classes))
140 |
141 |     def predict_step(self, x):
142 |         if x.ndim == 1:
143 |             x = x[np.newaxis, :]
144 |         return self.clf.predict(x)
145 |
146 |
147 | class ContextualMemoryTree():
148 | def __init__(self, opt):
149 | from vowpalwabbit import Workspace
150 | from vowpalwabbit.dftovw import DFtoVW, SimpleLabel, Feature
151 | num_nodes = max(1, int(opt.num_classes / (np.log2(opt.num_classes) * 10)))  # --memory_tree expects an integer node count
152 | self.model = Workspace("--memory_tree "+str(num_nodes)+" --max_number_of_labels "+str(opt.num_classes)+' --online --dream_at_update 1 --leaf_example_multiplier 10 --dream_repeats 12 --learn_at_leaf --alpha 0.1 -l '+str(opt.lr)+' -b 30 -c --loss_function squared --sort_features')
153 | self.cols = ['label']
154 | for i in range(opt.feature_dim): self.cols.append('f_'+str(i))
155 |
156 | def learn_step(self, x, y):
157 | if x.ndim == 1:
158 | x = x[np.newaxis, :]
159 | df = pd.DataFrame(x, columns=self.cols[1:])
160 | df['label'] = y[0].tolist()
161 | feat = DFtoVW(df=df, label=SimpleLabel('label'), features=[Feature(col) for col in self.cols[1:]])
162 | example = feat.convert_df()[0]
163 | self.model.learn(example)
164 |
165 | def predict_step(self, x):
166 | if x.ndim == 1:
167 | x = x[np.newaxis, :]
168 | df = pd.DataFrame(x, columns=self.cols[1:])
169 | df['label'] = [-1]
170 | feat = DFtoVW(df=df, label=SimpleLabel('label'), features=[Feature(col) for col in self.cols[1:]])
171 | example = feat.convert_df()[0]
172 | return self.model.predict(example)
173 |
174 |
175 | class KNearestNeighbours():
176 | def __init__(self, opt):
177 | assert opt.search_metric in ['cosine', 'l2']
178 | self.num_neighbours = opt.num_neighbours
179 | self.train_x, self.train_y = None, None
180 | # Pairwise (num_queries x num_stored) scores: cosine is a similarity (higher = closer), l2 a distance (lower = closer)
181 | if opt.search_metric == 'cosine':
182 | self.dist = lambda a, b: torch.nn.functional.normalize(a, dim=1) @ torch.nn.functional.normalize(b, dim=1).T
183 | self.largest = True
184 | else:
185 | self.dist = lambda a, b: torch.cdist(a, b, p=2)
186 | self.largest = False
187 |
188 |
189 | def learn_step(self, x, y):
190 | with torch.no_grad():
191 | if x.ndim == 1:
192 | x = x.unsqueeze(0)
193 |
194 | if self.train_x is not None:
195 | self.train_y = torch.cat((self.train_y, y), dim=0)
196 | self.train_x = torch.cat((self.train_x, x), dim=0)
197 | else:
198 | self.train_y = y
199 | self.train_x = x
200 |
201 |
202 | def predict_step(self, x):
203 | with torch.no_grad():
204 | if x.ndim == 1:
205 | x = x.unsqueeze(0)
206 | k = min(self.num_neighbours, self.train_x.shape[0])
207 | _, idxes = torch.topk(self.dist(x, self.train_x), k, dim=1, largest=self.largest)
208 | labels, _ = torch.mode(self.train_y[idxes], dim=1)
209 |
210 | return labels
211 |
212 |
213 | class ApproxKNearestNeighbours():
214 | # https://raw.githubusercontent.com/nmslib/nmslib/master/manual/latex/manual.pdf
215 | def __init__(self, opt):
216 | import hnswlib
217 | self.index = hnswlib.Index(space=opt.search_metric, dim=opt.feature_dim)
218 | self.lock = threading.Lock()
219 | self.cur_idx = 0
220 | self.dset_size = 65536
221 | self.idx2label = np.zeros(self.dset_size, dtype=np.int16)
222 | self.index.init_index(max_elements=self.dset_size, ef_construction=opt.HNSW_ef, M=opt.HNSW_M)
223 | self.num_neighbours = opt.num_neighbours
224 | self.preds = [[], [], [], [], []]
225 | self.labels = []
226 |
227 | def learn_step(self, x, y):
228 | if x.ndim == 1:
229 | x = x[np.newaxis, :]
230 | assert(x.shape[0]==y.shape[0])
231 |
232 | num_added = x.shape[0]
233 | start_idx = self.cur_idx
234 | self.cur_idx += num_added
235 |
236 | if self.cur_idx >= self.dset_size - 2:
237 | with self.lock:
238 | self.dset_size = pow(2, math.ceil(math.log2(self.cur_idx)))
239 | self.dset_size *= 4
240 | self.index.resize_index(self.dset_size)
241 |
242 | new_idx2label = np.zeros(self.dset_size, dtype=np.int16)
243 | new_idx2label[:start_idx] = self.idx2label[:start_idx]
244 | self.idx2label = new_idx2label
245 |
246 | idx = np.arange(start_idx, start_idx + num_added)
247 | self.idx2label[start_idx:start_idx+num_added] = y
248 |
249 | self.index.add_items(data=x, ids=np.asarray(idx))
250 |
251 | def set_ef(self, ef):
252 | self.index.set_ef(ef)
253 |
254 | def load_index(self, path):
255 | self.index.load_index(path)
256 |
257 | with open(path + ".pkl", "rb") as f:
258 | self.cur_idx, self.idx2label = pickle.load(f)
259 |
260 | def save_index(self, path):
261 | self.index.save_index(path)
262 | with open(path + ".pkl", "wb") as f:
263 | pickle.dump((self.cur_idx, self.idx2label), f)
264 |
265 | def set_num_threads(self, num_threads):
266 | self.index.set_num_threads(num_threads)
267 |
268 | def predict_step(self, x):
269 | # Majority vote over the labels of the k = self.num_neighbours nearest stored samples.
270 | # full_predict_step additionally records votes for k in {1, 2, 4, 8, 16} so that deploy_num_neighbours can re-select k.
271 | if x.ndim == 1:
272 | x = x[np.newaxis, :]
273 |
274 | idxes, _ = self.index.knn_query(data=x, k=self.num_neighbours)
275 |
276 | neighbour_labels = self.idx2label[idxes]
277 |
278 | pred_labels, _ = scipy.stats.mode(neighbour_labels, axis=1)
279 | return pred_labels
280 |
281 | def full_predict_step(self, x, y):
282 | idxes, _ = self.index.knn_query(data=x, k=16)
283 | neighbour_labels = self.idx2label[idxes]
284 |
285 | for j in range(5):
286 | pred_labels, _ = scipy.stats.mode(neighbour_labels[:,:2**j], axis=1)
287 | self.preds[j].append(int(pred_labels))
288 | if 2**j == self.num_neighbours:
289 | out_pred = pred_labels
290 | self.labels.append(y)
291 |
292 | return out_pred
293 |
294 | def deploy_num_neighbours(self):
295 | best_acc, best_k = -1, 0
296 | labels = np.array(self.labels)
297 |
298 | for j in range(5):
299 | preds = np.array(self.preds[j])
300 | acc = ((preds == labels)*1.0).mean()
301 | if acc > best_acc:
302 | best_acc = acc
303 | best_k = 2**j
304 |
305 | self.num_neighbours = best_k
306 | print(f'Selected k = {self.num_neighbours}')
307 | self.preds = [[], [], [], [], []]
308 | self.labels = []
309 |
310 |
311 | class NearestClassMeanCosine():
312 | def __init__(self, opt):
313 | with torch.no_grad():
314 | # Class means is class sums, divided by number of samples
315 | self.class_sums = np.zeros((opt.num_classes, opt.feature_dim))
316 | self.num_samples = np.zeros((opt.num_classes,1))
317 |
318 |
319 | def learn_step(self, x, y):
320 | if x.ndim == 1:
321 | x = x[np.newaxis, :]
322 |
323 | with torch.no_grad():
324 | # Update class mean and number of samples
325 | if self.num_samples.shape[0] <= y.shape[0]:
326 | for index in range(self.num_samples.shape[0]):
327 | if (y==index).sum() == 0:
328 | continue
329 | self.class_sums[index,:] += x[y==index].sum(axis=0)
330 | self.num_samples[index] += (y==index).sum()
331 | else:
332 | for index in range(y.shape[0]):
333 | self.class_sums[y[index],:] += x[index].squeeze()
334 | self.num_samples[y[index]] += 1
335 |
336 |
337 | def predict_step(self, x):
338 | if x.ndim == 1:
339 | x = x[np.newaxis, :]
340 |
341 | class_means = self.class_sums / (self.num_samples+1e-6)
342 | similarities = cosine_similarity(x, class_means)
343 | return np.argmax(similarities, axis=1)
344 |
345 |
346 | class NearestClassMeanL2():
347 | def __init__(self, opt):
348 | with torch.no_grad():
349 | # Class means is class sums, divided by number of samples
350 | self.class_sums = torch.zeros((opt.num_classes, opt.feature_dim))
351 | self.num_samples = torch.zeros((opt.num_classes,1))
352 |
353 | if opt.gpu:
354 | self.class_sums = self.class_sums.cuda()
355 | self.num_samples = self.num_samples.cuda()
356 |
357 |
358 |
359 | def learn_step(self, x, y):
360 | x = torch.from_numpy(x).float().to(self.class_sums.device)
361 | if x.ndim == 1:
362 | x = x.unsqueeze(0)
363 |
364 | with torch.no_grad():
365 | # Update class mean and number of samples
366 | if self.num_samples.shape[0] <= y.shape[0]:
367 | for index in range(self.num_samples.shape[0]):
368 | if (y==index).sum() == 0:
369 | continue
370 | self.class_sums[index,:] += x[y==index].sum(dim=0).squeeze()
371 | self.num_samples[index] += (y==index).sum()
372 | else:
373 | for index in range(y.shape[0]):
374 | self.class_sums[y[index],:] += x[index].squeeze()
375 | self.num_samples[y[index]] += 1
376 |
377 |
378 | def predict_step(self, x):
379 | x = torch.from_numpy(x).float().to(self.class_sums.device)
380 | if x.ndim == 1:
381 | x = x.unsqueeze(0)
382 |
383 | with torch.no_grad():
384 | class_means = (self.class_sums / (self.num_samples+1e-6)).unsqueeze(0)
385 | x = x.unsqueeze(0)
386 | distances = torch.cdist(x, class_means, p=2.0).squeeze(dim=0)
387 | distances = torch.where(self.num_samples.T > 0, distances, torch.full_like(distances, 1e5))  # ignore classes with no samples yet
388 | return torch.argmin(distances, dim=1)
389 |
390 |
391 | class StreamingLinearDiscriminantAnalysis():
392 | def __init__(self, opt):
393 | with torch.no_grad():
394 | self.feature_dim = opt.feature_dim
395 | self.num_classes = opt.num_classes
396 | self.shrinkage_param = 1e-4
397 | self.streaming_update_sigma = True
398 |
399 | # setup weights for SLDA
400 | self.muK = torch.zeros((opt.num_classes, opt.feature_dim))
401 | self.cK = torch.zeros(opt.num_classes)
402 | self.Sigma = torch.ones((opt.feature_dim, opt.feature_dim))
403 | self.Lambda = torch.zeros_like(self.Sigma)
404 | self.num_updates = 0
405 | self.prev_num_updates = -1
406 |
407 | if opt.gpu:
408 | self.muK = self.muK.cuda()
409 | self.cK = self.cK.cuda()
410 | self.Sigma = self.Sigma.cuda()
411 | self.Lambda = self.Lambda.cuda()
412 |
413 |
414 | def learn_step(self, x, y):
415 | # make sure things are the right shape
416 | x = torch.from_numpy(x).float().to(self.muK.device)
417 | if x.ndim == 1:
418 | x = x.unsqueeze(0)
419 |
420 | with torch.no_grad():
421 | # covariance updates
422 | if self.streaming_update_sigma:
423 | x_minus_mu = (x - self.muK[y])
424 | mult = torch.matmul(x_minus_mu.transpose(1, 0), x_minus_mu)
425 | delta = mult * self.num_updates / (self.num_updates + 1)
426 | self.Sigma = (self.num_updates * self.Sigma + delta) / (self.num_updates + 1)
427 |
428 | # update class means
429 | self.muK[y, :] += (x - self.muK[y, :]) / (self.cK[y] + 1).unsqueeze(1)
430 | self.cK[y] += 1
431 | self.num_updates += 1
432 |
433 |
434 | def predict_step(self, x):
435 | x = torch.from_numpy(x).float().to(self.muK.device)
436 | if x.ndim == 1:
437 | x = x.unsqueeze(0)
438 |
439 | with torch.no_grad():
440 | # initialize parameters for testing
441 | num_samples = x.shape[0]
442 | scores = torch.empty((num_samples, self.num_classes))
443 | mb = num_samples
444 |
445 | # compute/load Lambda matrix
446 | if self.prev_num_updates != self.num_updates:
447 | # there have been updates to the model, compute Lambda
448 | # print('\nFirst predict since model update...computing Lambda matrix...')
449 | Lambda = torch.pinverse(
450 | (1 - self.shrinkage_param) * self.Sigma + self.shrinkage_param * torch.eye(self.feature_dim, device=self.Sigma.device)).to(self.Lambda.device)
451 | self.Lambda = Lambda
452 | self.prev_num_updates = self.num_updates
453 | else:
454 | Lambda = self.Lambda
455 |
456 | # parameters for predictions
457 | M = self.muK.transpose(1, 0)
458 | W = torch.matmul(Lambda, M)
459 | c = 0.5 * torch.sum(M * W, dim=0)
460 |
461 | # loop in mini-batches over test samples
462 | for i in range(0, num_samples, mb):
463 | start = min(i, num_samples - mb)
464 | end = i + mb
465 | X = x[start:end]
466 | scores[start:end, :] = torch.matmul(X, W) - c
467 |
468 | # return predicted class indices (argmax of the discriminant scores)
469 | return torch.argmax(scores, dim=1)
470 |
471 |
472 | def fit_base(self, x, y):
473 | print('\nFitting Base...')
474 | x = torch.from_numpy(x).float().to(self.muK.device)
475 | # update class means
476 | for k in torch.unique(y):
477 | self.muK[k] = x[y == k].mean(0)
478 | self.cK[k] = x[y == k].shape[0]
479 | self.num_updates = x.shape[0]
480 |
481 | print('\nEstimating initial covariance matrix...')
482 | from sklearn.covariance import OAS
483 | cov_estimator = OAS(assume_centered=True)
484 | cov_estimator.fit((x - self.muK[y]).cpu().numpy())
485 | self.Sigma = torch.from_numpy(cov_estimator.covariance_).float().to(self.Sigma.device)
--------------------------------------------------------------------------------
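The k-selection scheme in `ApproxKNearestNeighbours` works in two parts: `full_predict_step` records the majority vote for every k in {1, 2, 4, 8, 16} alongside the true label, and `deploy_num_neighbours` switches to the k with the best accuracy on those recorded samples. The sketch below illustrates the same idea without hnswlib. It is a hypothetical `BruteForceAdaptiveKNN` class (not part of the repository) that swaps the HNSW index for exact cosine search in NumPy, which is only practical for small stores.

```python
import numpy as np

KS = [1, 2, 4, 8, 16]  # candidate k values, the powers of two tried by full_predict_step


class BruteForceAdaptiveKNN:
    """Hypothetical exact-cosine kNN that periodically re-selects k (not part of the repository)."""

    def __init__(self, k: int = 2):
        self.k = k
        self.feats = []   # unit-normalised stored features
        self.labels = []  # labels aligned with self.feats
        self.preds = {kk: [] for kk in KS}  # per-k votes on queries since the last re-selection
        self.targets = []  # ground-truth labels of those queries

    def learn(self, x: np.ndarray, y: int) -> None:
        self.feats.append(x / (np.linalg.norm(x) + 1e-12))
        self.labels.append(int(y))

    @staticmethod
    def _vote(neighbour_labels: np.ndarray) -> int:
        # Majority label; np.argmax picks the smallest label among ties, like scipy.stats.mode
        return int(np.argmax(np.bincount(neighbour_labels)))

    def full_predict(self, x: np.ndarray, y: int) -> int:
        """Predict with the current k while recording the vote of every candidate k."""
        q = x / (np.linalg.norm(x) + 1e-12)
        sims = np.stack(self.feats) @ q  # cosine similarity to every stored sample
        order = np.argsort(-sims, kind='stable')[:max(KS)]  # most similar first
        nb = np.asarray(self.labels)[order]
        for kk in KS:
            self.preds[kk].append(self._vote(nb[:kk]))
        self.targets.append(int(y))
        return self._vote(nb[:self.k])

    def deploy_k(self) -> int:
        """Switch to the candidate k with the best recorded accuracy, then reset the records."""
        targets = np.asarray(self.targets)
        accs = {kk: float(np.mean(np.asarray(p) == targets)) for kk, p in self.preds.items()}
        self.k = max(KS, key=lambda kk: accs[kk])  # ties go to the smaller k, as in deploy_num_neighbours
        self.preds = {kk: [] for kk in KS}
        self.targets = []
        return self.k
```

As in `deploy_num_neighbours`, ties in accuracy resolve to the smallest k, because candidates are scanned in ascending order and only a strictly better accuracy replaces the current choice. How often `deploy_k` is called is left to the caller.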
/src/knn/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parse_args():
4 | parser = argparse.ArgumentParser(description='main.py')
5 | parser.add_argument('--feature_path', type=str, default='../../../improved_knn/saved_features/', help='Full path to the directory where the extracted features are stored')
6 | parser.add_argument('--log_dir', type=str, default='../../logs/', help='Full path to the directory where all logs are stored')
7 | parser.add_argument('--model', type=str, default='xcit_dino', choices=['xcit_dino','r50','rsnet50_i1b'], help='Model used for classification')
8 | parser.add_argument('--online_classifier', type=str, default='ApproxKNearestNeighbours', help='Name of online classifier', choices=['OnlineLogisticClassification_VowpalWabbit',
9 | 'OnlineSVM_VowpalWabbit', 'OnlineSVM_Scikit', 'OnlineLogisticClassification_Scikit',
10 | 'HuberLossClassifier_Scikit', 'ContextualMemoryTree', 'KNearestNeighbours', 'ApproxKNearestNeighbours',
11 | 'NearestClassMeanCosine', 'NearestClassMeanL2', 'StreamingLinearDiscriminantAnalysis'])
12 | parser.add_argument('--dataset', type=str, default='cglm', help='Name of dataset', choices=['clear10', 'clear100', 'cglm', 'cloc'])
13 | parser.add_argument('--search_metric', type=str, default='cosine', choices=['cosine', 'l2'], help='Types of search')
14 | parser.add_argument('--HNSW_ef', type=int, default=200, help='HNSW ef_construction: size of the candidate list used while building the index')
15 | parser.add_argument('--HNSW_M', type=int, default=25, help='HNSW M: number of bi-directional links created per inserted element')
16 | parser.add_argument('--lr', type=float, default=2.0, help='Learning rate for the VowpalWabbit-based classifiers')
17 | parser.add_argument('--wd', type=float, default=1e-4, help='L2 regularization strength (alpha for scikit-learn SGDClassifier, l2 for VowpalWabbit)')
18 | parser.add_argument('--normalize_input', action="store_true", help='Normalize the input to the search')
19 | parser.add_argument('--gpu', action="store_true", help='Perform online learning on GPU')
20 | parser.add_argument('--online_exp_name', type=str, default='test', help='Name of the online experiment')
21 | parser.add_argument('--seed', type=int, default=0, help='Seed for reproducibility')
22 | parser.add_argument('--print_freq', type=int, default=5000, help='Print progress every this many samples')
23 | parser.add_argument('--num_neighbours', type=int, default=2, help='k for kNN')
24 | parser.add_argument('--update_k', type=int, default=1, help='Update k for kNN after this many samples')
25 | parser.add_argument('--update_size', type=int, default=1, help='Number of samples considered when computing accuracy for the k update')
26 | parser.add_argument('--delay', type=int, default=1, help='Delay for ACM')
27 | opt = parser.parse_args()
28 | return opt
29 |
30 |
--------------------------------------------------------------------------------
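`--search_metric` offers `cosine` and `l2`. For L2-normalised vectors the two give the same neighbour ranking, because ||q - s||^2 = 2 - 2 * cos(q, s) whenever ||q|| = ||s|| = 1, and that is a decreasing function of the cosine. The short NumPy check below illustrates this. It assumes `--normalize_input` L2-normalises the features before search (its handling lives in `main.py`, which is not shown here). `neighbour_order` and `l2_normalize` are hypothetical helpers, not part of the repository.

```python
import numpy as np


def neighbour_order(query: np.ndarray, stored: np.ndarray, metric: str) -> np.ndarray:
    """Hypothetical helper: indices of `stored` rows from nearest to farthest."""
    if metric == 'cosine':
        sims = stored @ query / (np.linalg.norm(stored, axis=1) * np.linalg.norm(query))
        return np.argsort(-sims, kind='stable')  # higher similarity = closer
    if metric == 'l2':
        dists = np.linalg.norm(stored - query, axis=1)
        return np.argsort(dists, kind='stable')  # lower distance = closer
    raise ValueError(f'unknown metric: {metric}')


def l2_normalize(a: np.ndarray) -> np.ndarray:
    return a / np.linalg.norm(a, axis=-1, keepdims=True)


rng = np.random.default_rng(0)
stored = l2_normalize(rng.normal(size=(100, 16)))
query = l2_normalize(rng.normal(size=16))

# With unit-norm vectors, ||q - s||^2 = 2 - 2 * cos(q, s), so both metrics order neighbours identically
same = np.array_equal(neighbour_order(query, stored, 'cosine'), neighbour_order(query, stored, 'l2'))
print(same)
```

Without normalisation the two metrics can rank neighbours differently, since L2 distance is then also sensitive to vector norms.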
/src/run_blind.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import numpy as np
3 | from scipy import stats
4 | from os.path import exists
5 | import h5py, os
6 |
7 | def load_filelist(filepath):
8 | imglist, y = [], []
9 | with open(filepath,'r') as f:
10 | for line in f:
11 | label, _, _ = line.strip().split('\t')
12 | y.append(int(label))
13 | return imglist, y
14 |
15 |
16 | def get_preds(k):
17 | pred = np.ones_like(labels)[1:]
18 |
19 | for i in range(1, len(labels)):
20 | if i <= k:
21 | pred[i-1] = labels[0]
22 | else:
23 | pred[i-1] = stats.mode(labels[i-(k):(i)])[0]
24 | np.save(logdir+'/'+dataset+'_blind_preds_'+str(k)+'.npy', pred)
25 | return pred
26 |
27 |
28 | if __name__ == '__main__':
29 | global logdir, labels
30 | dataset='CLOC' # Choose one from: ['CLEAR10', 'CLEAR100', 'CGLM', 'CLOC']
31 | num_processes = 16
32 | path_to_cldatasets = '/media/bayesiankitten/Hanson/CLDatasets/data/' # Please set directory for labels in the CLDatasets folder
33 | logdir = '/media/bayesiankitten/Alexander/ACM/blind_logs/' # Please set directory for logs
34 | os.makedirs(logdir, exist_ok=True)
35 |
36 | # Load labels file
37 | # Note: to actually select k, use the pretrain_labels.hdf5 file instead. train_labels.hdf5 is used here for the analysis plots shown in the paper (both give the same results but serve different purposes).
38 | with h5py.File(f'{path_to_cldatasets}/{dataset}/order_files/train_labels.hdf5', 'r') as f:
39 | labels = np.array(f['store_list'])[1:]
40 |
41 | labels = np.array(labels, dtype=np.uint16)
42 | pred = np.ones_like(labels)[1:]
43 |
44 | # Get dataset mode, get performance to check dataset imbalance
45 | if not exists(logdir+'/'+dataset+'_mode.npy'):
46 | modelabel = stats.mode(labels)[0]
47 | pred = pred*modelabel
48 | np.save(logdir+'/'+dataset+'_mode.npy', pred)
49 |
50 | pred = np.load(logdir+'/'+dataset+'_mode.npy')
51 | gt = labels[1:]
52 | acc = np.equal(gt,pred)*1.0
53 | idx = np.arange(acc.shape[0])+1
54 | cumacc = np.cumsum(acc)/idx
55 | print(f'IID Best Classifier: {cumacc.mean()}')
56 |
57 | # Get blind classifier performance
58 | pending = [k for k in [1, 2, 3, 5, 7, 10, 20, 25, 35, 50, 75, 100, 150, 250, 500, 750, 1000, 2500, 5000, 7500, 10000, 15000, 25000, 50000, 75000] if not exists(logdir+'/'+dataset+'_blind_preds_'+str(k)+'.npy')]
59 | # Use one shared pool so the missing k values are actually processed in parallel
60 | with multiprocessing.Pool(num_processes) as p:
61 | result = []
62 | for k in pending:
63 | print(f'Processing mode of past: {k}')
64 | result.append(p.apply_async(get_preds, [k]))
65 | for r in result: r.get()
66 |
67 | # Show results, see the degree of label correlation
68 | for k in [1, 2, 3, 5, 7, 10, 20, 25, 35, 50, 75, 100, 150, 250, 500, 750, 1000, 2500, 5000, 7500, 10000, 15000, 25000, 50000, 75000]:
69 | pred = np.load(logdir+'/'+dataset+'_blind_preds_'+str(k)+'.npy')
70 | gt = labels[1:]
71 | acc = np.equal(gt,pred)*1.0
72 | idx = np.arange(acc.shape[0])+1
73 | cumacc = np.cumsum(acc)/idx
74 | print(f'Blind Classifier @ {k}: {cumacc.mean()}')
75 |
76 | print('Extracted all blind classifier results!')
77 |
--------------------------------------------------------------------------------
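`get_preds` in `run_blind.py` calls `stats.mode` on a fresh window of `k` labels at every step, so its cost grows with `k`. A sliding label histogram yields the same predictions while touching only the label that enters and the label that leaves the window at each step; the argmax over the histogram still costs O(C) per step for C classes. The sketch below is a hypothetical `sliding_window_mode_preds` function, not part of the repository. It assumes non-negative integer labels and relies on `np.argmax` returning the first (smallest) label among ties, which matches the smallest-value tie-breaking of `scipy.stats.mode`.

```python
import numpy as np


def sliding_window_mode_preds(labels: np.ndarray, k: int) -> np.ndarray:
    """Blind-classifier predictions: labels[i] is predicted as the mode of labels[i-k:i].

    Hypothetical helper mirroring get_preds: for i <= k it predicts labels[0].
    Assumes non-negative integer labels.
    """
    labels = np.asarray(labels, dtype=np.int64)
    n = labels.shape[0]
    counts = np.zeros(int(labels.max()) + 1, dtype=np.int64)  # histogram of the current window
    pred = np.empty(max(n - 1, 0), dtype=np.int64)
    for i in range(1, n):
        counts[labels[i - 1]] += 1  # label i-1 enters the window
        if i > k:
            counts[labels[i - k - 1]] -= 1  # label i-k-1 leaves; window is now labels[i-k:i]
            pred[i - 1] = np.argmax(counts)  # first maximum = smallest label among ties
        else:
            pred[i - 1] = labels[0]
    return pred
```

Passing the `labels` array loaded in `run_blind.py` should give the same predicted labels as `get_preds(k)`, stored as int64 rather than uint16.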