├── plench
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── networks.py
│   │   ├── model_selection.py
│   │   ├── hparams_registry.py
│   │   └── algorithms.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── cutout.py
│   │   ├── randaugment.py
│   │   └── datasets.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── reporting.py
│   │   ├── command_launchers.py
│   │   ├── densenet.py
│   │   ├── resnet.py
│   │   ├── query.py
│   │   └── misc.py
│   ├── collect_results.py
│   ├── sweep.py
│   └── train.py
└── readme.md
/plench/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /plench/core/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /plench/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /plench/lib/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /plench/lib/reporting.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import json 4 | import os 5 | 6 | import tqdm 7 | 8 | from .query import Q 9 | 10 | def load_records(path): 11 | records = [] 12 | for i, subdir in tqdm.tqdm(list(enumerate(os.listdir(path))), 13 | ncols=80, 14 | leave=False): 15 | results_path = os.path.join(path, subdir, "results.jsonl") 16 | try: 17 | with open(results_path, "r") as f: 18 | for line in f: 19 | records.append(json.loads(line[:-1])) 20 | except IOError: 21 | pass 22 | 23 | return Q(records) 24 | 25 | def get_grouped_records(records): 26 | """Group records by (trial_seed, dataset, algorithm). """ 27 | result = collections.defaultdict(lambda: []) 28 | for r in records: 29 | group = (r["args"]["trial_seed"], 30 | r["args"]["dataset"], 31 | r["args"]["algorithm"]) 32 | result[group].append(r) 33 | return Q([{"trial_seed": t, "dataset": d, "algorithm": a, 34 | "records": Q(r)} for (t,d,a),r in result.items()]) 35 | -------------------------------------------------------------------------------- /plench/data/cutout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class Cutout(object): 5 | """Randomly mask out one or more patches from an image. 6 | Args: 7 | n_holes (int): Number of patches to cut out of each image. 8 | length (int): The length (in pixels) of each square patch. 9 | """ 10 | def __init__(self, n_holes, length): 11 | self.n_holes = n_holes 12 | self.length = length 13 | 14 | def __call__(self, img): 15 | """ 16 | Args: 17 | img (Tensor): Tensor image of size (C, H, W). 18 | Returns: 19 | Tensor: Image with n_holes of dimension length x length cut out of it. 20 | """ 21 | h = img.size(1) 22 | w = img.size(2) 23 | 24 | mask = np.ones((h, w), np.float32) 25 | 26 | for n in range(self.n_holes): 27 | y = np.random.randint(h) 28 | x = np.random.randint(w) 29 | 30 | y1 = np.clip(y - self.length // 2, 0, h) 31 | y2 = np.clip(y + self.length // 2, 0, h) 32 | x1 = np.clip(x - self.length // 2, 0, w) 33 | x2 = np.clip(x + self.length // 2, 0, w) 34 | 35 | mask[y1: y2, x1: x2] = 0.
36 | 37 | mask = torch.from_numpy(mask) 38 | mask = mask.expand_as(img) 39 | img = img * mask 40 | 41 | return img 42 | -------------------------------------------------------------------------------- /plench/core/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from plench.lib import resnet 5 | 6 | class MLP(nn.Module): 7 | def __init__(self, input_dim, hidden_dim): 8 | super(MLP, self).__init__() 9 | self.fc1 = nn.Linear(input_dim, hidden_dim) 10 | self.relu1 = nn.ReLU() 11 | self.n_outputs = hidden_dim 12 | #self.fc2 = nn.Linear(hidden_dim, output_dim) 13 | 14 | def forward(self, x): 15 | out = x.view(-1, self.num_flat_features(x)) 16 | #print(out.dtype) 17 | out = self.fc1(out) 18 | out = self.relu1(out) 19 | #out = self.fc2(out) 20 | return out 21 | 22 | def num_flat_features(self, x): 23 | size = x.size()[1:] 24 | num_features = 1 25 | for s in size: 26 | num_features *= s 27 | return num_features 28 | 29 | def calc_dim(input_shape): 30 | input_shape = input_shape[1:] 31 | num_features = 1 32 | for s in input_shape: 33 | num_features *= s 34 | return num_features 35 | 36 | def Featurizer(input_shape, hparams): 37 | """Select a featurizer based on hparams["model"] and the input shape.""" 38 | dim = calc_dim(input_shape) 39 | if hparams["model"] == "MLP": 40 | return MLP(dim, 500) 41 | elif hparams["model"] == "LeNet": 42 | return LeNet() 43 | elif hparams["model"] == "ResNet": 44 | return resnet.resnet(depth=32) 45 | else: 46 | raise NotImplementedError 47 | 48 | 49 | def Classifier(in_features, out_features): 50 | return torch.nn.Linear(in_features, out_features) 51 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # PLENCH 2 | This repository is the official implementation of the paper "Realistic Evaluation of Deep Partial-Label Learning Algorithms". Technical details of the approach can be found in the paper. 3 | 4 | ## Requirements 5 | - Python 3.6.13 6 | - numpy 1.19.2 7 | - PyTorch 1.7.1 8 | - torchvision 0.8.2 9 | - pandas 1.1.5 10 | - scipy 1.5.4 11 | - tqdm 12 | - Pillow 13 | 14 | ## Dataset 15 | Most tabular datasets can be found at https://palm.seu.edu.cn/zhangml/Resources.htm#data. PLCIFAR10 can be found at https://github.com/wwangwitsel/PLCIFAR10. 16 | 17 | ## Run an Algorithm 18 | ``` 19 | python -m plench.train --data_dir=<data_dir> --algorithm PRODEN --dataset PLCIFAR10_Aggregate --output_dir=<output_dir> --steps 60000 --skip_model_save --checkpoint_freq 1000 20 | ``` 21 | 22 | ## Run Algorithms in Batches 23 | ``` 24 | python -m plench.sweep launch --data_dir=<data_dir> --command_launcher multi_gpu --n_hparams_from 0 --n_hparams 20 --n_trials_from 0 --n_trials 3 --datasets PLCIFAR10_Aggregate PLCIFAR10_Vaguest --algorithms PRODEN CAVL --output_dir=<output_dir> --skip_confirmation --skip_model_save --steps 60000 25 | ``` 26 | To preview the generated commands without running them, see the dry-run example below. 27 | ## Collect Experimental Results 28 | ``` 29 | python -m plench.collect_results --input_dir=<output_dir> 30 | ``` 31 | 32 | ## Acknowledgement 33 | The code is based on the codebase of the following paper: 34 | 35 | - Ishaan Gulrajani and David Lopez-Paz. In search of lost domain generalization. In Proceedings of the 9th International Conference on Learning Representations, 2021.
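## Dry Run
The `dummy` launcher in `plench/lib/command_launchers.py` prints each generated command instead of executing it, so a sweep configuration can be checked before launching. An illustrative example (adjust the arguments to your own setup):
```
python -m plench.sweep launch --data_dir=<data_dir> --command_launcher dummy --n_hparams 20 --n_trials 3 --datasets PLCIFAR10_Aggregate --algorithms PRODEN --output_dir=<output_dir> --skip_confirmation --skip_model_save --steps 60000
```
Jobs left incomplete by an interrupted sweep can be removed with the `delete_incomplete` command of `plench.sweep` before relaunching.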
36 | 37 | 38 | ## Citation 39 | ``` 40 | @inproceedings{wang2025realistic, 41 | author = {Wang, Wei and Wu, Dong-Dong and Wang, Jindong and Niu, Gang and Zhang, Min-Ling and Sugiyama, Masashi}, 42 | title = {Realistic evaluation of deep partial-label learning algorithms}, 43 | booktitle = {Proceedings of the 13th International Conference on Learning Representations}, 44 | year = {2025} 45 | } 46 | ``` 47 | 48 | 49 | -------------------------------------------------------------------------------- /plench/lib/command_launchers.py: -------------------------------------------------------------------------------- 1 | """ 2 | A command launcher launches a list of commands on a cluster; implement your own 3 | launcher to add support for your cluster. We've provided an example launcher 4 | which runs all commands serially on the local machine. 5 | """ 6 | 7 | import subprocess 8 | import time 9 | import torch 10 | import os 11 | 12 | def local_launcher(commands): 13 | """Launch commands serially on the local machine.""" 14 | for cmd in commands: 15 | subprocess.call(cmd, shell=True) 16 | 17 | def dummy_launcher(commands): 18 | """ 19 | Doesn't run anything; instead, prints each command. 20 | Useful for testing. 21 | """ 22 | for cmd in commands: 23 | print(f'Dummy launcher: {cmd}') 24 | 25 | def multi_gpu_launcher(commands): 26 | """ 27 | Launch commands on the local machine, using all GPUs in parallel. 28 | """ 29 | print('WARNING: using experimental multi_gpu_launcher.') 30 | try: 31 | # Get list of GPUs from env, split by ',' and remove empty string '' 32 | # To handle the case when there is one extra comma: `CUDA_VISIBLE_DEVICES=0,1,2,3, python3 ...` 33 | available_gpus = [x for x in os.environ['CUDA_VISIBLE_DEVICES'].split(',') if x != ''] 34 | except Exception: 35 | # If the env variable is not set, we use all GPUs 36 | available_gpus = [str(x) for x in range(torch.cuda.device_count())] 37 | n_gpus = len(available_gpus) 38 | procs_by_gpu = [None]*n_gpus 39 | 40 | while len(commands) > 0: 41 | for idx, gpu_idx in enumerate(available_gpus): 42 | proc = procs_by_gpu[idx] 43 | if (proc is None) or (proc.poll() is not None): 44 | # Nothing is running on this GPU; launch a command. 45 | cmd = commands.pop(0) 46 | new_proc = subprocess.Popen( 47 | f'CUDA_VISIBLE_DEVICES={gpu_idx} {cmd}', shell=True) 48 | procs_by_gpu[idx] = new_proc 49 | break 50 | time.sleep(1) 51 | 52 | # Wait for the last few tasks to finish before returning 53 | for p in procs_by_gpu: 54 | if p is not None: 55 | p.wait() 56 | 57 | REGISTRY = { 58 | 'local': local_launcher, 59 | 'dummy': dummy_launcher, 60 | 'multi_gpu': multi_gpu_launcher 61 | } 62 | 63 | -------------------------------------------------------------------------------- /plench/core/model_selection.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import numpy as np 3 | 4 | def filter_step0_records(records): 5 | """Filter step 0""" 6 | return records.filter(lambda r: r['step'] != 0) 7 | 8 | class SelectionMethod: 9 | """Abstract class whose subclasses implement strategies for model 10 | selection across hparams and timesteps.""" 11 | 12 | def __init__(self): 13 | raise TypeError 14 | 15 | @classmethod 16 | def run_acc(self, run_records): 17 | """ 18 | Given records from a run, return a {val_acc, test_acc} dict representing 19 | the best val-acc and corresponding test-acc for that run. 
20 | """ 21 | raise NotImplementedError 22 | 23 | @classmethod 24 | def hparams_accs(self, records): 25 | """ 26 | Given all records from a single (dataset, algorithm) pair, 27 | return a sorted list of (run_acc, records) tuples. 28 | """ 29 | return (records.group('args.hparams_seed') 30 | .map(lambda _, run_records: 31 | ( 32 | self.run_acc(run_records), 33 | run_records 34 | ) 35 | ).filter(lambda x: x[0] is not None) 36 | .sorted(key=lambda x: x[0]['val_acc'])[::-1] 37 | ) 38 | 39 | @classmethod 40 | def sweep_acc(self, records): 41 | """ 42 | Given all records from a single (dataset, algorithm) pair, 43 | return the mean test acc of the k runs with the top val accs. 44 | """ 45 | _hparams_accs = self.hparams_accs(records) 46 | if len(_hparams_accs): 47 | return _hparams_accs[0][0]['test_acc'] 48 | else: 49 | return None 50 | 51 | class OracleSelectionMethod(SelectionMethod): 52 | """Like Selection method which picks argmax(val_accuracy) across all hparams 53 | and checkpoints, but instead of taking the argmax over all 54 | checkpoints, we pick the last checkpoint, i.e. no early stopping.""" 55 | name = "Oracle Accuracy" 56 | 57 | @classmethod 58 | def run_acc(self, run_records): 59 | run_records = filter_step0_records(run_records) 60 | if not len(run_records): 61 | return None 62 | chosen_record = run_records.sorted(lambda r: r['step'])[-1] 63 | return { 64 | 'val_acc': chosen_record['val_accuracy'], 65 | 'test_acc': chosen_record['test_acc'] 66 | } 67 | 68 | class OracleSelectionWithEarlyStoppingMethod(SelectionMethod): 69 | """Picks argmax(val_accuracy), with early stopping""" 70 | name = "Oracle Accuracy w/ ES" 71 | 72 | @classmethod 73 | def _step_acc(self, record): 74 | """Given a single record, return a {val_acc, test_acc} dict.""" 75 | 76 | return { 77 | 'val_acc': record['val_accuracy'], 78 | 'test_acc': record['test_acc'] 79 | } 80 | 81 | @classmethod 82 | def run_acc(self, run_records): 83 | test_records = filter_step0_records(run_records) 84 | if not len(test_records): 85 | return None 86 | return test_records.map(self._step_acc).argmax('val_acc') 87 | 88 | class CoveringRateSelectionMethod(SelectionMethod): 89 | """Picks argmax(val_covering_rate)""" 90 | name = "Covering Rate" 91 | 92 | @classmethod 93 | def _step_acc(self, record): 94 | """Given a single record, return a {val_acc, test_acc} dict.""" 95 | 96 | return { 97 | 'val_acc': record['val_covering_rate'], 98 | 'test_acc': record['test_acc'] 99 | } 100 | 101 | @classmethod 102 | def run_acc(self, run_records): 103 | test_records = filter_step0_records(run_records) 104 | if not len(test_records): 105 | return None 106 | return test_records.map(self._step_acc).argmax('val_acc') 107 | 108 | class ApproximatedAccuracySelectionMethod(SelectionMethod): 109 | """Picks argmax(val_approximated_acc)""" 110 | name = "Approximated Accuracy" 111 | 112 | @classmethod 113 | def _step_acc(self, record): 114 | """Given a single record, return a {val_acc, test_acc} dict.""" 115 | 116 | return { 117 | 'val_acc': record['val_approximated_acc'], 118 | 'test_acc': record['test_acc'] 119 | } 120 | 121 | @classmethod 122 | def run_acc(self, run_records): 123 | test_records = filter_step0_records(run_records) 124 | if not len(test_records): 125 | return None 126 | return test_records.map(self._step_acc).argmax('val_acc') 127 | -------------------------------------------------------------------------------- /plench/data/randaugment.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 
import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | 8 | 9 | def AutoContrast(img, _): 10 | return PIL.ImageOps.autocontrast(img) 11 | 12 | 13 | def Brightness(img, v): 14 | assert v >= 0.0 15 | return PIL.ImageEnhance.Brightness(img).enhance(v) 16 | 17 | 18 | def Color(img, v): 19 | assert v >= 0.0 20 | return PIL.ImageEnhance.Color(img).enhance(v) 21 | 22 | 23 | def Contrast(img, v): 24 | assert v >= 0.0 25 | return PIL.ImageEnhance.Contrast(img).enhance(v) 26 | 27 | 28 | def Equalize(img, _): 29 | return PIL.ImageOps.equalize(img) 30 | 31 | 32 | def Invert(img, _): 33 | return PIL.ImageOps.invert(img) 34 | 35 | 36 | def Identity(img, v): 37 | return img 38 | 39 | 40 | def Posterize(img, v): # [4, 8] 41 | v = int(v) 42 | v = max(1, v) 43 | return PIL.ImageOps.posterize(img, v) 44 | 45 | 46 | def Rotate(img, v): # [-30, 30] 47 | #assert -30 <= v <= 30 48 | #if random.random() > 0.5: 49 | # v = -v 50 | return img.rotate(v) 51 | 52 | 53 | 54 | def Sharpness(img, v): # [0.1,1.9] 55 | assert v >= 0.0 56 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 57 | 58 | 59 | def ShearX(img, v): # [-0.3, 0.3] 60 | #assert -0.3 <= v <= 0.3 61 | #if random.random() > 0.5: 62 | # v = -v 63 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 64 | 65 | 66 | def ShearY(img, v): # [-0.3, 0.3] 67 | #assert -0.3 <= v <= 0.3 68 | #if random.random() > 0.5: 69 | # v = -v 70 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 71 | 72 | 73 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 74 | #assert -0.3 <= v <= 0.3 75 | #if random.random() > 0.5: 76 | # v = -v 77 | v = v * img.size[0] 78 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 79 | 80 | 81 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 82 | #assert v >= 0.0 83 | #if random.random() > 0.5: 84 | # v = -v 85 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 86 | 87 | 88 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 89 | #assert -0.3 <= v <= 0.3 90 | #if random.random() > 0.5: 91 | # v = -v 92 | v = v * img.size[1] 93 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 94 | 95 | 96 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 97 | #assert 0 <= v 98 | #if random.random() > 0.5: 99 | # v = -v 100 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 101 | 102 | 103 | def Solarize(img, v): # [0, 256] 104 | assert 0 <= v <= 256 105 | return PIL.ImageOps.solarize(img, v) 106 | 107 | 108 | def Cutout(img, v): #[0, 60] => percentage: [0, 0.2] => change to [0, 0.5] 109 | assert 0.0 <= v <= 0.5 110 | if v <= 0.: 111 | return img 112 | 113 | v = v * img.size[0] 114 | return CutoutAbs(img, v) 115 | 116 | 117 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 118 | # assert 0 <= v <= 20 119 | if v < 0: 120 | return img 121 | w, h = img.size 122 | x0 = np.random.uniform(w) 123 | y0 = np.random.uniform(h) 124 | 125 | x0 = int(max(0, x0 - v / 2.)) 126 | y0 = int(max(0, y0 - v / 2.)) 127 | x1 = min(w, x0 + v) 128 | y1 = min(h, y0 + v) 129 | 130 | xy = (x0, y0, x1, y1) 131 | color = (125, 123, 114) 132 | # color = (0, 0, 0) 133 | img = img.copy() 134 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 135 | return img 136 | 137 | 138 | def augment_list(): 139 | l = [ 140 | (AutoContrast, 0, 1), 141 | (Brightness, 0.05, 0.95), 142 | (Color, 0.05, 0.95), 143 | (Contrast, 0.05, 0.95), 144 | 
(Equalize, 0, 1), 145 | (Identity, 0, 1), 146 | (Posterize, 4, 8), 147 | (Rotate, -30, 30), 148 | (Sharpness, 0.05, 0.95), 149 | (ShearX, -0.3, 0.3), 150 | (ShearY, -0.3, 0.3), 151 | (Solarize, 0, 256), 152 | (TranslateX, -0.3, 0.3), 153 | (TranslateY, -0.3, 0.3) 154 | ] 155 | return l 156 | 157 | 158 | class RandomAugment: 159 | def __init__(self, n, m=None): 160 | self.n = n 161 | self.m = m # [0, 30] in fixmatch, deprecated. 162 | self.augment_list = augment_list() 163 | 164 | 165 | def __call__(self, img): 166 | ops = random.choices(self.augment_list, k=self.n) 167 | for op, min_val, max_val in ops: 168 | val = min_val + float(max_val - min_val)*random.random() 169 | img = op(img, val) 170 | cutout_val = random.random() * 0.5 171 | img = Cutout(img, cutout_val) #for fixmatch 172 | return img 173 | 174 | 175 | if __name__ == '__main__': 176 | randaug = RandomAugment(3,5) 177 | print(randaug) 178 | for item in randaug.augment_list: 179 | print(item) 180 | -------------------------------------------------------------------------------- /plench/lib/densenet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/bearpaw/pytorch-classification/blob/master/models/cifar/densenet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | 9 | __all__ = ['densenet'] 10 | 11 | 12 | from torch.autograd import Variable 13 | 14 | class Bottleneck(nn.Module): 15 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 16 | super(Bottleneck, self).__init__() 17 | planes = expansion * growthRate 18 | self.bn1 = nn.BatchNorm2d(inplanes) 19 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, 22 | padding=1, bias=False) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.dropRate = dropRate 25 | 26 | def forward(self, x): 27 | out = self.bn1(x) 28 | out = self.relu(out) 29 | out = self.conv1(out) 30 | out = self.bn2(out) 31 | out = self.relu(out) 32 | out = self.conv2(out) 33 | if self.dropRate > 0: 34 | out = F.dropout(out, p=self.dropRate, training=self.training) 35 | 36 | out = torch.cat((x, out), 1) 37 | 38 | return out 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0): 43 | super(BasicBlock, self).__init__() 44 | planes = expansion * growthRate 45 | self.bn1 = nn.BatchNorm2d(inplanes) 46 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3, 47 | padding=1, bias=False) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.dropRate = dropRate 50 | 51 | def forward(self, x): 52 | out = self.bn1(x) 53 | out = self.relu(out) 54 | out = self.conv1(out) 55 | if self.dropRate > 0: 56 | out = F.dropout(out, p=self.dropRate, training=self.training) 57 | 58 | out = torch.cat((x, out), 1) 59 | 60 | return out 61 | 62 | 63 | class Transition(nn.Module): 64 | def __init__(self, inplanes, outplanes): 65 | super(Transition, self).__init__() 66 | self.bn1 = nn.BatchNorm2d(inplanes) 67 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, 68 | bias=False) 69 | self.relu = nn.ReLU(inplace=True) 70 | 71 | def forward(self, x): 72 | out = self.bn1(x) 73 | out = self.relu(out) 74 | out = self.conv1(out) 75 | out = F.avg_pool2d(out, 2) 76 | return out 77 | 78 | 79 | class DenseNet(nn.Module): 80 | 81 | def __init__(self, depth=22, block=Bottleneck, 82 | dropRate=0, growthRate=12, compressionRate=2): 83 | 
super(DenseNet, self).__init__() 84 | 85 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 86 | n = (depth - 4) / 3 if block == BasicBlock else (depth - 4) // 6 87 | 88 | self.growthRate = growthRate 89 | self.dropRate = dropRate 90 | 91 | # self.inplanes is a global variable used across multiple 92 | # helper functions 93 | self.inplanes = growthRate * 2 94 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 95 | bias=False) 96 | self.dense1 = self._make_denseblock(block, n) 97 | self.trans1 = self._make_transition(compressionRate) 98 | self.dense2 = self._make_denseblock(block, n) 99 | self.trans2 = self._make_transition(compressionRate) 100 | self.dense3 = self._make_denseblock(block, n) 101 | self.bn = nn.BatchNorm2d(self.inplanes) 102 | self.relu = nn.ReLU(inplace=True) 103 | self.avgpool = nn.AvgPool2d(8) 104 | self.n_outputs = self.inplanes 105 | 106 | # Weight initialization 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. / n)) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | 115 | def _make_denseblock(self, block, blocks): 116 | layers = [] 117 | for i in range(blocks): 118 | # Currently we fix the expansion ratio as the default value 119 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 120 | self.inplanes += self.growthRate 121 | 122 | return nn.Sequential(*layers) 123 | 124 | def _make_transition(self, compressionRate): 125 | inplanes = self.inplanes 126 | outplanes = int(math.floor(self.inplanes // compressionRate)) 127 | self.inplanes = outplanes 128 | return Transition(inplanes, outplanes) 129 | 130 | 131 | def forward(self, x): 132 | x = self.conv1(x) 133 | 134 | x = self.trans1(self.dense1(x)) 135 | x = self.trans2(self.dense2(x)) 136 | x = self.dense3(x) 137 | x = self.bn(x) 138 | x = self.relu(x) 139 | 140 | x = self.avgpool(x) 141 | x = x.view(x.size(0), -1) 142 | 143 | return x 144 | 145 | 146 | def densenet(**kwargs): 147 | """ 148 | Constructs a DenseNet model. 149 | """ 150 | return DenseNet(**kwargs) -------------------------------------------------------------------------------- /plench/lib/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
4 | Ported form 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | ''' 9 | import torch.nn as nn 10 | import math 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | 15 | __all__ = ['resnet'] 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | "3x3 convolution with padding" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = nn.BatchNorm2d(planes) 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.bn3 = nn.BatchNorm2d(planes * 4) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x): 72 | residual = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ResNet(nn.Module): 95 | 96 | def __init__(self, depth): 97 | super(ResNet, self).__init__() 98 | # Model type specifies number of layers for CIFAR-10 model 99 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 100 | n = (depth - 2) // 6 101 | 102 | block = Bottleneck if depth >=44 else BasicBlock 103 | 104 | self.inplanes = 16 105 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 106 | bias=False) 107 | self.bn1 = nn.BatchNorm2d(16) 108 | self.relu = nn.ReLU(inplace=True) 109 | self.layer1 = self._make_layer(block, 16, n) 110 | self.layer2 = self._make_layer(block, 32, n, stride=2) 111 | self.layer3 = self._make_layer(block, 64, n, stride=2) 112 | self.avgpool = nn.AvgPool2d(8) 113 | self.n_outputs = 64 * block.expansion 114 | #self.fc = nn.Linear(64 * block.expansion, num_classes) 115 | 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 119 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 120 | elif isinstance(m, nn.BatchNorm2d): 121 | m.weight.data.fill_(1) 122 | m.bias.data.zero_() 123 | 124 | def _make_layer(self, block, planes, blocks, stride=1): 125 | downsample = None 126 | if stride != 1 or self.inplanes != planes * block.expansion: 127 | downsample = nn.Sequential( 128 | nn.Conv2d(self.inplanes, planes * block.expansion, 129 | kernel_size=1, stride=stride, bias=False), 130 | nn.BatchNorm2d(planes * block.expansion), 131 | ) 132 | 133 | layers = [] 134 | layers.append(block(self.inplanes, planes, stride, downsample)) 135 | self.inplanes = planes * block.expansion 136 | for i in range(1, blocks): 137 | layers.append(block(self.inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.relu(x) # 32x32 145 | 146 | x = self.layer1(x) # 32x32 147 | x = self.layer2(x) # 16x16 148 | x = self.layer3(x) # 8x8 149 | 150 | x = self.avgpool(x) 151 | x = x.view(x.size(0), -1) 152 | #x = self.fc(x) 153 | 154 | return x 155 | 156 | 157 | def resnet(**kwargs): 158 | """ 159 | Constructs a ResNet model. 160 | """ 161 | return ResNet(**kwargs) -------------------------------------------------------------------------------- /plench/collect_results.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import argparse 3 | import functools 4 | import glob 5 | import pickle 6 | import itertools 7 | import json 8 | import os 9 | import random 10 | import sys 11 | 12 | import numpy as np 13 | import tqdm 14 | 15 | from .data import datasets 16 | from .core import algorithms, model_selection 17 | from .lib import misc, reporting 18 | from .lib.query import Q 19 | import warnings 20 | 21 | def format_mean(data, latex): 22 | """Given a list of datapoints, return a string describing their mean and 23 | standard error""" 24 | if len(data) == 0: 25 | return None, None, "X" 26 | mean = 100 * np.mean(list(data)) 27 | err = 100 * np.std(list(data) / np.sqrt(len(data))) 28 | if latex: 29 | return mean, err, "{:.2f}$\\pm${:.2f}".format(mean, err) 30 | else: 31 | return mean, err, "{:.2f} +/- {:.2f}".format(mean, err) 32 | 33 | def print_table(table, header_text, row_labels, col_labels, colwidth=10, 34 | latex=True): 35 | """Pretty-print a 2D array of data, optionally with row/col labels""" 36 | print("") 37 | 38 | if latex: 39 | num_cols = len(table[0]) 40 | print("\\begin{center}") 41 | print("\\adjustbox{max width=\\textwidth}{%") 42 | print("\\begin{tabular}{l" + "c" * num_cols + "}") 43 | print("\\toprule") 44 | else: 45 | print("--------", header_text) 46 | 47 | for row, label in zip(table, row_labels): 48 | row.insert(0, label) 49 | 50 | if latex: 51 | col_labels = ["\\textbf{" + str(col_label).replace("%", "\\%") + "}" 52 | for col_label in col_labels] 53 | table.insert(0, col_labels) 54 | 55 | for r, row in enumerate(table): 56 | misc.print_row(row, colwidth=colwidth, latex=latex) 57 | if latex and r == 0: 58 | print("\\midrule") 59 | if latex: 60 | print("\\bottomrule") 61 | print("\\end{tabular}}") 62 | print("\\end{center}") 63 | 64 | def print_results_tables(records, dataset, latex): 65 | """Given all records, print a results table for each dataset.""" 66 | SELECTION_METHODS = [ 67 | model_selection.CoveringRateSelectionMethod, 68 | model_selection.ApproximatedAccuracySelectionMethod, 69 | model_selection.OracleSelectionMethod, 70 | model_selection.OracleSelectionWithEarlyStoppingMethod 71 | ] 72 | 73 | # read algorithm names and 
sort (predefined order) 74 | alg_names = Q(records).select("args.algorithm").unique() 75 | alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] + 76 | [n for n in alg_names if n not in algorithms.ALGORITHMS]) 77 | 78 | 79 | # Print an "averages" table 80 | if latex: 81 | print() 82 | print("\\subsubsection{Averages}") 83 | 84 | selection_method_names = [] 85 | for j, selection_method in enumerate(SELECTION_METHODS): 86 | selection_method_names.append(selection_method.name) 87 | table = [[None for _ in [*selection_method_names]] for _ in alg_names] 88 | for i, algorithm in enumerate(alg_names): 89 | for j, selection_method in enumerate(SELECTION_METHODS): 90 | grouped_records = reporting.get_grouped_records(records).map(lambda group: 91 | { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) } 92 | ).filter(lambda g: g["sweep_acc"] is not None) 93 | trial_averages = (grouped_records 94 | .filter_equals("algorithm, dataset", (algorithm, dataset)) 95 | .group("trial_seed") 96 | .map(lambda trial_seed, group: 97 | group.select("sweep_acc").mean() 98 | ) 99 | ) 100 | mean, err, table[i][j] = format_mean(trial_averages, latex) 101 | 102 | col_labels = ["Algorithm", *selection_method_names] 103 | header_text = f"Dataset: {dataset}" 104 | print_table(table, header_text, alg_names, col_labels, colwidth=25, 105 | latex=latex) 106 | 107 | if __name__ == "__main__": 108 | np.set_printoptions(suppress=True) 109 | 110 | parser = argparse.ArgumentParser( 111 | description="Partial-label learning testbed") 112 | parser.add_argument("--input_dir", type=str, required=True) 113 | parser.add_argument("--latex", action="store_true") 114 | args = parser.parse_args() 115 | 116 | results_file = "results.tex" if args.latex else "results.txt" 117 | 118 | sys.stdout = misc.Tee(os.path.join(args.input_dir, results_file), "w") 119 | 120 | records = reporting.load_records(args.input_dir) 121 | 122 | if args.latex: 123 | print("\\documentclass{article}") 124 | print("\\usepackage{booktabs}") 125 | print("\\usepackage{adjustbox}") 126 | print("\\begin{document}") 127 | print("\\section{Full PLENCH results}") 128 | print("% Total records:", len(records)) 129 | else: 130 | print("Total records:", len(records)) 131 | 132 | # read dataset names and sort (lexicographic order) 133 | dataset_names = Q(records).select("args.dataset").unique().sorted() 134 | dataset_names = [d for d in datasets.DATASETS if d in dataset_names] 135 | 136 | for dataset in dataset_names: 137 | if args.latex: 138 | print() 139 | print("\\subsection{{Dataset: {}}}".format( 140 | dataset)) 141 | print_results_tables(records, dataset, args.latex) 142 | 143 | if args.latex: 144 | print("\\end{document}") 145 | -------------------------------------------------------------------------------- /plench/lib/query.py: -------------------------------------------------------------------------------- 1 | """Small query library.""" 2 | 3 | import collections 4 | import inspect 5 | import json 6 | import types 7 | import unittest 8 | import warnings 9 | import math 10 | 11 | import numpy as np 12 | 13 | 14 | def make_selector_fn(selector): 15 | """ 16 | If selector is a function, return selector. 17 | Otherwise, return a function corresponding to the selector string. 
Examples 18 | of valid selector strings and the corresponding functions: 19 | x lambda obj: obj['x'] 20 | x.y lambda obj: obj['x']['y'] 21 | x,y lambda obj: (obj['x'], obj['y']) 22 | """ 23 | if isinstance(selector, str): 24 | if ',' in selector: 25 | parts = selector.split(',') 26 | part_selectors = [make_selector_fn(part) for part in parts] 27 | return lambda obj: tuple(sel(obj) for sel in part_selectors) 28 | elif '.' in selector: 29 | parts = selector.split('.') 30 | part_selectors = [make_selector_fn(part) for part in parts] 31 | def f(obj): 32 | for sel in part_selectors: 33 | obj = sel(obj) 34 | return obj 35 | return f 36 | else: 37 | key = selector.strip() 38 | return lambda obj: obj[key] 39 | elif isinstance(selector, types.FunctionType): 40 | return selector 41 | else: 42 | raise TypeError 43 | 44 | def hashable(obj): 45 | try: 46 | hash(obj) 47 | return obj 48 | except TypeError: 49 | return json.dumps({'_':obj}, sort_keys=True) 50 | 51 | class Q(object): 52 | def __init__(self, list_): 53 | super(Q, self).__init__() 54 | self._list = list_ 55 | 56 | def __len__(self): 57 | return len(self._list) 58 | 59 | def __getitem__(self, key): 60 | return self._list[key] 61 | 62 | def __eq__(self, other): 63 | if isinstance(other, self.__class__): 64 | return self._list == other._list 65 | else: 66 | return self._list == other 67 | 68 | def __str__(self): 69 | return str(self._list) 70 | 71 | def __repr__(self): 72 | return repr(self._list) 73 | 74 | def _append(self, item): 75 | """Unsafe, be careful you know what you're doing.""" 76 | self._list.append(item) 77 | 78 | def group(self, selector): 79 | """ 80 | Group elements by selector and return a list of (group, group_records) 81 | tuples. 82 | """ 83 | selector = make_selector_fn(selector) 84 | groups = {} 85 | for x in self._list: 86 | group = selector(x) 87 | group_key = hashable(group) 88 | if group_key not in groups: 89 | groups[group_key] = (group, Q([])) 90 | groups[group_key][1]._append(x) 91 | results = [groups[key] for key in sorted(groups.keys())] 92 | return Q(results) 93 | 94 | def group_map(self, selector, fn): 95 | """ 96 | Group elements by selector, apply fn to each group, and return a list 97 | of the results. 98 | """ 99 | return self.group(selector).map(fn) 100 | 101 | def map(self, fn): 102 | """ 103 | map self onto fn. If fn takes multiple args, tuple-unpacking 104 | is applied. 
105 | """ 106 | if len(inspect.signature(fn).parameters) > 1: 107 | return Q([fn(*x) for x in self._list]) 108 | else: 109 | return Q([fn(x) for x in self._list]) 110 | 111 | def select(self, selector): 112 | selector = make_selector_fn(selector) 113 | return Q([selector(x) for x in self._list]) 114 | 115 | def min(self): 116 | return min(self._list) 117 | 118 | def max(self): 119 | return max(self._list) 120 | 121 | def sum(self): 122 | return sum(self._list) 123 | 124 | def len(self): 125 | return len(self._list) 126 | 127 | def mean(self): 128 | with warnings.catch_warnings(): 129 | warnings.simplefilter("ignore") 130 | return float(np.mean(self._list)) 131 | 132 | def std(self): 133 | with warnings.catch_warnings(): 134 | warnings.simplefilter("ignore") 135 | return float(np.std(self._list)) 136 | 137 | def mean_std(self): 138 | return (self.mean(), self.std()) 139 | 140 | def argmax(self, selector): 141 | selector = make_selector_fn(selector) 142 | return max(self._list, key=selector) 143 | 144 | def filter(self, fn): 145 | return Q([x for x in self._list if fn(x)]) 146 | 147 | def filter_equals(self, selector, value): 148 | """like [x for x in y if x.selector == value]""" 149 | selector = make_selector_fn(selector) 150 | return self.filter(lambda r: selector(r) == value) 151 | 152 | def filter_not_none(self): 153 | return self.filter(lambda r: r is not None) 154 | 155 | def filter_not_nan(self): 156 | return self.filter(lambda r: not np.isnan(r)) 157 | 158 | def flatten(self): 159 | return Q([y for x in self._list for y in x]) 160 | 161 | def unique(self): 162 | result = [] 163 | result_set = set() 164 | for x in self._list: 165 | hashable_x = hashable(x) 166 | if hashable_x not in result_set: 167 | result_set.add(hashable_x) 168 | result.append(x) 169 | return Q(result) 170 | 171 | def sorted(self, key=None): 172 | if key is None: 173 | key = lambda x: x 174 | def key2(x): 175 | x = key(x) 176 | if isinstance(x, (np.floating, float)) and np.isnan(x): 177 | return float('-inf') 178 | else: 179 | return x 180 | return Q(sorted(self._list, key=key2)) 181 | -------------------------------------------------------------------------------- /plench/lib/misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import hashlib 3 | import sys 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import json 10 | 11 | class NpEncoder(json.JSONEncoder): 12 | def default(self, obj): 13 | if isinstance(obj, np.integer): 14 | return int(obj) 15 | if isinstance(obj, np.floating): 16 | return float(obj) 17 | if isinstance(obj, np.ndarray): 18 | return obj.tolist() 19 | return super(NpEncoder, self).default(obj) 20 | 21 | def pdb(): 22 | sys.stdout = sys.__stdout__ 23 | import pdb 24 | print("Launching PDB, enter 'n' to step to parent function.") 25 | pdb.set_trace() 26 | 27 | def seed_hash(*args): 28 | """ 29 | Derive an integer hash from all args, for use as a random seed. 
30 | """ 31 | args_str = str(args) 32 | return int(hashlib.md5(args_str.encode("utf-8")).hexdigest(), 16) % (2**31) 33 | 34 | def print_separator(): 35 | print("="*80) 36 | 37 | def print_row(row, colwidth=10, latex=False): 38 | if latex: 39 | sep = " & " 40 | end_ = "\\\\" 41 | else: 42 | sep = " " 43 | end_ = "" 44 | 45 | def format_val(x): 46 | if np.issubdtype(type(x), np.floating): 47 | x = "{:.10f}".format(x) 48 | return str(x).ljust(colwidth)[:colwidth] 49 | print(sep.join([format_val(x) for x in row]), end_) 50 | 51 | class _SplitDataset(torch.utils.data.Dataset): 52 | """Used by split_dataset""" 53 | def __init__(self, underlying_dataset, keys): 54 | super(_SplitDataset, self).__init__() 55 | self.underlying_dataset = underlying_dataset 56 | self.keys = keys 57 | self.data = self.underlying_dataset.data[self.keys] 58 | self.partial_targets = self.underlying_dataset.partial_targets[self.keys] 59 | def __getitem__(self, key): 60 | original_image, weak_image, strong_image, distill_image, partial_targets, ord_labels = self.underlying_dataset[self.keys[key]] 61 | return original_image, weak_image, strong_image, distill_image, partial_targets, ord_labels, key 62 | def __len__(self): 63 | return len(self.keys) 64 | 65 | def split_dataset(dataset, n, seed=0): 66 | assert(n <= len(dataset)) 67 | keys = list(range(len(dataset))) 68 | np.random.RandomState(seed).shuffle(keys) 69 | keys_1 = keys[:n] 70 | keys_2 = keys[n:] 71 | return _SplitDataset(dataset, keys_1), _SplitDataset(dataset, keys_2) 72 | 73 | 74 | def accuracy(network, loader, device): 75 | correct = 0 76 | total = 0 77 | network.eval() 78 | with torch.no_grad(): 79 | for x, y in loader: 80 | x = x.to(device) 81 | y = y.to(device) 82 | p = network.predict(x) 83 | batch_weights = torch.ones(len(x)) 84 | batch_weights = batch_weights.to(device) 85 | if p.size(1) == 1: 86 | correct += (p.gt(0).eq(y).float() * batch_weights.view(-1, 1)).sum().item() 87 | else: 88 | correct += (p.argmax(1).eq(y).float() * batch_weights).sum().item() 89 | total += batch_weights.sum().item() 90 | network.train() 91 | 92 | return correct / total 93 | 94 | class Tee: 95 | def __init__(self, fname, mode="a"): 96 | self.stdout = sys.stdout 97 | self.file = open(fname, mode) 98 | 99 | def write(self, message): 100 | self.stdout.write(message) 101 | self.file.write(message) 102 | self.flush() 103 | 104 | def flush(self): 105 | self.stdout.flush() 106 | self.file.flush() 107 | 108 | def val_accuracy(network, loader, device): 109 | correct = 0 110 | total = 0 111 | network.eval() 112 | with torch.no_grad(): 113 | for x, x_weak, x_strong, x_distill, partial_y, y, _ in loader: 114 | x = x.to(device) 115 | partial_y = partial_y.to(device) 116 | y = y.to(device) 117 | p = network.predict(x) 118 | if p.size(1) == 1: 119 | correct += p.gt(0).eq(y).float().sum().item() 120 | else: 121 | correct += p.argmax(1).eq(y).float().sum().item() 122 | total += len(x) 123 | network.train() 124 | 125 | return correct / total 126 | 127 | 128 | def val_covering_rate(network, loader, device): 129 | correct = 0 130 | total = 0 131 | network.eval() 132 | with torch.no_grad(): 133 | for x, x_weak, x_strong, x_distill, partial_y, _, _ in loader: 134 | x = x.to(device) 135 | partial_y = partial_y.to(device) 136 | p = network.predict(x) 137 | predicted_label = p.argmax(1) 138 | covering_per_example = partial_y[torch.arange(len(x)), predicted_label] 139 | correct += covering_per_example.sum().item() 140 | total += len(x) 141 | network.train() 142 | 143 | return correct / total 144 | 145 | 
def val_approximated_accuracy(network, loader, device): 146 | correct = 0 147 | total = 0 148 | network.eval() 149 | with torch.no_grad(): 150 | for x, x_weak, x_strong, x_distill, partial_y, _, _ in loader: 151 | x = x.to(device) 152 | partial_y = partial_y.to(device) 153 | batch_outputs = network.predict(x) 154 | temp_un_conf = F.softmax(batch_outputs, dim=1) 155 | label_confidence = temp_un_conf * partial_y 156 | base_value = label_confidence.sum(dim=1).unsqueeze(1).repeat(1, label_confidence.shape[1]) + 1e-12 157 | label_confidence = label_confidence / base_value 158 | predicted_label = batch_outputs.argmax(1) 159 | risk_mat = torch.ones_like(partial_y).float() 160 | risk_mat[torch.arange(len(x)), predicted_label] = 0 161 | correct += len(x) - (risk_mat * label_confidence).sum().item() 162 | total += len(x) 163 | network.train() 164 | return correct / total -------------------------------------------------------------------------------- /plench/sweep.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run sweeps 3 | """ 4 | 5 | import argparse 6 | import copy 7 | import getpass 8 | import hashlib 9 | import json 10 | import os 11 | import random 12 | import shutil 13 | import time 14 | import uuid 15 | 16 | import numpy as np 17 | import torch 18 | 19 | from .core import hparams_registry, algorithms 20 | from .lib import misc, command_launchers 21 | 22 | import tqdm 23 | import shlex 24 | 25 | class Job: 26 | NOT_LAUNCHED = 'Not launched' 27 | INCOMPLETE = 'Incomplete' 28 | DONE = 'Done' 29 | 30 | def __init__(self, train_args, sweep_output_dir): 31 | args_str = json.dumps(train_args, sort_keys=True) 32 | args_hash = hashlib.md5(args_str.encode('utf-8')).hexdigest() 33 | self.output_dir = os.path.join(sweep_output_dir, args_hash) 34 | 35 | self.train_args = copy.deepcopy(train_args) 36 | self.train_args['output_dir'] = self.output_dir 37 | command = ['python -m', 'plench.train'] 38 | for k, v in sorted(self.train_args.items()): 39 | if k == 'skip_model_save': 40 | command.append(f'--{k}') 41 | else: 42 | if isinstance(v, list): 43 | v = ' '.join([str(v_) for v_ in v]) 44 | elif isinstance(v, str): 45 | v = shlex.quote(v) 46 | command.append(f'--{k} {v}') 47 | 48 | self.command_str = ' '.join(command) 49 | 50 | if os.path.exists(os.path.join(self.output_dir, 'done')): 51 | self.state = Job.DONE 52 | elif os.path.exists(self.output_dir): 53 | self.state = Job.INCOMPLETE 54 | else: 55 | self.state = Job.NOT_LAUNCHED 56 | 57 | def __str__(self): 58 | job_info = (self.train_args['dataset'], 59 | self.train_args['algorithm'], 60 | self.train_args['hparams_seed'], 61 | self.train_args['trial_seed']) 62 | return '{}: {} {}'.format( 63 | self.state, 64 | self.output_dir, 65 | job_info) 66 | 67 | @staticmethod 68 | def launch(jobs, launcher_fn): 69 | print('Launching...') 70 | jobs = jobs.copy() 71 | np.random.shuffle(jobs) 72 | print('Making job directories:') 73 | for job in tqdm.tqdm(jobs, leave=False): 74 | os.makedirs(job.output_dir, exist_ok=True) 75 | commands = [job.command_str for job in jobs] 76 | launcher_fn(commands) 77 | print(f'Launched {len(jobs)} jobs!') 78 | 79 | @staticmethod 80 | def delete(jobs): 81 | print('Deleting...') 82 | for job in jobs: 83 | shutil.rmtree(job.output_dir) 84 | print(f'Deleted {len(jobs)} jobs!') 85 | 86 | def make_args_list(n_trials_from, n_trials, dataset_names, algorithms, n_hparams_from, n_hparams, steps, 87 | data_dir, holdout_fraction, hparams, pl_type, fps_rate,skip_model_save): 88 | args_list = [] 89 | 
for trial_seed in range(n_trials_from, n_trials): 90 | for dataset in dataset_names: 91 | for algorithm in algorithms: 92 | for hparams_seed in range(n_hparams_from, n_hparams): 93 | train_args = {} 94 | train_args['dataset'] = dataset 95 | train_args['algorithm'] = algorithm 96 | train_args['holdout_fraction'] = holdout_fraction 97 | train_args['hparams_seed'] = hparams_seed 98 | train_args['data_dir'] = data_dir 99 | train_args['trial_seed'] = trial_seed 100 | train_args['seed'] = misc.seed_hash(dataset, 101 | algorithm, hparams_seed, trial_seed) 102 | if pl_type is not None: 103 | train_args['pl_type'] = pl_type 104 | if fps_rate is not None: 105 | train_args['fps_rate'] = fps_rate 106 | if steps is not None: 107 | train_args['steps'] = steps 108 | if hparams is not None: 109 | train_args['hparams'] = hparams 110 | if skip_model_save: 111 | train_args['skip_model_save'] = "" 112 | args_list.append(train_args) 113 | return args_list 114 | 115 | def ask_for_confirmation(): 116 | response = input('Are you sure? (y/n) ') 117 | if not response.lower().strip()[:1] == "y": 118 | print('Nevermind!') 119 | exit(0) 120 | 121 | if __name__ == "__main__": 122 | parser = argparse.ArgumentParser(description='Run a sweep') 123 | parser.add_argument('command', choices=['launch', 'delete_incomplete']) 124 | parser.add_argument('--datasets', nargs='+', type=str, required=True) 125 | parser.add_argument('--algorithms', nargs='+', type=str, required=True) 126 | parser.add_argument('--n_hparams_from', type=int, default=0) 127 | parser.add_argument('--n_hparams', type=int, default=20) 128 | parser.add_argument('--output_dir', type=str, required=True) 129 | parser.add_argument('--data_dir', type=str, required=True) 130 | parser.add_argument('--seed', type=int, default=0) 131 | parser.add_argument('--n_trials_from', type=int, default=0) 132 | parser.add_argument('--n_trials', type=int, default=3) 133 | parser.add_argument('--command_launcher', type=str, required=True) 134 | parser.add_argument('--steps', type=int, default=None) 135 | parser.add_argument('--hparams', type=str, default=None) 136 | parser.add_argument('--holdout_fraction', type=float, default=0.1) 137 | parser.add_argument('--skip_confirmation', action='store_true') 138 | parser.add_argument('--skip_model_save', action='store_true') 139 | parser.add_argument('--pl_type', help='partial label generation type', default='real', type=str, choices=['uss','fps', 'real'], required=False) 140 | parser.add_argument('--fps_rate', help='partial label generation flip rate of fps', default=0, type=float, required=False) 141 | 142 | args = parser.parse_args() 143 | 144 | args_list = make_args_list( 145 | n_trials_from=args.n_trials_from, 146 | n_trials=args.n_trials, 147 | dataset_names=args.datasets, 148 | algorithms=args.algorithms, 149 | n_hparams_from=args.n_hparams_from, 150 | n_hparams=args.n_hparams, 151 | steps=args.steps, 152 | data_dir=args.data_dir, 153 | holdout_fraction=args.holdout_fraction, 154 | hparams=args.hparams, 155 | pl_type=args.pl_type, 156 | fps_rate=args.fps_rate, 157 | skip_model_save=args.skip_model_save 158 | ) 159 | 160 | jobs = [Job(train_args, args.output_dir) for train_args in args_list] 161 | 162 | for job in jobs: 163 | print(job) 164 | 165 | # if delete incomplete 166 | if len([j for j in jobs if j.state == Job.INCOMPLETE]) > 0: 167 | for j_delete in [j for j in jobs if j.state == Job.INCOMPLETE]: 168 | print(j_delete) 169 | 170 | print("{} jobs: {} done, {} incomplete, {} not launched.".format( 171 | len(jobs), 172 | len([j 
for j in jobs if j.state == Job.DONE]), 173 | len([j for j in jobs if j.state == Job.INCOMPLETE]), 174 | len([j for j in jobs if j.state == Job.NOT_LAUNCHED])) 175 | ) 176 | 177 | if args.command == 'launch': 178 | to_launch = [j for j in jobs if j.state == Job.NOT_LAUNCHED] 179 | print(f'About to launch {len(to_launch)} jobs.') 180 | if not args.skip_confirmation: 181 | ask_for_confirmation() 182 | launcher_fn = command_launchers.REGISTRY[args.command_launcher] 183 | Job.launch(to_launch, launcher_fn) 184 | 185 | elif args.command == 'delete_incomplete': 186 | to_delete = [j for j in jobs if j.state == Job.INCOMPLETE] 187 | print(f'About to delete {len(to_delete)} jobs.') 188 | if not args.skip_confirmation: 189 | ask_for_confirmation() 190 | Job.delete(to_delete) -------------------------------------------------------------------------------- /plench/core/hparams_registry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from plench.lib import misc 3 | 4 | MLP_DATASET = [ 5 | "Lost", 6 | "MSRCv2", 7 | "Mirflickr", 8 | "Birdsong", 9 | "Malagasy", 10 | "SoccerPlayer", 11 | "Italian", 12 | "YahooNews", 13 | "English" 14 | ] 15 | 16 | 17 | RESNET_DATASET = [ 18 | "PLCIFAR10", 19 | "PLCIFAR10_Aggregate", 20 | "PLCIFAR10_Vaguest" 21 | ] 22 | 23 | 24 | def _define_hparam(hparams, hparam_name, default_val, random_val_fn): 25 | hparams[hparam_name] = (hparams, hparam_name, default_val, random_val_fn) 26 | 27 | 28 | def _hparams(algorithm, dataset, random_seed): 29 | """ 30 | Global registry of hyperparams. Each entry is a (default, random) tuple. 31 | New algorithms / networks / etc. should add entries here. 32 | """ 33 | 34 | hparams = {} 35 | 36 | def _hparam(name, default_val, random_val_fn): 37 | """Define a hyperparameter. random_val_fn takes a RandomState and 38 | returns a random hyperparameter value.""" 39 | assert(name not in hparams) 40 | random_state = np.random.RandomState( 41 | misc.seed_hash(random_seed, name) 42 | ) 43 | hparams[name] = (default_val, random_val_fn(random_state)) 44 | 45 | # Unconditional hparam definitions. 
46 | 47 | if dataset in MLP_DATASET: 48 | _hparam('model', 'MLP', lambda r: 'MLP') 49 | elif dataset in RESNET_DATASET: 50 | _hparam('model', 'ResNet', lambda r: 'ResNet') 51 | 52 | _hparam('lr', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5)) 53 | 54 | _hparam('weight_decay', 1e-5, lambda r: 10**r.uniform(-6, -3)) 55 | if dataset in RESNET_DATASET: 56 | _hparam('batch_size', 256, lambda r: 2**int(r.uniform(6, 9))) 57 | elif dataset in MLP_DATASET: 58 | _hparam('batch_size', 128, lambda r: 2**int(r.uniform(5, 8))) 59 | 60 | # algorithm-specific hyperparameters 61 | if algorithm == 'LWS': 62 | _hparam('lw_weight', 2, lambda r: r.choice([1, 2])) 63 | elif algorithm == 'POP': 64 | _hparam('rollWindow', 5, lambda r: r.choice([3, 4, 5, 6, 7])) 65 | _hparam('warm_up', 20, lambda r: r.choice([10, 15, 20])) 66 | _hparam('theta', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5)) 67 | _hparam('inc', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5)) 68 | _hparam('pop_interval', 1000, lambda r: r.choice([500, 1000, 1500, 2000])) 69 | elif algorithm == 'IDGP': 70 | _hparam('warm_up_epoch', 10, lambda r: r.choice([5, 10, 15, 20])) 71 | elif algorithm == 'ABS_GCE': 72 | _hparam('q', 0.7, lambda r: 0.7) 73 | elif algorithm == 'DIRK': 74 | _hparam('momentum',0.99, lambda r: r.choice([0.5,0.9,0.99])) 75 | _hparam('moco_queue',8192,lambda r:r.choice([4096,8192])) 76 | _hparam('feat_dim', 128, lambda r: r.choice([64, 128, 256])) 77 | _hparam('prot_start',178,lambda r:r.choice([178])) 78 | _hparam('weight',1.0,lambda r: r.choice([0.1,0.3,0.5,0.7,0.9,1.0])) 79 | _hparam('dist_temperature',0.4,lambda r: r.choice([0.2,0.4,0.6])) 80 | _hparam('feat_temperature', 0.07, lambda r: r.choice([0.03, 0.05, 0.07, 0.09])) 81 | elif algorithm == 'ABLE': 82 | _hparam('feat_dim',128,lambda r: r.choice([64,128,256])) 83 | _hparam('loss_weight',1.0,lambda r: r.choice([0.5,1.0,2.0])) 84 | _hparam('temperature',0.07,lambda r:r.choice([0.03,0.05,0.07,0.09])) 85 | elif algorithm == 'PiCO': 86 | _hparam('prot_start',178,lambda r:r.choice([178])) 87 | _hparam('momentum', 0.9, lambda r: r.choice([0.8,0.9])) 88 | _hparam('feat_dim',128,lambda r: r.choice([64,128,256])) 89 | _hparam('moco_queue',8192,lambda r:r.choice([4096,8192])) 90 | _hparam('moco_m',0.999,lambda r:r.choice([0.9,0.999])) 91 | _hparam('proto_m',0.99,lambda r:r.choice([0.99,0.9])) 92 | _hparam('loss_weight',0.5,lambda r:r.choice([0.5,1.0])) 93 | _hparam('conf_ema_range','0.95,0.8',lambda r:r.choice(['0.95,0.8'])) 94 | elif algorithm == 'VALEN': 95 | _hparam('warm_up', 1000, lambda r: r.choice([1000,2000, 3000, 4000, 5000])) 96 | _hparam('knn',3, lambda r:r.choice([3,4])) 97 | _hparam('alpha',1.0,lambda r:r.choice([1.0])) 98 | _hparam('beta',1.0,lambda r:r.choice([1.0])) 99 | _hparam('lambda',1.0,lambda r:r.choice([1.0])) 100 | _hparam('gamma',1.0,lambda r:r.choice([1.0])) 101 | _hparam('theta',1.0,lambda r:r.choice([1.0])) 102 | _hparam('sigma',1.0,lambda r:r.choice([1.0])) 103 | _hparam('correct',1.0,lambda r:r.choice([1.0])) 104 | elif algorithm == 'NN': 105 | _hparam('beta', 0., lambda r: 0.) 
106 | 107 | elif algorithm == 'FREDIS': 108 | _hparam('theta', 1e-6, lambda r: 1e-6) 109 | _hparam('inc', 1e-6, lambda r: r.choice([1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1])) 110 | _hparam('delta', 1.0, lambda r: 1.0) 111 | _hparam('dec', 1e-6, lambda r: r.choice([1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1])) 112 | _hparam('times', int(2), lambda r: r.choice([int(1), int(2), int(3), int(4), int(5)])) 113 | _hparam('change_size', int(500), lambda r: r.choice([int(50), int(100), int(500), int(300), int(800), int(1000)])) 114 | _hparam('update_interval', int(20), lambda r: r.choice([int(10), int(20)])) 115 | _hparam('lam', 10, lambda r: r.choice([1, 10])) 116 | _hparam('alpha', 1.0, lambda r: r.choice([1e-3, 1e-2, 1e-1, 1])) 117 | elif algorithm == 'PiCO_plus': 118 | _hparam('prot_start', 250, lambda r: r.choice([1,250])) 119 | _hparam('momentum', 0.9, lambda r: r.choice([0.8, 0.9])) 120 | _hparam('feat_dim', 128, lambda r: r.choice([64, 128, 256])) 121 | _hparam('moco_queue', 8192, lambda r: r.choice([4096, 8192])) 122 | _hparam('moco_m', 0.999, lambda r: r.choice([0.9, 0.999])) 123 | _hparam('proto_m', 0.99, lambda r: r.choice([0.99, 0.9])) 124 | _hparam('loss_weight', 0.5, lambda r: r.choice([0.5, 1.0])) 125 | _hparam('conf_ema_range', '0.95,0.8', lambda r: r.choice(['0.95,0.8'])) 126 | if dataset == 'PLCIFAR10_Aggregate': 127 | _hparam('pure_ratio', 1 - 0.00132, lambda r: 1 - 0.00132) 128 | elif dataset == 'PLCIFAR10_Vaguest': 129 | _hparam('pure_ratio', 1 - 0.17556, lambda r: 1 - 0.17556) 130 | else: 131 | _hparam('pure_ratio', 1, lambda r: r.choice([0.6,0.8,0.95,0.99,1])) 132 | _hparam('knn_start',5000,lambda r:r.choice([2,5000,100])) 133 | _hparam('chosen_neighbors',16,lambda r:r.choice([8,16])) 134 | _hparam('temperature_guess',0.07,lambda r:r.choice([0.1,0.07])) 135 | _hparam('ur_weight',0.1,lambda r:r.choice([0.1,0.5])) 136 | _hparam('cls_weight',2,lambda r:r.choice([2,3,5])) 137 | elif algorithm == 'ALIM': 138 | _hparam('prot_start',178,lambda r:r.choice([178])) 139 | _hparam('momentum', 0.9, lambda r: r.choice([0.8,0.9])) 140 | _hparam('feat_dim',128,lambda r: r.choice([64,128,256])) 141 | _hparam('moco_queue',8192,lambda r:r.choice([4096,8192])) 142 | _hparam('moco_m',0.999,lambda r:r.choice([0.9,0.999])) 143 | _hparam('proto_m',0.99,lambda r:r.choice([0.99,0.9])) 144 | _hparam('loss_weight',0.5,lambda r:r.choice([0.5,1.0])) 145 | _hparam('conf_ema_range','0.95,0.8',lambda r:r.choice(['0.95,0.8'])) 146 | 147 | _hparam('start_epoch', 40, lambda r: r.choice([20, 40, 80, 100, 140])) 148 | _hparam('loss_weight_mixup', 1.0, lambda r: 1.0) 149 | _hparam('mixup_alpha', 4, lambda r: 4) 150 | if dataset == 'PLCIFAR10_Aggregate': 151 | _hparam('noise_rate', 0.00132, lambda r: 0.00132) 152 | elif dataset == 'PLCIFAR10_Vaguest': 153 | _hparam('noise_rate', 0.17556, lambda r: 0.17556) 154 | else: 155 | _hparam('noise_rate', 0.2, lambda r: r.choice([0.1, 0.2, 0.3, 0.4, 0.5, 0.7])) 156 | 157 | return hparams 158 | 159 | 160 | def default_hparams(algorithm, dataset): 161 | return {a: b for a, (b, c) in _hparams(algorithm, dataset, 0).items()} 162 | 163 | 164 | def random_hparams(algorithm, dataset, seed): 165 | return {a: c for a, (b, c) in _hparams(algorithm, dataset, seed).items()} 166 | -------------------------------------------------------------------------------- /plench/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import json 4 | import os 5 | import random 6 | import sys 7 | import time 8 | import numpy as np 
9 | import PIL 10 | import torch 11 | import torchvision 12 | import torch.utils.data 13 | 14 | from .data import datasets 15 | from .core import hparams_registry, algorithms 16 | from .lib import misc 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser(description='Partial-label learning') 20 | parser.add_argument('--data_dir', type=str, default="") 21 | parser.add_argument('--dataset', type=str, default="PLCIFAR10_Aggregate") 22 | parser.add_argument('--algorithm', type=str, default="PRODEN") 23 | parser.add_argument('--hparams', type=str, 24 | help='JSON-serialized hparams dict') 25 | parser.add_argument('--hparams_seed', type=int, default=0, 26 | help='Seed for random hparams (0 means "default hparams")') 27 | parser.add_argument('--trial_seed', type=int, default=0, 28 | help='Trial number (used for seeding split_dataset and ' 29 | 'random_hparams).') 30 | parser.add_argument('--seed', type=int, default=0, 31 | help='Seed for everything else') 32 | parser.add_argument('--steps', type=int, default=60000, 33 | help='Number of steps. Default is dataset-dependent.') 34 | parser.add_argument('--checkpoint_freq', type=int, default=1000, 35 | help='Checkpoint every N steps. Default is dataset-dependent.') 36 | parser.add_argument('--output_dir', type=str, default="train_output") 37 | parser.add_argument('--holdout_fraction', type=float, default=0.1) 38 | parser.add_argument('--tabular_test_fraction', type=float, default=0.1, required=False) 39 | parser.add_argument('--skip_model_save', action='store_true') 40 | parser.add_argument('--save_model_every_checkpoint', action='store_true') 41 | parser.add_argument('--pl_type', help='partial label generation type', default='real', type=str, choices=['uss','fps', 'real'], required=False) 42 | parser.add_argument('--fps_rate', help='partial label generation flip rate of fps', default=0, type=float, required=False) 43 | args = parser.parse_args() 44 | start_step = 0 45 | algorithm_dict = None 46 | 47 | os.makedirs(args.output_dir, exist_ok=True) 48 | sys.stdout = misc.Tee(os.path.join(args.output_dir, 'out.txt')) 49 | sys.stderr = misc.Tee(os.path.join(args.output_dir, 'err.txt')) 50 | 51 | print("Environment:") 52 | print("\tPython: {}".format(sys.version.split(" ")[0])) 53 | print("\tPyTorch: {}".format(torch.__version__)) 54 | print("\tTorchvision: {}".format(torchvision.__version__)) 55 | print("\tCUDA: {}".format(torch.version.cuda)) 56 | print("\tCUDNN: {}".format(torch.backends.cudnn.version())) 57 | print("\tNumPy: {}".format(np.__version__)) 58 | print("\tPIL: {}".format(PIL.__version__)) 59 | 60 | print('Args:') 61 | for k, v in sorted(vars(args).items()): 62 | print('\t{}: {}'.format(k, v)) 63 | 64 | if args.hparams_seed == 0: 65 | hparams = hparams_registry.default_hparams(args.algorithm, args.dataset) 66 | else: 67 | hparams = hparams_registry.random_hparams(args.algorithm, args.dataset, 68 | misc.seed_hash(args.hparams_seed, args.trial_seed)) 69 | if args.hparams: 70 | hparams.update(json.loads(args.hparams)) 71 | 72 | hparams['max_steps']=args.steps 73 | hparams['output_dir']=args.output_dir 74 | print('HParams:') 75 | for k, v in sorted(hparams.items()): 76 | print('\t{}: {}'.format(k, v)) 77 | 78 | random.seed(args.seed) 79 | np.random.seed(args.seed) 80 | torch.manual_seed(args.seed) 81 | torch.backends.cudnn.deterministic = True 82 | torch.backends.cudnn.benchmark = False 83 | 84 | os.environ['PYTHONHASHSEED'] = str(args.seed) 85 | torch.cuda.manual_seed(args.seed) 86 | torch.cuda.manual_seed_all(args.seed) 87 
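# Example with hypothetical values: individual hyperparameters can be overridden from the
# command line as a JSON dict, which is merged into `hparams` above via
# hparams.update(json.loads(args.hparams)):
#
#   python -m plench.train --data_dir=<data_dir> --algorithm PRODEN \
#       --dataset PLCIFAR10_Aggregate --hparams '{"lr": 0.001, "batch_size": 256}'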
| 88 | 89 | 90 | if torch.cuda.is_available(): 91 | device = "cuda" 92 | else: 93 | device = "cpu" 94 | 95 | if args.dataset in datasets.IMAGE_DATASETS and args.dataset in vars(datasets): 96 | dataset = vars(datasets)[args.dataset](args.data_dir, args) 97 | test_dataset = datasets.test_dataset_gen(args.data_dir, args) 98 | elif args.dataset in datasets.TABULAR_DATASETS: 99 | dataset, test_dataset = datasets.tabular_train_test_dataset_gen(root=args.data_dir, seed=misc.seed_hash(args.trial_seed), args=args) 100 | else: 101 | raise NotImplementedError 102 | 103 | val_dataset, train_dataset = misc.split_dataset(dataset, int(len(dataset)*args.holdout_fraction), misc.seed_hash(args.trial_seed)) 104 | 105 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=hparams['batch_size'], shuffle=True,num_workers=0, drop_last=True) 106 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False, num_workers=0, drop_last=False) 107 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=64, shuffle=True,num_workers=0, drop_last=False) 108 | 109 | algorithm_class = algorithms.get_algorithm_class(args.algorithm) 110 | algorithm = algorithm_class(train_dataset.data.shape, train_dataset.partial_targets, hparams) 111 | 112 | if algorithm_dict is not None: 113 | algorithm.load_state_dict(algorithm_dict) 114 | 115 | algorithm.to(device) 116 | train_minibatches_iterator = iter(train_loader) 117 | checkpoint_vals = collections.defaultdict(lambda: []) 118 | steps_per_epoch = len(train_dataset)/hparams['batch_size'] 119 | n_steps = args.steps 120 | checkpoint_freq = args.checkpoint_freq 121 | 122 | def save_checkpoint(filename): 123 | if args.skip_model_save: 124 | return 125 | save_dict = { 126 | "args": vars(args), 127 | "model_input_shape": train_dataset.data.shape, 128 | "model_num_classes": dataset.num_classes, 129 | "model_hparams": hparams, 130 | "model_dict": algorithm.state_dict() 131 | } 132 | torch.save(save_dict, os.path.join(args.output_dir, filename)) 133 | 134 | last_results_keys = None 135 | for step in range(start_step, n_steps): 136 | algorithm.train() 137 | step_start_time = time.time() 138 | if args.algorithm == 'ALIM': 139 | try: 140 | minibatches_device1 = [item.to(device) for item in next(train_minibatches_iterator)] 141 | except: 142 | train_minibatches_iterator = iter(train_loader) 143 | minibatches_device1 = [item.to(device) for item in next(train_minibatches_iterator)] 144 | try: 145 | minibatches_device2 = [item.to(device) for item in next(train_minibatches_iterator)] 146 | except: 147 | train_minibatches_iterator = iter(train_loader) 148 | minibatches_device2 = [item.to(device) for item in next(train_minibatches_iterator)] 149 | step_vals = algorithm.update(minibatches_device1, minibatches_device2) 150 | else: 151 | try: 152 | minibatches_device = [item.to(device) for item in next(train_minibatches_iterator)] 153 | except: 154 | train_minibatches_iterator = iter(train_loader) 155 | minibatches_device = [item.to(device) for item in next(train_minibatches_iterator)] 156 | step_vals = algorithm.update(minibatches_device) 157 | checkpoint_vals['step_time'].append(time.time() - step_start_time) 158 | 159 | for key, val in step_vals.items(): 160 | checkpoint_vals[key].append(val) 161 | 162 | if (step % checkpoint_freq == 0) or (step == n_steps - 1): 163 | results = { 164 | 'step': step, 165 | 'epoch': step / steps_per_epoch, 166 | } 167 | 168 | for key, val in checkpoint_vals.items(): 169 | results[key] = 
np.mean(val) 170 | 171 | val_acc = misc.val_accuracy(algorithm, val_loader, device) 172 | results['val_accuracy'] = val_acc 173 | val_covering_rate = misc.val_covering_rate(algorithm, val_loader, device) 174 | results['val_covering_rate'] = val_covering_rate 175 | val_approximated_acc = misc.val_approximated_accuracy(algorithm, val_loader, device) 176 | results['val_approximated_acc'] = val_approximated_acc 177 | acc = misc.accuracy(algorithm, test_loader, device) 178 | results['test_acc'] = acc 179 | results['mem_gb'] = torch.cuda.max_memory_allocated() / (1024.*1024.*1024.) 180 | 181 | results_keys = sorted(results.keys()) 182 | if results_keys != last_results_keys: 183 | misc.print_row(results_keys, colwidth=12) 184 | last_results_keys = results_keys 185 | misc.print_row([results[key] for key in results_keys], colwidth=12) 186 | 187 | results.update({ 188 | 'hparams': hparams, 189 | 'args': vars(args) 190 | }) 191 | 192 | epochs_path = os.path.join(args.output_dir, 'results.jsonl') 193 | with open(epochs_path, 'a') as f: 194 | f.write(json.dumps(results, cls=misc.NpEncoder, sort_keys=True) + "\n") 195 | 196 | algorithm_dict = algorithm.state_dict() 197 | checkpoint_vals = collections.defaultdict(lambda: []) 198 | 199 | if args.save_model_every_checkpoint: 200 | save_checkpoint(f'model_step{step}.pkl') 201 | 202 | save_checkpoint('model.pkl') 203 | # delete VALEN adj file 204 | if args.algorithm == 'VALEN': 205 | adj_file_path = os.path.join(args.output_dir, 'adj_matrix.npy') 206 | if os.path.exists(adj_file_path): 207 | os.remove(adj_file_path) 208 | 209 | with open(os.path.join(args.output_dir, 'done'), 'w') as f: 210 | f.write('done') -------------------------------------------------------------------------------- /plench/data/datasets.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | from PIL import Image, ImageFile 5 | from torchvision import transforms 6 | import torchvision.datasets as dsets 7 | from torch.utils.data import TensorDataset, Subset, ConcatDataset, Dataset 8 | import pickle 9 | import numpy as np 10 | from scipy.special import comb 11 | from .randaugment import RandomAugment 12 | from .cutout import Cutout 13 | from scipy.io import loadmat 14 | 15 | 16 | ImageFile.LOAD_TRUNCATED_IMAGES = True 17 | 18 | DATASETS = [ 19 | "PLCIFAR10_Aggregate", 20 | "PLCIFAR10_Vaguest", 21 | "Lost", 22 | "MSRCv2", 23 | "Mirflickr", 24 | "Birdsong", 25 | "Malagasy", 26 | "SoccerPlayer", 27 | "Italian", 28 | "YahooNews", 29 | "English" 30 | ] 31 | 32 | IMAGE_DATASETS = [ 33 | "PLCIFAR10_Aggregate", 34 | "PLCIFAR10_Vaguest" 35 | ] 36 | 37 | TABULAR_DATASETS = [ 38 | "Lost", 39 | "MSRCv2", 40 | "Mirflickr", 41 | "Birdsong", 42 | "Malagasy", 43 | "SoccerPlayer", 44 | "Italian", 45 | "YahooNews", 46 | "English" 47 | ] 48 | 49 | class gen_index_test_tabular_dataset(Dataset): 50 | def __init__(self, images, true_labels): 51 | self.data = images 52 | self.data = torch.from_numpy(self.data).to(torch.float32) 53 | self.ord_labels = true_labels 54 | 55 | def __len__(self): 56 | return len(self.ord_labels) 57 | 58 | def __getitem__(self, index): 59 | each_image = self.data[index] 60 | each_true_label = self.ord_labels[index] 61 | 62 | return each_image, each_true_label 63 | 64 | class gen_index_train_tabular_dataset(Dataset): 65 | def __init__(self, images, given_label_matrix, true_labels): 66 | self.data = images 67 | self.data = torch.from_numpy(self.data).to(torch.float32) 68 | self.partial_targets = 
given_label_matrix 69 | self.ord_labels = true_labels 70 | self.input_dim = images.shape[1] 71 | self.num_classes = given_label_matrix.shape[1] 72 | 73 | def __len__(self): 74 | return len(self.ord_labels) 75 | 76 | def __getitem__(self, index): 77 | each_image = self.data[index] 78 | each_label = self.partial_targets[index] 79 | each_true_label = self.ord_labels[index] 80 | 81 | return each_image, each_image, each_image, each_image, each_label, each_true_label 82 | 83 | def tabular_train_test_dataset_gen(root, seed, args=None): 84 | dataset_path = os.path.join(root, (args.dataset+".mat")) 85 | total_data = loadmat(dataset_path) 86 | data, ord_labels_mat, partial_targets = total_data['data'], total_data['target'], total_data['partial_target'] 87 | ord_labels_mat = ord_labels_mat.transpose() 88 | partial_targets = partial_targets.transpose() 89 | if type(ord_labels_mat) != np.ndarray: 90 | ord_labels_mat = ord_labels_mat.toarray() 91 | partial_targets = partial_targets.toarray() 92 | if data.shape[0] != ord_labels_mat.shape[0] or data.shape[0] != partial_targets.shape[0]: 93 | raise RuntimeError('The shape of data and labels does not match!') 94 | if ord_labels_mat.sum() != len(data): 95 | raise RuntimeError('Data may have more than one label!') 96 | _, ord_labels = np.where(ord_labels_mat == 1) 97 | data = (data - data.mean(axis=0, keepdims=True))/(data.std(axis=0, keepdims=True)+1e-6) 98 | data = data.astype(float) 99 | total_size = data.shape[0] 100 | train_size = int(total_size * (1 - args.tabular_test_fraction)) 101 | keys = list(range(total_size)) 102 | np.random.RandomState(seed).shuffle(keys) 103 | train_idx = keys[:train_size] 104 | test_idx = keys[train_size:] 105 | train_set = gen_index_train_tabular_dataset(data[train_idx], partial_targets[train_idx], ord_labels[train_idx]) 106 | test_set = gen_index_test_tabular_dataset(data[test_idx], ord_labels[test_idx]) 107 | return train_set, test_set 108 | 109 | 110 | 111 | 112 | 113 | class Tabular_Dataset(Dataset): 114 | def __init__(self, root, args=None): 115 | self.dataset_path = root 116 | self.total_data = loadmat(mat_path) 117 | 118 | self.data, self.ord_labels_mat, self.partial_targets = self.total_data['data'], self.total_data['target'], self.total_data['partial_target'] 119 | self.ord_labels_mat = self.ord_labels_mat.transpose() 120 | self.partial_targets = self.partial_targets.transpose() 121 | if self.data.shape[0] != self.ord_labels_mat.shape[0] or self.data.shape[0] != self.partial_targets.shape[0]: 122 | raise RuntimeError('The shape of data and labels does not match!') 123 | if self.ord_labels_mat.sum() != len(self.data): 124 | raise RuntimeError('Data may have more than one label!') 125 | _, self.ord_labels = np.where(self.ord_labels_mat == 1) 126 | self.data = (self.data - self.data.mean(axis=0, keepdims=True))/self.data.std(axis=0, keepdims=True) 127 | self.input_dim = self.data.shape[1] 128 | self.num_classes = self.ord_labels_mat.shape[1] 129 | 130 | def __len__(self): 131 | return len(self.data) 132 | 133 | def __getitem__(self, index): 134 | image = self.data[index] 135 | original_image = image 136 | weak_image = image 137 | strong_image = image 138 | return original_image, weak_image, strong_image, self.partial_targets[index], self.ord_labels[index] 139 | 140 | 141 | class PLCIFAR10_Aggregate(Dataset): 142 | def __init__(self, root, args=None): 143 | 144 | 145 | dataset_path = os.path.join(root, 'plcifar10', f"plcifar10.pkl") 146 | partial_label_all = pickle.load(open(dataset_path, "rb")) 147 | 148 | 
self.transform = transforms.Compose( 149 | [transforms.ToTensor(), 150 | transforms.RandomHorizontalFlip(), 151 | transforms.RandomCrop(32,4), 152 | transforms.Normalize((0.4922, 0.4832, 0.4486), (0.2456, 0.2419, 0.2605))]) 153 | self.strong_transform = transforms.Compose([ 154 | transforms.ToPILImage(), 155 | transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)), 156 | transforms.RandomHorizontalFlip(), 157 | RandomAugment(3, None), 158 | transforms.ToTensor(), 159 | transforms.Normalize((0.4922, 0.4832, 0.4486), (0.2456, 0.2419, 0.2605)), 160 | ]) 161 | self.distill_transform = transforms.Compose([ 162 | transforms.ToPILImage(), 163 | transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)), 164 | transforms.RandomHorizontalFlip(), 165 | transforms.ToTensor(), 166 | Cutout(n_holes=1, length=16), 167 | transforms.Normalize((0.4922, 0.4832, 0.4486), (0.2456, 0.2419, 0.2605)) 168 | ]) 169 | self.test_transform = transforms.Compose( 170 | [transforms.ToTensor(), 171 | transforms.Normalize((0.4922, 0.4832, 0.4486), (0.2456, 0.2419, 0.2605))]) 172 | self.input_dim = 32 * 32 * 3 173 | self.num_classes = 10 174 | 175 | original_dataset = dsets.CIFAR10(root=root, train=True, download=True) 176 | self.data = original_dataset.data 177 | 178 | self.partial_targets = np.zeros((len(self.data), self.num_classes)) 179 | 180 | 181 | for key, value in partial_label_all.items(): 182 | for candidate_label_set in value: 183 | for label in candidate_label_set: 184 | self.partial_targets[key, label] = 1 185 | 186 | 187 | self.ord_labels = original_dataset.targets 188 | self.ord_labels = np.array(self.ord_labels) 189 | 190 | def __len__(self): 191 | return len(self.data) 192 | 193 | def __getitem__(self, index): 194 | image = self.data[index] 195 | original_image = self.test_transform(image) 196 | weak_image = self.transform(image) 197 | strong_image=self.strong_transform(image) 198 | distill_image=self.distill_transform(image) 199 | return original_image, weak_image, strong_image, distill_image, self.partial_targets[index], self.ord_labels[index] 200 | 201 | 202 | class PLCIFAR10_Vaguest(Dataset): 203 | def __init__(self, root, args=None): 204 | dataset_path = os.path.join(root, 'plcifar10', f"plcifar10.pkl") 205 | partial_label_all = pickle.load(open(dataset_path, "rb")) 206 | self.transform = transforms.Compose( 207 | [transforms.ToTensor(), 208 | transforms.RandomHorizontalFlip(), 209 | transforms.RandomCrop(32,4), 210 | transforms.Normalize((0.4922, 0.4832, 0.4486), (0.2456, 0.2419, 0.2605))]) 211 | self.strong_transform = transforms.Compose([ 212 | transforms.ToPILImage(), 213 | transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)), 214 | transforms.RandomHorizontalFlip(), 215 | RandomAugment(3, None), 216 | transforms.ToTensor(), 217 | transforms.Normalize((0.4922, 0.4832, 0.4486), (0.2456, 0.2419, 0.2605)), 218 | ]) 219 | self.distill_transform = transforms.Compose([ 220 | transforms.ToPILImage(), 221 | transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)), 222 | transforms.RandomHorizontalFlip(), 223 | transforms.ToTensor(), 224 | Cutout(n_holes=1, length=16), 225 | transforms.Normalize((0.4922, 0.4832, 0.4486), (0.2456, 0.2419, 0.2605)) 226 | ]) 227 | self.test_transform = transforms.Compose( 228 | [transforms.ToTensor(), 229 | transforms.Normalize((0.4922, 0.4832, 0.4486), (0.2456, 0.2419, 0.2605))]) 230 | self.input_dim = 32 * 32 * 3 231 | self.num_classes = 10 232 | 233 | original_dataset = dsets.CIFAR10(root=root, train=True, download=True) 234 | self.data = original_dataset.data 235 | 236 | 
self.partial_targets = np.zeros((len(self.data), self.num_classes)) 237 | 238 | 239 | for key, value in partial_label_all.items(): 240 | vaguest_candidate_label_set = [] 241 | largest_num = 0 242 | for candidate_label_set in value: 243 | if len(candidate_label_set) > largest_num: 244 | vaguest_candidate_label_set = candidate_label_set 245 | largest_num = len(candidate_label_set) 246 | for label in vaguest_candidate_label_set: 247 | self.partial_targets[key, label] = 1 248 | 249 | self.ord_labels = original_dataset.targets 250 | self.ord_labels = np.array(self.ord_labels) 251 | 252 | def __len__(self): 253 | return len(self.data) 254 | 255 | def __getitem__(self, index): 256 | image = self.data[index] 257 | original_image = self.test_transform(image) 258 | weak_image = self.transform(image) 259 | strong_image=self.strong_transform(image) 260 | distill_image = self.distill_transform(image) 261 | return original_image, weak_image, strong_image, distill_image, self.partial_targets[index], self.ord_labels[index] 262 | 263 | 264 | def test_dataset_gen(root, args=None): 265 | if args.dataset == "PLCIFAR10_Aggregate" or args.dataset == "PLCIFAR10_Vaguest": 266 | test_transform = transforms.Compose( 267 | [transforms.ToTensor(), 268 | transforms.Normalize((0.4922, 0.4832, 0.4486), (0.2456, 0.2419, 0.2605))]) 269 | test_dataset = dsets.CIFAR10(root=root, train=False, transform=test_transform) 270 | return test_dataset 271 | -------------------------------------------------------------------------------- /plench/core/algorithms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | import copy 6 | import numpy as np 7 | import random 8 | import scipy.sparse as sp 9 | import os 10 | from sklearn.metrics import euclidean_distances 11 | 12 | from . import networks 13 | 14 | ALGORITHMS = [ 15 | 'PRODEN', 16 | 'VALEN', 17 | 'CAVL', 18 | 'POP', 19 | 'ABS_MAE', 20 | 'ABS_GCE', 21 | 'CC', 22 | 'EXP', 23 | 'MCL_GCE', 24 | 'MCL_MSE', 25 | 'LWS', 26 | 'IDGP', 27 | 'PC', 28 | 'Forward', 29 | 'NN', 30 | 'GA', 31 | 'SCL_EXP', 32 | 'SCL_NL', 33 | 'L_W', 34 | 'OP_W', 35 | 'PiCO', 36 | 'ABLE', 37 | 'CRDPLL', 38 | 'DIRK', 39 | 'FREDIS', 40 | 'ALIM', 41 | 'PiCO_plus', 42 | ] 43 | 44 | def get_algorithm_class(algorithm_name): 45 | """Return the algorithm class with the given name.""" 46 | if algorithm_name not in globals(): 47 | raise NotImplementedError("Algorithm not found: {}".format(algorithm_name)) 48 | return globals()[algorithm_name] 49 | 50 | class Algorithm(torch.nn.Module): 51 | """ 52 | A subclass of Algorithm implements a partial-label learning algorithm. 53 | Subclasses should implement the following: 54 | - update() 55 | - predict() 56 | """ 57 | def __init__(self, input_shape, train_givenY, hparams): 58 | super(Algorithm, self).__init__() 59 | self.hparams = hparams 60 | self.num_data = input_shape[0] 61 | self.num_classes = train_givenY.shape[1] 62 | 63 | def update(self, minibatches, unlabeled=None): 64 | """ 65 | Perform one update step 66 | """ 67 | raise NotImplementedError 68 | 69 | def predict(self, x): 70 | raise NotImplementedError 71 | 72 | class PRODEN(Algorithm): 73 | """ 74 | PRODEN 75 | Reference: Progressive identification of true labels for partial-label learning, ICML 2020. 
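In the implementation below, the classifier minimizes a confidence-weighted cross entropy
(rc_loss), L(x, S) = - sum_y w_y * log softmax_y(f(x)), and after every step the confidences
are refreshed as w_y proportional to softmax_y(f(x)) for candidate labels y in S (zero
otherwise), re-normalized per example (confidence_update).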
76 | """ 77 | 78 | def __init__(self, input_shape, train_givenY, hparams): 79 | super(PRODEN, self).__init__(input_shape, train_givenY, hparams) 80 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 81 | self.classifier = networks.Classifier( 82 | self.featurizer.n_outputs, 83 | self.num_classes) 84 | 85 | self.network = nn.Sequential(self.featurizer, self.classifier) 86 | self.optimizer = torch.optim.Adam( 87 | self.network.parameters(), 88 | lr=self.hparams["lr"], 89 | weight_decay=self.hparams['weight_decay'] 90 | ) 91 | train_givenY = torch.from_numpy(train_givenY) 92 | tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1]) 93 | label_confidence = train_givenY.float()/tempY 94 | self.label_confidence = label_confidence 95 | 96 | def update(self, minibatches): 97 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 98 | loss = self.rc_loss(self.predict(x), index) 99 | self.optimizer.zero_grad() 100 | loss.backward() 101 | self.optimizer.step() 102 | self.confidence_update(x, partial_y, index) 103 | return {'loss': loss.item()} 104 | 105 | def rc_loss(self, outputs, index): 106 | device = "cuda" if index.is_cuda else "cpu" 107 | self.label_confidence = self.label_confidence.to(device) 108 | logsm_outputs = F.log_softmax(outputs, dim=1) 109 | final_outputs = logsm_outputs * self.label_confidence[index, :] 110 | average_loss = - ((final_outputs).sum(dim=1)).mean() 111 | return average_loss 112 | 113 | def predict(self, x): 114 | return self.network(x) 115 | 116 | def confidence_update(self, batchX, batchY, batch_index): 117 | with torch.no_grad(): 118 | batch_outputs = self.predict(batchX) 119 | temp_un_conf = F.softmax(batch_outputs, dim=1) 120 | self.label_confidence[batch_index, :] = temp_un_conf * batchY # un_confidence stores the weight of each example 121 | base_value = self.label_confidence.sum(dim=1).unsqueeze(1).repeat(1, self.label_confidence.shape[1]) 122 | self.label_confidence = self.label_confidence / base_value 123 | 124 | class CC(Algorithm): 125 | """ 126 | CC 127 | Reference: Provably consistent partial-label learning, NeurIPS 2020. 128 | """ 129 | 130 | def __init__(self, input_shape, train_givenY, hparams): 131 | super(CC, self).__init__(input_shape, train_givenY, hparams) 132 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 133 | self.classifier = networks.Classifier( 134 | self.featurizer.n_outputs, 135 | self.num_classes) 136 | 137 | self.network = nn.Sequential(self.featurizer, self.classifier) 138 | self.optimizer = torch.optim.Adam( 139 | self.network.parameters(), 140 | lr=self.hparams["lr"], 141 | weight_decay=self.hparams['weight_decay'] 142 | ) 143 | 144 | def update(self, minibatches): 145 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 146 | loss = self.cc_loss(self.predict(x), partial_y) 147 | self.optimizer.zero_grad() 148 | loss.backward() 149 | self.optimizer.step() 150 | return {'loss': loss.item()} 151 | 152 | def cc_loss(self, outputs, partialY): 153 | sm_outputs = F.softmax(outputs, dim=1) 154 | final_outputs = sm_outputs * partialY 155 | average_loss = - torch.log(final_outputs.sum(dim=1)).mean() 156 | return average_loss 157 | 158 | def predict(self, x): 159 | return self.network(x) 160 | 161 | class EXP(Algorithm): 162 | """ 163 | EXP 164 | Reference: Learning with multiple complementary labels, ICML 2020. 
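As implemented in exp_loss below, the per-example loss is
((K - 1) / (K - |S|)) * exp(-sum_{y in S} p_y(x)), where K is the number of classes,
S the candidate label set, and p the softmax output; the minibatch mean is returned.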
165 | """ 166 | 167 | def __init__(self, input_shape, train_givenY, hparams): 168 | super(EXP, self).__init__(input_shape, train_givenY, hparams) 169 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 170 | self.classifier = networks.Classifier( 171 | self.featurizer.n_outputs, 172 | self.num_classes) 173 | 174 | self.network = nn.Sequential(self.featurizer, self.classifier) 175 | self.optimizer = torch.optim.Adam( 176 | self.network.parameters(), 177 | lr=self.hparams["lr"], 178 | weight_decay=self.hparams['weight_decay'] 179 | ) 180 | 181 | def update(self, minibatches): 182 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 183 | loss = self.exp_loss(self.predict(x), partial_y) 184 | self.optimizer.zero_grad() 185 | loss.backward() 186 | self.optimizer.step() 187 | return {'loss': loss.item()} 188 | 189 | def exp_loss(self, outputs, partialY): 190 | can_num = partialY.sum(dim=1).float() # n 191 | soft_max = nn.Softmax(dim=1) 192 | sm_outputs = soft_max(outputs) 193 | final_outputs = sm_outputs * partialY 194 | average_loss = ((self.num_classes-1)/(self.num_classes-can_num) * torch.exp(-final_outputs.sum(dim=1))).mean() 195 | return average_loss 196 | 197 | def predict(self, x): 198 | return self.network(x) 199 | 200 | class URE_LMCL(Algorithm): 201 | """ 202 | URE for LMCL 203 | Reference: Learning with multiple complementary labels, ICML 2020. 204 | """ 205 | def __init__(self, input_shape, train_givenY, hparams): 206 | super(URE_LMCL, self).__init__(input_shape, train_givenY, hparams) 207 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 208 | self.classifier = networks.Classifier( 209 | self.featurizer.n_outputs, 210 | self.num_classes) 211 | 212 | self.network = nn.Sequential(self.featurizer, self.classifier) 213 | self.optimizer = torch.optim.Adam( 214 | self.network.parameters(), 215 | lr=self.hparams["lr"], 216 | weight_decay=self.hparams['weight_decay'] 217 | ) 218 | 219 | def update(self, minibatches): 220 | _, x, strong_x, distill_x, partial_y, _, _ = minibatches 221 | loss = self.unbiased_estimator(self.predict(x), partial_y) 222 | self.optimizer.zero_grad() 223 | loss.backward() 224 | self.optimizer.step() 225 | return {'loss': loss.item()} 226 | 227 | def unbiased_estimator(self, outputs, partialY): 228 | device = "cuda" if outputs.is_cuda else "cpu" 229 | comp_num = self.num_classes - partialY.sum(dim=1) 230 | temp_loss = torch.zeros_like(outputs).to(device) 231 | for i in range(self.num_classes): 232 | tempY = torch.zeros_like(outputs).to(device) 233 | tempY[:, i] = 1.0 234 | temp_loss[:, i] = self.loss_fn(outputs, tempY) 235 | 236 | candidate_loss = (temp_loss * partialY).sum(dim=1) 237 | noncandidate_loss = (temp_loss * (1-partialY)).sum(dim=1) 238 | total_loss = candidate_loss - (self.num_classes-comp_num-1.0)/(comp_num * noncandidate_loss+1e-20) 239 | average_loss = total_loss.mean() 240 | return average_loss 241 | 242 | def loss_fn(self, outputs, Y): 243 | 244 | raise NotImplementedError 245 | 246 | def predict(self, x): 247 | return self.network(x) 248 | 249 | class MCL_GCE(URE_LMCL): 250 | """ 251 | MCL_GCE 252 | Reference: Learning with multiple complementary labels, ICML 2020. 
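The base loss plugged into the URE_LMCL unbiased risk estimator is the generalized cross
entropy l(f(x), y) = (1 - p_y(x)^q) / q with q = 0.7 (see loss_fn below).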
253 | """ 254 | 255 | def __init__(self, input_shape, train_givenY, hparams): 256 | super(MCL_GCE, self).__init__(input_shape, train_givenY, hparams) 257 | 258 | def loss_fn(self, outputs, Y): 259 | q = 0.7 260 | sm_outputs = F.softmax(outputs, dim=1) 261 | pow_outputs = torch.pow(sm_outputs, q) 262 | sample_loss = (1-(pow_outputs*Y).sum(dim=1))/q # n 263 | return sample_loss 264 | 265 | class MCL_MSE(URE_LMCL): 266 | """ 267 | MCL_MSE 268 | Reference: Learning with multiple complementary labels, ICML 2020. 269 | """ 270 | 271 | def __init__(self, input_shape, train_givenY, hparams): 272 | super(MCL_MSE, self).__init__(input_shape, train_givenY, hparams) 273 | 274 | def loss_fn(self, outputs, Y): 275 | sm_outputs = F.softmax(outputs, dim=1) 276 | loss_fn_local = nn.MSELoss(reduction='none') 277 | loss_matrix = loss_fn_local(sm_outputs, Y.float()) 278 | sample_loss = loss_matrix.sum(dim=-1) 279 | return sample_loss 280 | 281 | class LWS(Algorithm): 282 | """ 283 | LWS 284 | Reference: Leveraged weighted loss for partial label learning, ICML 2021. 285 | """ 286 | 287 | def __init__(self, input_shape, train_givenY, hparams): 288 | super(LWS, self).__init__(input_shape, train_givenY, hparams) 289 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 290 | self.classifier = networks.Classifier( 291 | self.featurizer.n_outputs, 292 | self.num_classes) 293 | 294 | self.network = nn.Sequential(self.featurizer, self.classifier) 295 | self.optimizer = torch.optim.Adam( 296 | self.network.parameters(), 297 | lr=self.hparams["lr"], 298 | weight_decay=self.hparams['weight_decay'] 299 | ) 300 | train_givenY = torch.from_numpy(train_givenY) 301 | label_confidence = torch.ones(train_givenY.shape[0], train_givenY.shape[1]) / train_givenY.shape[1] 302 | self.label_confidence = label_confidence 303 | 304 | def update(self, minibatches): 305 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 306 | device = "cuda" if index.is_cuda else "cpu" 307 | self.label_confidence = self.label_confidence.to(device) 308 | loss = self.lws_loss(self.predict(x), partial_y, index) 309 | self.optimizer.zero_grad() 310 | loss.backward() 311 | self.optimizer.step() 312 | self.confidence_update(x, partial_y, index) 313 | return {'loss': loss.item()} 314 | 315 | def lws_loss(self, outputs, partialY, index): 316 | device = "cuda" if outputs.is_cuda else "cpu" 317 | onezero = torch.zeros(outputs.shape[0], outputs.shape[1]) 318 | onezero[partialY > 0] = 1 319 | counter_onezero = 1 - onezero 320 | onezero = onezero.to(device) 321 | counter_onezero = counter_onezero.to(device) 322 | sig_loss1 = 0.5 * torch.ones(outputs.shape[0], outputs.shape[1]) 323 | sig_loss1 = sig_loss1.to(device) 324 | sig_loss1[outputs < 0] = 1 / (1 + torch.exp(outputs[outputs < 0])) 325 | sig_loss1[outputs > 0] = torch.exp(-outputs[outputs > 0]) / ( 326 | 1 + torch.exp(-outputs[outputs > 0])) 327 | l1 = self.label_confidence[index, :] * onezero * sig_loss1 328 | average_loss1 = torch.sum(l1) / l1.size(0) 329 | sig_loss2 = 0.5 * torch.ones(outputs.shape[0], outputs.shape[1]) 330 | sig_loss2 = sig_loss2.to(device) 331 | sig_loss2[outputs > 0] = 1 / (1 + torch.exp(-outputs[outputs > 0])) 332 | sig_loss2[outputs < 0] = torch.exp( 333 | outputs[outputs < 0]) / (1 + torch.exp(outputs[outputs < 0])) 334 | l2 = self.label_confidence[index, :] * counter_onezero * sig_loss2 335 | average_loss2 = torch.sum(l2) / l2.size(0) 336 | average_loss = average_loss1 + self.hparams["lw_weight"] * average_loss2 337 | return average_loss 338 | 339 | def 
predict(self, x): 340 | return self.network(x) 341 | 342 | def confidence_update(self, batchX, batchY, batch_index): 343 | with torch.no_grad(): 344 | device = "cuda" if batch_index.is_cuda else "cpu" 345 | batch_outputs = self.predict(batchX) 346 | sm_outputs = F.softmax(batch_outputs, dim=1) 347 | onezero = torch.zeros(sm_outputs.shape[0], sm_outputs.shape[1]) 348 | onezero[batchY > 0] = 1 349 | counter_onezero = 1 - onezero 350 | onezero = onezero.to(device) 351 | counter_onezero = counter_onezero.to(device) 352 | new_weight1 = sm_outputs * onezero 353 | new_weight1 = new_weight1 / (new_weight1 + 1e-8).sum(dim=1).repeat( 354 | self.label_confidence.shape[1], 1).transpose(0, 1) 355 | new_weight2 = sm_outputs * counter_onezero 356 | new_weight2 = new_weight2 / (new_weight2 + 1e-8).sum(dim=1).repeat( 357 | self.label_confidence.shape[1], 1).transpose(0, 1) 358 | new_weight = new_weight1 + new_weight2 359 | self.label_confidence[batch_index, :] = new_weight 360 | 361 | class CAVL(Algorithm): 362 | """ 363 | CAVL 364 | Reference: Exploiting Class Activation Value for Partial-Label Learning, ICLR 2022. 365 | """ 366 | 367 | def __init__(self, input_shape, train_givenY, hparams): 368 | super(CAVL, self).__init__(input_shape, train_givenY, hparams) 369 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 370 | self.classifier = networks.Classifier( 371 | self.featurizer.n_outputs, 372 | self.num_classes) 373 | 374 | self.network = nn.Sequential(self.featurizer, self.classifier) 375 | self.optimizer = torch.optim.Adam( 376 | self.network.parameters(), 377 | lr=self.hparams["lr"], 378 | weight_decay=self.hparams['weight_decay'] 379 | ) 380 | train_givenY = torch.from_numpy(train_givenY) 381 | tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1]) 382 | label_confidence = train_givenY.float()/tempY 383 | self.label_confidence = label_confidence 384 | self.label_confidence = self.label_confidence.double() 385 | 386 | def update(self, minibatches): 387 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 388 | loss = self.rc_loss(self.predict(x), index) 389 | self.optimizer.zero_grad() 390 | loss.backward() 391 | self.optimizer.step() 392 | self.confidence_update(x, partial_y, index) 393 | return {'loss': loss.item()} 394 | 395 | def rc_loss(self, outputs, index): 396 | device = "cuda" if index.is_cuda else "cpu" 397 | self.label_confidence = self.label_confidence.to(device) 398 | logsm_outputs = F.log_softmax(outputs, dim=1) 399 | final_outputs = logsm_outputs * self.label_confidence[index, :] 400 | average_loss = - ((final_outputs).sum(dim=1)).mean() 401 | return average_loss 402 | 403 | def predict(self, x): 404 | return self.network(x) 405 | 406 | def confidence_update(self, batchX, batchY, batch_index): 407 | with torch.no_grad(): 408 | batch_outputs = self.predict(batchX) 409 | cav = (batch_outputs*torch.abs(1-batch_outputs))*batchY 410 | cav_pred = torch.max(cav,dim=1)[1] 411 | gt_label = F.one_hot(cav_pred,batchY.shape[1]) 412 | self.label_confidence[batch_index,:] = gt_label.double() 413 | 414 | class POP(Algorithm): 415 | """ 416 | POP 417 | Reference: Progressive purification for instance-dependent partial label learning, ICML 2023. 
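The implementation below trains with a PRODEN-style confidence-weighted cross entropy
(rc_loss) and, after a warm-up phase, progressively purifies the candidate sets:
confidences averaged over a rolling window (f_record) are used to drop candidate labels
whose confidence falls below a fraction theta of the per-example maximum, and theta is
enlarged by a factor (1 + inc) whenever very few entries change, as long as theta < 0.4.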
418 | """ 419 | 420 | def __init__(self, input_shape, train_givenY, hparams): 421 | super(POP, self).__init__(input_shape, train_givenY, hparams) 422 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 423 | self.classifier = networks.Classifier( 424 | self.featurizer.n_outputs, 425 | self.num_classes) 426 | 427 | self.network = nn.Sequential(self.featurizer, self.classifier) 428 | self.optimizer = torch.optim.Adam( 429 | self.network.parameters(), 430 | lr=self.hparams["lr"], 431 | weight_decay=self.hparams['weight_decay'] 432 | ) 433 | self.train_givenY = torch.from_numpy(train_givenY) 434 | tempY = self.train_givenY.sum(dim=1).unsqueeze(1).repeat(1, self.train_givenY.shape[1]) 435 | label_confidence = self.train_givenY.float()/tempY 436 | self.label_confidence = label_confidence 437 | self.f_record = torch.zeros([self.hparams['rollWindow'], label_confidence.shape[0], label_confidence.shape[1]]) 438 | self.curr_iter = 0 439 | self.theta = self.hparams['theta'] 440 | self.steps_per_epoch = train_givenY.shape[0] // self.hparams['batch_size'] 441 | 442 | 443 | def update(self, minibatches): 444 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 445 | device = "cuda" if index.is_cuda else "cpu" 446 | loss = self.rc_loss(self.predict(x), index) 447 | self.optimizer.zero_grad() 448 | loss.backward() 449 | self.optimizer.step() 450 | self.confidence_update(x, partial_y, index) 451 | self.f_record = self.f_record.to(device) 452 | if self.curr_iter % self.steps_per_epoch == 0: 453 | epoch_num = self.curr_iter / self.steps_per_epoch 454 | self.f_record[int(epoch_num % self.hparams['rollWindow']), :] = self.label_confidence 455 | if self.curr_iter >= (self.hparams['warm_up'] * self.steps_per_epoch): 456 | temp_prob_matrix = self.f_record.mean(0) 457 | # label correction 458 | temp_prob_matrix = temp_prob_matrix / temp_prob_matrix.sum(dim=1).repeat(temp_prob_matrix.size(1),1).transpose(0, 1) 459 | correction_label_matrix = self.train_givenY 460 | correction_label_matrix = correction_label_matrix.to(device) 461 | pre_correction_label_matrix = correction_label_matrix.clone() 462 | correction_label_matrix[temp_prob_matrix / torch.max(temp_prob_matrix, dim=1, keepdim=True)[0] < self.theta] = 0 463 | tmp_label_matrix = temp_prob_matrix * correction_label_matrix 464 | self.label_confidence = tmp_label_matrix / tmp_label_matrix.sum(dim=1).repeat(tmp_label_matrix.size(1), 1).transpose(0, 1) 465 | if self.theta < 0.4: 466 | if torch.sum( 467 | torch.not_equal(pre_correction_label_matrix, correction_label_matrix)) < 0.0001 * pre_correction_label_matrix.shape[0] * self.num_classes: 468 | self.theta *= (self.hparams['inc'] + 1) 469 | self.curr_iter = self.curr_iter + 1 470 | 471 | return {'loss': loss.item()} 472 | 473 | def rc_loss(self, outputs, index): 474 | device = "cuda" if index.is_cuda else "cpu" 475 | self.label_confidence = self.label_confidence.to(device) 476 | logsm_outputs = F.log_softmax(outputs, dim=1) 477 | final_outputs = logsm_outputs * self.label_confidence[index, :] 478 | average_loss = - ((final_outputs).sum(dim=1)).mean() 479 | return average_loss 480 | 481 | def predict(self, x): 482 | return self.network(x) 483 | 484 | def confidence_update(self, batchX, batchY, batch_index): 485 | with torch.no_grad(): 486 | batch_outputs = self.predict(batchX) 487 | temp_un_conf = F.softmax(batch_outputs, dim=1) 488 | self.label_confidence[batch_index, :] = temp_un_conf * batchY # un_confidence stores the weight of each example 489 | base_value = 
self.label_confidence.sum(dim=1).unsqueeze(1).repeat(1, self.label_confidence.shape[1]) 490 | self.label_confidence = self.label_confidence / base_value 491 | 492 | class IDGP(Algorithm): 493 | """ 494 | IDGP 495 | Reference: Decompositional Generation Process for Instance-Dependent Partial Label Learning, ICLR 2023. 496 | """ 497 | 498 | def __init__(self, input_shape, train_givenY, hparams): 499 | super(IDGP, self).__init__(input_shape, train_givenY, hparams) 500 | self.featurizer_f = networks.Featurizer(input_shape, self.hparams) 501 | self.classifier_f = networks.Classifier( 502 | self.featurizer_f.n_outputs, 503 | self.num_classes) 504 | self.f = nn.Sequential(self.featurizer_f, self.classifier_f) 505 | self.f_opt = torch.optim.Adam( 506 | self.f.parameters(), 507 | lr=self.hparams["lr"], 508 | weight_decay=self.hparams['weight_decay'] 509 | ) 510 | 511 | self.featurizer_g = networks.Featurizer(input_shape, self.hparams) 512 | self.classifier_g = networks.Classifier( 513 | self.featurizer_g.n_outputs, 514 | self.num_classes) 515 | self.g = nn.Sequential(self.featurizer_g, self.classifier_g) 516 | self.g_opt = torch.optim.Adam( 517 | self.g.parameters(), 518 | lr=self.hparams["lr"], 519 | weight_decay=self.hparams['weight_decay'] 520 | ) 521 | 522 | train_givenY = torch.from_numpy(train_givenY) 523 | tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1]) 524 | label_confidence = train_givenY.float()/tempY 525 | self.d_array = label_confidence 526 | self.b_array = train_givenY 527 | self.d_array = self.d_array.double() 528 | self.b_array = self.b_array.double() 529 | self.curr_iter = 0 530 | self.warm_up_epoch = hparams['warm_up_epoch'] 531 | self.ramp_iter_num = int(hparams['max_steps'] * 0.2) 532 | self.steps_per_epoch = train_givenY.shape[0] / self.hparams['batch_size'] 533 | 534 | 535 | def weighted_crossentropy_f(self, f_outputs, weight, eps=1e-12): 536 | l = weight * torch.log(f_outputs+eps) 537 | loss = (-torch.sum(l)) / l.size(0) 538 | 539 | return loss 540 | 541 | def weighted_crossentropy_f_with_g(self, f_outputs, g_outputs, targets, eps=1e-12): 542 | weight = g_outputs.clone().detach() * targets 543 | weight[weight == 0] = 1.0 544 | logits1 = (1 - weight) / (weight+eps) 545 | logits2 = weight.prod(dim=1, keepdim=True) 546 | weight = logits1 * logits2 547 | weight = weight * targets 548 | weight = weight / (weight.sum(dim=1, keepdim=True)+eps) 549 | weight = weight.clone().detach() 550 | 551 | l = weight * torch.log(f_outputs+eps) 552 | loss = (-torch.sum(l)) / l.size(0) 553 | 554 | return loss 555 | 556 | def weighted_crossentropy_g_with_f(self, g_outputs, f_outputs, targets, eps=1e-12): 557 | 558 | weight = f_outputs.clone().detach() * targets 559 | weight = weight / (weight.sum(dim=1, keepdim=True) + eps) 560 | l = weight * ( torch.log((1 - g_outputs) / (g_outputs + eps)+eps)) 561 | l = weight * (torch.log(1.0000001 - g_outputs)) 562 | loss = ( - torch.sum(l)) / ( l.size(0)) + \ 563 | ( - torch.sum(targets * torch.log(g_outputs+eps) + (1 - targets) * torch.log(1.0000001 - g_outputs))) / (l.size(0)) 564 | 565 | return loss 566 | 567 | def weighted_crossentropy_g(self, g_outputs, weight, eps=1e-12): 568 | l = weight * torch.log(g_outputs+eps) + (1 - weight) * torch.log(1.0000001 - g_outputs) 569 | loss = ( - torch.sum(l)) / (l.size(0)) 570 | 571 | return loss 572 | 573 | def update_d(self, f_outputs, targets, eps=1e-12): 574 | new_d = f_outputs.clone().detach() * targets.clone().detach() 575 | new_d = new_d / (new_d.sum(dim=1, keepdim=True) + eps) 576 | 
new_d = new_d.double() 577 | return new_d 578 | 579 | def update_b(self, g_outputs, targets): 580 | new_b = g_outputs.clone().detach() * targets.clone().detach() 581 | new_b = new_b.double() 582 | return new_b 583 | 584 | def noisy_output(self, outputs, d_array, targets): 585 | _, true_labels = torch.max(d_array * targets, dim=1) 586 | device = "cuda" if outputs.is_cuda else "cpu" 587 | pseudo_matrix = F.one_hot(true_labels, outputs.shape[1]).float().to(device).detach() 588 | return pseudo_matrix * (1 - outputs) + (1 - pseudo_matrix) * outputs 589 | 590 | def update(self, minibatches): 591 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 592 | device = "cuda" if index.is_cuda else "cpu" 593 | consistency_criterion_f = nn.KLDivLoss(reduction='batchmean').to(device) 594 | consistency_criterion_g = nn.KLDivLoss(reduction='batchmean').to(device) 595 | self.d_array = self.d_array.to(device) 596 | self.b_array = self.b_array.to(device) 597 | L_F = None 598 | if self.curr_iter <= self.warm_up_epoch * self.steps_per_epoch: 599 | # warm up of f 600 | f_logits_o = self.f(x) 601 | #f_logits_o_max = torch.max(f_logits_o, dim=1) 602 | #f_logits_o = f_logits_o - f_logits_o_max.view(-1, 1).expand_as(f_logits_o) 603 | f_outputs_o = F.softmax(f_logits_o / 1., dim=1) 604 | L_f_o = self.weighted_crossentropy_f(f_outputs_o, self.d_array[index,:]) 605 | L_F = L_f_o 606 | self.f_opt.zero_grad() 607 | L_F.backward() 608 | self.f_opt.step() 609 | # warm up of g 610 | g_logits_o = self.g(x) 611 | g_outputs_o = torch.sigmoid(g_logits_o / 1) 612 | L_g_o = self.weighted_crossentropy_g(g_outputs_o, self.b_array[index,:]) 613 | L_g = L_g_o 614 | self.g_opt.zero_grad() 615 | L_g.backward() 616 | self.g_opt.step() 617 | else: 618 | f_logits_o = self.f(x) 619 | g_logits_o = self.g(x) 620 | 621 | f_outputs_o = F.softmax(f_logits_o / 1., dim=1) 622 | g_outputs_o = torch.sigmoid(g_logits_o / 1.) 623 | 624 | L_f = self.weighted_crossentropy_f(f_outputs_o, self.d_array[index,:]) 625 | L_f_g = self.weighted_crossentropy_f_with_g(f_outputs_o, self.noisy_output(g_outputs_o, self.d_array[index, :], partial_y), partial_y) 626 | 627 | L_g = self.weighted_crossentropy_g(g_outputs_o, self.b_array[index,:]) 628 | 629 | L_g_f = self.weighted_crossentropy_g_with_f(g_outputs_o, f_outputs_o, partial_y) 630 | 631 | f_outputs_log_o = torch.log_softmax(f_logits_o, dim=-1) 632 | f_consist_loss0 = consistency_criterion_f(f_outputs_log_o, self.d_array[index,:].float()) 633 | f_consist_loss = f_consist_loss0 634 | g_outputs_log_o = nn.LogSigmoid()(g_logits_o) 635 | g_consist_loss0 = consistency_criterion_g(g_outputs_log_o, self.b_array[index,:].float()) 636 | g_consist_loss = g_consist_loss0 637 | lam = min(self.curr_iter / self.ramp_iter_num, 1) 638 | 639 | L_F = L_f + L_f_g + lam * f_consist_loss 640 | L_G = L_g + L_g_f + lam * g_consist_loss 641 | self.f_opt.zero_grad() 642 | L_F.backward() 643 | self.f_opt.step() 644 | self.g_opt.zero_grad() 645 | L_G.backward() 646 | self.g_opt.step() 647 | self.d_array[index,:] = self.update_d(f_outputs_o, partial_y) 648 | self.b_array[index,:] = self.update_b(g_outputs_o, partial_y) 649 | self.curr_iter += 1 650 | 651 | return {'loss': L_F.item()} 652 | 653 | def predict(self, x): 654 | return self.f(x) 655 | 656 | class ABS_MAE(Algorithm): 657 | """ 658 | ABS_MAE 659 | Reference: On the Robustness of Average Losses for Partial-Label Learning, TPAMI 2024. 
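As implemented in mae_loss below, the loss is a confidence-weighted mean absolute error
against each one-hot label: L(x, S) = sum_y w_y * ||p(x) - e_y||_1, where p is the softmax
output, e_y the one-hot vector of class y, and w the uniform confidence over the candidate
set S.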
660 | """ 661 | 662 | def __init__(self, input_shape, train_givenY, hparams): 663 | super(ABS_MAE, self).__init__(input_shape, train_givenY, hparams) 664 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 665 | self.classifier = networks.Classifier( 666 | self.featurizer.n_outputs, 667 | self.num_classes) 668 | 669 | self.network = nn.Sequential(self.featurizer, self.classifier) 670 | self.optimizer = torch.optim.Adam( 671 | self.network.parameters(), 672 | lr=self.hparams["lr"], 673 | weight_decay=self.hparams['weight_decay'] 674 | ) 675 | train_givenY = torch.from_numpy(train_givenY) 676 | tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1]) 677 | label_confidence = train_givenY.float()/tempY 678 | self.label_confidence = label_confidence 679 | 680 | def update(self, minibatches): 681 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 682 | device = "cuda" if partial_y.is_cuda else "cpu" 683 | loss = self.mae_loss(self.predict(x), index, device) 684 | self.optimizer.zero_grad() 685 | loss.backward() 686 | self.optimizer.step() 687 | return {'loss': loss.item()} 688 | 689 | def mae_loss(self, outputs, index, device): 690 | sm_outputs = F.softmax(outputs, dim=1) 691 | sm_outputs = sm_outputs.unsqueeze(1) 692 | sm_outputs = sm_outputs.expand([-1,self.num_classes,-1]) 693 | label_one_hot = torch.eye(self.num_classes).to(device) 694 | loss = torch.abs(sm_outputs - label_one_hot).sum(dim=-1) 695 | self.label_confidence = self.label_confidence.to(device) 696 | loss = loss * self.label_confidence[index, :] 697 | avg_loss = loss.sum(dim=1).mean() 698 | return avg_loss 699 | 700 | def predict(self, x): 701 | return self.network(x) 702 | 703 | class ABS_GCE(Algorithm): 704 | """ 705 | ABS_GCE 706 | Reference: On the Robustness of Average Losses for Partial-Label Learning, TPAMI 2024. 707 | """ 708 | 709 | def __init__(self, input_shape, train_givenY, hparams): 710 | super(ABS_GCE, self).__init__(input_shape, train_givenY, hparams) 711 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 712 | self.classifier = networks.Classifier( 713 | self.featurizer.n_outputs, 714 | self.num_classes) 715 | 716 | self.network = nn.Sequential(self.featurizer, self.classifier) 717 | self.optimizer = torch.optim.Adam( 718 | self.network.parameters(), 719 | lr=self.hparams["lr"], 720 | weight_decay=self.hparams['weight_decay'] 721 | ) 722 | train_givenY = torch.from_numpy(train_givenY) 723 | tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1]) 724 | label_confidence = train_givenY.float()/tempY 725 | self.label_confidence = label_confidence 726 | self.q = hparams['q'] 727 | 728 | def update(self, minibatches): 729 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 730 | device = "cuda" if partial_y.is_cuda else "cpu" 731 | loss = self.gce_loss(self.predict(x), index, device, q=self.q) 732 | self.optimizer.zero_grad() 733 | loss.backward() 734 | self.optimizer.step() 735 | return {'loss': loss.item()} 736 | 737 | def gce_loss(self, outputs, index, device, q): 738 | sm_outputs = F.softmax(outputs, dim=1) 739 | sm_outputs = torch.pow(sm_outputs, q) 740 | loss = (1. 
- sm_outputs) / q 741 | self.label_confidence = self.label_confidence.to(device) 742 | loss = loss * self.label_confidence[index, :] 743 | avg_loss = loss.sum(dim=1).mean() 744 | return avg_loss 745 | 746 | def predict(self, x): 747 | return self.network(x) 748 | 749 | class DIRK(Algorithm): 750 | """ 751 | DIRK 752 | Reference: Distilling Reliable Knowledge for Instance-dependent Partial Label Learning, AAAI 2024 753 | """ 754 | 755 | class tea_model(nn.Module): 756 | def __init__(self, num_classes,input_shape,hparams, base_encoder): 757 | super().__init__() 758 | self.encoder = base_encoder(num_classes,input_shape,hparams) 759 | self.register_buffer("queue_feat", torch.randn(hparams['moco_queue'], hparams['feat_dim'])) 760 | self.register_buffer("queue_dist", torch.randn(hparams['moco_queue'], num_classes)) 761 | self.register_buffer("queue_partY", torch.randn(hparams['moco_queue'], num_classes)) 762 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 763 | self.queue_feat = F.normalize(self.queue_feat, dim=0) 764 | self.moco_queue = hparams['moco_queue'] 765 | 766 | @torch.no_grad() 767 | def _dequeue_and_enqueue(self, keys_feat, keys_dist, keys_partY): 768 | batch_size = keys_feat.shape[0] 769 | ptr = int(self.queue_ptr) 770 | assert self.moco_queue % batch_size == 0 771 | self.queue_feat[ptr:ptr + batch_size] = keys_feat 772 | self.queue_dist[ptr:ptr + batch_size] = keys_dist 773 | self.queue_partY[ptr:ptr + batch_size] = keys_partY 774 | ptr = (ptr + batch_size) % self.moco_queue 775 | self.queue_ptr[0] = ptr 776 | 777 | def forward(self, img_w=None, img_s=None, img_distill=None, partY=None): 778 | with torch.no_grad(): 779 | _, feat_k = self.encoder(img_w) 780 | output_k, _ = self.encoder(img_distill) 781 | output_k = torch.softmax(output_k, dim=1) 782 | output_k = self.get_correct_conf(output_k, partY) 783 | 784 | features = torch.cat((feat_k, self.queue_feat.clone().detach()), dim=0) 785 | partYs = torch.cat((partY, self.queue_partY.clone().detach()), dim=0) 786 | dists = torch.cat((output_k, self.queue_dist.clone().detach()), dim=0) 787 | self._dequeue_and_enqueue(feat_k, output_k, partY) 788 | return features, partYs, dists, output_k 789 | 790 | def get_correct_conf(self, un_conf, partY): 791 | part_confidence = un_conf * partY 792 | part_confidence = part_confidence / part_confidence.sum(dim=1).unsqueeze(1).repeat(1,part_confidence.shape[1]) 793 | comp_confidence = un_conf * (1 - partY) 794 | comp_confidence = comp_confidence / (comp_confidence.sum(dim=1).unsqueeze(1).repeat(1, comp_confidence.shape[1]) + 1e-20) 795 | comp_max = comp_confidence.max(dim=1)[0].unsqueeze(1).repeat(1, partY.shape[1]) 796 | part_min = ((1 - partY) + part_confidence).min(dim=1)[0].unsqueeze(1).repeat(1, partY.shape[1]) 797 | fenmu = (un_conf * partY).sum(dim=1) 798 | M = 1.0 / fenmu 799 | M = M.unsqueeze(1).repeat(1, partY.shape[1]) 800 | a = (M * comp_max) / (M * comp_max + part_min) 801 | a[a == 0] = 1 802 | rec_confidence = part_confidence * a + comp_confidence * (1 - a) 803 | return rec_confidence 804 | 805 | class stu_model(nn.Module): 806 | def __init__(self, num_classes,input_shape,hparams, base_encoder): 807 | super().__init__() 808 | self.encoder = base_encoder(num_classes,input_shape,hparams) 809 | def forward(self, img_s, img_distill, is_eval=False): 810 | output_s, _ = self.encoder(img_distill) 811 | if is_eval: 812 | return output_s 813 | _, feat_s = self.encoder(img_s) 814 | return output_s, feat_s 815 | 816 | class DIRKNet(nn.Module): 817 | def __init__(self, 
num_classes, input_shape, hparams): 818 | super().__init__() 819 | self.featurizer = networks.Featurizer(input_shape, hparams) 820 | self.classifier = networks.Classifier( 821 | self.featurizer.n_outputs, 822 | num_classes) 823 | self.head = nn.Sequential( 824 | nn.Linear(self.featurizer.n_outputs, self.featurizer.n_outputs), 825 | nn.ReLU(inplace=True), 826 | nn.Linear(self.featurizer.n_outputs, hparams['feat_dim'])) 827 | def forward(self, x): 828 | feat = self.featurizer(x) 829 | feat_c = self.head(feat) 830 | logits = self.classifier(feat) 831 | return logits, F.normalize(feat_c, dim=1) 832 | 833 | def dirk_loss(self,output, confidence, Y=None): 834 | logsm_outputs = F.log_softmax(output, dim=1) 835 | final_outputs = logsm_outputs * confidence 836 | average_loss = - ((final_outputs).sum(dim=1)).mean() 837 | return average_loss 838 | 839 | 840 | class WeightedConLoss(nn.Module): 841 | def __init__(self, temperature=0.07, base_temperature=0.07, dist_temperature=0.07): 842 | super().__init__() 843 | self.temperature = temperature 844 | self.base_temperature = base_temperature 845 | self.dist_temperature = dist_temperature 846 | def forward(self, features, dist, mask=None, batch_size=-1): 847 | if mask is not None: 848 | mask = mask.float() 849 | anchor_dot_contrast = torch.div(torch.matmul(features[:batch_size], features.T),self.temperature) 850 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 851 | logits = anchor_dot_contrast - logits_max.detach() 852 | logits_mask = torch.scatter(torch.ones_like(anchor_dot_contrast),1,torch.arange(batch_size).view(-1, 1).cuda(),0) 853 | mask = logits_mask * mask 854 | dist_temperature = self.dist_temperature 855 | dist_norm = dist / torch.norm(dist, dim=-1, keepdim=True) 856 | anchor_dot_simi = torch.div(torch.matmul(dist_norm[:batch_size], dist_norm.T), dist_temperature) 857 | logits_simi_max, _ = torch.max(anchor_dot_simi, dim=1, keepdim=True) 858 | logits_simi = anchor_dot_simi - logits_simi_max.detach() 859 | exp_simi = torch.exp(logits_simi) * mask 860 | weight = exp_simi / exp_simi.sum(dim=1).unsqueeze(1).repeat(1, anchor_dot_simi.shape[1]) 861 | exp_logits = torch.exp(logits) * logits_mask 862 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12) 863 | weighted_log_prob_pos = weight * log_prob 864 | loss = -(self.temperature / self.base_temperature) * weighted_log_prob_pos 865 | loss = loss.sum(dim=1).mean() 866 | else: 867 | q = features[:batch_size] 868 | k = features[batch_size:batch_size * 2] 869 | queue = features[batch_size * 2:] 870 | k, queue = k.detach(), queue.detach() 871 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 872 | l_neg = torch.einsum('nc,kc->nk', [q, queue]) 873 | logits = torch.cat([l_pos, l_neg], dim=1) 874 | logits /= self.temperature 875 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 876 | loss = F.cross_entropy(logits, labels) 877 | return loss 878 | 879 | def __init__(self,input_shape,train_givenY,hparams): 880 | super(DIRK, self).__init__(input_shape,train_givenY,hparams) 881 | self.stu = self.stu_model(self.num_classes,input_shape,hparams,self.DIRKNet) 882 | self.tea=self.tea_model(self.num_classes,input_shape,hparams,self.DIRKNet) 883 | self.optimizer = torch.optim.Adam( 884 | self.stu.parameters(), 885 | lr=self.hparams["lr"], 886 | weight_decay=self.hparams['weight_decay'] 887 | ) 888 | self.train_givenY = torch.from_numpy(train_givenY) 889 | self.loss_cont_fn = self.WeightedConLoss(temperature=self.hparams['feat_temperature'], 
dist_temperature=self.hparams['dist_temperature']) 890 | self.curr_iter = 0 891 | 892 | 893 | def update(self,minibatches): 894 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 895 | 896 | features, partYs, dists, rec_conf_t = self.tea(x, strong_x, distill_x, partial_y) 897 | output_s, feat_s = self.stu(strong_x, distill_x) 898 | features_cont = torch.cat((feat_s, features), dim=0) 899 | partY_cont = torch.cat((partial_y, partYs), dim=0) 900 | dist_cont = torch.cat((rec_conf_t, dists), dim=0) 901 | batch_size = output_s.shape[0] 902 | mask_partial = torch.matmul(partY_cont[:batch_size], partY_cont.T) 903 | mask_partial[mask_partial != 0] = 1 904 | _, pseudo_target = torch.max(dist_cont, dim=1) 905 | pseudo_target = pseudo_target.contiguous().view(-1, 1) 906 | mask_pseudo_target = torch.eq(pseudo_target[:batch_size], pseudo_target.T).float() 907 | start_upd_prot = self.curr_iter >= self.hparams['prot_start'] 908 | if start_upd_prot: 909 | mask = mask_partial * mask_pseudo_target 910 | else: 911 | mask = None 912 | 913 | if self.hparams['weight'] != 0: 914 | loss_cont = self.loss_cont_fn(features=features_cont, dist=dist_cont, mask=mask, batch_size=partial_y.shape[0]) 915 | else: 916 | loss_cont = torch.tensor(0.0).cuda() 917 | loss_dirk = self.dirk_loss(output_s, rec_conf_t) 918 | loss = loss_dirk + self.hparams['weight'] * loss_cont 919 | 920 | self.optimizer.zero_grad() 921 | loss.backward() 922 | self.optimizer.step() 923 | self.curr_iter += 1 924 | self.model_update(self.tea,self.stu,self.hparams['momentum']) 925 | return {'loss': loss.item()} 926 | 927 | def predict(self, x): 928 | return self.stu(None,x,is_eval=True) 929 | 930 | def tea_predict(self,x): 931 | return self.tea(x) 932 | 933 | def model_update(self,model_tea, model_stu, momentum=0.99): 934 | for param_tea, param_stu in zip(model_tea.parameters(), model_stu.parameters()): 935 | param_tea.data = param_tea.data * momentum + param_stu.data * (1 - momentum) 936 | 937 | 938 | 939 | class CRDPLL(Algorithm): 940 | """ 941 | CRDPLL 942 | Reference: Revisiting Consistency Regularization for Deep Partial Label Learning, ICML 2022. 
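As implemented below, the loss combines a supervised term on non-candidate labels,
-mean_x sum_{y not in S} log(1 - p_y(x)), with KL-divergence consistency terms that pull the
predictions on the weakly and strongly augmented views toward the current label confidences;
the consistency weight is ramped up linearly over the first half of training, and the
confidences are refreshed with the geometric mean of the two views' softmax outputs
restricted to the candidate set.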
943 | """ 944 | 945 | def __init__(self, input_shape, train_givenY, hparams): 946 | super(CRDPLL, self).__init__(input_shape, train_givenY, hparams) 947 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 948 | self.classifier = networks.Classifier( 949 | self.featurizer.n_outputs, 950 | self.num_classes) 951 | 952 | self.network = nn.Sequential(self.featurizer, self.classifier) 953 | self.optimizer = torch.optim.Adam( 954 | self.network.parameters(), 955 | lr=self.hparams["lr"], 956 | weight_decay=self.hparams['weight_decay'] 957 | ) 958 | 959 | train_givenY = torch.from_numpy(train_givenY) 960 | tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1]) 961 | label_confidence = train_givenY.float() / tempY 962 | self.label_confidence = label_confidence 963 | 964 | self.consistency_criterion = nn.KLDivLoss(reduction='batchmean') 965 | self.train_givenY=train_givenY 966 | self.lam = 1 967 | self.curr_iter = 0 968 | self.max_steps = self.hparams['max_steps'] 969 | 970 | def update(self, minibatches): 971 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 972 | loss = self.cr_loss(self.predict(x), self.predict(strong_x), index) 973 | self.optimizer.zero_grad() 974 | loss.backward() 975 | self.optimizer.step() 976 | self.curr_iter = self.curr_iter + 1 977 | self.confidence_update(x,strong_x, partial_y, index) 978 | return {'loss': loss.item()} 979 | 980 | def cr_loss(self, outputs, strong_outputs, index): 981 | device = "cuda" if index.is_cuda else "cpu" 982 | self.label_confidence = self.label_confidence.to(device) 983 | self.consistency_criterion=self.consistency_criterion.to(device) 984 | self.train_givenY=self.train_givenY.to(device) 985 | consist_loss0 = self.consistency_criterion(F.log_softmax(outputs, dim=1), self.label_confidence[index, :].float()) 986 | consist_loss1 = self.consistency_criterion(F.log_softmax(strong_outputs, dim=1), self.label_confidence[index, :].float()) 987 | super_loss = -torch.mean( 988 | torch.sum(torch.log(1.0000001 - F.softmax(outputs, dim=1)) * (1 - self.train_givenY[index, :]), dim=1)) 989 | lam = min((self.curr_iter / (self.max_steps*0.5)) * self.lam, self.lam) 990 | average_loss = lam * (consist_loss0 + consist_loss1) + super_loss 991 | return average_loss 992 | 993 | def predict(self, x): 994 | return self.network(x) 995 | 996 | def confidence_update(self,batchX,strong_batchX,batchY,batch_index): 997 | with torch.no_grad(): 998 | batch_outputs = self.predict(batchX) 999 | strong_batch_outputs=self.predict(strong_batchX) 1000 | temp_un_conf=F.softmax(batch_outputs,dim=1) 1001 | strong_temp_un_conf=F.softmax(strong_batch_outputs,dim=1) 1002 | self.label_confidence[batch_index,:]=torch.pow(temp_un_conf,1/(1+1))*torch.pow(strong_temp_un_conf,1/(1+1))*batchY 1003 | base_value=self.label_confidence[batch_index,:].sum(dim=1).unsqueeze(1).repeat(1,self.label_confidence[batch_index,:].shape[1]) 1004 | self.label_confidence[batch_index,:]=self.label_confidence[batch_index,:]/base_value 1005 | 1006 | 1007 | 1008 | 1009 | class ABLE(Algorithm): 1010 | """ 1011 | ABLE 1012 | Reference: Ambiguity-Induced Contrastive Learning for Instance-Dependent Partial Label Learning, IJCAI 2022 1013 | """ 1014 | 1015 | class ABLE_model(nn.Module): 1016 | def __init__(self, num_classes,input_shape,hparams, base_encoder): 1017 | super().__init__() 1018 | self.encoder = base_encoder(num_classes,input_shape,hparams) 1019 | def forward(self, hparams=None, img_w=None, images=None, partial_Y=None, is_eval=False): 1020 | if is_eval: 1021 | 
output_raw, q = self.encoder(img_w) 1022 | return output_raw 1023 | outputs, features = self.encoder(images) 1024 | batch_size = hparams['batch_size'] 1025 | f1, f2 = torch.split(features, [batch_size, batch_size], dim=0) 1026 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) 1027 | return outputs, features 1028 | 1029 | class ABLENet(nn.Module): 1030 | def __init__(self,num_classes,input_shape,hparams): 1031 | super().__init__() 1032 | self.featurizer = networks.Featurizer(input_shape, hparams) 1033 | self.classifier = networks.Classifier( 1034 | self.featurizer.n_outputs, 1035 | num_classes) 1036 | self.head = nn.Sequential( 1037 | nn.Linear(self.featurizer.n_outputs, self.featurizer.n_outputs), 1038 | nn.ReLU(inplace=True), 1039 | nn.Linear(self.featurizer.n_outputs, hparams['feat_dim'])) 1040 | def forward(self,x): 1041 | feat = self.featurizer(x) 1042 | feat_c = self.head(feat) 1043 | logits = self.classifier(feat) 1044 | return logits, F.normalize(feat_c, dim=1) 1045 | 1046 | class ClsLoss(nn.Module): 1047 | def __init__(self, predicted_score): 1048 | super().__init__() 1049 | self.predicted_score = predicted_score 1050 | self.init_predicted_score = predicted_score.detach() 1051 | def forward(self, outputs, index): 1052 | device = "cuda" if outputs.is_cuda else "cpu" 1053 | self.predicted_score=self.predicted_score.to(device) 1054 | logsm_outputs = F.log_softmax(outputs, dim=1) 1055 | final_outputs = self.predicted_score[index, :] * logsm_outputs 1056 | cls_loss = - ((final_outputs).sum(dim=1)).mean() 1057 | return cls_loss 1058 | def update_target(self, batch_index, updated_confidence): 1059 | with torch.no_grad(): 1060 | self.predicted_score[batch_index, :] = updated_confidence.detach() 1061 | return None 1062 | 1063 | class ConLoss(nn.Module): 1064 | def __init__(self, predicted_score, base_temperature=0.07): 1065 | super().__init__() 1066 | self.predicted_score = predicted_score 1067 | self.init_predicted_score = predicted_score.detach() 1068 | self.base_temperature = base_temperature 1069 | def forward(self, hparams, outputs, features, Y, index): 1070 | batch_size = hparams['batch_size'] 1071 | device = "cuda" if outputs.is_cuda else "cpu" 1072 | self.predicted_score=self.predicted_score.to(device) 1073 | contrast_count = features.shape[1] 1074 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 1075 | anchor_feature = contrast_feature 1076 | anchor_count = contrast_count 1077 | anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), hparams['temperature']) 1078 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 1079 | logits = anchor_dot_contrast - logits_max.detach() 1080 | Y = Y.float() 1081 | output_sm = F.softmax(outputs[0: batch_size, :], dim=1).float() 1082 | output_sm_d = output_sm.detach() 1083 | _, target_predict = (output_sm_d * Y).max(1) 1084 | predict_labels = target_predict.repeat(batch_size, 1).to(device) 1085 | mask_logits = torch.zeros_like(predict_labels).float().to(device) 1086 | pos_set = (Y == 1.0).nonzero().to(device) 1087 | ones_flag = torch.ones(batch_size).float().to(device) 1088 | zeros_flag = torch.zeros(batch_size).float().to(device) 1089 | for pos_set_i in range(pos_set.shape[0]): 1090 | sample_idx = pos_set[pos_set_i][0] 1091 | class_idx = pos_set[pos_set_i][1] 1092 | mask_logits_tmp = torch.where(predict_labels[sample_idx] == class_idx, ones_flag, zeros_flag).float() 1093 | if mask_logits_tmp.sum() > 0: 1094 | mask_logits_tmp = mask_logits_tmp / mask_logits_tmp.sum() 1095 
| mask_logits[sample_idx] = mask_logits[sample_idx] + mask_logits_tmp * \ 1096 | self.predicted_score[sample_idx][class_idx] 1097 | mask_logits = mask_logits.repeat(anchor_count, contrast_count) 1098 | logits_mask = torch.scatter(torch.ones_like(mask_logits),1,torch.arange(batch_size * anchor_count).view(-1, 1).to(device),0).float() 1099 | mask_logits = mask_logits * logits_mask 1100 | exp_logits = logits_mask * torch.exp(logits) 1101 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 1102 | mean_log_prob_pos = (mask_logits * log_prob).sum(1) 1103 | loss_con_m = - (hparams['temperature'] / self.base_temperature) * mean_log_prob_pos 1104 | loss_con = loss_con_m.view(anchor_count, batch_size).mean() 1105 | revisedY_raw = Y.clone() 1106 | revisedY_raw = revisedY_raw * output_sm_d 1107 | revisedY_raw = revisedY_raw / revisedY_raw.sum(dim=1).repeat(Y.shape[1], 1).transpose(0, 1) 1108 | new_target = revisedY_raw.detach() 1109 | return loss_con, new_target 1110 | def update_target(self, batch_index, updated_confidence): 1111 | with torch.no_grad(): 1112 | self.predicted_score[batch_index, :] = updated_confidence.detach() 1113 | return None 1114 | 1115 | def __init__(self,input_shape,train_givenY,hparams): 1116 | super(ABLE, self).__init__(input_shape,train_givenY,hparams) 1117 | self.network=self.ABLE_model(self.num_classes,input_shape,hparams=hparams,base_encoder=self.ABLENet) 1118 | self.optimizer = torch.optim.Adam( 1119 | self.network.parameters(), 1120 | lr=self.hparams["lr"], 1121 | weight_decay=self.hparams['weight_decay'] 1122 | ) 1123 | train_givenY = torch.from_numpy(train_givenY) 1124 | tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1]) 1125 | label_confidence = train_givenY.float() / tempY 1126 | self.label_confidence = label_confidence 1127 | self.loss_cls = self.ClsLoss(predicted_score=label_confidence.float()) 1128 | self.loss_con = self.ConLoss(predicted_score=label_confidence.float()) 1129 | self.train_givenY = train_givenY 1130 | 1131 | def update(self,minibatches): 1132 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 1133 | X_tot = torch.cat([x, strong_x], dim=0) 1134 | batch_size = self.hparams['batch_size'] 1135 | 1136 | cls_out, features = self.network(hparams=self.hparams, images=X_tot, partial_Y=partial_y, is_eval=False) 1137 | cls_out_w = cls_out[0: batch_size, :] 1138 | 1139 | cls_loss = self.loss_cls(cls_out_w, index) 1140 | con_loss, new_target = self.loss_con(self.hparams, cls_out, features, partial_y, index) 1141 | loss = cls_loss + self.hparams['loss_weight'] * con_loss 1142 | self.optimizer.zero_grad() 1143 | loss.backward() 1144 | self.optimizer.step() 1145 | self.loss_cls.update_target(batch_index=index, updated_confidence=new_target) 1146 | self.loss_con.update_target(batch_index=index, updated_confidence=new_target) 1147 | return {'loss': loss.item()} 1148 | 1149 | def predict(self,images,): 1150 | return self.network(img_w=images,is_eval=True) 1151 | 1152 | 1153 | class PiCO(Algorithm): 1154 | """ 1155 | PiCO 1156 | Reference: PiCO: Contrastive Label Disambiguation for Partial Label Learning, ICLR 2022. 
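    In brief (a rough summary of PiCO_model and update below; see the reference for the exact formulation): a MoCo-style pair of encoders is maintained, with encoder_k a momentum copy of encoder_q and key embeddings pushed into a fixed-size feature queue. Each example's pseudo label is the candidate class with the largest classifier score, class prototypes are updated as an exponential moving average of the query embeddings assigned to them, and after prot_start steps the prototype scores EMA-update the per-example label confidence. The loss is the confidence-weighted classification loss plus loss_weight times a contrastive term, which uses shared pseudo labels as positives once the warm-up has passed and plain MoCo-style instance discrimination before that.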
1157 | """ 1158 | 1159 | class PiCO_model(nn.Module): 1160 | def __init__(self, num_classes,input_shape,hparams, base_encoder): 1161 | super().__init__() 1162 | self.encoder_q = base_encoder(num_classes, input_shape, hparams) 1163 | self.encoder_k = base_encoder(num_classes, input_shape, hparams) 1164 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 1165 | param_k.data.copy_(param_q.data) 1166 | param_k.requires_grad = False 1167 | self.register_buffer("queue", torch.randn(hparams['moco_queue'], hparams['feat_dim'])) 1168 | self.register_buffer("queue_pseudo", torch.randn(hparams['moco_queue'])) 1169 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 1170 | self.register_buffer("prototypes", torch.zeros(num_classes, hparams['feat_dim'])) 1171 | self.queue = F.normalize(self.queue, dim=0) 1172 | 1173 | @torch.no_grad() 1174 | def _momentum_update_key_encoder(self, hparams): 1175 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 1176 | param_k.data = param_k.data * hparams['moco_m'] + param_q.data * (1. - hparams['moco_m']) 1177 | 1178 | @torch.no_grad() 1179 | def _dequeue_and_enqueue(self, keys, labels, hparams): 1180 | batch_size = keys.shape[0] 1181 | ptr = int(self.queue_ptr) 1182 | assert hparams['moco_queue'] % batch_size == 0 1183 | self.queue[ptr:ptr + batch_size, :] = keys 1184 | self.queue_pseudo[ptr:ptr + batch_size] = labels 1185 | ptr = (ptr + batch_size) % hparams['moco_queue'] 1186 | self.queue_ptr[0] = ptr 1187 | 1188 | 1189 | def forward(self, img_q, im_k=None, partial_Y=None, hparams=None, is_eval=False): 1190 | output, q = self.encoder_q(img_q) 1191 | if is_eval: 1192 | return output 1193 | 1194 | predicted_scores = torch.softmax(output, dim=1) * partial_Y 1195 | max_scores, pseudo_labels_b = torch.max(predicted_scores, dim=1) 1196 | prototypes = self.prototypes.clone().detach() 1197 | logits_prot = torch.mm(q, prototypes.t()) 1198 | score_prot = torch.softmax(logits_prot, dim=1) 1199 | with torch.no_grad(): 1200 | for feat, label in zip(q, pseudo_labels_b): 1201 | self.prototypes[label] = self.prototypes[label] * hparams['proto_m'] + (1 - hparams['proto_m']) * feat 1202 | self.prototypes = F.normalize(self.prototypes, p=2, dim=1).detach() 1203 | with torch.no_grad(): 1204 | self._momentum_update_key_encoder(hparams) 1205 | _, k = self.encoder_k(im_k) 1206 | features = torch.cat((q, k, self.queue.clone().detach()), dim=0) 1207 | pseudo_labels = torch.cat((pseudo_labels_b, pseudo_labels_b, self.queue_pseudo.clone().detach()), dim=0) 1208 | self._dequeue_and_enqueue(k, pseudo_labels_b, hparams) 1209 | return output, features, pseudo_labels, score_prot 1210 | 1211 | 1212 | class PiCONet(nn.Module): 1213 | def __init__(self,num_classes,input_shape,hparams): 1214 | super().__init__() 1215 | self.featurizer = networks.Featurizer(input_shape, hparams) 1216 | self.classifier = networks.Classifier( 1217 | self.featurizer.n_outputs, 1218 | num_classes) 1219 | self.head = nn.Sequential( 1220 | nn.Linear(self.featurizer.n_outputs, self.featurizer.n_outputs), 1221 | nn.ReLU(inplace=True), 1222 | nn.Linear(self.featurizer.n_outputs, hparams['feat_dim'])) 1223 | self.register_buffer("prototypes", torch.zeros(num_classes, hparams['feat_dim'])) 1224 | 1225 | def forward(self, x): 1226 | feat = self.featurizer(x) 1227 | feat_c = self.head(feat) 1228 | logits = self.classifier(feat) 1229 | return logits, F.normalize(feat_c, dim=1) 1230 | 1231 | class partial_loss(nn.Module): 1232 | def 
__init__(self, confidence, hparams, conf_ema_m=0.99): 1233 | super().__init__() 1234 | self.confidence = confidence 1235 | self.init_conf = confidence.detach() 1236 | self.conf_ema_m = conf_ema_m 1237 | self.conf_ema_range = [float(item) for item in hparams['conf_ema_range'].split(',')] 1238 | def set_conf_ema_m(self, epoch, total_epochs): 1239 | start = self.conf_ema_range[0] 1240 | end = self.conf_ema_range[1] 1241 | self.conf_ema_m = 1. * epoch /total_epochs * (end - start) + start 1242 | def forward(self, outputs, index): 1243 | device = "cuda" if outputs.is_cuda else "cpu" 1244 | self.confidence=self.confidence.to(device) 1245 | logsm_outputs = F.log_softmax(outputs, dim=1) 1246 | final_outputs = logsm_outputs * self.confidence[index, :] 1247 | average_loss = - ((final_outputs).sum(dim=1)).mean() 1248 | return average_loss 1249 | def confidence_update(self, temp_un_conf, batch_index, batchY): 1250 | with torch.no_grad(): 1251 | _, prot_pred = (temp_un_conf * batchY).max(dim=1) 1252 | pseudo_label = F.one_hot(prot_pred, batchY.shape[1]).float().cuda().detach() 1253 | self.confidence[batch_index, :] = self.conf_ema_m * self.confidence[batch_index, :] \ 1254 | + (1 - self.conf_ema_m) * pseudo_label 1255 | return None 1256 | 1257 | class SupConLoss(nn.Module): 1258 | def __init__(self, temperature=0.07, base_temperature=0.07): 1259 | super().__init__() 1260 | self.temperature = temperature 1261 | self.base_temperature = base_temperature 1262 | def forward(self, features, mask=None, batch_size=-1): 1263 | device = (torch.device('cuda') if features.is_cuda else torch.device('cpu')) 1264 | if mask is not None: 1265 | mask = mask.float().detach().to(device) 1266 | anchor_dot_contrast = torch.div( 1267 | torch.matmul(features[:batch_size], features.T), 1268 | self.temperature) 1269 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 1270 | logits = anchor_dot_contrast - logits_max.detach() 1271 | logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(batch_size).view(-1, 1).to(device), 0) 1272 | mask = mask * logits_mask 1273 | exp_logits = torch.exp(logits) * logits_mask 1274 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12) 1275 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 1276 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 1277 | loss = loss.mean() 1278 | else: 1279 | q = features[:batch_size] 1280 | k = features[batch_size:batch_size * 2] 1281 | queue = features[batch_size * 2:] 1282 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 1283 | l_neg = torch.einsum('nc,kc->nk', [q, queue]) 1284 | logits = torch.cat([l_pos, l_neg], dim=1) 1285 | logits /= self.temperature 1286 | labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() 1287 | loss = F.cross_entropy(logits, labels) 1288 | return loss 1289 | 1290 | def __init__(self,input_shape,train_givenY,hparams): 1291 | super(PiCO, self).__init__(input_shape,train_givenY,hparams) 1292 | self.network=self.PiCO_model(self.num_classes,input_shape,hparams=hparams,base_encoder=self.PiCONet) 1293 | self.optimizer = torch.optim.Adam( 1294 | self.network.parameters(), 1295 | lr=self.hparams["lr"], 1296 | weight_decay=self.hparams['weight_decay'] 1297 | ) 1298 | 1299 | train_givenY = torch.from_numpy(train_givenY) 1300 | tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1]) 1301 | label_confidence = train_givenY.float() / tempY 1302 | self.label_confidence = label_confidence 1303 | self.loss_fn = 
self.partial_loss(label_confidence.float(),self.hparams) 1304 | self.loss_cont_fn = self.SupConLoss() 1305 | self.train_givenY = train_givenY 1306 | self.curr_iter = 0 1307 | self.max_steps = self.hparams['max_steps'] 1308 | 1309 | def update(self,minibatches): 1310 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 1311 | cls_out, features_cont, pseudo_target_cont, score_prot = self.network(x, strong_x, partial_y, self.hparams) 1312 | batch_size = cls_out.shape[0] 1313 | pseudo_target_cont = pseudo_target_cont.contiguous().view(-1, 1) 1314 | 1315 | start_upd_prot = self.curr_iter >= self.hparams['prot_start'] 1316 | if start_upd_prot: 1317 | self.loss_fn.confidence_update(temp_un_conf=score_prot, batch_index=index, batchY=partial_y) 1318 | if start_upd_prot: 1319 | mask = torch.eq(pseudo_target_cont[:batch_size], pseudo_target_cont.T).float().cuda() 1320 | else: 1321 | mask = None 1322 | loss_cont = self.loss_cont_fn(features=features_cont, mask=mask, batch_size=batch_size) 1323 | loss_cls = self.loss_fn(cls_out, index) 1324 | loss = loss_cls + self.hparams['loss_weight'] * loss_cont 1325 | self.optimizer.zero_grad() 1326 | loss.backward() 1327 | self.optimizer.step() 1328 | self.loss_fn.set_conf_ema_m(self.curr_iter, self.max_steps) 1329 | self.curr_iter = self.curr_iter + 1 1330 | return {'loss': loss.item()} 1331 | 1332 | def predict(self,images,): 1333 | return self.network(img_q=images,is_eval=True) 1334 | 1335 | 1336 | class VALEN(Algorithm): 1337 | """ 1338 | VALEN 1339 | Reference: Instance-Dependent Partial Label Learning, NeurIPS 2021. 1340 | """ 1341 | 1342 | class VAE_Bernulli_Decoder(nn.Module): 1343 | def __init__(self, n_in, n_hidden, n_out, keep_prob=1.0) -> None: 1344 | super().__init__() 1345 | self.layer1 = nn.Linear(n_in, n_hidden) 1346 | self.layer2 = nn.Linear(n_hidden, n_out) 1347 | self._init_weight() 1348 | 1349 | def _init_weight(self): 1350 | for m in self.modules(): 1351 | if isinstance(m, nn.Linear): 1352 | nn.init.xavier_normal_(m.weight.data) 1353 | m.bias.data.fill_(0.01) 1354 | 1355 | def forward(self, inputs): 1356 | h0 = self.layer1(inputs) 1357 | h0 = F.relu(h0) 1358 | x_hat = self.layer2(h0) 1359 | return x_hat 1360 | 1361 | def num_flat_features(self, input_shape): 1362 | size = input_shape[1:] 1363 | num_features = 1 1364 | for s in size: 1365 | num_features *= s 1366 | return num_features 1367 | 1368 | def __init__(self, input_shape, train_givenY, hparams): 1369 | super(VALEN, self).__init__(input_shape,train_givenY,hparams) 1370 | self.featurizer_net = networks.Featurizer(input_shape, self.hparams) 1371 | self.classifier_net = networks.Classifier( 1372 | self.featurizer_net.n_outputs, 1373 | self.num_classes) 1374 | self.net = nn.Sequential(self.featurizer_net, self.classifier_net) 1375 | self.enc=copy.deepcopy(self.net) 1376 | self.num_features = self.num_flat_features(input_shape) 1377 | self.dec=self.VAE_Bernulli_Decoder(self.num_classes, self.num_features, self.num_features) 1378 | self.optimizer = torch.optim.Adam( 1379 | list(self.net.parameters()) + list(self.enc.parameters()) + list(self.dec.parameters()), 1380 | lr=self.hparams["lr"], 1381 | weight_decay=self.hparams['weight_decay'] 1382 | ) 1383 | 1384 | self.warmup_opt=torch.optim.SGD(list(self.net.parameters()), lr=self.hparams["lr"], weight_decay=self.hparams['weight_decay'], momentum=0.9) 1385 | train_givenY = torch.from_numpy(train_givenY) 1386 | tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1]) 1387 | partial_weight = train_givenY.float() / 
tempY 1388 | self.o_array = partial_weight 1389 | self.feature_extracted= torch.zeros((train_givenY.shape[0], self.featurizer_net.n_outputs)) 1390 | 1391 | self.curr_iter = 0 1392 | self.steps_per_epoch = train_givenY.shape[0] / self.hparams['batch_size'] 1393 | self.mat_save_path = hparams['output_dir'] 1394 | 1395 | def partial_loss(self,output1, target, true, eps=1e-12): 1396 | output = F.softmax(output1, dim=1) 1397 | l = target * torch.log(output + eps) 1398 | loss = (-torch.sum(l)) / l.size(0) 1399 | revisedY = target.clone() 1400 | revisedY[revisedY > 0] = 1 1401 | revisedY = revisedY * (output.clone().detach() + eps) 1402 | revisedY = revisedY / revisedY.sum(dim=1).repeat(revisedY.size(1), 1).transpose(0, 1) 1403 | new_target = revisedY 1404 | return loss, new_target 1405 | 1406 | def alpha_loss(self, alpha, prior_alpha): 1407 | KLD = torch.mvlgamma(alpha.sum(1), p=1) - torch.mvlgamma(alpha, p=1).sum(1) - torch.mvlgamma(prior_alpha.sum(1), 1408 | p=1) + torch.mvlgamma( 1409 | prior_alpha, p=1).sum(1) + ((alpha - prior_alpha) * ( 1410 | torch.digamma(alpha) - torch.digamma(alpha.sum(dim=1, keepdim=True).expand_as(alpha)))).sum(1) 1411 | return KLD.mean() 1412 | 1413 | def update(self,minibatches): 1414 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 1415 | device = "cuda" if index.is_cuda else "cpu" 1416 | if self.curr_iter 0] = 1.0 1447 | s_alpha = s_alpha * revised_alpha 1448 | s_alpha_sum = s_alpha.clone().detach().sum(dim=1, keepdim=True) 1449 | s_alpha = s_alpha / s_alpha_sum + 1e-2 1450 | L_d, new_d = self.partial_loss(alpha, self.o_array[index, :], None) 1451 | alpha = torch.exp(alpha / 4) 1452 | alpha = F.hardtanh(alpha, min_val=1e-2, max_val=30) 1453 | L_alpha = self.alpha_loss(alpha, self.prior_alpha) 1454 | dirichlet_sample_machine = torch.distributions.dirichlet.Dirichlet(s_alpha) 1455 | d = dirichlet_sample_machine.rsample() 1456 | x_hat = self.dec(d.float()) 1457 | x_hat = x_hat.view(x.shape) 1458 | A_hat = F.softmax(self.dot_product_decode(d.float()), dim=1) 1459 | L_recx = 0.01 * F.mse_loss(x_hat, x) 1460 | #L_recy = 0.01 * F.binary_cross_entropy_with_logits(d, partial_y) 1461 | L_recy = 0.01 * F.binary_cross_entropy_with_logits(d, partial_y.float()) 1462 | L_recA = F.mse_loss(A_hat, self.A.to(device)[index, :][:, index].to(device)) 1463 | L_rec = L_recx + L_recy + L_recA 1464 | L_o, new_o = self.partial_loss(outputs, self.d_array[index, :], None) 1465 | L = self.hparams['alpha'] * L_rec + self.hparams['beta'] * L_alpha + self.hparams['gamma'] * L_d + self.hparams['theta'] * L_o 1466 | self.optimizer.zero_grad() 1467 | L.backward() 1468 | self.optimizer.step() 1469 | new_d = self.revised_target(d, new_d) 1470 | new_d = self.hparams['correct'] * new_d + (1 - self.hparams['correct']) * self.o_array[index, :] 1471 | self.d_array[index, :] = new_d.clone().detach() 1472 | self.o_array[index, :] = new_o.clone().detach() 1473 | self.curr_iter = self.curr_iter + 1 1474 | return {'loss': L.item()} 1475 | 1476 | def predict(self, x): 1477 | return self.net(x) 1478 | 1479 | def revised_target(self,output, target): 1480 | revisedY = target.clone() 1481 | revisedY[revisedY > 0] = 1 1482 | revisedY = revisedY * (output.clone().detach()) 1483 | revisedY = revisedY / revisedY.sum(dim=1).repeat(revisedY.size(1), 1).transpose(0, 1) 1484 | new_target = revisedY 1485 | 1486 | return new_target 1487 | def dot_product_decode(self,Z): 1488 | A_pred = torch.sigmoid(torch.matmul(Z, Z.t())) 1489 | return A_pred 1490 | 1491 | def adj_normalize(self, mx): 1492 | rowsum = 
np.array(mx.sum(1)) 1493 | r_inv = np.power(rowsum, -1).flatten() 1494 | r_inv[np.isinf(r_inv)] = 0 1495 | r_mat_inv = sp.diags(r_inv) 1496 | mx = r_mat_inv.dot(mx) 1497 | return mx 1498 | 1499 | def sparse_mx_to_torch_sparse_tensor(self,sparse_mx): 1500 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 1501 | indices = torch.from_numpy( 1502 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64) 1503 | ) 1504 | values = torch.from_numpy(sparse_mx.data) 1505 | shape = torch.Size(sparse_mx.shape) 1506 | return torch.sparse.FloatTensor(indices, values, shape) 1507 | 1508 | def gen_adj_matrix2(self, X, k=10, path=""): 1509 | if os.path.exists(path): 1510 | print("Found adj matrix file and Load.") 1511 | adj_m = np.load(path) 1512 | print("Adj matrix Finished.") 1513 | else: 1514 | print("Not Found adj matrix file and Compute.") 1515 | dm = euclidean_distances(X, X) 1516 | adj_m = np.zeros_like(dm) 1517 | row = np.arange(0, X.shape[0]) 1518 | dm[row, row] = np.inf 1519 | for _ in range(0, k): 1520 | col = np.argmin(dm, axis=1) 1521 | dm[row, col] = np.inf 1522 | adj_m[row, col] = 1.0 1523 | np.save(path, adj_m) 1524 | print("Adj matrix Finished.") 1525 | adj_m = sp.coo_matrix(adj_m) 1526 | adj_m = self.adj_normalize(adj_m + sp.eye(adj_m.shape[0])) 1527 | adj_m = self.sparse_mx_to_torch_sparse_tensor(adj_m) 1528 | return adj_m 1529 | 1530 | ## Complementary-label learning algorithms 1531 | 1532 | class PC(Algorithm): 1533 | """ 1534 | PC 1535 | Reference: Learning from Complementary Labels, NIPS 2017. 1536 | """ 1537 | 1538 | def __init__(self, input_shape, train_givenY, hparams): 1539 | super(PC, self).__init__(input_shape, train_givenY, hparams) 1540 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 1541 | self.classifier = networks.Classifier( 1542 | self.featurizer.n_outputs, 1543 | self.num_classes) 1544 | 1545 | self.network = nn.Sequential(self.featurizer, self.classifier) 1546 | self.optimizer = torch.optim.Adam( 1547 | self.network.parameters(), 1548 | lr=self.hparams["lr"], 1549 | weight_decay=self.hparams['weight_decay'] 1550 | ) 1551 | 1552 | def update(self, minibatches): 1553 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 1554 | total_idxes, comp_labels = torch.where(partial_y == 0) 1555 | K = partial_y.shape[1] 1556 | outputs = self.predict(x)[total_idxes] 1557 | loss = self.pc_loss(outputs, K, comp_labels) 1558 | self.optimizer.zero_grad() 1559 | loss.backward() 1560 | self.optimizer.step() 1561 | return {'loss': loss.item()} 1562 | 1563 | def pc_loss(self, f, K, labels): 1564 | sigmoid = nn.Sigmoid() 1565 | fbar = f.gather(1, labels.long().view(-1, 1)).repeat(1, K) 1566 | loss_matrix = sigmoid( -1. * (f - fbar)) # multiply -1 for "complementary" 1567 | M1, M2 = K*(K-1)/2, K-1 1568 | pc_loss = torch.sum(loss_matrix)*(K-1)/len(labels) - M1 + M2 1569 | return pc_loss 1570 | 1571 | def predict(self, x): 1572 | return self.network(x) 1573 | 1574 | class Forward(Algorithm): 1575 | """ 1576 | Forward 1577 | Reference: Learning with Biased Complementary Labels, ECCV 2018. 
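    A minimal sketch of the uniform forward correction implemented in forward_loss below (names, shapes, and values here are illustrative only, not taken from this repository):

        import torch
        import torch.nn.functional as F

        K = 4                                  # number of classes (illustrative)
        f = torch.randn(8, K)                  # logits for a batch of 8 examples
        ybar = torch.randint(0, K, (8,))       # observed complementary labels
        Q = torch.full((K, K), 1.0 / (K - 1))  # uniform off-diagonal transition matrix
        Q.fill_diagonal_(0.0)                  # a class is never its own complement
        q = F.softmax(f, dim=1) @ Q            # forward-corrected complementary posteriors
        loss = F.nll_loss(q.log(), ybar)       # NLL of the observed complementary label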
1578 | """ 1579 | 1580 | def __init__(self, input_shape, train_givenY, hparams): 1581 | super(Forward, self).__init__(input_shape, train_givenY, hparams) 1582 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 1583 | self.classifier = networks.Classifier( 1584 | self.featurizer.n_outputs, 1585 | self.num_classes) 1586 | 1587 | self.network = nn.Sequential(self.featurizer, self.classifier) 1588 | self.optimizer = torch.optim.Adam( 1589 | self.network.parameters(), 1590 | lr=self.hparams["lr"], 1591 | weight_decay=self.hparams['weight_decay'] 1592 | ) 1593 | 1594 | def update(self, minibatches): 1595 | _, x, strong_x, distill_x, partial_y, _, _ = minibatches 1596 | device = "cuda" if partial_y.is_cuda else "cpu" 1597 | K = partial_y.shape[1] 1598 | total_idxes, comp_labels = torch.where(partial_y == 0) 1599 | outputs = self.predict(x)[total_idxes] 1600 | loss = self.forward_loss(f=outputs, K=K, labels=comp_labels, device=device) 1601 | self.optimizer.zero_grad() 1602 | loss.backward() 1603 | self.optimizer.step() 1604 | return {'loss': loss.item()} 1605 | 1606 | def forward_loss(self, f, K, labels, device): 1607 | Q = torch.ones(K,K) * 1/(K-1) 1608 | Q = Q.to(device) 1609 | for k in range(K): 1610 | Q[k,k] = 0 1611 | q = torch.mm(F.softmax(f, 1), Q) 1612 | return F.nll_loss(q.log(), labels.long()) 1613 | 1614 | def predict(self, x): 1615 | return self.network(x) 1616 | 1617 | class NN(Algorithm): 1618 | """ 1619 | NN 1620 | Reference: Complementary-Label Learning for Arbitrary Losses and Models, ICML 2019. 1621 | """ 1622 | 1623 | def __init__(self, input_shape, train_givenY, hparams): 1624 | super(NN, self).__init__(input_shape, train_givenY, hparams) 1625 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 1626 | self.classifier = networks.Classifier( 1627 | self.featurizer.n_outputs, 1628 | self.num_classes) 1629 | 1630 | self.network = nn.Sequential(self.featurizer, self.classifier) 1631 | self.optimizer = torch.optim.Adam( 1632 | self.network.parameters(), 1633 | lr=self.hparams["lr"], 1634 | weight_decay=self.hparams['weight_decay'] 1635 | ) 1636 | 1637 | self.ccp = self.class_prior(train_givenY) 1638 | 1639 | def update(self, minibatches): 1640 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 1641 | device = "cuda" if index.is_cuda else "cpu" 1642 | total_idxes, comp_labels = torch.where(partial_y == 0) 1643 | K = partial_y.shape[1] 1644 | outputs = self.predict(x)[total_idxes] 1645 | loss = self.non_negative_loss(f=outputs, K=K, labels=comp_labels, ccp=self.ccp, beta=self.hparams['beta'], device=device) 1646 | self.optimizer.zero_grad() 1647 | loss.backward() 1648 | self.optimizer.step() 1649 | return {'loss': loss.item()} 1650 | 1651 | def non_negative_loss(self, f, K, labels, ccp, beta, device): 1652 | ccp = torch.from_numpy(ccp).float().to(device) 1653 | neglog = -F.log_softmax(f, dim=1) 1654 | loss_vector = torch.zeros(K, requires_grad=True).to(device) 1655 | temp_loss_vector = torch.zeros(K).to(device) 1656 | for k in range(K): 1657 | idx = labels == k 1658 | if torch.sum(idx).item() > 0: 1659 | idxs = idx.bool().view(-1,1).repeat(1,K) 1660 | neglog_k = torch.masked_select(neglog, idxs).view(-1,K) 1661 | temp_loss_vector[k] = -(K-1) * ccp[k] * torch.mean(neglog_k, dim=0)[k] # average of k-th class loss for k-th comp class samples 1662 | loss_vector = loss_vector + torch.mul(ccp, torch.mean(neglog_k, dim=0)) # only k-th in the summation of the second term inside max 1663 | loss_vector = loss_vector + temp_loss_vector 1664 | count = 
np.bincount(labels.data.cpu()).astype('float') 1665 | while len(count) < K: 1666 | count = np.append(count, 0) # when largest label is below K, bincount will not take care of them 1667 | loss_vector_with_zeros = torch.cat((loss_vector.view(-1,1), torch.zeros(K, requires_grad=True).view(-1,1).to(device)-beta), 1) 1668 | max_loss_vector, _ = torch.max(loss_vector_with_zeros, dim=1) 1669 | final_loss = torch.sum(max_loss_vector) 1670 | return final_loss 1671 | 1672 | def class_prior(self, train_givenY): 1673 | _, complementary_labels = np.where(train_givenY == 0) 1674 | return np.bincount(complementary_labels) / len(complementary_labels) 1675 | 1676 | def predict(self, x): 1677 | return self.network(x) 1678 | 1679 | class GA(Algorithm): 1680 | """ 1681 | GA 1682 | Reference: Complementary-Label Learning for Arbitrary Losses and Models, ICML 2019. 1683 | """ 1684 | 1685 | def __init__(self, input_shape, train_givenY, hparams): 1686 | super(GA, self).__init__(input_shape, train_givenY, hparams) 1687 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 1688 | self.classifier = networks.Classifier( 1689 | self.featurizer.n_outputs, 1690 | self.num_classes) 1691 | 1692 | self.network = nn.Sequential(self.featurizer, self.classifier) 1693 | self.optimizer = torch.optim.Adam( 1694 | self.network.parameters(), 1695 | lr=self.hparams["lr"], 1696 | weight_decay=self.hparams['weight_decay'] 1697 | ) 1698 | 1699 | self.ccp = self.class_prior(train_givenY) 1700 | 1701 | def update(self, minibatches): 1702 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 1703 | device = "cuda" if index.is_cuda else "cpu" 1704 | total_idxes, comp_labels = torch.where(partial_y == 0) 1705 | K = partial_y.shape[1] 1706 | outputs = self.predict(x)[total_idxes] 1707 | loss, loss_vector = self.assump_free_loss(f=outputs, K=K, labels=comp_labels, ccp=self.ccp, device=device) 1708 | self.optimizer.zero_grad() 1709 | if torch.min(loss_vector).item() < 0: 1710 | loss_vector_with_zeros = torch.cat((loss_vector.view(-1,1), torch.zeros(K, requires_grad=True).view(-1,1).to(device)), 1) 1711 | min_loss_vector, _ = torch.min(loss_vector_with_zeros, dim=1) 1712 | loss = torch.sum(min_loss_vector) 1713 | loss.backward() 1714 | for group in self.optimizer.param_groups: 1715 | for p in group['params']: 1716 | p.grad = -1*p.grad 1717 | else: 1718 | loss.backward() 1719 | self.optimizer.step() 1720 | return {'loss': loss.item()} 1721 | 1722 | def non_negative_loss(self, f, K, labels, ccp, beta, device): 1723 | ccp = torch.from_numpy(ccp).float().to(device) 1724 | neglog = -F.log_softmax(f, dim=1) 1725 | loss_vector = torch.zeros(K, requires_grad=True).to(device) 1726 | temp_loss_vector = torch.zeros(K).to(device) 1727 | for k in range(K): 1728 | idx = labels == k 1729 | if torch.sum(idx).item() > 0: 1730 | idxs = idx.bool().view(-1,1).repeat(1,K) 1731 | neglog_k = torch.masked_select(neglog, idxs).view(-1,K) 1732 | temp_loss_vector[k] = -(K-1) * ccp[k] * torch.mean(neglog_k, dim=0)[k] # average of k-th class loss for k-th comp class samples 1733 | loss_vector = loss_vector + torch.mul(ccp, torch.mean(neglog_k, dim=0)) # only k-th in the summation of the second term inside max 1734 | loss_vector = loss_vector + temp_loss_vector 1735 | count = np.bincount(labels.data.cpu()).astype('float') 1736 | while len(count) < K: 1737 | count = np.append(count, 0) # when largest label is below K, bincount will not take care of them 1738 | loss_vector_with_zeros = torch.cat((loss_vector.view(-1,1), torch.zeros(K, 
requires_grad=True).view(-1,1).to(device)-beta), 1) 1739 | max_loss_vector, _ = torch.max(loss_vector_with_zeros, dim=1) 1740 | final_loss = torch.sum(max_loss_vector) 1741 | return final_loss, torch.mul(torch.from_numpy(count).float().to(device), loss_vector) 1742 | 1743 | def assump_free_loss(self, f, K, labels, ccp, device): 1744 | """Assumption free loss (based on Thm 1) is equivalent to non_negative_loss if the max operator's threshold is negative inf.""" 1745 | return self.non_negative_loss(f=f, K=K, labels=labels, ccp=ccp, beta=np.inf, device=device) 1746 | 1747 | def class_prior(self, train_givenY): 1748 | _, complementary_labels = np.where(train_givenY == 0) 1749 | return np.bincount(complementary_labels) / len(complementary_labels) 1750 | 1751 | def predict(self, x): 1752 | return self.network(x) 1753 | 1754 | class SCL_EXP(Algorithm): 1755 | """ 1756 | SCL-EXP 1757 | Reference: Unbiased Risk Estimators Can Mislead: A Case Study of Learning with Complementary Labels, ICML 2020. 1758 | """ 1759 | 1760 | def __init__(self, input_shape, train_givenY, hparams): 1761 | super(SCL_EXP, self).__init__(input_shape, train_givenY, hparams) 1762 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 1763 | self.classifier = networks.Classifier( 1764 | self.featurizer.n_outputs, 1765 | self.num_classes) 1766 | 1767 | self.network = nn.Sequential(self.featurizer, self.classifier) 1768 | self.optimizer = torch.optim.Adam( 1769 | self.network.parameters(), 1770 | lr=self.hparams["lr"], 1771 | weight_decay=self.hparams['weight_decay'] 1772 | ) 1773 | 1774 | def update(self, minibatches): 1775 | _, x, strong_x, distill_x, partial_y, _, _ = minibatches 1776 | total_idxes, comp_labels = torch.where(partial_y == 0) 1777 | outputs = self.predict(x)[total_idxes] 1778 | loss = self.SCL_EXP_loss(f=outputs, labels=comp_labels) 1779 | self.optimizer.zero_grad() 1780 | loss.backward() 1781 | self.optimizer.step() 1782 | return {'loss': loss.item()} 1783 | 1784 | def SCL_EXP_loss(self, f, labels): 1785 | sm_outputs = F.softmax(f, dim=1) 1786 | loss = -F.nll_loss(sm_outputs.exp(), labels.long()) 1787 | return loss 1788 | 1789 | def predict(self, x): 1790 | return self.network(x) 1791 | 1792 | class SCL_NL(Algorithm): 1793 | """ 1794 | SCL-NL 1795 | Reference: Unbiased Risk Estimators Can Mislead: A Case Study of Learning with Complementary Labels, ICML 2020. 
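    In brief (a rough summary of SCL_NL_loss below): the negative-learning surrogate minimizes the negative log-probability that the observed complementary label is not the true label, roughly L = -mean_i log(1 - p_{ybar_i}(x_i) + 1e-6), where p is the softmax output and ybar_i is the complementary label of example i.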
1796 | """ 1797 | 1798 | def __init__(self, input_shape, train_givenY, hparams): 1799 | super(SCL_NL, self).__init__(input_shape, train_givenY, hparams) 1800 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 1801 | self.classifier = networks.Classifier( 1802 | self.featurizer.n_outputs, 1803 | self.num_classes) 1804 | 1805 | self.network = nn.Sequential(self.featurizer, self.classifier) 1806 | self.optimizer = torch.optim.Adam( 1807 | self.network.parameters(), 1808 | lr=self.hparams["lr"], 1809 | weight_decay=self.hparams['weight_decay'] 1810 | ) 1811 | 1812 | def update(self, minibatches): 1813 | _, x, strong_x, distill_x, partial_y, _, _ = minibatches 1814 | total_idxes, comp_labels = torch.where(partial_y == 0) 1815 | outputs = self.predict(x)[total_idxes] 1816 | loss = self.SCL_NL_loss(f=outputs, labels=comp_labels) 1817 | self.optimizer.zero_grad() 1818 | loss.backward() 1819 | self.optimizer.step() 1820 | return {'loss': loss.item()} 1821 | 1822 | def SCL_NL_loss(self, f, labels): 1823 | p = (1 - F.softmax(f, dim=1) + 1e-6).log() 1824 | loss = F.nll_loss(p, labels.long()) 1825 | return loss 1826 | 1827 | def predict(self, x): 1828 | return self.network(x) 1829 | 1830 | class L_W(Algorithm): 1831 | """ 1832 | L-W 1833 | Reference: Discriminative Complementary-Label Learning with Weighted Loss, ICML 2021. 1834 | """ 1835 | 1836 | def __init__(self, input_shape, train_givenY, hparams): 1837 | super(L_W, self).__init__(input_shape, train_givenY, hparams) 1838 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 1839 | self.classifier = networks.Classifier( 1840 | self.featurizer.n_outputs, 1841 | self.num_classes) 1842 | 1843 | self.network = nn.Sequential(self.featurizer, self.classifier) 1844 | self.optimizer = torch.optim.Adam( 1845 | self.network.parameters(), 1846 | lr=self.hparams["lr"], 1847 | weight_decay=self.hparams['weight_decay'] 1848 | ) 1849 | 1850 | def update(self, minibatches): 1851 | _, x, strong_x, distill_x, partial_y, _, _ = minibatches 1852 | total_idxes, comp_labels = torch.where(partial_y == 0) 1853 | outputs = self.predict(x)[total_idxes] 1854 | K = partial_y.shape[1] 1855 | loss = self.w_loss(f=outputs, K=K, labels=comp_labels) 1856 | self.optimizer.zero_grad() 1857 | loss.backward() 1858 | self.optimizer.step() 1859 | return {'loss': loss.item()} 1860 | 1861 | def non_k_softmax_loss(self, f, K, labels): 1862 | Q_1 = 1 - F.softmax(f, 1) 1863 | Q_1 = F.softmax(Q_1, 1) 1864 | labels = labels.long() 1865 | return F.nll_loss(Q_1.log(), labels.long()) 1866 | 1867 | def w_loss(self, f, K, labels): 1868 | 1869 | loss_class = self.non_k_softmax_loss(f=f, K=K, labels=labels) 1870 | loss_w = self.w_loss_p(f=f, K=K, labels=labels) 1871 | final_loss = loss_class + loss_w 1872 | return final_loss 1873 | 1874 | def w_loss_p(self, f, K, labels): 1875 | Q_1 = 1-F.softmax(f, 1) 1876 | Q = F.softmax(Q_1, 1) 1877 | q = torch.tensor(1.0) / torch.sum(Q_1, dim=1) 1878 | q = q.view(-1, 1).repeat(1, K) 1879 | w = torch.mul(Q_1, q) # weight 1880 | w_1 = torch.mul(w, Q.log()) 1881 | return F.nll_loss(w_1, labels.long()) 1882 | 1883 | def predict(self, x): 1884 | return self.network(x) 1885 | 1886 | class OP_W(Algorithm): 1887 | """ 1888 | OP-W 1889 | Reference: Consistent Complementary-Label Learning via Order-Preserving Losses, AISTATS 2023. 
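    In brief (a rough summary of OP_W_loss below; see the reference for the derivation): with P = softmax(f) and P_rev = softmax(-f), a per-class weight is built by taking a softmax over 1/P_rev + 1 and multiplying it elementwise by P; the loss is the weighted negative log of P_rev at the observed complementary label, so the penalty grows when the model assigns high confidence to the class that was marked as complementary.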
1890 | """ 1891 | 1892 | def __init__(self, input_shape, train_givenY, hparams): 1893 | super(OP_W, self).__init__(input_shape, train_givenY, hparams) 1894 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 1895 | self.classifier = networks.Classifier( 1896 | self.featurizer.n_outputs, 1897 | self.num_classes) 1898 | 1899 | self.network = nn.Sequential(self.featurizer, self.classifier) 1900 | self.optimizer = torch.optim.Adam( 1901 | self.network.parameters(), 1902 | lr=self.hparams["lr"], 1903 | weight_decay=self.hparams['weight_decay'] 1904 | ) 1905 | 1906 | def update(self, minibatches): 1907 | _, x, strong_x, distill_x, partial_y, _, _ = minibatches 1908 | total_idxes, comp_labels = torch.where(partial_y == 0) 1909 | outputs = self.predict(x)[total_idxes] 1910 | K = partial_y.shape[1] 1911 | loss = self.OP_W_loss(f=outputs, K=K, labels=comp_labels) 1912 | self.optimizer.zero_grad() 1913 | loss.backward() 1914 | self.optimizer.step() 1915 | return {'loss': loss.item()} 1916 | 1917 | def OP_W_loss(self, f, K, labels): 1918 | Q_1 = F.softmax(f, 1)+1e-18 1919 | Q_2 = F.softmax(-f, 1)+1e-18 1920 | w_ = torch.div(1, Q_2) 1921 | 1922 | w_ = w_ + 1 1923 | w = F.softmax(w_,1) 1924 | 1925 | w = torch.mul(Q_1,w)+1e-6 1926 | w_1 = torch.mul(w, Q_2.log()) 1927 | l2 = F.nll_loss(w_1, labels.long()) 1928 | return l2 1929 | 1930 | def predict(self, x): 1931 | return self.network(x) 1932 | 1933 | class FREDIS(Algorithm): 1934 | """ 1935 | FREDIS 1936 | Reference: FREDIS: A Fusion Framework of Refinement and Disambiguation for Unreliable Partial Label Learning, ICML 2023. 1937 | """ 1938 | 1939 | def __init__(self, input_shape, train_givenY, hparams): 1940 | super(FREDIS, self).__init__(input_shape, train_givenY, hparams) 1941 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 1942 | self.classifier = networks.Classifier( 1943 | self.featurizer.n_outputs, 1944 | self.num_classes) 1945 | 1946 | self.network = nn.Sequential(self.featurizer, self.classifier) 1947 | self.optimizer = torch.optim.Adam( 1948 | self.network.parameters(), 1949 | lr=self.hparams["lr"], 1950 | weight_decay=self.hparams['weight_decay'] 1951 | ) 1952 | self.train_givenY = torch.from_numpy(train_givenY) 1953 | tempY = self.train_givenY.sum(dim=1).unsqueeze(1).repeat(1, self.train_givenY.shape[1]) 1954 | label_confidence = self.train_givenY/tempY 1955 | self.confidence = label_confidence.double() 1956 | self.posterior = torch.ones_like(self.confidence).double() 1957 | self.posterior = self.posterior / self.posterior.sum(dim=1, keepdim=True) 1958 | self.steps_per_epoch = train_givenY.shape[0] / self.hparams['batch_size'] 1959 | self.steps_update_interval = int(self.hparams['update_interval'] * self.steps_per_epoch) 1960 | self.curr_iter = 0 1961 | self.theta = self.hparams['theta'] 1962 | self.delta = self.hparams['delta'] 1963 | self.pre_correction_label_matrix = torch.zeros_like(self.train_givenY) 1964 | self.correction_label_matrix = copy.deepcopy(self.train_givenY) 1965 | 1966 | def update(self, minibatches): 1967 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 1968 | device = "cuda" if index.is_cuda else "cpu" 1969 | self.confidence = self.confidence.to(device) 1970 | self.posterior = self.posterior.to(device) 1971 | self.correction_label_matrix = self.correction_label_matrix.to(device) 1972 | self.pre_correction_label_matrix = self.pre_correction_label_matrix.to(device) 1973 | consistency_criterion = nn.KLDivLoss(reduction='batchmean').to(device) 1974 | batch_outputs = 
self.predict(x) 1975 | y_pred_aug0_probas_log = torch.log_softmax(batch_outputs, dim=-1) 1976 | consist_loss = consistency_criterion(y_pred_aug0_probas_log, self.confidence[index].float()) 1977 | super_loss = -torch.mean(torch.sum(torch.log(1.0000001 - F.softmax(batch_outputs, dim=1)) * (1 - partial_y), dim=1)) 1978 | loss = float(self.hparams['lam']) * consist_loss + float(self.hparams['alpha']) * super_loss 1979 | self.optimizer.zero_grad() 1980 | loss.backward() 1981 | self.optimizer.step() 1982 | self.confidence_update(x, partial_y, index) 1983 | self.posterior_update(x, index) 1984 | if self.curr_iter % self.steps_update_interval == 0: 1985 | pred, _ = torch.max(self.posterior, dim=1, keepdim=True) 1986 | tmp_diff = pred - self.posterior 1987 | self.pre_correction_label_matrix = copy.deepcopy(self.correction_label_matrix) 1988 | non_change_matrix = copy.deepcopy(self.pre_correction_label_matrix) 1989 | non_change_matrix[tmp_diff < self.theta] = 1 1990 | non_change = torch.sum(torch.not_equal(self.pre_correction_label_matrix, non_change_matrix)) 1991 | if non_change > self.hparams['change_size']: 1992 | row, col = torch.where(tmp_diff < self.theta) 1993 | idx_list = [ i for i in range(0, len(row))] 1994 | random.shuffle(idx_list) 1995 | non_row, non_col = row[idx_list[0:self.hparams['change_size']]], col[idx_list[0:self.hparams['change_size']]] 1996 | non_change = self.hparams['change_size'] 1997 | self.correction_label_matrix[non_row, non_col] = 1 1998 | else: 1999 | self.correction_label_matrix[tmp_diff < self.theta] = 1 2000 | 2001 | can_change_matrix = copy.deepcopy(self.pre_correction_label_matrix) 2002 | can_change_matrix[tmp_diff > self.delta] = 0 2003 | can_change = torch.sum(torch.not_equal(self.pre_correction_label_matrix, can_change_matrix)) 2004 | while can_change < non_change * self.hparams['times']: 2005 | self.delta = self.delta - self.hparams['dec'] 2006 | can_change_matrix = copy.deepcopy(self.pre_correction_label_matrix) 2007 | can_change_matrix[tmp_diff > self.delta] = 0 2008 | can_change = torch.sum(torch.not_equal(self.pre_correction_label_matrix, can_change_matrix)) 2009 | if can_change > self.hparams['change_size'] * self.hparams['times']: 2010 | row, col = torch.where(tmp_diff > self.delta) 2011 | idx_list = [ i for i in range(0, len(row))] 2012 | random.shuffle(idx_list) 2013 | can_row, can_col = row[idx_list[0: self.hparams['change_size'] * self.hparams['times']]], col[idx_list[0: self.hparams['change_size'] * self.hparams['times']]] 2014 | can_change = self.hparams['change_size'] * self.hparams['times'] 2015 | self.correction_label_matrix[can_row, can_col] = 0 2016 | else: 2017 | self.correction_label_matrix[tmp_diff > self.delta] = 0 2018 | for i in range(len(self.correction_label_matrix)): 2019 | if self.correction_label_matrix[i].sum() == 0: # restore the row if disambiguation removed every candidate label 2020 | self.correction_label_matrix[i] = copy.deepcopy(self.pre_correction_label_matrix[i]) 2021 | 2022 | tmp_label_matrix = self.posterior * (self.correction_label_matrix + 1e-12) 2023 | self.confidence = tmp_label_matrix / tmp_label_matrix.sum(dim=1).repeat(tmp_label_matrix.size(1), 1).transpose(0, 1) 2024 | 2025 | change = torch.sum(torch.not_equal(self.pre_correction_label_matrix, self.correction_label_matrix)) 2026 | 2027 | if self.theta < 0.9 and self.delta > 0.1: 2028 | if change < self.hparams['change_size']: 2029 | self.theta += self.hparams['inc'] 2030 | self.delta -= self.hparams['dec'] 2031 | self.curr_iter = self.curr_iter + 1 2032 | return {'loss': loss.item()} 2033 | 2034 | def predict(self, x): 2035 |
return self.network(x) 2036 | 2037 | def confidence_update(self, batchX, batchY, batch_index): 2038 | with torch.no_grad(): 2039 | batch_outputs = self.predict(batchX) 2040 | temp_un_conf = F.softmax(batch_outputs, dim=1) 2041 | self.confidence[batch_index, :] = (temp_un_conf * batchY).double() # un_confidence stores the weight of each example 2042 | base_value = self.confidence.sum(dim=1).unsqueeze(1).repeat(1, self.confidence.shape[1]) 2043 | self.confidence = self.confidence / base_value 2044 | 2045 | def posterior_update(self, batchX, batch_index): 2046 | with torch.no_grad(): 2047 | batch_outputs = self.predict(batchX) 2048 | self.posterior[batch_index, :] = torch.softmax(batch_outputs, dim=-1).double() 2049 | 2050 | class PiCO_plus(PiCO): 2051 | """ 2052 | PiCO_plus: PiCO+: Contrastive Label Disambiguation for Robust Partial Label Learning, TPAMI 2024. 2053 | """ 2054 | 2055 | class PiCO_plus_model(PiCO.PiCO_model): 2056 | def __init__(self, num_classes, input_shape, hparams, base_encoder): 2057 | super().__init__(num_classes,input_shape,hparams, base_encoder) 2058 | self.register_buffer("queue_rel", torch.zeros(hparams['moco_queue'], dtype=torch.bool)) 2059 | @torch.no_grad() 2060 | def _dequeue_and_enqueue(self, keys, labels, is_rel, hparams): 2061 | batch_size = is_rel.shape[0] 2062 | ptr = int(self.queue_ptr) 2063 | self.queue_rel[ptr:ptr + batch_size] = is_rel 2064 | super()._dequeue_and_enqueue(keys, labels, hparams) 2065 | 2066 | def forward(self, img_q, im_k=None, Y_ori=None, Y_cor=None, is_rel=None, hparams=None, is_eval=False, ): 2067 | output, q = self.encoder_q(img_q) 2068 | if is_eval: 2069 | return output 2070 | 2071 | batch_weight = is_rel.float() 2072 | with torch.no_grad(): 2073 | predicetd_scores = torch.softmax(output, dim=1) 2074 | _, within_max_cls = torch.max(predicetd_scores * Y_ori, dim=1) 2075 | _, all_max_cls = torch.max(predicetd_scores, dim=1) 2076 | pseudo_labels_b = batch_weight * within_max_cls + (1 - batch_weight) * all_max_cls 2077 | pseudo_labels_b = pseudo_labels_b.long() 2078 | prototypes = self.prototypes.clone().detach() 2079 | logits_prot = torch.mm(q, prototypes.t()) 2080 | score_prot = torch.softmax(logits_prot, dim=1) 2081 | _, within_max_cls_ori = torch.max(predicetd_scores * Y_ori, dim=1) 2082 | distance_prot = - (q * prototypes[within_max_cls_ori]).sum(dim=1) 2083 | with torch.no_grad(): 2084 | for feat, label in zip(q[is_rel], pseudo_labels_b[is_rel]): 2085 | self.prototypes[label] = self.prototypes[label] * hparams['proto_m'] + (1 - hparams['proto_m']) * feat 2086 | self.prototypes = F.normalize(self.prototypes, p=2, dim=1).detach() 2087 | self._momentum_update_key_encoder(hparams) 2088 | _, k = self.encoder_k(im_k) 2089 | features = torch.cat((q, k, self.queue.clone().detach()), dim=0) 2090 | pseudo_labels = torch.cat((pseudo_labels_b, pseudo_labels_b, self.queue_pseudo.clone().detach()), dim=0) 2091 | is_rel_queue = torch.cat((is_rel, is_rel, self.queue_rel.clone().detach()), dim=0) 2092 | self._dequeue_and_enqueue(k, pseudo_labels_b, is_rel, hparams) 2093 | return output, features, pseudo_labels, score_prot, distance_prot, is_rel_queue 2094 | 2095 | class partial_loss(nn.Module): 2096 | def __init__(self, confidence, hparams, conf_ema_m=0.99): 2097 | super().__init__() 2098 | self.confidence = confidence 2099 | self.conf_ema_m = conf_ema_m 2100 | self.num_class = confidence.shape[1] 2101 | self.conf_ema_range = [float(item) for item in hparams['conf_ema_range'].split(',')] 2102 | 2103 | def set_conf_ema_m(self, epoch, total_epochs): 
2104 | start = self.conf_ema_range[0] 2105 | end = self.conf_ema_range[1] 2106 | self.conf_ema_m = 1. * epoch / total_epochs * (end - start) + start 2107 | 2108 | def forward(self, outputs, index, is_rel=None): 2109 | device = "cuda" if outputs.is_cuda else "cpu" 2110 | self.confidence = self.confidence.to(device) 2111 | confidence = self.confidence[index, :].to(device) 2112 | loss_vec, _ = self.ce_loss(outputs, confidence) 2113 | if is_rel is None: 2114 | average_loss = loss_vec.mean() 2115 | else: 2116 | average_loss = loss_vec[is_rel].mean() 2117 | return average_loss 2118 | 2119 | def ce_loss(self, outputs, targets, sel=None): 2120 | targets = targets.detach() 2121 | logsm_outputs = F.log_softmax(outputs, dim=1) 2122 | final_outputs = logsm_outputs * targets 2123 | loss_vec = - (final_outputs).sum(dim=1) 2124 | if sel is None: 2125 | average_loss = loss_vec.mean() 2126 | else: 2127 | if sel.sum()==0: 2128 | average_loss=torch.tensor(0.0).cuda() 2129 | else: 2130 | average_loss = loss_vec[sel].mean() 2131 | return loss_vec, average_loss 2132 | 2133 | def confidence_update(self, temp_un_conf, batch_index, batchY): 2134 | with torch.no_grad(): 2135 | _, prot_pred = (temp_un_conf * batchY).max(dim=1) 2136 | device = (torch.device('cuda') if temp_un_conf.is_cuda else torch.device('cpu')) 2137 | pseudo_label = F.one_hot(prot_pred, self.num_class).float().to(device).detach() 2138 | self.confidence=self.confidence.to(device) 2139 | self.confidence[batch_index, :] = self.conf_ema_m * self.confidence[batch_index, :] + (1 - self.conf_ema_m) * pseudo_label 2140 | return None 2141 | 2142 | class SupConLoss(nn.Module): 2143 | def __init__(self, temperature=0.07, base_temperature=0.07): 2144 | super().__init__() 2145 | self.temperature = temperature 2146 | self.base_temperature = base_temperature 2147 | def forward(self, features, mask=None, batch_size=-1, weights=None): 2148 | device = (torch.device('cuda') if features.is_cuda else torch.device('cpu')) 2149 | if mask is not None: 2150 | mask = mask.float().detach().to(device) 2151 | anchor_dot_contrast = torch.div( 2152 | torch.matmul(features[:batch_size], features.T), 2153 | self.temperature) 2154 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 2155 | logits = anchor_dot_contrast - logits_max.detach() 2156 | logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(batch_size).view(-1, 1).to(device), 0) 2157 | mask = mask * logits_mask 2158 | exp_logits = torch.exp(logits) * logits_mask 2159 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12) 2160 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 2161 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 2162 | if weights is None: 2163 | loss = loss.mean() 2164 | else: 2165 | weights = weights.detach() 2166 | loss = (loss * weights).mean() 2167 | else: 2168 | q = features[:batch_size] 2169 | k = features[batch_size:batch_size * 2] 2170 | queue = features[batch_size * 2:] 2171 | l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) 2172 | l_neg = torch.einsum('nc,kc->nk', [q, queue]) 2173 | logits = torch.cat([l_pos, l_neg], dim=1) 2174 | logits /= self.temperature 2175 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(device) 2176 | loss = F.cross_entropy(logits, labels) 2177 | return loss 2178 | 2179 | 2180 | 2181 | def reliable_set_selection(self, hparams, epoch, sel_stats): 2182 | dist = sel_stats['dist'] 2183 | n = dist.shape[0] 2184 | device = (torch.device('cuda') if dist.is_cuda else 
torch.device('cpu')) 2185 | is_rel = torch.zeros(n).bool().to(device) 2186 | sorted_idx = torch.argsort(dist) 2187 | chosen_num = int(n * hparams['pure_ratio']) 2188 | is_rel[sorted_idx[:chosen_num]] = True 2189 | sel_stats['is_rel'] = is_rel 2190 | 2191 | def __init__(self,input_shape,train_givenY,hparams): 2192 | super(PiCO_plus, self).__init__(input_shape,train_givenY,hparams) 2193 | self.network=self.PiCO_plus_model(self.num_classes,input_shape,hparams=hparams,base_encoder=self.PiCONet) 2194 | self.optimizer = torch.optim.Adam( 2195 | self.network.parameters(), 2196 | lr=self.hparams["lr"], 2197 | weight_decay=self.hparams['weight_decay'] 2198 | ) 2199 | self.num_instance=train_givenY.shape[0] 2200 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 2201 | self.sel_stats = {'dist': torch.zeros(self.num_instance).to(device),'is_rel': torch.ones(self.num_instance).bool().to(device)} 2202 | 2203 | def update(self, minibatches): 2204 | if self.curr_iter >= self.hparams['prot_start']: 2205 | self.reliable_set_selection(self.hparams, self.curr_iter, self.sel_stats) 2206 | start_upd_prot = self.curr_iter >= self.hparams['prot_start'] 2207 | 2208 | _, x, strong_x, distill_x, partial_y, _, index = minibatches 2209 | device = (torch.device('cuda') if x.is_cuda else torch.device('cpu')) 2210 | is_rel = self.sel_stats['is_rel'][index] 2211 | batch_weight = is_rel.float() 2212 | cls_out, features_cont, pseudo_labels, score_prot, distance_prot, is_rel_queue = self.network(x, strong_x, partial_y, Y_cor=None, is_rel=is_rel, hparams=self.hparams) 2213 | batch_size = cls_out.shape[0] 2214 | pseudo_target_cont = pseudo_labels.contiguous().view(-1, 1) 2215 | 2216 | 2217 | if start_upd_prot: 2218 | self.loss_fn.confidence_update(temp_un_conf=score_prot, batch_index=index, batchY=partial_y) 2219 | if start_upd_prot: 2220 | mask_all = torch.eq(pseudo_target_cont[:batch_size], pseudo_target_cont.T).float().to(device) 2221 | loss_cont_all = self.loss_cont_fn(features=features_cont, mask=mask_all, batch_size=batch_size, weights=None) 2222 | mask = copy.deepcopy(mask_all).detach() 2223 | mask = batch_weight.unsqueeze(1).repeat(1, mask.shape[1]) * mask 2224 | mask = is_rel_queue.view(1, -1).repeat(mask.shape[0], 1) * mask 2225 | if self.curr_iter >= self.hparams['knn_start']: 2226 | cosine_corr = features_cont[:batch_size] @ features_cont.T 2227 | _, kNN_index = torch.topk(cosine_corr, k=self.hparams['chosen_neighbors'], dim=-1, largest=True) 2228 | mask_kNN = torch.scatter(torch.zeros(mask.shape).to(device), 1, kNN_index, 1) 2229 | mask[~is_rel] = mask_kNN[~is_rel] 2230 | mask[:, batch_size:batch_size * 2] = ((mask[:, batch_size:batch_size * 2] + torch.eye(batch_size).to(device)) > 0).float() 2231 | mask[:, :batch_size] = ((mask[:, :batch_size] + torch.eye(batch_size).to(device)) > 0).float() 2232 | if self.curr_iter >= self.hparams['knn_start']: 2233 | weights = self.hparams['loss_weight'] * batch_weight + self.hparams['ur_weight'] * (1 - batch_weight) 2234 | loss_cont_rel_knn = self.loss_cont_fn(features=features_cont, mask=mask, batch_size=batch_size,weights=weights) 2235 | else: 2236 | loss_cont_rel_knn = self.loss_cont_fn(features=features_cont, mask=mask, batch_size=batch_size, weights=None) 2237 | loss_cont = loss_cont_rel_knn + self.hparams['ur_weight'] * loss_cont_all 2238 | loss_cls = self.loss_fn(cls_out, index, is_rel) 2239 | sp_temp_scale = score_prot ** (1 / self.hparams['temperature_guess']) 2240 | targets_guess = sp_temp_scale / sp_temp_scale.sum(dim=1, keepdim=True) 2241 | 
            _, loss_cls_ur = self.loss_fn.ce_loss(cls_out, targets_guess, sel=~is_rel)
            l = np.random.beta(4, 4)
            l = max(l, 1 - l)
            pseudo_label = self.loss_fn.confidence[index]
            pseudo_label[~is_rel] = targets_guess[~is_rel]
            idx = torch.randperm(x.size(0))
            X_w_rand = x[idx]
            pseudo_label_rand = pseudo_label[idx]
            X_w_mix = l * x + (1 - l) * X_w_rand
            pseudo_label_mix = l * pseudo_label + (1 - l) * pseudo_label_rand
            logits_mix, _ = self.network.encoder_q(X_w_mix)
            _, loss_mix = self.loss_fn.ce_loss(logits_mix, targets=pseudo_label_mix)
            loss_cls = loss_mix + self.hparams['cls_weight'] * loss_cls + self.hparams['ur_weight'] * loss_cls_ur
            loss = loss_cls + loss_cont
        else:
            loss_cls = self.loss_fn(cls_out, index, is_rel=None)
            loss_cont = self.loss_cont_fn(features=features_cont, mask=None, batch_size=batch_size)
            loss = loss_cls + self.hparams['loss_weight'] * loss_cont

        self.sel_stats['dist'][index] = distance_prot.clone().detach()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.loss_fn.set_conf_ema_m(self.curr_iter, self.max_steps)
        return {'loss': loss.item()}


class ALIM(PiCO):
    """
    ALIM
    Reference: ALIM: Adjusting Label Importance Mechanism for Noisy Partial Label Learning, NeurIPS 2023
    """

    class partial_loss(nn.Module):
        def __init__(self, confidence, hparams, conf_ema_m=0.99):
            super().__init__()
            self.confidence = confidence
            self.init_conf = confidence.detach()
            self.conf_ema_m = conf_ema_m
            self.conf_ema_range = [float(item) for item in hparams['conf_ema_range'].split(',')]
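
        # The EMA momentum applied to the label confidence is annealed linearly from
        # conf_ema_range[0] to conf_ema_range[1] over training via set_conf_ema_m below.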
        def set_conf_ema_m(self, epoch, total_epochs):
            start = self.conf_ema_range[0]
            end = self.conf_ema_range[1]
            self.conf_ema_m = 1. * epoch / total_epochs * (end - start) + start

        def forward(self, outputs, index):
            device = "cuda" if outputs.is_cuda else "cpu"
            self.confidence = self.confidence.to(device)
            logsm_outputs = F.log_softmax(outputs, dim=1)
            final_outputs = logsm_outputs * self.confidence[index, :]
            average_loss = - ((final_outputs).sum(dim=1)).mean()
            return average_loss

        def confidence_update(self, temp_un_conf, batch_index, batchY, prior):
            with torch.no_grad():
                _, prot_pred = (temp_un_conf * (batchY + prior * (1 - batchY))).max(dim=1)  # ALIM
                pseudo_label = F.one_hot(prot_pred, batchY.shape[1]).float().cuda().detach()
                self.confidence[batch_index, :] = self.conf_ema_m * self.confidence[batch_index, :] \
                    + (1 - self.conf_ema_m) * pseudo_label
            return None

    def __init__(self, input_shape, train_givenY, hparams):
        super(ALIM, self).__init__(input_shape, train_givenY, hparams)
        self.network = self.PiCO_model(self.num_classes, input_shape, hparams=hparams, base_encoder=self.PiCONet)
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )

        train_givenY = torch.from_numpy(train_givenY)
        tempY = train_givenY.sum(dim=1).unsqueeze(1).repeat(1, train_givenY.shape[1])
        label_confidence = train_givenY.float() / tempY
        self.label_confidence = label_confidence
        self.loss_fn = self.partial_loss(label_confidence.float(), self.hparams)
        self.loss_cont_fn = self.SupConLoss()
        self.train_givenY = train_givenY
        self.prior = 0
        self.curr_iter = 0
        self.margin = []
        self.steps_per_epoch = train_givenY.shape[0] / self.hparams['batch_size']

    def update(self, minibatches1, minibatches2):
        if self.curr_iter % self.steps_per_epoch == 0:
            if self.curr_iter >= self.hparams['start_epoch'] * self.steps_per_epoch:
                self.prior = sorted(self.margin)[int(len(self.margin) * self.hparams['noise_rate'])]
            self.margin = []

        _, x, strong_x, distill_x, partial_y, _, index = minibatches1
        _, x2, _, _, _, _, index2 = minibatches2
        cls_out, features_cont, pseudo_target_cont, score_prot = self.network(x, strong_x, partial_y, self.hparams)
        batch_size = cls_out.shape[0]
        pseudo_target_cont = pseudo_target_cont.contiguous().view(-1, 1)

        start_upd_prot = self.curr_iter >= self.hparams['prot_start']
        if start_upd_prot:
            self.loss_fn.confidence_update(temp_un_conf=score_prot, batch_index=index, batchY=partial_y, prior=self.prior)
        if start_upd_prot:
            mask = torch.eq(pseudo_target_cont[:batch_size], pseudo_target_cont.T).float().cuda()
        else:
            mask = None
        loss_cont = self.loss_cont_fn(features=features_cont, mask=mask, batch_size=batch_size)
        loss_cls = self.loss_fn(cls_out, index)

        lam = np.random.beta(self.hparams['mixup_alpha'], self.hparams['mixup_alpha'])
        lam = max(lam, 1 - lam)
        pseudo_label_1 = self.loss_fn.confidence[index]
        pseudo_label_2 = self.loss_fn.confidence[index2]
        X_w_mix = lam * x + (1 - lam) * x2
        pseudo_label_mix = lam * pseudo_label_1 + (1 - lam) * pseudo_label_2
        logits_mix, _ = self.network.encoder_q(X_w_mix)
        pred_mix = torch.softmax(logits_mix, dim=1)
        loss_mixup = - ((torch.log(pred_mix) * pseudo_label_mix).sum(dim=1)).mean()

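        # Total objective below: classification loss on the EMA label confidence,
        # prototype-aligned supervised contrastive loss, and the mixup consistency
        # term, weighted by loss_weight and loss_weight_mixup.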
        loss = loss_cls + self.hparams['loss_weight'] * loss_cont + self.hparams['loss_weight_mixup'] * loss_mixup
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.loss_fn.set_conf_ema_m(self.curr_iter, self.max_steps)
        self.margin += ((torch.max(score_prot * partial_y, 1)[0]) / (1e-9 + torch.max(score_prot * (1 - partial_y), 1)[0])).tolist()
        return {'loss': loss.item()}
--------------------------------------------------------------------------------
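For readers of the ALIM code above, the standalone snippet below isolates the label-importance adjustment used in `confidence_update` (the `batchY + prior * (1 - batchY)` rescaling). It is an illustrative sketch only: the tensors and the `prior` value are made up here, whereas ALIM estimates the prior from the per-epoch margin statistics collected in `update`.

```
import torch
import torch.nn.functional as F

# Toy prototype scores for 4 samples over 5 classes (rows sum to 1).
score_prot = torch.softmax(torch.randn(4, 5), dim=1)
# Toy candidate-label matrix: 1 marks a candidate label, 0 a non-candidate.
partial_y = (torch.rand(4, 5) > 0.5).float()
partial_y[torch.arange(4), torch.randint(0, 5, (4,))] = 1.0  # keep every candidate set non-empty
prior = 0.3  # assumed value for illustration only

# Instead of zeroing out non-candidate labels, ALIM keeps a `prior` fraction of their
# score mass, which makes the pseudo-label robust to noisy candidate sets.
adjusted = score_prot * (partial_y + prior * (1 - partial_y))
pseudo_label = F.one_hot(adjusted.argmax(dim=1), num_classes=5).float()
print(pseudo_label)
```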