├── .gitignore ├── LICENSE ├── README.md ├── assets └── teaser.png ├── data ├── __init__.py ├── features │ └── .empty ├── group_workers.json ├── groups.txt ├── imagenet.py ├── structure_released.xml ├── train_labeled.txt └── utils.py ├── main.py └── online_label ├── __init__.py ├── aggregator ├── __init__.py └── aggregator.py ├── annotation_holder.py ├── config ├── config.yaml ├── experiment │ ├── imagenet_100_classes.yaml │ ├── imagenet_animal.yaml │ ├── imagenet_split_0.yaml │ ├── imagenet_split_1.yaml │ ├── imagenet_split_2.yaml │ ├── imagenet_split_3.yaml │ ├── imagenet_split_4.yaml │ └── imagenet_split_5.yaml ├── hyperparams.yaml ├── learner_method │ ├── ds_model.yaml │ ├── efficient_annotation.yaml │ ├── improved_lean.yaml │ └── lean.yaml └── simulation │ ├── amt_structured_noise.yaml │ └── amt_uniform_noise.yaml ├── learner ├── __init__.py ├── nn_learner.py └── utils.py ├── logger ├── __init__.py ├── batch_logger.py └── utils.py ├── online_loop.py ├── optimizer ├── __init__.py ├── em.py └── utils.py ├── sampler ├── __init__.py ├── random_sampler.py ├── sampler.py └── task_assignment_sampler.py └── worker.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/features/byol_r50-e3b0c442.pth_feat1.npy 2 | 3 | .ipynb_checkpoints 4 | outputs/ 5 | 6 | .DS_Store 7 | __pycache__ 8 | *.pyc 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Sanja Fidler's Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets 2 | 3 | This is the official implementation of *"Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets"* (CVPR 2021). 
4 | For more details, please refer to: 5 | ----------------------- ------------------------------------ 6 | **Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets** 7 | 8 | [Yuan-Hong Liao](https://andrewliao11.github.io), [Amlan Kar](https://amlankar.github.io), [Sanja Fidler](http://www.cs.utoronto.ca/~fidler/) 9 | 10 | University of Toronto 11 | 12 | [[Paper]](https://arxiv.org/abs/2104.12690) [[Video]]() [[Project]](https://fidler-lab.github.io/efficient-annotation-cookbook/) 13 | 14 | **CVPR2021 Oral** 15 | 16 | ![](assets/teaser.png) 17 | 18 | Data is the engine of modern computer vision, which necessitates collecting large-scale datasets. This is expensive, and guaranteeing the quality of the labels is a major challenge. In this paper, we investigate efficient annotation strategies for collecting multi-class classification labels for a large collection of images. While methods that exploit learnt models for labeling exist, a surprisingly prevalent approach is to query humans for a fixed number of labels per datum and aggregate them, which is expensive. Building on prior work on online joint probabilistic modeling of human annotations and machine-generated beliefs, we propose modifications and best practices aimed at minimizing human labeling effort. Specifically, we make use of advances in self-supervised learning, view annotation as a semi-supervised learning problem, identify and mitigate pitfalls, and ablate several key design choices to propose effective guidelines for labeling. Our analysis is done in a more realistic simulation that involves querying human labelers, which uncovers issues with evaluation using existing worker simulation methods. Simulated experiments on a 125k image subset of the ImageNet dataset with 100 classes show that it can be annotated to 80% top-1 accuracy with 0.35 annotations per image on average, a 2.7x and 6.7x improvement over prior work and manual annotation, respectively. 19 | 20 | ----------------------- ------------------------------------ 21 | 22 | 23 | ## Code usage 24 | 25 | - Download the extracted [BYOL](https://papers.nips.cc/paper/2020/file/f3ada80d5c4ee70142b17b8192b2958e-Paper.pdf) features and change the root directory accordingly 26 | ``` 27 | wget -P data/features/ http://www.cs.toronto.edu/~andrew/research/cvpr2021-good_practices/data/byol_r50-e3b0c442.pth_feat1.npy 28 | ``` 29 | 30 | Replace `REPO_DIR` ([here](https://github.com/fidler-lab/efficient-annotation-cookbook/blob/master/data/__init__.py)) with the absolute path to the repository. 31 | 32 | 33 | - Run online labeling with simulated workers (a concrete example command is given below) 34 | - `<experiment>` can be `imagenet_split_0~5`, `imagenet_animal`, `imagenet_100_classes` 35 | - `<learner_method>` can be `ds_model`, `lean`, `improved_lean`, `efficient_annotation` 36 | - `<simulation>` can be `amt_structured_noise`, `amt_uniform_noise` 37 | ```python 38 | python main.py experiment=<experiment> learner_method=<learner_method> simulation=<simulation> 39 | ``` 40 | To change other configurations, see [config.yaml](https://github.com/fidler-lab/efficient-annotation-cookbook/blob/master/online_label/config/config.yaml). 41 |
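For example, to label the animal subset with the full proposed method under structured worker noise (all option values are taken from the lists above), one possible invocation is:
```python
python main.py experiment=imagenet_animal learner_method=efficient_annotation simulation=amt_structured_noise
```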
42 | ## Code Structure 43 | There are several components in our system: `Sampler`, `AnnotationHolder`, `Learner`, `Optimizer` and `Aggregator`. 44 | 45 | - `Sampler`: We implement `RandomSampler` and `GreedyTaskAssignmentSampler`. For `GreedyTaskAssignmentSampler`, you need to specify an additional flag, `max_annotation_per_worker`. 46 | 47 | For example, 48 | ```python 49 | python main.py experiment=imagenet_animal learner_method=efficient_annotation simulation=amt_structured_noise sampler.algo=greedy_task_assignment sampler.max_annotation_per_worker=2000 50 | ``` 51 | 52 | - `AnnotationHolder`: It holds all the information for each example, including worker annotations, the ground truth and the current risk estimate. For simulated workers, you can call `annotation_holder.collect_annotation` to query annotations. You can also sample annotations elsewhere and add them by calling `annotation_holder.add_annotation`. 53 | 54 | - `Learner`: We implement `DummyLearner` and `LinearNNLearner`. You can use your favorite architecture by overriding `NNLearner.init_learner`. 55 | 56 | - `Optimizer`: We implement `EMOptimizer`. When `optimizer.step` is called, the optimizer performs EM for a fixed number of steps or until it converges. If `DummyLearner` is not used, the optimizer is expected to call `optimizer.fit_machine_learner` to train the machine learner and run prediction over all data examples. 57 | 58 | - `Aggregator`: We implement `MjAggregator` and `BayesAggregator`. `MjAggregator` performs a majority vote to infer the final label. `BayesAggregator` treats the ground truth and worker skills as hidden variables and infers them from the observations (worker annotations). 59 | 60 | 61 | ## Citation 62 | If you use this code, please cite: 63 | ``` 64 | @misc{liao2021good, 65 | title={Towards Good Practices for Efficiently Annotating Large-Scale Image Classification Datasets}, 66 | author={Yuan-Hong Liao and Amlan Kar and Sanja Fidler}, 67 | year={2021}, 68 | eprint={2104.12690}, 69 | archivePrefix={arXiv}, 70 | primaryClass={cs.CV} 71 | } 72 | ``` 73 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fidler-lab/efficient-annotation-cookbook/8d02da89c8049c549748761e1762a04f40a64da0/assets/teaser.png -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | REPO_DIR = '/path/to/the/repository' 2 | 3 | imagenet100 = 'n02869837 n01749939 n02488291 n02107142 n13037406 n02091831 n04517823 n04589890 n03062245 n01773797 n01735189 n07831146 n07753275 n03085013 n04485082 n02105505 n01983481 n02788148 n03530642 n04435653 n02086910 n02859443 n13040303 n03594734 n02085620 n02099849 n01558993 n04493381 n02109047 n04111531 n02877765 n04429376 n02009229 n01978455 n02106550 n01820546 n01692333 n07714571 n02974003 n02114855 n03785016 n03764736 n03775546 n02087046 n07836838 n04099969 n04592741 n03891251 n02701002 n03379051 n02259212 n07715103 n03947888 n04026417 n02326432 n03637318 n01980166 n02113799 n02086240 n03903868 n02483362 n04127249 n02089973 n03017168 n02093428 n02804414 n02396427 n04418357 n02172182 n01729322 n02113978 n03787032 n02089867 n02119022 n03777754 n04238763 n02231487 n03032252 n02138441 n02104029 n03837869 n03494278 n04136333 n03794056 n03492542 n02018207 n04067472 n03930630 n03584829 n02123045 n04229816 n02100583 n03642806 n04336792 n03259280 n02116738 n02108089 n03424325 n01855672 n02090622' 4 | imagenet100 = imagenet100.split(' ') 5 | -------------------------------------------------------------------------------- /data/features/.empty:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/fidler-lab/efficient-annotation-cookbook/8d02da89c8049c549748761e1762a04f40a64da0/data/features/.empty -------------------------------------------------------------------------------- /data/groups.txt: -------------------------------------------------------------------------------- 1 | n02105505 2 | n02113978 3 | n02100583 4 | n02087046 5 | n02086910 6 | n02108089 7 | n02106550 8 | n02089973 9 | n02085620 10 | n02086240 11 | n02099849 12 | n02091831 13 | n02089867 14 | n02090622 15 | n02113799 16 | n02104029 17 | n02093428 18 | n02107142 19 | n02109047 20 | 21 | n02138441 22 | n02326432 23 | n01558993 24 | n02009229 25 | n02123045 26 | n01855672 27 | n02488291 28 | n02396427 29 | n01820546 30 | n01735189 31 | n02116738 32 | n02483362 33 | n02114855 34 | n01729322 35 | n02018207 36 | n01749939 37 | n01692333 38 | n02119022 39 | 40 | n13037406 41 | n07715103 42 | n01980166 43 | n02172182 44 | n07714571 45 | n01983481 46 | n13040303 47 | n02259212 48 | n01978455 49 | n01773797 50 | n07831146 51 | n02231487 52 | n07836838 53 | 54 | n03794056 55 | n04336792 56 | n04127249 57 | n02701002 58 | n04238763 59 | n04429376 60 | n02974003 61 | n02804414 62 | n03891251 63 | n03017168 64 | n04067472 65 | n03642806 66 | n03930630 67 | n04099969 68 | n03785016 69 | n03947888 70 | n03494278 71 | n04592741 72 | n03492542 73 | 74 | n04517823 75 | n03085013 76 | n02877765 77 | n03764736 78 | n03584829 79 | n03787032 80 | n02869837 81 | n04136333 82 | n03777754 83 | n04493381 84 | n04026417 85 | n03062245 86 | n04111531 87 | n03594734 88 | n03259280 89 | n03379051 90 | 91 | n04589890 92 | n07753275 93 | n04485082 94 | n02788148 95 | n03530642 96 | n04435653 97 | n02859443 98 | n03775546 99 | n03637318 100 | n03903868 101 | n04418357 102 | n03032252 103 | n03837869 104 | n04229816 105 | n03424325 106 | 107 | -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from copy import deepcopy 5 | from collections import Counter 6 | from collections import OrderedDict # For python >= 3.7, dict() is fine 7 | 8 | from . 
import REPO_DIR, imagenet100 9 | from .utils import ImageNetStruc 10 | 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class ImageNetData(object): 16 | 17 | imagenet_struc = ImageNetStruc() 18 | label_path = os.path.join(REPO_DIR, 'data/train_labeled.txt') 19 | feat_dir = os.path.join(REPO_DIR, 'data/features') 20 | 21 | def __init__(self, config): 22 | 23 | self.npr = np.random.RandomState(config.seed) 24 | self.wnids = config.wnids.split(' ') 25 | assert len(self.wnids) == config.n_classes 26 | 27 | self.config = config 28 | self.imagenet_struc.register_nodes(self.wnids) 29 | 30 | self.include_distraction = config.n_data_distraction_per_class > 0 31 | a_data, p_image_path, image_path, log_image_path, d_image_path, features_dict = self.load_data() 32 | 33 | ys = [a_data[p][1] for p in image_path if not self.include_distraction or a_data[p][1] != self.config.n_classes-1] 34 | counter = Counter(ys) 35 | n_max = counter.most_common()[0][1] 36 | n_min = counter.most_common()[-1][1] 37 | rho = n_max / n_min 38 | logger.info(f'Imbalance level: {rho:.2f}') 39 | 40 | self.n_p = len(p_image_path) 41 | self.n = len(image_path) 42 | 43 | self.a_data = a_data 44 | self.p_image_path = p_image_path 45 | self.image_path = image_path 46 | self.log_image_path = log_image_path 47 | self.d_image_path = d_image_path 48 | self.features_dict = features_dict 49 | logger.info(f'Number of data to label: {self.n}, number of prototype data: {self.n_p}') 50 | 51 | def get_y(self, path): 52 | y = [] 53 | for p in path: 54 | y.append(self.a_data[p][1]) 55 | return y 56 | 57 | def get_features_dict(self): 58 | return deepcopy(self.features_dict) 59 | 60 | def __load_a_data(self): 61 | 62 | a_data = OrderedDict() 63 | with open(self.label_path) as f: 64 | content = f.readlines() 65 | path = [i.split(' ')[0] for i in content] 66 | for p in path: 67 | a_data[p] = [p.split('/')[0], None] 68 | 69 | return a_data 70 | 71 | def add_distracting_images(self, a_data, p_image_path, image_path, log_image_path): 72 | 73 | logger.info('Include distracting images') 74 | 75 | other_class = [p for p, (wnid, _) in a_data.items() if wnid not in self.wnids] 76 | valid = [p for p in other_class if p not in p_image_path and p not in image_path] 77 | d_image_path = self.npr.choice(valid, 78 | min(len(valid), self.config.n_classes * self.config.n_data_distraction_per_class), 79 | replace=False).tolist() 80 | image_path = image_path + d_image_path 81 | log_image_path = log_image_path + d_image_path 82 | 83 | valid = [p for p in other_class if p not in p_image_path and \ 84 | p not in image_path and \ 85 | p not in d_image_path] 86 | d_p_image_path = self.npr.choice(valid, 87 | min(len(valid), self.config.n_data_distraction_per_class), 88 | replace=False).tolist() 89 | p_image_path = p_image_path + d_p_image_path 90 | 91 | logger.info(f'Number of distraction: {len(d_image_path)}') 92 | return p_image_path, image_path, log_image_path, d_image_path 93 | 94 | def load_prototype(self, a_data, selected=[]): 95 | 96 | # Random 97 | def __choose(id, n): 98 | valid = [p for p, (wnid, _) in a_data.items() if wnid == id and \ 99 | p not in selected] 100 | return self.npr.choice(valid, min(len(valid), n), replace=False).tolist() 101 | 102 | p_image_path = [] 103 | for id in self.wnids: 104 | p_image_path.extend(__choose(id, self.config.n_prototype_per_class)) 105 | 106 | return p_image_path 107 | 108 | def which_class(self, wnid): 109 | if wnid in self.wnids: 110 | return self.wnids.index(wnid) 111 | elif 
self.include_distraction: 112 | return self.config.n_classes - 1 113 | else: 114 | raise ValueError 115 | 116 | def load_random_image_path(self, a_data): 117 | p_image_path = self.load_prototype(a_data) 118 | 119 | def __choose(id, n): 120 | valid = [p for p, (wnid, _) in a_data.items() if wnid == id and \ 121 | p not in p_image_path] 122 | return self.npr.choice(valid, min(len(valid), n), replace=False).tolist() 123 | 124 | 125 | image_path = [] 126 | for id in self.wnids: 127 | image_path.extend(__choose(id, self.config.n_data_per_class)) 128 | 129 | log_image_path = image_path.copy() 130 | return p_image_path, image_path, log_image_path 131 | 132 | def load_data(self): 133 | ''' 134 | return 135 | a_data: (dict) key=image_path, value=[wnid, y] 136 | p_image_path: (list) paths of prototype images 137 | image_path: (list) paths of all accessible images (use to compute features) 138 | log_image_path: (list) paths of images of interests 139 | d_image_path: (list) paths of distracting images 140 | features_dict: (dict) {prototype_features: ?, features: ?} 141 | ''' 142 | config = self.config 143 | 144 | # Load all data 145 | a_data = self.__load_a_data() 146 | a_n_data = len(a_data) 147 | a_image_path = list(a_data.keys()) 148 | 149 | 150 | # Sample Prototype/Accessible/Log data 151 | p_image_path, image_path, log_image_path = self.load_random_image_path(a_data) 152 | 153 | 154 | # Include distraction data if needed 155 | if self.include_distraction: 156 | p_image_path, image_path, log_image_path, d_image_path = \ 157 | self.add_distracting_images(a_data, 158 | p_image_path, 159 | image_path, 160 | log_image_path) 161 | 162 | config.n_classes += 1 # Append other_class at the end 163 | else: 164 | d_image_path = [] 165 | 166 | 167 | p = os.path.join(self.feat_dir, config.learner.features) 168 | assert os.path.exists(p), logger.warning(f'Features {p} does not exist') 169 | logger.info(f'Load features from {p}') 170 | 171 | a_features = np.load(p) 172 | p_features = np.array([a_features[a_image_path.index(p)] for p in p_image_path]) 173 | features = np.array([a_features[a_image_path.index(p)] for p in image_path]) 174 | features_dict = { 175 | 'features': features, 176 | 'prototype_features': p_features 177 | } 178 | 179 | 180 | for k, v in a_data.items(): 181 | wnid, _ = v 182 | if self.include_distraction: 183 | y = self.wnids.index(wnid) if wnid in self.wnids else config.n_classes - 1 184 | else: 185 | y = self.wnids.index(wnid) if wnid in self.wnids else None 186 | 187 | a_data[k] = [wnid, y] 188 | 189 | return a_data, p_image_path, image_path, log_image_path, d_image_path, features_dict 190 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xmltodict 3 | from . 
import REPO_DIR 4 | 5 | 6 | class Node(): 7 | def __init__(self, node, depth=0, parent=None): 8 | self._node = node 9 | self.depth = depth 10 | self.wnid = node['@wnid'] 11 | self.gloss = node['@gloss'] 12 | self.name = node['@words'] 13 | self.parent = parent 14 | self._set_children_node() 15 | 16 | def _set_children_node(self): 17 | self.children = [] 18 | if 'synset' in self._node: 19 | children = self._node['synset'] 20 | if not isinstance(children, list): 21 | children = [children] 22 | for child in children: 23 | child = Node(child, depth=self.depth+1, parent=self) 24 | self.children.append(child) 25 | 26 | self.num_children = len(self.children) 27 | 28 | def is_leaf(self): 29 | return self.num_children == 0 30 | 31 | def __repr__(self): 32 | return self.name 33 | 34 | 35 | class Tree(): 36 | def __init__(self, root): 37 | self.root = root 38 | self.leaf_nodes = self.find_nodes(lambda node: node.is_leaf()) 39 | self.max_depth = max([n.depth for n in self.leaf_nodes]) 40 | 41 | def find_nodes(self, filter): 42 | nodes = [] 43 | to_expand = [self.root] 44 | while len(to_expand) > 0: 45 | node = to_expand.pop() 46 | if filter(node): 47 | nodes.append(node) 48 | to_expand.extend(node.children) 49 | return nodes 50 | 51 | 52 | class ImageNetStruc(): 53 | structure_path = os.path.join(REPO_DIR, 'data/structure_released.xml') 54 | 55 | def __init__(self): 56 | 57 | # >>> ImageNet Structure 58 | with open(self.structure_path) as f: 59 | xml = f.read() 60 | struct = xmltodict.parse(xml) 61 | root = struct['ImageNetStructure']['synset'] 62 | root_node = Node(root) 63 | tree = Tree(root_node) 64 | self.tree = tree 65 | 66 | # Info 67 | wnid_to_name = {} 68 | to_expand = [tree.root] 69 | while len(to_expand) > 0: 70 | node = to_expand.pop(0) 71 | wnid_to_name.update({node.wnid: node.name}) 72 | to_expand.extend(node.children) 73 | 74 | self.wnid_to_name = wnid_to_name 75 | 76 | def register_nodes(self, class_wnid): 77 | nodes_of_interest = [] 78 | for wnid in class_wnid: 79 | nodes = self.tree.find_nodes(lambda node: node.wnid == wnid) 80 | # Each wnid might correspond to multiple nodes in the tree 81 | nodes_of_interest.append(nodes[0]) 82 | self.nodes_of_interest = nodes_of_interest 83 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import hydra 4 | import json 5 | import numpy as np 6 | from shutil import copyfile 7 | from collections import OrderedDict 8 | 9 | from online_label.online_loop import run_online_loop 10 | from online_label.worker import get_worker_class 11 | from online_label.sampler import get_sampler_class 12 | from online_label.aggregator import get_aggregator_class 13 | from online_label.learner import get_learner_class 14 | from online_label.optimizer import get_optimizer_class 15 | from online_label.annotation_holder import AnnotationHolder 16 | from online_label.logger import BatchLogger 17 | from data.imagenet import ImageNetData 18 | 19 | 20 | import logging 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def setup(config): 25 | '''Basic setup 26 | Including loading best hyper-params and saving the experiment configurations. 
27 | ''' 28 | 29 | ## Save experiment configurations 30 | logger.info(config.pretty()) 31 | logger.info(f'Current working directory: {os.getcwd()}') 32 | 33 | with open('config.txt', 'w') as f: 34 | f.write(config.pretty()) 35 | 36 | 37 | def init_workers(config, wnids): 38 | 39 | worker_class = get_worker_class(config, wnids) 40 | 41 | workers = OrderedDict() 42 | for i in range(config.worker.n): 43 | w = worker_class(config=config, seed=config.seed+i, known=config.worker.known) 44 | workers.update({w.id: w}) 45 | 46 | mean_reliability = sum([np.diag(w.m).mean() for w in workers.values()]) / config.worker.n 47 | logger.info(f'Average worker reliability: {mean_reliability}') 48 | return workers 49 | 50 | 51 | def save_state(config, workers, annotation_holder, optimizer, learner, step, p=None): 52 | workers_str = json.dumps([w.save_state() for w in workers.values()]) 53 | annotation_holder_str = annotation_holder.save_state() 54 | optimizer_str = optimizer.save_state() 55 | learner_str = learner.save_state() 56 | 57 | state = dict(workers_str=workers_str, 58 | annotation_holder_str=annotation_holder_str, 59 | optimizer_str=optimizer_str, 60 | learner_str=learner_str, 61 | step=step) 62 | 63 | if p is None: 64 | if os.path.exists('latest_state.json'): 65 | copyfile('latest_state.json', 'backup_state.json') 66 | 67 | json.dump(state, open('latest_state.json', 'w')) 68 | else: 69 | json.dump(state, open(p, 'w')) 70 | 71 | 72 | def load_state(workers, annotation_holder, optimizer, learner, sampler, filename): 73 | state = json.load(open(filename)) 74 | workers_str = json.loads(state['workers_str']) 75 | 76 | for w, s in zip(workers.values(), workers_str): 77 | w.load_state(s) 78 | 79 | _workers = OrderedDict() 80 | for w in workers.values(): 81 | _workers.update({w.id: w}) 82 | workers = _workers 83 | 84 | annotation_holder_str = state['annotation_holder_str'] 85 | annotation_holder.load_state(annotation_holder_str, workers) 86 | 87 | optimizer_str = state['optimizer_str'] 88 | optimizer.load_state(optimizer_str, workers) 89 | 90 | learner_str = state['learner_str'] 91 | learner.load_state(learner_str) 92 | 93 | sampler.load_state(annotation_holder, workers) 94 | return state['step'] 95 | 96 | 97 | # support pre-emption 98 | def load_from_latest_state(config, workers, annotation_holder, optimizer, learner, sampler): 99 | 100 | resume = os.path.exists('latest_state.json') 101 | if resume: 102 | try: 103 | start_step = load_state(workers, annotation_holder, optimizer, learner, sampler, 'latest_state.json') 104 | logger.info('Log state from latest_state.json') 105 | except json.decoder.JSONDecodeError: 106 | start_step = load_state(workers, annotation_holder, optimizer, learner, sampler, 'backup_state.json') 107 | logger.info('Log state from backup_state.json') 108 | 109 | logger.info(f'From step {start_step}') 110 | else: 111 | save_state(config, workers, annotation_holder, optimizer, learner, step=0) 112 | start_step = 0 113 | 114 | return start_step 115 | 116 | 117 | @hydra.main(config_path='online_label/config', config_name='config') 118 | def main(config): 119 | 120 | setup(config) 121 | 122 | # >>> Data and Simulated Workers 123 | imagenet_data = ImageNetData(config) 124 | workers = init_workers(config, imagenet_data.wnids) 125 | annotation_holder = AnnotationHolder(config, 126 | workers, 127 | imagenet_data.image_path, 128 | imagenet_data.imagenet_struc) 129 | 130 | # >>> Initialize Components in Online Labeling 131 | aggregator = get_aggregator_class(config)(config, imagenet_data.n) 132 | 
learner = get_learner_class(config)(config) 133 | optimizer = get_optimizer_class(config)(config, imagenet_data, workers) 134 | sampler = get_sampler_class(config)(config, annotation_holder, workers, optimizer=optimizer) 135 | 136 | 137 | # >>> Load the state 138 | start_step = load_from_latest_state(config, workers, annotation_holder, optimizer, learner, sampler) 139 | 140 | batch_logger = BatchLogger(config, 141 | imagenet_data.get_y(imagenet_data.log_image_path), 142 | imagenet_data.n, 143 | imagenet_data.wnids, 144 | imagenet_data.p_image_path, 145 | imagenet_data.image_path, 146 | imagenet_data.log_image_path, 147 | imagenet_data.d_image_path) 148 | 149 | 150 | # >>> Online Labeling 151 | run_online_loop(config, 152 | imagenet_data, 153 | annotation_holder, 154 | sampler, 155 | aggregator, 156 | learner, 157 | optimizer, 158 | batch_logger, 159 | save_state, 160 | start_step) 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /online_label/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fidler-lab/efficient-annotation-cookbook/8d02da89c8049c549748761e1762a04f40a64da0/online_label/__init__.py -------------------------------------------------------------------------------- /online_label/aggregator/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregator import MjAggregator, BayesAggregator 2 | 3 | 4 | def get_aggregator_class(config): 5 | if config.aggregator.algo.lower() == 'mj': 6 | return MjAggregator 7 | elif config.aggregator.algo.lower() == 'bayes': 8 | return BayesAggregator 9 | else: 10 | raise ValueError 11 | -------------------------------------------------------------------------------- /online_label/aggregator/aggregator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | epsilon = 1e-8 4 | 5 | 6 | class Aggregator(object): 7 | def __init__(self, config, n_data): 8 | self.config = config 9 | self.algo = config.aggregator.algo 10 | self.n_classes = config.n_classes 11 | self.n_data = n_data 12 | 13 | def compute_risk(self, belief): 14 | assert np.max(belief) <= 1. and np.min(belief) >= 0. 
15 | y_hat = np.argmax(belief, 1) 16 | confidence = np.take_along_axis(belief, y_hat.reshape(-1, 1), axis=1) 17 | confidence = confidence.reshape(-1) 18 | risk = 1 - confidence 19 | return risk 20 | 21 | def aggregate(self, annotation_holder, **kwargs): 22 | raise NotImplementedError 23 | 24 | def empty_belief(self): 25 | belief = np.zeros((self.n_data, self.n_classes)) 26 | return belief 27 | 28 | 29 | class MjAggregator(Aggregator): 30 | def __init__(self, config, n_data, **kwargs): 31 | Aggregator.__init__(self, config, n_data) 32 | 33 | def aggregate(self, annotation_holder, **kwargs): 34 | belief = self.empty_belief() 35 | 36 | image_path = list(annotation_holder.annotation.keys()) 37 | for i, p in enumerate(image_path): 38 | for anta in annotation_holder.annotation[p]: 39 | _, z, _, _ = anta 40 | belief[i][z] += 1 41 | 42 | votes_per_data = np.sum(belief, 1, keepdims=True) 43 | belief = np.divide(belief, votes_per_data, 44 | out=np.zeros_like(belief), 45 | where=votes_per_data!=0) 46 | return belief 47 | 48 | 49 | class BayesAggregator(Aggregator): 50 | def __init__(self, config, n_data, **kwargs): 51 | Aggregator.__init__(self, config, n_data) 52 | self.uniform_prior = np.ones((1, self.n_classes)) / self.n_classes 53 | 54 | def __normalize(self, log_p): 55 | if len(log_p.shape) == 1: 56 | log_p = log_p.reshape(1, -1) 57 | 58 | b = np.max(log_p, 1, keepdims=True) 59 | log_sum_p = b + np.log(np.exp(log_p - b).sum(1, keepdims=True)) 60 | 61 | p = np.exp(log_p - log_sum_p) 62 | p = np.clip(p, 0., 1.) 63 | return p 64 | 65 | def bayes(self, annotation_holder, belief, prior): 66 | 67 | image_path = list(annotation_holder.annotation.keys()) 68 | for i, p in enumerate(image_path): 69 | 70 | log_prior = np.log(prior[i] + epsilon) if len(prior) == len(image_path) else np.log(prior[0] + epsilon) 71 | 72 | log_likelihood = np.zeros_like(log_prior) 73 | for anta in annotation_holder.annotation[p]: 74 | y, _, _, p_z_given_y = anta 75 | log_likelihood += np.log(p_z_given_y + epsilon) 76 | 77 | log_p = log_prior + log_likelihood 78 | 79 | belief[i] = log_p 80 | 81 | return belief 82 | 83 | def aggregate(self, annotation_holder, prior=None, learner_prob=None, prototype_targets=None, **kwargs): 84 | 85 | belief = self.empty_belief() 86 | if prior is None: 87 | prior = self.uniform_prior 88 | unnormalized_worker_belief = self.bayes(annotation_holder, belief, prior) 89 | worker_belief = self.__normalize(unnormalized_worker_belief) 90 | 91 | if learner_prob is not None: 92 | n_data = annotation_holder.n_data 93 | learner_log_prob = np.log(learner_prob + epsilon) 94 | unnormalized_belief = unnormalized_worker_belief + learner_log_prob[:n_data, :] 95 | return self.__normalize(unnormalized_belief) 96 | else: 97 | return self.__normalize(unnormalized_worker_belief) 98 | -------------------------------------------------------------------------------- /online_label/annotation_holder.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from collections import OrderedDict 4 | from copy import deepcopy 5 | 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class NumpyEncoder(json.JSONEncoder): 11 | def default(self, obj): 12 | if isinstance(obj, np.ndarray): 13 | return obj.tolist() 14 | return json.JSONEncoder.default(self, obj) 15 | 16 | 17 | class AnnotationHolder(): 18 | def __init__(self, config, workers, image_path, imagenet_struc): 19 | self.config = config 20 | self.npr = np.random.RandomState(config.seed) 21 
| self.imagenet_struc = imagenet_struc 22 | self.workers = workers 23 | 24 | self.n_data = len(image_path) 25 | self.annotation = OrderedDict() 26 | for p in image_path: 27 | self.annotation.update({p: []}) 28 | self.n_annotation = 0 29 | 30 | def load_state(self, checkpoint_annotation_holder, workers): 31 | 32 | self.workers = workers 33 | 34 | checkpoint_annotation_holder = json.loads(checkpoint_annotation_holder) 35 | new_annotation = OrderedDict() 36 | 37 | for k in self.annotation.keys(): 38 | anno_k = [] 39 | for a in checkpoint_annotation_holder[k]: 40 | y, z, j, p_z_given_y = a 41 | anno_k.append((y, z, j, np.array(p_z_given_y))) 42 | 43 | new_annotation.update({k: anno_k}) 44 | 45 | self.annotation = new_annotation 46 | n_annotation = sum([len(v) for v in self.annotation.values()]) 47 | self.n_annotation = n_annotation 48 | 49 | def save_state(self): 50 | return json.dumps(deepcopy(self.annotation), cls=NumpyEncoder) 51 | 52 | def collect_annotation(self, 53 | a_data, 54 | data_path, 55 | worker_id, 56 | belief): 57 | 58 | logger.info(f'Collect {len(data_path)} Annotation') 59 | for p, j in zip(data_path, worker_id): 60 | y = a_data[p][1] 61 | w = self.workers[j] 62 | z, p_z_given_y = w.annotate(y) 63 | self.annotation[p].append((int(y), int(z), j, p_z_given_y)) 64 | 65 | self.n_annotation += len(data_path) 66 | 67 | def add_annotation(self, annotations): 68 | 69 | for anno_i in annotations: 70 | p, y, z, j = anno_i 71 | p_z_given_y = None 72 | 73 | self.annotation[p].append((int(y), int(z), j, p_z_given_y)) 74 | 75 | self.n_annotation += len(annotations) 76 | -------------------------------------------------------------------------------- /online_label/config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra/job_logging: colorlog 3 | - hydra/hydra_logging: colorlog 4 | - experiment: imagenet_split_1 5 | - simulation: amt_structured_noise 6 | - learner_method: efficient_annotation # lean, improved_lean, efficient_annotation, online_ds 7 | 8 | 9 | hydra: 10 | run: 11 | dir: ./outputs/${now:%Y-%m-%d}-${now:%H-%M-%S} 12 | 13 | exp_name: test 14 | seed: 123 15 | n_jobs: 2 16 | early_stop: false 17 | 18 | aggregator: 19 | algo: bayes # mj bayes 20 | 21 | 22 | optimizer: 23 | max_em_steps: 10 24 | likelihood_epsilon: 0.001 25 | 26 | 27 | learner: 28 | features: byol_r50-e3b0c442.pth_feat1.npy 29 | risk_thres: 0.1 30 | n_hidden_layer: 1 31 | hidden_size: 64 32 | batch_size: 1024 33 | max_epochs: 300 34 | lr_ratio: 0.0001 35 | weight_decay: 0.001 36 | 37 | algo: dummy 38 | calibrate: temperature 39 | semi_supervised: none 40 | early_stop_scope: local 41 | prototype_as_val: true 42 | 43 | mixmatch: 44 | mu: 5 45 | alpha: 0.75 46 | mixmatch_w: 75 47 | 48 | 49 | sampler: 50 | algo: random # random, greedy_task_assignment 51 | max_annotation_per_example: 3 52 | max_annotation_per_worker: 2000 53 | -------------------------------------------------------------------------------- /online_label/config/experiment/imagenet_100_classes.yaml: -------------------------------------------------------------------------------- 1 | wnids: n02869837 n01749939 n02488291 n02107142 n13037406 n02091831 n04517823 n04589890 n03062245 n01773797 n01735189 n07831146 n07753275 n03085013 n04485082 n02105505 n01983481 n02788148 n03530642 n04435653 n02086910 n02859443 n13040303 n03594734 n02085620 n02099849 n01558993 n04493381 n02109047 n04111531 n02877765 n04429376 n02009229 n01978455 n02106550 n01820546 n01692333 n07714571 n02974003 n02114855 
n03785016 n03764736 n03775546 n02087046 n07836838 n04099969 n04592741 n03891251 n02701002 n03379051 n02259212 n07715103 n03947888 n04026417 n02326432 n03637318 n01980166 n02113799 n02086240 n03903868 n02483362 n04127249 n02089973 n03017168 n02093428 n02804414 n02396427 n04418357 n02172182 n01729322 n02113978 n03787032 n02089867 n02119022 n03777754 n04238763 n02231487 n03032252 n02138441 n02104029 n03837869 n03494278 n04136333 n03794056 n03492542 n02018207 n04067472 n03930630 n03584829 n02123045 n04229816 n02100583 n03642806 n04336792 n03259280 n02116738 n02108089 n03424325 n01855672 n02090622 2 | n_classes: 100 3 | n_data_per_class: 3000 4 | n_data_distraction_per_class: 0 5 | n_prototype_per_class: 10 6 | risk_thres: 0.1 7 | 8 | online: 9 | n_hits_per_step: 50 10 | budget: 800000 11 | hit_size: 40 12 | -------------------------------------------------------------------------------- /online_label/config/experiment/imagenet_animal.yaml: -------------------------------------------------------------------------------- 1 | # Experiment Setting 2 | experiment: imagenet_animal 3 | wnids: n02105505 n02113978 n02100583 n02087046 n02086910 n02108089 n02106550 n02089973 n02085620 n02086240 n02099849 n02091831 n02089867 n02090622 n02113799 n02104029 n02093428 n02107142 n02109047 n02138441 n02326432 n01558993 n02009229 n02123045 n01855672 n02488291 n02396427 n01820546 n01735189 n02116738 n02483362 n02114855 n01729322 n02018207 n01749939 n01692333 n02119022 4 | n_classes: 37 5 | n_data_per_class: 3000 6 | n_data_distraction_per_class: 0 7 | n_prototype_per_class: 10 8 | risk_thres: 0.1 9 | 10 | online: 11 | n_hits_per_step: 5 12 | budget: 500000 13 | hit_size: 40 14 | 15 | -------------------------------------------------------------------------------- /online_label/config/experiment/imagenet_split_0.yaml: -------------------------------------------------------------------------------- 1 | wnids: n02105505 n02113978 n02100583 n02087046 n02086910 n02108089 n02106550 n02089973 n02085620 n02086240 n02099849 n02091831 n02089867 n02090622 n02113799 n02104029 n02093428 n02107142 n02109047 2 | n_classes: 19 3 | n_data_per_class: 2000 4 | n_data_distraction_per_class: 0 5 | n_prototype_per_class: 10 6 | risk_thres: 0.1 7 | 8 | online: 9 | n_hits_per_step: 5 10 | budget: 400000 11 | hit_size: 40 -------------------------------------------------------------------------------- /online_label/config/experiment/imagenet_split_1.yaml: -------------------------------------------------------------------------------- 1 | wnids: n02138441 n02326432 n01558993 n02009229 n02123045 n01855672 n02488291 n02396427 n01820546 n01735189 n02116738 n02483362 n02114855 n01729322 n02018207 n01749939 n01692333 n02119022 2 | n_classes: 18 3 | n_data_per_class: 2000 4 | n_data_distraction_per_class: 0 5 | n_prototype_per_class: 10 6 | risk_thres: 0.1 7 | 8 | online: 9 | n_hits_per_step: 5 10 | budget: 400000 11 | hit_size: 40 -------------------------------------------------------------------------------- /online_label/config/experiment/imagenet_split_2.yaml: -------------------------------------------------------------------------------- 1 | wnids: n13037406 n07715103 n01980166 n02172182 n07714571 n01983481 n13040303 n02259212 n01978455 n01773797 n07831146 n02231487 n07836838 2 | n_classes: 13 3 | n_data_per_class: 2000 4 | n_data_distraction_per_class: 0 5 | n_prototype_per_class: 10 6 | risk_thres: 0.1 7 | 8 | online: 9 | n_hits_per_step: 5 10 | budget: 400000 11 | hit_size: 40 
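Each experiment YAML pairs a space-separated `wnids` string with `n_classes`, and `ImageNetData.__init__` (in `data/imagenet.py`) asserts that the two agree. A minimal standalone sanity check for a new split is sketched below; the helper name is hypothetical (not part of the repository) and PyYAML is assumed to be installed:
```python
# Hypothetical helper, not part of the repository: check that an experiment
# YAML satisfies the assumptions data/imagenet.py relies on, i.e. that the
# number of space-separated wnids equals n_classes and that wnids are unique.
import yaml

def check_experiment_config(path):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    wnids = cfg['wnids'].split(' ')
    assert len(wnids) == cfg['n_classes'], \
        f"n_classes={cfg['n_classes']} but {len(wnids)} wnids are listed"
    assert len(set(wnids)) == len(wnids), 'duplicate wnids in the split'
    # Rough size of the labeling pool (distraction images excluded)
    n_images = cfg['n_classes'] * cfg['n_data_per_class']
    print(f"{path}: {cfg['n_classes']} classes, ~{n_images} images, "
          f"annotation budget {cfg['online']['budget']}")

check_experiment_config('online_label/config/experiment/imagenet_split_2.yaml')
```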
-------------------------------------------------------------------------------- /online_label/config/experiment/imagenet_split_3.yaml: -------------------------------------------------------------------------------- 1 | wnids: n03794056 n04336792 n04127249 n02701002 n04238763 n04429376 n02974003 n02804414 n03891251 n03017168 n04067472 n03642806 n03930630 n04099969 n03785016 n03947888 n03494278 n04592741 n03492542 2 | n_classes: 19 3 | n_data_per_class: 2000 4 | n_data_distraction_per_class: 0 5 | n_prototype_per_class: 10 6 | risk_thres: 0.1 7 | 8 | online: 9 | n_hits_per_step: 5 10 | budget: 400000 11 | hit_size: 40 -------------------------------------------------------------------------------- /online_label/config/experiment/imagenet_split_4.yaml: -------------------------------------------------------------------------------- 1 | wnids: n04517823 n03085013 n02877765 n03764736 n03584829 n03787032 n02869837 n04136333 n03777754 n04493381 n04026417 n03062245 n04111531 n03594734 n03259280 n03379051 2 | n_classes: 16 3 | n_data_per_class: 2000 4 | n_data_distraction_per_class: 0 5 | n_prototype_per_class: 10 6 | risk_thres: 0.1 7 | 8 | online: 9 | n_hits_per_step: 5 10 | budget: 400000 11 | hit_size: 40 -------------------------------------------------------------------------------- /online_label/config/experiment/imagenet_split_5.yaml: -------------------------------------------------------------------------------- 1 | wnids: n04589890 n07753275 n04485082 n02788148 n03530642 n04435653 n02859443 n03775546 n03637318 n03903868 n04418357 n03032252 n03837869 n04229816 n03424325 2 | n_classes: 15 3 | n_data_per_class: 2000 4 | n_data_distraction_per_class: 0 5 | n_prototype_per_class: 10 6 | risk_thres: 0.1 7 | 8 | online: 9 | n_hits_per_step: 5 10 | budget: 400000 11 | hit_size: 40 -------------------------------------------------------------------------------- /online_label/config/hyperparams.yaml: -------------------------------------------------------------------------------- 1 | imagenet_split_0: 2 | none: 3 | lr_ratio: 0.0001 4 | weight_decay: 0.005 5 | pseudolabel: 6 | lr_ratio: 0.0005 7 | weight_decay: 0.005 8 | mixmatch: 9 | lr_ratio: 0.0001 10 | weight_decay: 0.005 11 | mu: 3 12 | mixmatch_w: 50 13 | 14 | 15 | imagenet_split_2: 16 | none: 17 | lr_ratio: 0.001 18 | weight_decay: 0.001 19 | pseudolabel: 20 | lr_ratio: 0.0005 21 | weight_decay: 0.0005 22 | mixmatch: 23 | lr_ratio: 0.001 24 | weight_decay: 0.0005 25 | mu: 5 26 | mixmatch_w: 100 27 | 28 | 29 | imagenet_split_4: 30 | none: 31 | lr_ratio: 0.001 32 | weight_decay: 0.001 33 | pseudolabel: 34 | lr_ratio: 0.0001 35 | weight_decay: 0.0005 36 | mixmatch: 37 | lr_ratio: 0.0005 38 | weight_decay: 0.005 39 | mu: 3 40 | mixmatch_w: 150 41 | 42 | 43 | imagenet_animal: 44 | none: 45 | lr_ratio: 0.0001 46 | weight_decay: 0.005 47 | pseudolabel: 48 | lr_ratio: 0.00005 49 | weight_decay: 0.005 50 | mixmatch: 51 | lr_ratio: 0.001 52 | weight_decay: 0.0005 53 | mu: 5 54 | mixmatch_w: 75 55 | 56 | 57 | imagenet_100_classes: 58 | none: 59 | lr_ratio: 0.0005 60 | weight_decay: 0.0001 61 | pseudolabel: 62 | lr_ratio: 0.0005 63 | weight_decay: 0.0001 64 | mixmatch: 65 | lr_ratio: 0.0005 66 | weight_decay: 0.0001 67 | mu: 10 68 | mixmatch_w: 50 69 | -------------------------------------------------------------------------------- /online_label/config/learner_method/ds_model.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | criterion: soft 3 | 4 | learner: 5 | algo: dummy 
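The `hyperparams.yaml` file above is keyed first by experiment name and then by the `learner.semi_supervised` mode (`none`, `pseudolabel`, `mixmatch`), with the tuned `lr_ratio`, `weight_decay` and MixMatch settings as leaves. A minimal sketch of reading one entry is given below; the helper is hypothetical (the repository's own loading logic is not shown in this excerpt) and PyYAML is assumed:
```python
# Hypothetical helper, not part of the repository: look up the tuned
# hyper-parameters stored in online_label/config/hyperparams.yaml for a given
# experiment and semi-supervised mode, returning {} when no entry exists.
import yaml

def tuned_hyperparams(experiment, semi_supervised,
                      path='online_label/config/hyperparams.yaml'):
    with open(path) as f:
        table = yaml.safe_load(f)
    return (table.get(experiment) or {}).get(semi_supervised) or {}

# Expected: {'lr_ratio': 0.001, 'weight_decay': 0.0005, 'mu': 5, 'mixmatch_w': 75}
print(tuned_hyperparams('imagenet_animal', 'mixmatch'))
```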
-------------------------------------------------------------------------------- /online_label/config/learner_method/efficient_annotation.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | criterion: soft 3 | 4 | learner: 5 | algo: mlp 6 | calibrate: temperature 7 | semi_supervised: pseudolabel 8 | early_stop_scope: global 9 | prototype_as_val: true 10 | lr_ratio: 0.0001 11 | weight_decay: 0.001 -------------------------------------------------------------------------------- /online_label/config/learner_method/improved_lean.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | criterion: soft 3 | 4 | learner: 5 | algo: mlp 6 | calibrate: temperature 7 | semi_supervised: none 8 | early_stop_scope: global 9 | prototype_as_val: true 10 | lr_ratio: 0.0001 11 | weight_decay: 0.001 -------------------------------------------------------------------------------- /online_label/config/learner_method/lean.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | criterion: hard 3 | 4 | learner: 5 | algo: mlp 6 | calibrate: cv_3 7 | semi_supervised: none 8 | early_stop_scope: local 9 | prototype_as_val: false 10 | lr_ratio: 0.0001 11 | weight_decay: 0.001 -------------------------------------------------------------------------------- /online_label/config/simulation/amt_structured_noise.yaml: -------------------------------------------------------------------------------- 1 | worker: 2 | type: structured_noise 3 | known: false 4 | n: 50 5 | prior: 6 | type: homegeneous_workers 7 | strength: 10 8 | predefined_params: 9 | pos_count: 2 10 | neg_count: 1 11 | -------------------------------------------------------------------------------- /online_label/config/simulation/amt_uniform_noise.yaml: -------------------------------------------------------------------------------- 1 | worker: 2 | type: uniform_noise 3 | known: false 4 | n: 50 5 | prior: 6 | type: homegeneous_workers 7 | strength: 10 8 | predefined_params: 9 | pos_count: 2 10 | neg_count: 1 11 | -------------------------------------------------------------------------------- /online_label/learner/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import logging 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class Learner(object): 9 | def __init__(self, config): 10 | self.config = config 11 | self.npr = np.random.RandomState(config.seed) 12 | self.calibrate = config.learner.calibrate 13 | self.semi_supervised = config.learner.semi_supervised 14 | self.risk_thres = config.learner.risk_thres 15 | self.use_cuda = torch.cuda.is_available() 16 | self.early_stop_scope = config.learner.early_stop_scope 17 | self.prototype_as_val = config.learner.prototype_as_val 18 | 19 | def save_state(self): 20 | raise NotImplementedError 21 | 22 | def load_state(self): 23 | raise NotImplementedError 24 | 25 | def fit_and_predict(self, features, prototype_targets, belief, n_annotation, ground_truth): 26 | raise NotImplementedError 27 | 28 | 29 | class DummyLearner(Learner): 30 | def __init__(self, config): 31 | Learner.__init__(self, config) 32 | logger.info('No learner is used') 33 | 34 | def save_state(self): 35 | pass 36 | 37 | def load_state(self, state): 38 | pass 39 | 40 | def fit_and_predict(self, features, prototype_targets, belief, n_annotation, ground_truth): 41 | return None 42 | 43 | 44 | def 
get_learner_class(config): 45 | from .nn_learner import LinearNNLearner 46 | if config.learner.algo == 'dummy': 47 | return DummyLearner 48 | elif config.learner.algo == 'mlp': 49 | return LinearNNLearner 50 | else: 51 | raise ValueError 52 | -------------------------------------------------------------------------------- /online_label/learner/nn_learner.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import math 4 | import itertools 5 | import numpy as np 6 | from collections import Counter 7 | from copy import deepcopy 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import TensorDataset, DataLoader 12 | 13 | from .utils import ModelWithTemperature, LinearModel 14 | from . import Learner 15 | 16 | import logging 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class NNLearner(Learner): 21 | def __init__(self, config): 22 | Learner.__init__(self, config) 23 | self.batch_size = self.config.learner.batch_size 24 | self.max_epochs = self.config.learner.max_epochs 25 | self.best_model_loss = np.inf 26 | self.best_probs = None 27 | 28 | def save_state(self): 29 | if self.best_probs is not None: 30 | best_model_loss = self.best_model_loss 31 | return json.dumps(dict(best_model_loss=[float(best_model_loss)], best_probs=self.best_probs.tolist())) 32 | 33 | def load_state(self, state): 34 | if state is not None: 35 | state = json.loads(state) 36 | self.best_probs = np.array(state['best_probs']) 37 | self.best_model_loss = state['best_model_loss'][0] 38 | 39 | def init_learner(self, in_channels, out_channels): 40 | raise NotImplementedError 41 | 42 | def fit(self, in_channels, weighted_labeled_loader, val_loader): 43 | '''Fit the model on `weighted_labeled_loader` and validate on `val_loader` 44 | Return the best model based on the loss in val_loader and 45 | its corresponding loss in val_loader 46 | ''' 47 | 48 | t1 = time.time() 49 | model = self.init_learner(in_channels, self.config.n_classes) 50 | best_model = self.init_learner(in_channels, self.config.n_classes) 51 | 52 | if self.use_cuda: 53 | model = model.cuda() 54 | 55 | lr = self.config.learner.lr_ratio * math.sqrt(min(weighted_labeled_loader.batch_size, 56 | len(weighted_labeled_loader.dataset))) 57 | optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=self.config.learner.weight_decay) 58 | 59 | 60 | def __train_step(model, inputs, targets): 61 | 62 | if self.use_cuda: 63 | inputs = inputs.cuda() 64 | targets = targets.cuda() 65 | 66 | optim.zero_grad() 67 | logits = model(inputs) 68 | loss = model.compute_loss(logits, targets) 69 | 70 | loss.backward() 71 | optim.step() 72 | return loss.cpu().item() 73 | 74 | 75 | @torch.no_grad() 76 | def __eval(model, loader): 77 | model.eval() 78 | loss_l = [] 79 | targets_l = [] 80 | for inputs, targets in loader: 81 | if self.use_cuda: 82 | inputs = inputs.cuda() 83 | targets = targets.cuda() 84 | logits = model(inputs) 85 | loss = model.compute_loss(logits, targets, reduction='none') 86 | 87 | loss_l.extend(loss.cpu().numpy()) 88 | targets_l.extend(targets.cpu().numpy()) 89 | return loss_l, targets_l 90 | 91 | 92 | best_loss = np.inf 93 | best_epoch = 0 94 | for epoch in range(1, self.max_epochs+1): 95 | 96 | model.train() 97 | 98 | train_loss_average = [] 99 | for inputs, targets in weighted_labeled_loader: 100 | loss = __train_step(model, inputs, targets) 101 | train_loss_average.append(loss) 102 | 103 | if epoch % 10 == 0: 104 | model.eval() 105 | val_loss_l, _ = 
__eval(model, val_loader) 106 | val_loss = np.mean(val_loss_l) 107 | logger.debug(f'[Epoch {epoch}] Val Loss: {val_loss}') 108 | 109 | if val_loss < best_loss: 110 | best_model.load_state_dict(deepcopy(model.state_dict())) 111 | best_loss = val_loss 112 | best_epoch = epoch 113 | 114 | logger.debug(f'Use model at epoch {best_epoch}') 115 | logger.debug(f'Fiting nn takes {time.time()-t1} sec in {epoch} epoch') 116 | if self.use_cuda: 117 | best_model = best_model.cuda() 118 | 119 | return best_model, best_loss 120 | 121 | def get_train_val(self, features, prototype_targets, belief, n_annotation, ground_truth): 122 | ''' 123 | features: dict with key "features" and "prototype_features" 124 | prototype_targets: np.ndarray 125 | belief: np.ndarray with shape (#data, #class) 126 | n_annotation: np.ndarray describing the number of annotation for each image 127 | ground_truth: np.ndarray 128 | ''' 129 | 130 | n_prototypes = len(prototype_targets) 131 | n_data = len(features['features']) 132 | 133 | def __get_train_mask(risk, confident_mask): 134 | # Get training data 135 | at_least_one_annotation = n_annotation > 0 136 | logger.debug(f'Number of annotated examples (|W_i| > 0) : {at_least_one_annotation.sum()}') 137 | logger.debug(f'Number of confident examples: {confident_mask.sum()}') 138 | if self.semi_supervised == 'none': 139 | train_mask = np.copy(at_least_one_annotation) 140 | elif self.semi_supervised == 'pseudolabel': 141 | train_mask = np.copy(confident_mask) 142 | else: 143 | raise ValueError 144 | 145 | # Include prototypes 146 | #at_least_one_annotation = np.concatenate([at_least_one_annotation, 147 | # np.zeros(n_prototypes).astype(np.bool)]) 148 | confident_mask = np.concatenate([confident_mask, np.ones(n_prototypes).astype(np.bool)]) 149 | if self.prototype_as_val: 150 | train_mask = np.concatenate([train_mask, np.zeros(n_prototypes).astype(np.bool)]) 151 | else: 152 | train_mask = np.concatenate([train_mask, np.ones(n_prototypes).astype(np.bool)]) 153 | 154 | return train_mask, confident_mask 155 | 156 | prototype_prob = np.zeros([n_prototypes, self.config.n_classes]) 157 | prototype_prob[range(n_prototypes), prototype_targets] = 1. 
158 | 159 | # Get confident mask 160 | y_hat = belief.argmax(1) 161 | confidence = belief.max(1) 162 | risk = 1 - confidence 163 | confident_mask = risk < self.risk_thres 164 | 165 | 166 | # All data to use 167 | P = np.concatenate([belief, prototype_prob]) 168 | Y = np.concatenate([y_hat, prototype_targets]) 169 | X = np.concatenate([features['features'], features['prototype_features']]) 170 | ground_truth = np.concatenate([ground_truth, prototype_targets]) 171 | 172 | 173 | # Train 174 | train_mask, confident_mask = __get_train_mask(risk, confident_mask) 175 | 176 | prototype_mask = np.zeros(n_data + n_prototypes).astype(np.bool) 177 | prototype_mask[-n_prototypes:] = True 178 | 179 | # Val 180 | if self.prototype_as_val: 181 | val_mask = prototype_mask.copy() 182 | else: 183 | # Choose a balance set 184 | n_val = prototype_mask.sum() 185 | 186 | idx = self.npr.choice(np.where(train_mask)[0], n_val, replace=False) 187 | val_mask = np.zeros(n_data + n_prototypes).astype(np.bool) 188 | val_mask[idx] = True 189 | train_mask[idx] = False 190 | 191 | 192 | def __ensure_enough_in_train_and_val(train_mask, val_mask): 193 | val_Y = Y[val_mask] 194 | train_Y = Y[train_mask] 195 | 196 | train_counter = Counter(train_Y) 197 | val_counter = Counter(val_Y) 198 | for c in range(self.config.n_classes): 199 | train_c = train_counter[c] 200 | val_c = val_counter[c] 201 | if val_c < math.floor(self.config.n_prototype_per_class / 2): 202 | logger.debug('Not enough class ({c}) in validation set') 203 | n = math.floor(self.config.n_prototype_per_class / 2) - val_c 204 | 205 | # data move from t to v 206 | idx = np.where(Y[train_mask] == c)[0] 207 | idx = self.npr.choice(idx, n, replace=False) 208 | idx = np.where(train_mask)[0][idx] 209 | 210 | val_mask[idx] = True 211 | train_mask[idx] = False 212 | elif train_c < val_c: 213 | # Move half of data from v to t 214 | logger.debug(f'Not enough class ({c}) in train set') 215 | 216 | idx = np.where(Y[val_mask] == c)[0] 217 | idx = self.npr.choice(idx, math.floor(val_c * 0.5), replace=False) 218 | idx = np.where(val_mask)[0][idx] 219 | 220 | train_mask[idx] = True 221 | val_mask[idx] = False 222 | 223 | return train_mask, val_mask 224 | 225 | 226 | train_mask, val_mask = __ensure_enough_in_train_and_val(train_mask, val_mask) 227 | 228 | train_X = X[train_mask] 229 | train_Y = Y[train_mask] 230 | train_gt = ground_truth[train_mask] 231 | 232 | val_X = X[val_mask] 233 | val_Y = Y[val_mask] 234 | val_gt = ground_truth[val_mask] 235 | 236 | logger.debug(f'Train {len(train_X)}/Val {len(val_X)}') 237 | train_acc = (train_Y == train_gt).mean() 238 | val_acc = (val_Y == val_gt).mean() 239 | 240 | logger.debug(f'Accuracy of train dataset ({len(train_gt)}): {train_acc}') 241 | logger.debug(f'Accuracy of validation dataset ({len(val_gt)}): {val_acc}') 242 | 243 | return (X, Y, ground_truth), \ 244 | (train_X, train_Y, train_gt, train_mask), \ 245 | (val_X, val_Y, val_gt, val_mask) 246 | 247 | def create_dataloader_from_data(self, data, batch_size, shuffle=None, weighted=False): 248 | X, Y, _, _ = data 249 | tensor_X = torch.tensor(X) 250 | tensor_Y = torch.tensor(Y) 251 | dataset = TensorDataset(tensor_X, tensor_Y) 252 | 253 | batch_size = int(min(len(Y), batch_size)) 254 | if batch_size == 0: 255 | return None 256 | 257 | if weighted: 258 | if len(Y.shape) == 2: 259 | # Y is a probability distribution 260 | Y = Y.argmax(1) 261 | y_counter = Counter(Y) 262 | if len(Y) == 0: 263 | weights = np.ones_like(Y) 264 | else: 265 | weights = np.array([1. 
/ y_counter[t] for t in Y]) 266 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, batch_size) 267 | loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=self.config.n_jobs) 268 | else: 269 | loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=self.config.n_jobs) 270 | return loader 271 | 272 | def calibrate_model(self, model, val_loader): 273 | scaled_model = ModelWithTemperature(model, self.use_cuda) 274 | if self.use_cuda: 275 | scaled_model = scaled_model.cuda() 276 | scaled_model.set_temperature(val_loader) 277 | if scaled_model.before_ece < scaled_model.after_ece or scaled_model.before_nll < scaled_model.after_nll: 278 | logger.info('ECE or NLL in validation set goes higher than before. Use the uncalibrated model.') 279 | return model 280 | else: 281 | return scaled_model 282 | 283 | def predict(self, model, test_loader): 284 | with torch.no_grad(): 285 | if self.use_cuda: 286 | model = model.cuda() 287 | 288 | model.eval() 289 | probs = [] 290 | for inputs in test_loader: 291 | inputs = inputs[0] 292 | 293 | if self.use_cuda: 294 | inputs = inputs.cuda() 295 | logits = model(inputs) 296 | prob = F.softmax(logits, dim=1) 297 | probs.append(prob.cpu()) 298 | 299 | probs = torch.cat(probs, dim=0) 300 | return probs.cpu().numpy() 301 | 302 | def fit_and_predict(self, features, prototype_targets, belief, n_annotation, ground_truth): 303 | 304 | n_data = features['features'].shape[0] 305 | n_prototypes = features['prototype_features'].shape[0] 306 | in_channels = features['features'].shape[1] 307 | 308 | all_data, train_data, val_data = self.get_train_val(features, prototype_targets, belief, n_annotation, ground_truth) 309 | 310 | if 'cv' in self.calibrate: 311 | assert not self.prototype_as_val 312 | 313 | n_folds = int(self.calibrate.split('_')[1]) 314 | 315 | _, _, _, train_mask = train_data 316 | _, _, _, val_mask = val_data 317 | mask = train_mask | val_mask 318 | 319 | X, Y, gt = all_data 320 | 321 | cv_X = X[mask] 322 | cv_Y = Y[mask] 323 | cv_gt = gt[mask] 324 | 325 | 326 | cv_idx = np.arange(len(cv_X)) 327 | invalid = True 328 | while invalid: 329 | logger.debug(f'Split to {n_folds} folds') 330 | self.npr.shuffle(cv_idx) 331 | cv_splits = np.array_split(cv_idx, n_folds) 332 | # ensure at least one of the split is valid 333 | for cv_s in cv_splits: 334 | cv_val_mask = np.zeros(len(cv_idx)).astype(np.bool) 335 | cv_val_mask[cv_s] = True 336 | cv_train_mask = ~cv_val_mask 337 | if len(np.unique(cv_Y[cv_train_mask])) == self.config.n_classes and \ 338 | len(np.unique(cv_Y[cv_val_mask])) == self.config.n_classes: 339 | invalid = False 340 | 341 | 342 | cv_best_loss_average = [] 343 | cv_probs = [] 344 | for i, cv_s in enumerate(cv_splits): 345 | logger.debug('{} fold'.format(i)) 346 | 347 | cv_val_mask = np.zeros(len(cv_idx)).astype(np.bool) 348 | cv_val_mask[cv_s] = True 349 | cv_train_mask = ~cv_val_mask 350 | 351 | 352 | cv_train = (cv_X[cv_train_mask], 353 | cv_Y[cv_train_mask], 354 | cv_gt[cv_train_mask], 355 | None) 356 | 357 | cv_val = (cv_X[cv_val_mask], 358 | cv_Y[cv_val_mask], 359 | cv_gt[cv_val_mask], 360 | None) 361 | 362 | if len(np.unique(cv_train[1])) == self.config.n_classes and \ 363 | len(np.unique(cv_val[1])) == self.config.n_classes: 364 | 365 | weighted_train_loader = self.create_dataloader_from_data(cv_train, self.batch_size, weighted=True) 366 | val_loader = self.create_dataloader_from_data(cv_val, self.batch_size, shuffle=False) 367 | 368 | cv_best_model, cv_best_loss = 
self.fit(in_channels, weighted_train_loader, val_loader) 369 | cv_best_loss_average.append(cv_best_loss) 370 | cv_best_model = self.calibrate_model(cv_best_model, val_loader) 371 | 372 | 373 | dataset = TensorDataset(torch.tensor(np.concatenate([features['features'], features['prototype_features']]))) 374 | loader = DataLoader(dataset, 2**13) 375 | probs = self.predict(cv_best_model, loader) 376 | cv_probs.append(probs) 377 | 378 | probs = np.stack(cv_probs, 0).mean(0) 379 | best_model_loss = np.mean(cv_best_loss_average) 380 | 381 | elif self.calibrate == 'temperature': 382 | weighted_train_loader = self.create_dataloader_from_data(train_data, self.batch_size, weighted=True) 383 | val_loader = self.create_dataloader_from_data(val_data, self.batch_size, shuffle=False) 384 | best_model, best_model_loss = self.fit(in_channels, weighted_train_loader, val_loader) 385 | 386 | if self.calibrate == 'temperature': 387 | best_model = self.calibrate_model(best_model, val_loader) 388 | 389 | dataset = TensorDataset(torch.tensor(np.concatenate([features['features'], features['prototype_features']]))) 390 | loader = DataLoader(dataset, 2**13) 391 | probs = self.predict(best_model, loader) 392 | else: 393 | raise ValueError 394 | 395 | 396 | if self.early_stop_scope == 'global': 397 | if best_model_loss < self.best_model_loss: 398 | logger.info(f'Find the model ({best_model_loss}) better than the previous one ({self.best_model_loss})') 399 | self.best_model_loss = best_model_loss 400 | self.best_probs = np.copy(probs) 401 | else: 402 | logger.info(f'Use the previous learned model ({self.best_model_loss})') 403 | probs = np.copy(self.best_probs) 404 | 405 | return probs 406 | 407 | 408 | class LinearNNLearner(NNLearner): 409 | def __init__(self, config): 410 | NNLearner.__init__(self, config) 411 | self.n_hidden_layer = config.learner.n_hidden_layer 412 | self.hidden_size = config.learner.hidden_size 413 | 414 | def init_learner(self, in_channels, out_channels): 415 | model = LinearModel(in_channels, out_channels, self.hidden_size, self.n_hidden_layer) 416 | return model 417 | -------------------------------------------------------------------------------- /online_label/learner/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, optim 3 | from torch.nn import functional as F 4 | 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class LinearModel(nn.Module): 10 | def __init__(self, input_dim, output_dim, hidden_dim=32, n_hidden_layer=0): 11 | nn.Module.__init__(self) 12 | 13 | self.input_dim = input_dim 14 | self.output_dim = output_dim 15 | self.hidden_dim = hidden_dim 16 | 17 | modules = [] 18 | in_dim = input_dim 19 | for i in range(n_hidden_layer): 20 | modules.append(nn.Linear(in_dim, hidden_dim)) 21 | modules.append(nn.Tanh()) 22 | in_dim = hidden_dim 23 | modules.append(nn.Linear(in_dim, self.output_dim)) 24 | 25 | self.net = nn.Sequential(*modules) 26 | 27 | def forward(self, x): 28 | out = self.net(x) 29 | return out 30 | 31 | def compute_loss(self, logits, labeled_y, reduction='mean'): 32 | loss = F.cross_entropy(logits, labeled_y, reduction=reduction) 33 | return loss 34 | 35 | def compute_mixmatch_loss(self, logits_l, prob_l, logits_u, prob_u, u_weight): 36 | ce_loss = (-F.softmax(logits_l, 1).log() * prob_l).sum(1).mean() 37 | l2_loss = torch.mean((F.softmax(logits_u, 1) - prob_u)**2) 38 | return ce_loss + u_weight * l2_loss 39 | 40 | 41 | class ModelWithTemperature(nn.Module): 42 | ''' 43 
| A thin decorator, which wraps a model with temperature scaling 44 | model (nn.Module): 45 | A classification neural network 46 | NB: Output of the neural network should be the classification logits, 47 | NOT the softmax (or log softmax)! 48 | ''' 49 | def __init__(self, model, use_cuda=False): 50 | super(ModelWithTemperature, self).__init__() 51 | self.use_cuda = use_cuda 52 | self.model = model 53 | self.temperature = nn.Parameter(torch.ones(1) * 1.5) 54 | 55 | def forward(self, input): 56 | logits = self.model(input) 57 | return self.temperature_scale(logits) 58 | 59 | def temperature_scale(self, logits): 60 | ''' 61 | Perform temperature scaling on logits 62 | ''' 63 | # Expand temperature to match the size of logits 64 | temperature = self.temperature.unsqueeze(1).expand(logits.size(0), logits.size(1)) 65 | return logits / temperature 66 | 67 | 68 | # This function probably should live outside of this class, but whatever 69 | def set_temperature(self, valid_loader): 70 | ''' 71 | Tune the tempearature of the model (using the validation set). 72 | We're going to set it to optimize NLL. 73 | valid_loader (DataLoader): validation set loader 74 | ''' 75 | 76 | nll_criterion = nn.CrossEntropyLoss() 77 | ece_criterion = _ECELoss() 78 | if self.use_cuda: 79 | nll_criterion = nll_criterion.cuda() 80 | ece_criterion = ece_criterion.cuda() 81 | 82 | 83 | # First: collect all the logits and labels for the validation set 84 | logits_list = [] 85 | labels_list = [] 86 | with torch.no_grad(): 87 | for input, label in valid_loader: 88 | if self.use_cuda: 89 | input = input.cuda() 90 | logits = self.model(input) 91 | logits_list.append(logits) 92 | labels_list.append(label) 93 | 94 | logits = torch.cat(logits_list) 95 | labels = torch.cat(labels_list) 96 | if self.use_cuda: 97 | logits = logits.cuda() 98 | labels = labels.cuda() 99 | 100 | # Calculate NLL and ECE before temperature scaling 101 | before_temperature_nll = nll_criterion(logits, labels).item() 102 | before_temperature_ece = ece_criterion(logits, labels).item() 103 | logger.debug('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece)) 104 | 105 | # Next: optimize the temperature w.r.t. NLL 106 | optimizer = optim.LBFGS([self.temperature], lr=0.001, max_iter=50) 107 | 108 | def eval(): 109 | loss = nll_criterion(self.temperature_scale(logits), labels) 110 | loss.backward() 111 | return loss 112 | optimizer.step(eval) 113 | 114 | if self.temperature.item() > 100. or self.temperature.item() < -100.: 115 | logger.debug('Invalid temperature found') 116 | torch.nn.init.constant_(self.temperature, 1) 117 | 118 | # Calculate NLL and ECE after temperature scaling 119 | after_temperature_nll = nll_criterion(self.temperature_scale(logits), labels).item() 120 | after_temperature_ece = ece_criterion(self.temperature_scale(logits), labels).item() 121 | 122 | logger.debug(f'Optimal temperature: {self.temperature.item()}') 123 | logger.debug(f'After temperature - NLL: {after_temperature_nll}, ECE:{after_temperature_ece}') 124 | self.after_nll = after_temperature_nll 125 | self.before_nll = before_temperature_nll 126 | self.after_ece = after_temperature_ece 127 | self.before_ece = before_temperature_ece 128 | 129 | return self 130 | 131 | 132 | class _ECELoss(nn.Module): 133 | ''' 134 | Calculates the Expected Calibration Error of a model. 135 | (This isn't necessary for temperature scaling, just a cool metric). 136 | The input to this loss is the logits of a model, NOT the softmax scores. 
137 | This divides the confidence outputs into equally-sized interval bins. 138 | In each bin, we compute the confidence gap: 139 | bin_gap = | avg_confidence_in_bin - accuracy_in_bin | 140 | We then return a weighted average of the gaps, based on the number 141 | of samples in each bin 142 | See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht. 143 | 'Obtaining Well Calibrated Probabilities Using Bayesian Binning.' AAAI. 144 | 2015. 145 | ''' 146 | def __init__(self, n_bins=15): 147 | ''' 148 | n_bins (int): number of confidence interval bins 149 | ''' 150 | super(_ECELoss, self).__init__() 151 | bin_boundaries = torch.linspace(0, 1, n_bins + 1) 152 | self.bin_lowers = bin_boundaries[:-1] 153 | self.bin_uppers = bin_boundaries[1:] 154 | 155 | def forward(self, logits, labels): 156 | softmaxes = F.softmax(logits, dim=1) 157 | confidences, predictions = torch.max(softmaxes, 1) 158 | accuracies = predictions.eq(labels) 159 | 160 | ece = torch.zeros(1, device=logits.device) 161 | for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers): 162 | # Calculated |confidence - accuracy| in each bin 163 | in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item()) 164 | prop_in_bin = in_bin.float().mean() 165 | if prop_in_bin.item() > 0: 166 | accuracy_in_bin = accuracies[in_bin].float().mean() 167 | avg_confidence_in_bin = confidences[in_bin].mean() 168 | ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin 169 | 170 | return ece 171 | -------------------------------------------------------------------------------- /online_label/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch_logger import BatchLogger -------------------------------------------------------------------------------- /online_label/logger/batch_logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from .utils import NpEncoder, cal_ece 4 | from collections import defaultdict 5 | 6 | 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class BatchLogger(object): 12 | def __init__(self, config, y, n_data, wnids, p_image_path, image_path, log_image_path, d_image_path): 13 | self.config = config 14 | self.n_data = n_data 15 | self.wnids = wnids 16 | self.ground_truth_y = np.array(y) 17 | self.image_path = image_path 18 | self.log_image_path = log_image_path 19 | 20 | self.include_distraction = config.n_data_distraction_per_class > 0 21 | self.d_image_path = d_image_path 22 | 23 | self.info = defaultdict() 24 | self.info['p_image_path'] = p_image_path 25 | self.info['image_path'] = image_path 26 | self.info['log_image_path'] = log_image_path 27 | self.info['d_image_path'] = d_image_path 28 | self.info['ground_truth_y'] = self.ground_truth_y 29 | self.cur_step = 0. 
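        # Note: `self.info` accumulates numpy arrays and scalars; save() below relies on
        # NpEncoder (online_label/logger/utils.py) to convert them into JSON-serializable types.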
30 | 31 | def save(self): 32 | json.dump(self.info, open(f'info-{self.cur_step}.json', 'w'), cls=NpEncoder) 33 | 34 | def load(self, filename): 35 | info = json.load(open(filename, 'w')) 36 | self.info = info 37 | 38 | def step(self, 39 | n_data, 40 | annotation, 41 | sampled_image_path, 42 | step, 43 | learner_prob, 44 | y_posterior, 45 | y_posterior_risk): 46 | 47 | 48 | self.cur_step = step 49 | n_annotation = sum([len(v) for v in annotation.values()]) 50 | cost = n_annotation 51 | logger.info(f'Step: {self.cur_step}') 52 | logger.info(f'Number of annotation: {cost}') 53 | logger.info(f'Average annotation: {cost / len(self.log_image_path)}') 54 | 55 | x_axis = { 56 | 'step': step, 57 | 'cost': cost, 58 | 'cost-per-image': cost / self.n_data 59 | } 60 | 61 | self.info['data'] = { 62 | 'step': int(step), 63 | 'cost': int(cost), 64 | 'cost-per-image': cost / self.n_data, 65 | 'sampled_image_path': sampled_image_path, 66 | 'learner_prob': learner_prob, 67 | 'y_posterior': y_posterior.tolist(), 68 | 'y_posterior_risk': y_posterior_risk.tolist() 69 | } 70 | 71 | log_idx = np.array([self.image_path.index(p) for p in self.log_image_path]) 72 | 73 | y_posterior = y_posterior[log_idx, :] 74 | y_posterior_risk = y_posterior_risk[log_idx] 75 | 76 | 77 | self.log_acc(y_posterior, 'aggregation', y_posterior_risk, x_axis) 78 | self.log_risk(y_posterior_risk, 'aggregation', x_axis) 79 | self.log_count(annotation, 'aggregation', x_axis) 80 | self.log_ece(y_posterior, 'aggregation', y_posterior_risk, x_axis) 81 | 82 | if learner_prob is not None: 83 | # learner_prob: [data_learner_prob, prototype_learner_prob] 84 | learner_prob = learner_prob[:self.n_data, :] 85 | learner_prob = learner_prob[log_idx] 86 | self.log_acc(learner_prob, 'learner_prob', y_posterior_risk, x_axis) 87 | self.log_ece(learner_prob, 'learner_prob', y_posterior_risk, x_axis) 88 | 89 | def log_ece(self, belief, prefix, risk, x_axis): 90 | 91 | n_data = self.n_data 92 | ground_truth_y = self.ground_truth_y 93 | config = self.config 94 | 95 | y_hat = belief.argmax(1) 96 | pred_prob = np.take_along_axis(belief, y_hat[:, np.newaxis], axis=1).squeeze() 97 | 98 | def _log_ece(mask, tag, info): 99 | if sum(mask) > 0: 100 | ece = cal_ece(pred_prob[mask], 101 | y_hat[mask], 102 | ground_truth_y[mask]) 103 | logger.info(f'{prefix}\t{tag}/ece: {ece}') 104 | info.update({f'{prefix}/{tag}/ece': ece}) 105 | 106 | info = {} 107 | 108 | all_mask = np.ones(len(risk)).astype(np.bool) 109 | valid_mask = risk < config.risk_thres 110 | invalid_mask = ~valid_mask 111 | 112 | if self.include_distraction: 113 | 114 | # Class of interest 115 | _log_ece(all_mask, 'all', info) 116 | _log_ece(valid_mask, 'valid', info) 117 | _log_ece(invalid_mask, 'invalid', info) 118 | 119 | # Distraction classes 120 | _log_ece(all_mask, 'all', info) 121 | _log_ece(valid_mask, 'valid', info) 122 | _log_ece(invalid_mask, 'invalid', info) 123 | 124 | _log_ece(all_mask, 'all+distraction', info) 125 | _log_ece(valid_mask, 'valid+distraction', info) 126 | _log_ece(invalid_mask, 'invalid+distraction', info) 127 | else: 128 | _log_ece(all_mask, 'all', info) 129 | _log_ece(valid_mask, 'valid', info) 130 | _log_ece(invalid_mask, 'invalid', info) 131 | 132 | 133 | self.info['data'].update(info) 134 | 135 | def log_count(self, annotation, prefix, x_axis): 136 | 137 | count = np.array([len(anno_i) for anno_i in annotation]) 138 | n_data = self.n_data 139 | 140 | info = {f'{prefix}/num_labeled_data': len(np.where(count > 0)[0]) / n_data} 141 | logger.info(f'{prefix}\tnum_labeled_data: 
{len(np.where(count > 0)[0]) / n_data}') 142 | 143 | self.info['data'].update(info) 144 | 145 | def log_risk(self, risk, prefix, x_axis): 146 | valid_mask = (risk < self.config.risk_thres) 147 | 148 | logger.info(f'{prefix}\tvalid_num: {sum(valid_mask)}') 149 | info = {f'{prefix}/valid_num': sum(valid_mask)} 150 | 151 | self.info['data'].update(info) 152 | 153 | def log_acc(self, belief, prefix, risk, x_axis): 154 | 155 | n_data = self.n_data 156 | ground_truth_y = self.ground_truth_y 157 | config = self.config 158 | 159 | topK = [1, 5, 10] 160 | y_hat = np.argsort(belief, 1)[:, ::-1] 161 | topK_acc = {k: None for k in topK} 162 | for k in topK: 163 | acc = np.any(y_hat[:, :k] == ground_truth_y.reshape(-1, 1), axis=1) 164 | topK_acc[k] = acc 165 | 166 | 167 | def _log_acc(mask, tag, info): 168 | if sum(mask) > 0: 169 | for k in topK: 170 | average_acc = np.mean(topK_acc[k][mask]) 171 | logger.info(f'{prefix}\t{tag}/top{k}: {average_acc}') 172 | info.update({f'{prefix}/{tag}/top{k}': average_acc}) 173 | 174 | 175 | def _log_num(mask, tag, info): 176 | if sum(mask) > 0: 177 | n_correct = len(np.where(y_hat[mask, 0] == \ 178 | ground_truth_y[mask])[0]) 179 | n_incorrect = len(np.where(y_hat[mask, 0] != \ 180 | ground_truth_y[mask])[0]) 181 | logger.info(f'{prefix}\t{tag}/n_correct: {n_correct}') 182 | info.update({f'{prefix}/{tag}/n_correct': n_correct}) 183 | logger.info(f'{prefix}\t{tag}/n_incorrect: {n_incorrect}') 184 | info.update({f'{prefix}/{tag}/n_incorrect': n_incorrect}) 185 | 186 | 187 | info = {} 188 | all_mask = np.ones(len(risk)).astype(np.bool) 189 | valid_mask = risk < config.risk_thres 190 | invalid_mask = ~valid_mask 191 | if self.include_distraction: 192 | distraction_mask = ground_truth_y == config.n_classes-1 193 | 194 | # Class of interest 195 | _log_acc((all_mask & ~distraction_mask), 'all', info) 196 | _log_num((all_mask & ~distraction_mask), 'all', info) 197 | 198 | _log_acc((valid_mask & ~distraction_mask), 'valid', info) 199 | _log_num((valid_mask & ~distraction_mask), 'valid', info) 200 | 201 | _log_acc((invalid_mask & ~distraction_mask), 'invalid', info) 202 | _log_num((invalid_mask & ~distraction_mask), 'invalid', info) 203 | 204 | # Distraction classes 205 | _log_acc((all_mask & distraction_mask), 'all_distraction', info) 206 | _log_num((all_mask & distraction_mask), 'all_distraction', info) 207 | 208 | _log_acc((valid_mask & distraction_mask), 'valid_distraction', info) 209 | _log_num((valid_mask & distraction_mask), 'valid_distraction', info) 210 | 211 | _log_acc((invalid_mask & distraction_mask), 'invalid_distraction', info) 212 | _log_num((invalid_mask & distraction_mask), 'invalid_distraction', info) 213 | 214 | # All 215 | _log_acc(all_mask, 'all+distraction', info) 216 | _log_num(all_mask, 'all+distraction', info) 217 | 218 | _log_acc(valid_mask, 'valid+distraction', info) 219 | _log_num(valid_mask, 'valid+distraction', info) 220 | 221 | _log_acc(invalid_mask, 'invalid+distraction', info) 222 | _log_num(invalid_mask, 'invalid+distraction', info) 223 | 224 | else: 225 | _log_acc(all_mask, 'all', info) 226 | _log_num(all_mask, 'all', info) 227 | 228 | _log_acc(valid_mask, 'valid', info) 229 | _log_num(valid_mask, 'valid', info) 230 | 231 | _log_acc(invalid_mask, 'invalid', info) 232 | _log_num(invalid_mask, 'invalid', info) 233 | 234 | 235 | self.info['data'].update(info) 236 | -------------------------------------------------------------------------------- /online_label/logger/utils.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | 5 | class NpEncoder(json.JSONEncoder): 6 | def default(self, obj): 7 | if isinstance(obj, np.integer): 8 | return int(obj) 9 | elif isinstance(obj, np.floating): 10 | return float(obj) 11 | elif isinstance(obj, np.ndarray): 12 | return obj.tolist() 13 | else: 14 | return super(NpEncoder, self).default(obj) 15 | 16 | 17 | def compute_acc_bin(conf_thresh_lower, conf_thresh_upper, conf, pred, true): 18 | ''' 19 | # Computes accuracy and average confidence for bin 20 | 21 | Args: 22 | conf_thresh_lower (float): Lower Threshold of confidence interval 23 | conf_thresh_upper (float): Upper Threshold of confidence interval 24 | conf (numpy.ndarray): list of confidences 25 | pred (numpy.ndarray): list of predictions 26 | true (numpy.ndarray): list of true labels 27 | 28 | Returns: 29 | (accuracy, avg_conf, len_bin): accuracy of bin, confidence of bin and number of elements in bin. 30 | ''' 31 | filtered_tuples = [x for x in zip(pred, true, conf) if x[2] > conf_thresh_lower and x[2] <= conf_thresh_upper] 32 | if len(filtered_tuples) < 1: 33 | return 0,0,0 34 | else: 35 | correct = len([x for x in filtered_tuples if x[0] == x[1]]) # How many correct labels 36 | len_bin = len(filtered_tuples) # How many elements falls into given bin 37 | avg_conf = sum([x[2] for x in filtered_tuples]) / len_bin # Avg confidence of BIN 38 | accuracy = float(correct)/len_bin # accuracy of BIN 39 | return accuracy, avg_conf, len_bin 40 | 41 | 42 | #https://github.com/markus93/NN_calibration/blob/master/scripts/utility/evaluation.py#L130-L154 43 | def cal_ece(conf, pred, true, bin_size = 0.1): 44 | 45 | ''' 46 | Expected Calibration Error 47 | 48 | Args: 49 | conf (numpy.ndarray): list of confidences 50 | pred (numpy.ndarray): list of predictions 51 | true (numpy.ndarray): list of true labels 52 | bin_size: (float): size of one bin (0,1) # TODO should convert to number of bins? 
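            (the default bin_size of 0.1 gives 10 equal-width bins over (0, 1])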
53 | 54 | Returns: 55 | ece: expected calibration error 56 | ''' 57 | 58 | upper_bounds = np.arange(bin_size, 1+bin_size, bin_size) # Get bounds of bins 59 | 60 | n = len(conf) 61 | ece = 0 # Starting error 62 | 63 | for conf_thresh in upper_bounds: # Go through bounds and find accuracies and confidences 64 | acc, avg_conf, len_bin = compute_acc_bin(conf_thresh-bin_size, conf_thresh, conf, pred, true) 65 | ece += np.abs(acc-avg_conf)*len_bin/n # Add weigthed difference to ECE 66 | 67 | return ece -------------------------------------------------------------------------------- /online_label/online_loop.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | import logging 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class EarlyStopper(): 9 | def __init__(self, hit_size, n_hits_per_step, risk_thres, n_data): 10 | self.risk_thres = risk_thres 11 | self.n_data = n_data 12 | self.cost_per_step = hit_size * n_hits_per_step 13 | self.num_valid_list = [] 14 | 15 | def stop(self, num_valid, n_annotation): 16 | 17 | self.num_valid_list.append(num_valid) 18 | check_prev_n = int(max(5000 // self.cost_per_step, 5)) 19 | 20 | # At least runs `check_prev_n` steps 21 | if len(self.num_valid_list) <= check_prev_n: 22 | return False 23 | 24 | # At least collect n_data/2 annotations 25 | if n_annotation < (self.n_data / 2.): 26 | return False 27 | 28 | max_so_far = max(self.num_valid_list) 29 | 30 | if len(self.num_valid_list) > 3: 31 | num_valid_list = np.array(self.num_valid_list) 32 | if not np.any(max_so_far == num_valid_list[-check_prev_n:]): 33 | # Decrease at least for `check_prev_n` steps 34 | return True 35 | else: 36 | return False 37 | 38 | 39 | def __init(config, annotation_holder, optimizer, aggregator, learner, imagenet_data, batch_logger, start_step): 40 | 41 | logger.info(f'{"*"*20} Step: {start_step} {"*"*20}') 42 | 43 | # >>> Jointly optimize the worker skills and true labels 44 | info = optimizer.step(annotation_holder, aggregator, learner, imagenet_data) 45 | 46 | # >>> Log information 47 | if start_step == 0: 48 | batch_logger.step(annotation_holder.n_data, 49 | annotation_holder.annotation, 50 | sampled_image_path=[], 51 | step=start_step, 52 | **info) 53 | 54 | return info 55 | 56 | 57 | def run_online_loop(config, 58 | imagenet_data, 59 | annotation_holder, 60 | sampler, 61 | aggregator, 62 | learner, 63 | optimizer, 64 | batch_logger, 65 | save_state_fn, 66 | start_step): 67 | 68 | 69 | step = start_step 70 | early_stopper = EarlyStopper(config.online.hit_size, 71 | config.online.n_hits_per_step, 72 | config.risk_thres, 73 | annotation_holder.n_data) 74 | info = __init(config, 75 | annotation_holder, 76 | optimizer, 77 | aggregator, 78 | learner, 79 | imagenet_data, 80 | batch_logger, 81 | start_step) 82 | 83 | for step in range(start_step+1, config.online.budget // config.online.hit_size+1): 84 | 85 | num_valid = sum(info['y_posterior_risk'] < config.risk_thres) 86 | if early_stopper.stop(num_valid, annotation_holder.n_annotation): 87 | save_state_fn(config, annotation_holder.workers, annotation_holder, optimizer, learner, step=step, p='early_stop_state.json') 88 | logger.info('Early stop since the number of valid examples decreases for certain steps') 89 | if config.early_stop: 90 | break 91 | 92 | if sampler.stop(info['y_posterior_risk']): 93 | logger.info('Examples either satisfy the risk criterion or reach maximum number of annotation') 94 | break 95 | 96 | logger.info(f'{"*"*20} Step: {step} 
{"*"*20}') 97 | 98 | logger.debug('Construct HITs') 99 | data_idx, worker_id = sampler.sample(config.online.hit_size, 100 | config.online.n_hits_per_step, 101 | info['y_posterior_risk'], 102 | y_posterior=info['y_posterior'], 103 | feature_dict=imagenet_data.get_features_dict()) 104 | 105 | 106 | data_path = [imagenet_data.image_path[i] for i in data_idx] 107 | 108 | logger.debug('Workers Annotating') 109 | annotation_holder.collect_annotation(imagenet_data.a_data, data_path, worker_id, info['y_posterior']) 110 | 111 | # >>> Jointly optimize the worker skills and true labels 112 | info = optimizer.step(annotation_holder, aggregator, learner, imagenet_data) 113 | save_state_fn(config, annotation_holder.workers, annotation_holder, optimizer, learner, step=step) 114 | 115 | # >>> Log information 116 | batch_logger.step(annotation_holder.n_data, 117 | annotation_holder.annotation, 118 | sampled_image_path=data_path, 119 | step=step, 120 | **info) 121 | -------------------------------------------------------------------------------- /online_label/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | class Optimizer(object): 2 | def __init__(self, config): 3 | self.config = config 4 | self.converged = None 5 | self.max_step_reached = None 6 | 7 | def reset(self): 8 | self.converged = False 9 | self.max_step_reached = False 10 | 11 | def step(self, annotation_holder, aggregator, learner, imagenet_data): 12 | raise NotImplementedError 13 | 14 | 15 | def get_optimizer_class(config): 16 | from .em import EMOptimizer 17 | return EMOptimizer 18 | -------------------------------------------------------------------------------- /online_label/optimizer/em.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from collections import OrderedDict, defaultdict 4 | from . 
import Optimizer 5 | from .utils import DirichletDist 6 | 7 | import logging 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class EMOptimizer(Optimizer): 12 | 13 | def __init__(self, config, imagenet_data, workers): 14 | Optimizer.__init__(self, config) 15 | self.prev_learner_prob = None 16 | self.imagenet_data = imagenet_data 17 | self.__init_workers_estimation(workers) 18 | 19 | def __init_workers_estimation(self, workers): 20 | self.workers_estimated_m = {p: DirichletDist(self.config, 21 | global_cm=w.global_cm, 22 | worker_cm=w.m) for p, w in workers.items()} 23 | 24 | def reset_learner(self): 25 | logger.info('Reset learners') 26 | self.prev_learner_prob = None 27 | 28 | def save_state(self): 29 | state = {} 30 | if self.prev_learner_prob is not None : 31 | state.update({'prev_learner_prob': self.prev_learner_prob.tolist()}) 32 | for k, v in self.workers_estimated_m.items(): 33 | state.update({f'#workers-{k}': v.save_state()}) 34 | return json.dumps(state) 35 | 36 | def load_state(self, state, workers): 37 | if state is not None: 38 | state = json.loads(state) 39 | self.prev_learner_prob = np.array(state['prev_learner_prob']) 40 | self.__init_workers_estimation(workers) 41 | 42 | def step(self, annotation_holder, aggregator, learner, imagenet_data): 43 | 44 | self.reset() 45 | 46 | if self.config.worker.known or annotation_holder.n_annotation == 0: 47 | y_posterior = self.infer_y_posterior(aggregator, annotation_holder) 48 | self.converged = True 49 | else: 50 | self._init_w_as_prior() 51 | self._update_worker_likelihood(annotation_holder) 52 | 53 | prev_y_posterior = None 54 | prev_z_likelihood = None 55 | 56 | # Expectation Maximization 57 | for i in range(self.config.optimizer.max_em_steps): 58 | 59 | # E step 60 | y_posterior = self.infer_y_posterior(aggregator, 61 | annotation_holder, 62 | self.prev_learner_prob) 63 | 64 | # M step 65 | self.m_step(annotation_holder, y_posterior) 66 | 67 | z_likelihood = self.calculate_z_likelihood(annotation_holder, y_posterior) 68 | logger.debug(f'[{i}] Z Likelihood = {z_likelihood.mean()}') 69 | if prev_z_likelihood is not None: 70 | if z_likelihood.mean() - prev_z_likelihood.mean() < 0: 71 | logger.warning('[Warning] Likelihood is decreasing after E step') 72 | 73 | 74 | if self.config.optimizer.criterion == 'hard': 75 | if i > 0 and prev_y_posterior is not None and np.all(prev_y_posterior.argmax(1) == y_posterior.argmax(1)): 76 | self.converged = True 77 | logger.info(f'[Hard constraint] Stop at step: {i}') 78 | break 79 | elif self.config.optimizer.criterion == 'soft': 80 | diff_z_likelihood = abs(prev_z_likelihood.mean() - z_likelihood.mean()) 81 | if i > 0 and diff_z_likelihood < self.config.optimizer.likelihood_epsilon: 82 | self.converged = True 83 | logger.info(f'[Soft constraint] Stop at step: {i}') 84 | break 85 | else: 86 | raise ValueError 87 | 88 | prev_z_likelihood = z_likelihood 89 | prev_y_posterior = y_posterior 90 | 91 | if not self.converged: 92 | logger.info(f'Use the EM steps at step: {self.config.optimizer.max_em_steps - 1}') 93 | 94 | 95 | # Learning 96 | y_posterior = self.infer_y_posterior(aggregator, 97 | annotation_holder, 98 | self.prev_learner_prob) 99 | learner_prob = self.fit_machine_learner(annotation_holder, imagenet_data, learner, y_posterior) 100 | 101 | # Use new learner to infer y posterior again 102 | y_posterior = self.infer_y_posterior(aggregator, annotation_holder, learner_prob) 103 | 104 | 105 | self.prev_learner_prob = learner_prob 106 | return { 107 | 'y_posterior_risk': 
aggregator.compute_risk(y_posterior), 108 | 'learner_prob': learner_prob, 109 | 'y_posterior': y_posterior 110 | } 111 | 112 | def infer_y_posterior(self, aggregator, annotation_holder, learner_prob=None): 113 | y_posterior = aggregator.aggregate(annotation_holder, 114 | learner_prob=learner_prob) 115 | return y_posterior 116 | 117 | def fit_machine_learner(self, annotation_holder, imagenet_data, learner, y_posterior): 118 | 119 | logger.debug('Fitting Learner') 120 | feature_dict = imagenet_data.get_features_dict() 121 | n_annotation = np.array([len(v) for v in annotation_holder.annotation.values()]) 122 | 123 | learner_prob = learner.fit_and_predict(feature_dict, 124 | np.array(imagenet_data.get_y(imagenet_data.p_image_path)), 125 | y_posterior, n_annotation, 126 | np.array(imagenet_data.get_y(imagenet_data.image_path))) 127 | 128 | 129 | logger.debug('Finish fitting') 130 | return learner_prob 131 | 132 | def calculate_z_likelihood(self, annotation_holder, belief): 133 | y_pred = belief.argmax(1) 134 | likelihood = [] 135 | 136 | image_path = list(annotation_holder.annotation.keys()) 137 | for k, y in zip(image_path, y_pred): 138 | for a_i in annotation_holder.annotation[k]: 139 | _, z, j, _ = a_i 140 | likelihood.append(self.workers_estimated_m[j].p_z_given_y(z)[y]) 141 | return np.array(likelihood) 142 | 143 | def m_step(self, annotation_holder, belief): 144 | 145 | image_path = list(annotation_holder.annotation.keys()) 146 | for id, _ in self.workers_estimated_m.items(): 147 | # Collect annotation from a worker 148 | w_eval = [] 149 | 150 | for i, k in enumerate(image_path): 151 | for a_k in annotation_holder.annotation[k]: 152 | _, z, j, _ = a_k 153 | if str(j) == str(id): 154 | w_eval.append((belief[i], z)) 155 | 156 | # Update the worker posterior 157 | before = self.workers_estimated_m[id].posterior_alpha.sum() 158 | self.workers_estimated_m[id].update(w_eval) 159 | after = self.workers_estimated_m[id].posterior_alpha.sum() 160 | 161 | self._update_worker_likelihood(annotation_holder) 162 | 163 | def _update_worker_likelihood(self, annotation_holder): 164 | 165 | # Update `p_z_given_y` in the annotation holder 166 | annotation = annotation_holder.annotation 167 | n_data = annotation_holder.n_data 168 | new_annotation = OrderedDict() 169 | 170 | for k, v in annotation_holder.annotation.items(): 171 | anno_k = [] 172 | for a_k in v: 173 | y, z, j, _ = a_k 174 | p_z_given_y = self.workers_estimated_m[j].p_z_given_y(z) 175 | anno_k.append((y, z, j, np.array(p_z_given_y))) 176 | 177 | new_annotation.update({k: anno_k}) 178 | 179 | annotation_holder.annotation = new_annotation 180 | return annotation_holder 181 | 182 | def _init_w_as_prior(self): 183 | for estimated_m in self.workers_estimated_m.values(): 184 | estimated_m.init_posterior() 185 | -------------------------------------------------------------------------------- /online_label/optimizer/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | import logging 5 | logger = logging.getLogger(__name__) 6 | 7 | 8 | class DirichletDist(): 9 | def __init__(self, config, **kwargs): 10 | 11 | n_classes = config.n_classes 12 | prior_alpha = np.zeros((n_classes, n_classes)) 13 | if config.worker.prior.type == 'predefined_homegeneous' or len(kwargs) == 0: 14 | logger.debug('Use Predefined worker prior') 15 | neg_density = config.worker.prior.predefined_params.neg_count / n_classes 16 | prior_alpha += config.worker.prior.predefined_params.neg_count / 
n_classes 17 | prior_alpha += np.eye(n_classes) * (config.worker.prior.predefined_params.pos_count - neg_density) 18 | else: 19 | logger.debug('Use worker prior: {}'.format(config.worker.prior.type)) 20 | if config.worker.prior.type == 'homegeneous_workers': 21 | # Diagonal terms are initialized by the correctness of all the workers 22 | # Off-diagonal terms are initialized by the incorrecness of all the workers 23 | global_cm = kwargs['global_cm'] 24 | off_diagonal_density = (global_cm.sum(1).mean() - global_cm.diagonal().mean()) / (n_classes-1) 25 | prior_alpha += off_diagonal_density 26 | np.fill_diagonal(prior_alpha, np.ones(n_classes) * global_cm.diagonal().mean()) 27 | elif config.worker.prior.type == 'homegeneous_workers_diagonal_perfect': 28 | # Diagonal terms are initialized by the per-class correctness of all the workers 29 | # Off-diagonal terms are initialized by the per-class incorrecness of all the workers 30 | global_cm = kwargs['global_cm'] 31 | off_diagonal_density = global_cm.sum(1) - global_cm.diagonal() 32 | off_diagonal_density = off_diagonal_density.reshape(-1, 1) / (n_classes - 1) 33 | prior_alpha += off_diagonal_density 34 | np.fill_diagonal(prior_alpha, global_cm.diagonal()) 35 | elif config.worker.prior.type == 'homegeneous_workers_perfect': 36 | # Diagonal terms are initialized by the per-class correctness of all the workers 37 | # Off-diagonal terms are initialized by the per-class-per-prediction incorrecness of all the workers 38 | global_cm = kwargs['global_cm'] 39 | prior_alpha = global_cm 40 | elif config.worker.prior.type == 'heterogeneous_workers': 41 | # Diagonal terms are initialized by the correctness of the workers 42 | # Off-diagonal terms are initialized by the incorrecness of the workers 43 | worker_cm = kwargs['worker_cm'] 44 | off_diagonal_density = (worker_cm.sum(1).mean() - worker_cm.diagonal().mean()) / (n_classes-1) 45 | prior_alpha += off_diagonal_density 46 | np.fill_diagonal(prior_alpha, np.ones(n_classes) * worker_cm.diagonal().mean()) 47 | elif config.worker.prior.type == 'heterogeneous_workers_diagonal_perfect': 48 | # Diagonal terms are initialized by the per-class correctness of the workers 49 | # Off-diagonal terms are initialized by the per-class incorrecness of the workers 50 | worker_cm = kwargs['worker_cm'] 51 | off_diagonal_density = worker_cm.sum(1) - worker_cm.diagonal() 52 | off_diagonal_density = off_diagonal_density.reshape(-1, 1) / (n_classes - 1) 53 | prior_alpha += off_diagonal_density 54 | np.fill_diagonal(prior_alpha, worker_cm.diagonal()) 55 | elif config.worker.prior.type == 'heterogeneous_workers_perfect': 56 | # Diagonal terms are initialized by the per-class correctness of all the workers 57 | # Off-diagonal terms are initialized by the per-class-per-prediction incorrecness of all the workers 58 | worker_cm = kwargs['worker_cm'] 59 | prior_alpha = worker_cm 60 | else: 61 | raise ValueError 62 | 63 | self.prior_alpha = prior_alpha * config.worker.prior.strength 64 | self.n_classes = n_classes 65 | self.init_posterior() 66 | 67 | def init_posterior(self): 68 | self.posterior_alpha = np.zeros((self.n_classes, self.n_classes)) 69 | 70 | def save_state(self): 71 | return json.dumps({'posterior_alpha': self.posterior_alpha.tolist(), 72 | 'prior_alpha': self.prior_alpha.tolist()}) 73 | 74 | def load_state(self, state): 75 | state = json.loads(state) 76 | self.posterior_alpha = state['posterior_alpha'] 77 | 78 | def batch_confidence(self, belief): 79 | # belief: (n_samples, n_classes) 80 | if len(belief.shape) == 1: 81 | 
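            # a single belief vector is promoted to a (1, n_classes) batch below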
belief = belief.reshape(1, -1) 82 | 83 | cm = self.prior_alpha + self.posterior_alpha 84 | cm = cm / cm.sum(1, keepdims=True) 85 | c = (belief * cm.diagonal().reshape(1, -1)).sum(1).mean() 86 | return c 87 | 88 | def p_z_given_y(self, z): 89 | cm = self.prior_alpha + self.posterior_alpha 90 | cm = cm / cm.sum(1, keepdims=True) 91 | return cm[:, z] 92 | 93 | def update(self, eval_list): 94 | posterior_alpha = np.zeros((self.n_classes, self.n_classes)) 95 | for i in eval_list: 96 | prob, z = i 97 | pred = np.argmax(prob + np.random.rand(self.n_classes)*1e-8) # Add noise to break tie 98 | posterior_alpha[pred, z] += 1. 99 | 100 | self.posterior_alpha = posterior_alpha 101 | -------------------------------------------------------------------------------- /online_label/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import Sampler 2 | from .random_sampler import RandomSampler 3 | from .task_assignment_sampler import GreedyTaskAssignmentSampler 4 | 5 | 6 | def get_sampler_class(config): 7 | if config.sampler.algo.lower() == 'random': 8 | return RandomSampler 9 | elif config.sampler.algo.lower() == 'greedy_task_assignment': 10 | return GreedyTaskAssignmentSampler 11 | else: 12 | raise ValueError 13 | -------------------------------------------------------------------------------- /online_label/sampler/random_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .sampler import Sampler 4 | 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class RandomSampler(Sampler): 10 | 11 | def __init__(self, config, annotation_holder, workers, **kwargs): 12 | Sampler.__init__(self, config, annotation_holder, workers) 13 | 14 | def stop(self, risk, **kwargs): 15 | confident = risk < self.risk_thres 16 | exceed_max = self.n_annotation >= self.max_annotation_per_example 17 | logger.debug(f'Number of unconfident examples: {sum(confident)}') 18 | logger.debug(f'Number of examples exceeds budget: {sum(exceed_max)}') 19 | return np.all(np.logical_or(confident, exceed_max)) 20 | 21 | def sample(self, hit_size, n_hit, risk, **kwargs): 22 | 23 | assert self.n_workers >= n_hit, 'You are launching more HITs than the size of worker pool' 24 | 25 | sampled_workers = self.npr.choice([w.id for w in self.workers.values()], n_hit, replace=n_hit>self.n_workers) 26 | data_idx = [] 27 | worker_id = [] 28 | for w in sampled_workers: 29 | # Check every time to avoid over sampling some examples 30 | unconfident_mask = self.lower_risk_thres_if_necessary(risk, hit_size) 31 | unconfident = np.where(unconfident_mask)[0] 32 | 33 | _data_idx = self.npr.choice(unconfident, hit_size, replace=False) 34 | _worker_id = np.repeat([w], hit_size) 35 | self.n_annotation[_data_idx] += 1 36 | data_idx.extend(_data_idx) 37 | worker_id.extend(_worker_id) 38 | 39 | return data_idx, worker_id 40 | -------------------------------------------------------------------------------- /online_label/sampler/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from collections import defaultdict 4 | 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class Sampler(object): 10 | def __init__(self, config, annotation_holder, workers): 11 | self.npr = np.random.RandomState(config.seed) 12 | self.config = config 13 | self.annotation_holder = annotation_holder 14 | self.workers = workers 15 | self.n_workers = 
len(workers) 16 | self.risk_thres = config.risk_thres 17 | self.max_annotation_per_example = config.sampler.max_annotation_per_example 18 | self.max_annotation_per_worker = config.sampler.max_annotation_per_worker 19 | self.__init_annotation_stats() 20 | 21 | def __init_annotation_stats(self): 22 | self.n_annotation = np.array([len(v) for _, v in self.annotation_holder.annotation.items()]) 23 | 24 | worker_n_annotation_counter = defaultdict(lambda: 0) 25 | for w in self.workers.values(): 26 | worker_n_annotation_counter[w.id] = 0 27 | for _, vs in self.annotation_holder.annotation.items(): 28 | for v in vs: 29 | worker_n_annotation_counter[v[2]] += 1 30 | self.worker_n_annotation_counter = worker_n_annotation_counter 31 | 32 | def load_state(self, annotation_holder, workers): 33 | self.annotation_holder = annotation_holder 34 | self.workers = workers 35 | self.__init_annotation_stats() 36 | 37 | def lower_risk_thres_if_necessary(self, risk, hit_size): 38 | unconfident_mask = risk > self.risk_thres 39 | exceed_max = self.n_annotation >= self.max_annotation_per_example 40 | valid_sample = np.logical_and(~exceed_max, unconfident_mask) 41 | 42 | # Avoid sample the same image in the same HIT 43 | new_risk_thres = self.risk_thres 44 | while valid_sample.sum() < hit_size: 45 | new_risk_thres -= 0.01 46 | unconfident_mask = risk > new_risk_thres 47 | valid_sample = np.logical_and(~exceed_max, unconfident_mask) 48 | 49 | if new_risk_thres != self.risk_thres: 50 | logger.debug(f'Lower the risk threshold to {new_risk_thres}') 51 | 52 | return valid_sample 53 | 54 | def stop(self, risk, **kwargs): 55 | raise NotImplementedError 56 | 57 | def sample(self, hit_size, n_hit, risk, **kwargs): 58 | raise NotImplementedError 59 | -------------------------------------------------------------------------------- /online_label/sampler/task_assignment_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | 4 | from .sampler import Sampler 5 | 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class GreedyTaskAssignmentSampler(Sampler): 11 | def __init__(self, config, annotation_holder, workers, optimizer): 12 | Sampler.__init__(self, config, annotation_holder, workers) 13 | self.optimizer = optimizer 14 | 15 | def stop(self, risk, **kwargs): 16 | confident = risk < self.risk_thres 17 | exceed_max = self.n_annotation >= self.max_annotation_per_example 18 | 19 | valid_worker_id = self._get_valid_workers() 20 | if len(valid_worker_id) == 0: 21 | logger.info(f'All worker annotate reach maximum number of annotation {self.max_annotation_per_worker}') 22 | 23 | return np.all(np.logical_or(confident, exceed_max)) or len(valid_worker_id) == 0 24 | 25 | def _get_valid_workers(self): 26 | valid_worker_id = [k for k, v in self.worker_n_annotation_counter.items() if v + self.config.online.hit_size <= self.max_annotation_per_worker] 27 | return valid_worker_id 28 | 29 | def sample(self, hit_size, n_hit, risk, y_posterior, **kwargs): 30 | 31 | assert self.n_workers >= n_hit, 'You are launching more HITs than the size of worker pool' 32 | 33 | valid_worker_id = self._get_valid_workers() 34 | if len(valid_worker_id) < self.n_workers: 35 | logger.debug(f'{self.n_workers - len(valid_worker_id)} already have maximum annotation {self.max_annotation_per_worker}') 36 | 37 | if len(valid_worker_id) < n_hit: 38 | logger.info(f'Reduce the `n_hit_per_step` to {len(valid_worker_id)} since we only have 
{len(valid_worker_id)} available now') 39 | n_hit = len(valid_worker_id) 40 | 41 | 42 | data_idx = [] 43 | worker_id = [] 44 | 45 | for i in range(n_hit): 46 | unconfident_mask = self.lower_risk_thres_if_necessary(risk, hit_size) 47 | unconfident = np.where(unconfident_mask)[0] 48 | _data_idx = self.npr.choice(unconfident, hit_size, replace=False) 49 | _y_posterior = y_posterior[_data_idx] 50 | 51 | #valid_worker_confidence = defaultdict() 52 | _valid_worker_ids = [] 53 | _valid_worker_confidence = [] 54 | for k in valid_worker_id: 55 | confidence = self.optimizer.workers_estimated_m[k].batch_confidence(_y_posterior) 56 | _valid_worker_ids.append(k) 57 | _valid_worker_confidence.append(confidence) 58 | 59 | 60 | _valid_worker_confidence = np.array(_valid_worker_confidence) + \ 61 | self.npr.rand(len(_valid_worker_confidence)) * 1e-8 62 | idx = np.argmax(_valid_worker_confidence) 63 | _worker_id = _valid_worker_ids[idx] 64 | valid_worker_id.remove(_worker_id) 65 | 66 | data_idx.extend(_data_idx) 67 | worker_id.extend(np.repeat([_worker_id], hit_size)) 68 | 69 | self.n_annotation[_data_idx] += 1 70 | self.worker_n_annotation_counter[_worker_id] += hit_size 71 | 72 | 73 | return data_idx, worker_id 74 | -------------------------------------------------------------------------------- /online_label/worker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import uuid 4 | import numpy as np 5 | 6 | from data import REPO_DIR, imagenet100 7 | 8 | import logging 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class Worker(object): 13 | 14 | def __init__(self, config, known, seed, **kwargs): 15 | 16 | self.id = str(uuid.uuid4()) 17 | self.config = config 18 | self.seed = seed 19 | self.npr = np.random.RandomState(seed) 20 | self.n_classes = config.n_classes 21 | self.known = known 22 | 23 | self.m = self.sample_confusion_matrix() # (actual classes, predict classes) 24 | 25 | def save_state(self): 26 | return json.dumps(dict(id=self.id, m=self.m.tolist())) 27 | 28 | def load_state(self, state): 29 | state = json.loads(state) 30 | self.id = state['id'] 31 | self.m = np.array(state['m']) 32 | 33 | def annotate(self, true_y, qmask=None): 34 | m = self.m 35 | valid_options = range(self.n_classes) 36 | 37 | prob = m[true_y] 38 | prob = np.clip(prob, 0., 1.) 
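        # Row m[true_y] of the confusion matrix is this worker's label distribution for
        # the true class; the observed annotation z is sampled from it below.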
39 | 40 | z = self.npr.choice(range(self.n_classes), p=prob) 41 | if self.known: 42 | p_z_given_y = m[:, z] 43 | else: 44 | p_z_given_y = None 45 | 46 | return z, p_z_given_y 47 | 48 | def sample_confusion_matrix(self): 49 | raise NotImplementedError 50 | 51 | 52 | class UniformWorker(Worker): 53 | '''w/o class correlation 54 | ''' 55 | def sample_confusion_matrix(self): 56 | config = self.config 57 | 58 | m = np.zeros((self.n_classes, self.n_classes)) 59 | m += np.eye(self.n_classes) * self.npr.normal(config.worker.reliability.mean, 60 | config.worker.reliability.std, 61 | size=self.n_classes) 62 | m = np.clip(m, 0, 1) 63 | for i in range(self.n_classes): 64 | noise = self.npr.rand(self.n_classes) 65 | noise /= noise.sum() 66 | noise += noise[i] / (self.n_classes - 1) 67 | noise *= (1 - m[i, i]) 68 | m[i] += -1*(np.eye(self.n_classes)-1)[i] * noise 69 | 70 | reliability = np.diag(m).mean() 71 | logger.debug(f'Reliability: {reliability}') 72 | return m 73 | 74 | 75 | class PerfectWorker(Worker): 76 | def sample_confusion_matrix(self): 77 | m = np.eye(self.n_classes) 78 | m = np.clip(m, 0, 1) 79 | return m 80 | 81 | 82 | class RealWorker(Worker): 83 | wnids = None 84 | worker_cm_info = json.load(open(os.path.join(REPO_DIR, 'data/group_workers.json'), 'r')) 85 | groups_path = os.path.join(REPO_DIR, 'data/groups.txt') 86 | global_cm = [] 87 | for _k, _v in worker_cm_info['group_workers'].items(): 88 | _v = np.array(_v) 89 | global_cm.append(_v.sum(0)) 90 | global_cm = sum(global_cm) 91 | 92 | def __init__(self, config, known, seed, **kwargs): 93 | 94 | keep_indices = np.array([imagenet100.index(i.lower()) for i in self.wnids]) 95 | self.keep_indices = keep_indices 96 | self.global_cm = self.global_cm[keep_indices, :][:, keep_indices] 97 | Worker.__init__(self, config, known, seed, **kwargs) 98 | 99 | def sample_confusion_matrix(self): 100 | m = self._sample_confusion_matrix() 101 | 102 | with open(self.groups_path) as fp: 103 | groups = fp.read() 104 | groups = groups.split('\n\n') 105 | groups.pop(-1) 106 | groups = [np.array(i.split('\n')) for i in groups] 107 | 108 | 109 | def __which_group(i): 110 | for g_idx, g in enumerate(groups): 111 | if i in g: 112 | return g_idx 113 | 114 | 115 | # Add uniform noise in off-diagonal terms 116 | noise_level = 0.03 # According to the amt stats 117 | for i, i_wnid in enumerate(self.wnids): 118 | i_group = __which_group(i_wnid) 119 | same_group_mask = np.zeros(self.config.n_classes).astype(np.bool) 120 | same_group_mask[i] = True 121 | for j, j_wnid in enumerate(self.wnids): 122 | if i != j and i_group == __which_group(j_wnid): 123 | same_group_mask[j] = True 124 | 125 | 126 | if (~same_group_mask).sum() > 0: 127 | density_to_spread = m[i, same_group_mask].sum() 128 | m[i, same_group_mask] = m[i, same_group_mask] * (1 - noise_level) 129 | m[i, ~same_group_mask] += density_to_spread * (noise_level) / max(sum(~same_group_mask), 1e-8) 130 | 131 | return m 132 | 133 | 134 | class StructuredNoiseWorker(RealWorker): 135 | def _sample_confusion_matrix(self): 136 | 137 | cm = [] 138 | for _, v in self.worker_cm_info['group_workers'].items(): 139 | v = np.array(v) 140 | global_v = v.sum(0) 141 | idx = self.npr.choice(range(len(v)), 1) 142 | cm.append(global_v + v[idx][0] * 10) 143 | 144 | cm = sum(cm) 145 | if self.config.n_data_distraction_per_class > 0: 146 | m = np.zeros((self.config.n_classes, self.config.n_classes)) 147 | m[:self.config.n_classes-1, :self.config.n_classes-1] = cm[self.keep_indices, :][:, self.keep_indices] 148 | 149 | drop_indices_mask = 
np.ones(cm.shape[0]).astype(np.bool) 150 | drop_indices_mask[self.keep_indices] = False 151 | 152 | m[-1, :self.config.n_classes-1] = cm[drop_indices_mask, :][:, self.keep_indices].sum(0) # Last row 153 | m[:self.config.n_classes-1, -1] = cm[self.keep_indices, :][:, drop_indices_mask].sum(1) # Last column 154 | m[-1, -1] = cm[drop_indices_mask, :][:, drop_indices_mask].sum() 155 | else: 156 | m = cm[self.keep_indices, :][:, self.keep_indices] 157 | 158 | assert len(np.where(m.sum(1)==0)[0]) == 0 159 | 160 | m = m / (m.sum(1, keepdims=True) + 1e-8) 161 | 162 | 163 | reliability = np.diag(m).mean() 164 | logger.debug(f'Reliability: {reliability:.2f}') 165 | return m 166 | 167 | 168 | class UniformNoiseWorker(RealWorker): 169 | def _sample_confusion_matrix(self): 170 | 171 | cm = [] 172 | for _, v in self.worker_cm_info['group_workers'].items(): 173 | v = np.array(v) 174 | global_v = v.sum(0) 175 | idx = self.npr.choice(range(len(v)), 10) 176 | cm.append(global_v + v[idx][0] * 10) 177 | cm = sum(cm) 178 | 179 | imagenet100_name = self.worker_cm_info['imagenet100_name'] 180 | 181 | m = cm[self.keep_indices, :][:, self.keep_indices] 182 | 183 | if self.config.n_data_distraction_per_class > 0: 184 | m = np.zeros((self.config.n_classes, self.config.n_classes)) 185 | m[:self.config.n_classes-1, :self.config.n_classes-1] = cm[self.keep_indices, :][:, self.keep_indices] 186 | idx_to_drop_mask = np.ones(cm.shape[0]).astype(np.bool) 187 | idx_to_drop_mask[self.keep_indices] = False 188 | 189 | 190 | m[-1, :len(self.keep_indices)] = cm[idx_to_drop_mask, :][:, self.keep_indices].sum(0) 191 | m[:len(self.keep_indices), -1] = cm[self.keep_indices, :][:, idx_to_drop_mask].sum(1) 192 | m[-1, -1] = cm[idx_to_drop_mask, :][:, idx_to_drop_mask].sum() 193 | else: 194 | m = cm[self.keep_indices, :][:, self.keep_indices] 195 | assert len(np.where(m.sum(1)==0)[0]) == 0 196 | 197 | m = m / (m.sum(1, keepdims=True) + 1e-8) 198 | 199 | n_classes = len(self.wnids) 200 | _m = np.zeros((n_classes, n_classes)) 201 | _m += ((1 - m.diagonal()) / (n_classes - 1)).reshape(-1, 1) 202 | np.fill_diagonal(_m, m.diagonal()) 203 | m = _m 204 | 205 | reliability = np.diag(m).mean() 206 | logger.debug(f'Reliability: {reliability:.2f}') 207 | return m 208 | 209 | 210 | def get_worker_class(config, wnids): 211 | 212 | if config.worker.type == 'perfect': 213 | worker_class = PerfectWorker 214 | elif config.worker.type == 'uniform': 215 | worker_class = UniformWorker 216 | elif config.worker.type == 'uniform_noise': 217 | worker_class = UniformNoiseWorker 218 | worker_class.wnids = wnids 219 | elif config.worker.type == 'structured_noise': 220 | worker_class = StructuredNoiseWorker 221 | worker_class.wnids = wnids 222 | else: 223 | raise ValueError 224 | 225 | return worker_class 226 | --------------------------------------------------------------------------------
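All of the simulated worker classes in `online_label/worker.py` reduce to the same mechanic: each worker holds a confusion matrix `m` whose row `m[true_y]` is the distribution over labels the worker reports for true class `true_y`, and `annotate` samples from that row. Below is a minimal, self-contained sketch of this mechanic in plain NumPy; the 3-class confusion matrix and the class count are made up for illustration and are not part of the repository.

```python
import numpy as np

# Hypothetical 3-class confusion matrix (not from the repo): rows are true classes,
# columns are the labels a worker reports, and each row sums to 1.
m = np.array([
    [0.90, 0.07, 0.03],
    [0.10, 0.85, 0.05],
    [0.05, 0.05, 0.90],
])

npr = np.random.RandomState(0)

def annotate(true_y):
    # Mirror Worker.annotate: clip the row to [0, 1] and sample the reported label.
    prob = np.clip(m[true_y], 0., 1.)
    return npr.choice(len(prob), p=prob)

true_labels = npr.randint(0, 3, size=1000)
noisy_labels = np.array([annotate(y) for y in true_labels])

print(f'Expected reliability (diagonal mean): {np.diag(m).mean():.2f}')
print(f'Observed agreement with ground truth: {(noisy_labels == true_labels).mean():.2f}')
```

With a diagonal mean of about 0.88, the observed agreement between sampled annotations and the ground truth should land near that value, which corresponds to the `reliability` that the worker classes log inside `sample_confusion_matrix`.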