├── Graph
    ├── README.md
    ├── SOAP.py
    ├── config
    │   └── config_hiv.py
    ├── datasets.py
    ├── datasets
    │   └── hiv.csv
    ├── gine_mpnn.py
    ├── imbalanaced_loss.py
    ├── main_hiv.py
    ├── metric.py
    ├── model.py
    ├── predict.py
    ├── pretrained_models
    │   └── hiv_pretrained_model
    │   │   ├── hiv_gine_ce.ckpt
    │   │   ├── hiv_mlpmpnn_ce.ckpt
    │   │   └── hiv_mpnn_ce.ckpt
    ├── train_eval.py
    └── tran_data.py
├── Image
    ├── README.md
    ├── SOAP.py
    ├── cepretrainmodels
    │   └── tmp.md.docx
    ├── config_cifar.py
    ├── config_melanoma.py
    ├── data
    │   └── melanoma_split_inds
    │   │   ├── test_split.csv
    │   │   ├── train_split.csv
    │   │   └── valid_split.csv
    ├── data_split.py
    ├── imbalanced_cifar.py
    ├── loss.py
    ├── main_cifar100_resnet18.py
    ├── main_cifar100_resnet34.py
    ├── main_cifar10_resnet18.py
    ├── main_cifar10_resnet34.py
    ├── main_melanoma_resnet18.py
    ├── main_melanoma_resnet34.py
    ├── preprocess.py
    ├── train_eval.py
    ├── train_eval_melanoma.py
    └── utils.py
└── README.md
/Graph/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ### Configuration
4 | Dependencies:
5 | python>=3.6.8\
6 | torch>=1.7.0\
7 | torch_geometric==1.6.3\
8 | Other packages:\
9 | Install the packages required by the MoleculeKit GNN part: rdkit, PyTorch Geometric, descriptastorus. Ensure the rdkit version is 2020.03.3; otherwise the feature extraction may be problematic.\
10 | Reference:\
11 | https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html
12 | 
13 | 
14 | 
15 | ### Data
16 | Use tran_data.py to extract GNN features from the .csv data files.\
17 | The preprocessed data can be found on Google Drive: https://drive.google.com/drive/folders/13Bxt0eLeOKNEPbwbq1oEeOLNo9AhnQvr?usp=sharing \
18 | Put the inner files and directories into **./datasets** to run the code.
19 | 
20 | ### Model
21 | GINE -- a variant of the GIN model \
22 | MPNN -- message passing neural network \
23 | MLN -- the GNN in MoleculeKit \
24 | GINE and MPNN are implemented in gine_mpnn.py\
25 | MLN is implemented in model.py
26 | 
27 | 
28 | ### Algorithm and Pretrained Models
29 | The proposed loss and the SOAP algorithm are implemented in SOAP.py \
30 | SOAP with the **squared hinge (sqh)** surrogate loss is trained from ce-pretrained models. \
31 | The **pretrained models** for the **hiv** data are provided in \
32 | **pretrained_models/hiv_pretrained_model/**:
33 | - hiv_gine_ce.ckpt
34 | - hiv_mlpmpnn_ce.ckpt
35 | - hiv_mpnn_ce.ckpt
36 | 
37 | 
38 | Each pretrained model was trained with **ce_loss** for 100 epochs using Adam, with the default configuration in **./config/config_hiv.py**.
39 | 
40 | 
41 | 
42 | The default **conf['loss_type']** = 'ce_loss' selects the standard cross-entropy loss used to train the pretrained model.
43 | 
44 | **conf['ft_mode']** = 'fc_random': reinitializes the fully-connected layer of the pretrained model when SOAP training starts.
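A minimal sketch combining these fine-tuning options (editor's example; the checkpoint path is one of the files shipped in this repo, and `conf['pre_train']` is described next):

```python
from config.config_hiv import conf

conf['loss_type'] = 'SOAP'      # the proposed method ('ce_loss' is used for CE pretraining)
conf['ft_mode'] = 'fc_random'   # reinitialize the final FC layer before SOAP training
conf['pre_train'] = './pretrained_models/hiv_pretrained_model/hiv_gine_ce.ckpt'  # or None to train from scratch
```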
45 | **conf['pre_train']** = { None : training from scratch,
46 |                           'path_of_pretrained_model' : training from a pretrained model }
47 | 
48 | 
49 | 
50 | ### The hyperparameters for SOAPLOSS:
51 | --**threshold**: the margin in the squared hinge loss | **conf['loss_param']['threshold'] = 10** \
52 | --**batch_size**: batch size | **conf['batch_size'] = 64** \
53 | --**data_length**: length of the dataset \
54 | --**loss_type**: squared hinge surrogate loss for SOAP | **conf['loss_param']['type'] = 'sqh'** \
55 | --**gamma**: the gamma parameter in the paper | **conf['loss_param']['mv_gamma'] = {0.99, 0.9}** \
56 | --**conf['posNum']**: number of positive samples per batch \
57 | We use conf['posNum'] = 1 for both SOAP with the sqh loss and the baselines on the graph data.
58 | 
59 | 
60 | 
61 | ### Results
62 | To replicate the SOAP results in Table 2, run:
63 | ```
64 | CUDA_VISIBLE_DEVICES=0 python3 -W ignore main_hiv.py
65 | ```
66 | | HIV, mean (std) | GINE | MPNN | MLPNN |
67 | |-----------------|:----:|:----:|:-----:|
68 | | SOAP | 0.3462 (0.0083) | 0.3406 (0.0053) | 0.3646 (0.0076) |
69 | 
70 | A wrapped package can be found at https://github.com/Optimization-AI/LibAUC/
71 | with the following installation and usage example:
72 | ```
73 | pip install libauc
74 | ```
75 | ```python
76 | >>> # import library
77 | >>> from libauc.losses import APLoss_SH
78 | >>> from libauc.optimizers import SOAP_SGD, SOAP_ADAM
79 | ...
80 | >>> # define loss
81 | >>> Loss = APLoss_SH()
82 | >>> optimizer = SOAP_ADAM()
83 | ...
84 | >>> # training
85 | >>> model.train()
86 | >>> for index, data, targets in trainloader:
87 |         data, targets = data.cuda(), targets.cuda()
88 |         logits = model(data)
89 |         preds = torch.sigmoid(logits)
90 |         loss = Loss(preds, targets, index)
91 |         optimizer.zero_grad()
92 |         loss.backward()
93 |         optimizer.step()
94 | ```
95 | 
96 | ### Tips:
97 | If an error occurs because of a file path, **please adjust the corresponding data file and pretrained-model paths to your own locations**.
98 | 
99 | 
100 | 
101 | 
102 | 
103 | 
--------------------------------------------------------------------------------
/Graph/SOAP.py:
--------------------------------------------------------------------------------
1 | __author__ = 'Qi'
2 | # Created by Qi on 11/3/21.
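# Editor's overview (added): this module provides the two components used for
# SOAP training in this repo (see train_eval.py):
#   * AUPRCSampler -- a batch sampler that puts `posNum` positive samples at the
#     front of every batch of size `batchSize`, which is the layout SOAPLOSS expects.
#   * SOAPLOSS     -- the AP surrogate loss; it keeps per-sample moving averages
#     u_all / u_pos (updated with factor `gamma`) and is called as
#     criterion(preds[:posNum], preds[posNum:], idx) in train_eval.py.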
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | from torch.utils.data.sampler import Sampler
7 | import random
8 | from imbalanaced_loss import logistic_loss, sigmoid_loss  # note: the module filename is misspelled ("imbalanaced") in this repo
9 | 
10 | 
11 | class AUPRCSampler(Sampler):
12 |     def __init__(self, labels, batchSize, posNum=1):
13 |         # positive class: minority class
14 |         # negative class: majority class
15 | 
16 |         self.labels = labels
17 |         self.posNum = posNum
18 |         self.batchSize = batchSize
19 | 
20 |         self.clsLabelList = np.unique(labels)
21 |         self.dataDict = {}
22 | 
23 |         for label in self.clsLabelList:
24 |             self.dataDict[str(label)] = []
25 | 
26 |         for i in range(len(self.labels)):
27 |             self.dataDict[str(self.labels[i])].append(i)
28 | 
29 |         self.ret = []
30 | 
31 |     def __iter__(self):
32 |         minority_data_list = self.dataDict[str(1)]
33 |         majority_data_list = self.dataDict[str(0)]
34 | 
35 |         # print(len(minority_data_list), len(majority_data_list))
36 |         random.shuffle(minority_data_list)
37 |         random.shuffle(majority_data_list)
38 | 
39 |         # In every iteration: sample posNum positive sample(s) and batchSize - posNum negative samples
40 |         if len(minority_data_list) // self.posNum >= len(majority_data_list) // (
41 |                 self.batchSize - self.posNum):  # In this case, we go over all positive samples in every epoch.
42 |             # extend the length of majority_data_list from len(majority_data_list) to len(minority_data_list) * (batchSize - posNum)
43 |             majority_data_list.extend(np.random.choice(majority_data_list, len(minority_data_list) // self.posNum * (
44 |                     self.batchSize - self.posNum) - len(majority_data_list), replace=True).tolist())
45 | 
46 |             for i in range(len(minority_data_list) // self.posNum):
47 |                 if self.posNum == 1:
48 |                     self.ret.append(minority_data_list[i])
49 |                 else:
50 |                     self.ret.extend(minority_data_list[i * self.posNum:(i + 1) * self.posNum])
51 | 
52 |                 startIndex = i * (self.batchSize - self.posNum)
53 |                 endIndex = (i + 1) * (self.batchSize - self.posNum)
54 |                 self.ret.extend(majority_data_list[startIndex:endIndex])
55 | 
56 |         else:  # In this case, we go over all negative samples in every epoch.
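            # e.g. (editor's illustration): 100 positives and 10,000 negatives with
            # batchSize=64, posNum=1 fall into this branch, since 100 < 10000 // 63 = 158;
            # positives are then re-sampled so that every negative appears once per epoch.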
57 |             # extend the length of minority_data_list from len(minority_data_list) to len(majority_data_list) // (batchSize - posNum) + 1
58 | 
59 |             minority_data_list.extend(np.random.choice(minority_data_list, len(majority_data_list) // (
60 |                     self.batchSize - self.posNum) + 1 - len(minority_data_list) // self.posNum, replace=True).tolist())
61 |             for i in range(0, len(majority_data_list), self.batchSize - self.posNum):
62 | 
63 |                 if self.posNum == 1:
64 |                     self.ret.append(minority_data_list[i // (self.batchSize - self.posNum)])
65 |                 else:
66 |                     self.ret.extend(minority_data_list[i // (self.batchSize - self.posNum) * self.posNum: (i // (
67 |                             self.batchSize - self.posNum) + 1) * self.posNum])
68 | 
69 |                 self.ret.extend(majority_data_list[i:i + self.batchSize - self.posNum])
70 | 
71 |         return iter(self.ret)
72 | 
73 |     def __len__(self):
74 |         return len(self.ret)
75 | 
76 | 
77 | class SOAPLOSS(nn.Module):
78 |     def __init__(self, threshold, data_length, loss_type='sqh', gamma=0.9):
79 |         '''
80 |         :param threshold: margin for the squared hinge loss
81 |         '''
82 |         super(SOAPLOSS, self).__init__()
83 |         self.u_all = 1.0 * torch.tensor([0] * data_length).view(-1, 1).cuda()
84 |         self.u_pos = 1.0 * torch.tensor([0] * data_length).view(-1, 1).cuda()
85 |         self.threshold = threshold
86 |         self.loss_type = loss_type
87 |         self.gamma = gamma
88 |         print('The loss type is :', self.loss_type)
89 | 
90 |     def forward(self, f_ps, f_ns, index_s):
91 |         f_ps = f_ps.view(-1)
92 |         f_ns = f_ns.view(-1)
93 | 
94 |         vec_dat = torch.cat((f_ps, f_ns), 0)
95 |         mat_data = vec_dat.repeat(len(f_ps), 1)
96 | 
97 |         f_ps = f_ps.view(-1, 1)
98 | 
99 |         neg_mask = torch.ones_like(mat_data)
100 |         neg_mask[:, 0:f_ps.size(0)] = 0
101 | 
102 |         pos_mask = torch.zeros_like(mat_data)
103 |         pos_mask[:, 0:f_ps.size(0)] = 1
104 | 
105 |         if self.loss_type == 'sqh':
106 |             pos_loss = torch.max(self.threshold - (f_ps - mat_data), torch.zeros_like(mat_data)) ** 2 * pos_mask
107 |             neg_loss = torch.max(self.threshold - (f_ps - mat_data), torch.zeros_like(mat_data)) ** 2 * neg_mask
108 | 
109 |         elif self.loss_type == 'lgs':
110 | 
111 |             neg_loss = logistic_loss(f_ps, mat_data, self.threshold) * neg_mask
112 |             pos_loss = logistic_loss(f_ps, mat_data, self.threshold) * pos_mask
113 | 
114 |         elif self.loss_type == 'sgm':
115 | 
116 |             neg_loss = sigmoid_loss(f_ps, mat_data, self.threshold) * neg_mask
117 |             pos_loss = sigmoid_loss(f_ps, mat_data, self.threshold) * pos_mask
118 | 
119 |         loss = pos_loss + neg_loss
120 | 
121 |         if f_ps.size(0) == 1:
122 |             self.u_pos[index_s] = (1 - self.gamma) * self.u_pos[index_s] + self.gamma * (pos_loss.mean())
123 |             self.u_all[index_s] = (1 - self.gamma) * self.u_all[index_s] + self.gamma * (loss.mean())
124 |         else:
125 |             self.u_all[index_s] = (1 - self.gamma) * self.u_all[index_s] + self.gamma * (loss.mean(1, keepdim=True))
126 |             self.u_pos[index_s] = (1 - self.gamma) * self.u_pos[index_s] + self.gamma * (pos_loss.mean(1, keepdim=True))
127 | 
128 | 
129 |         p = (self.u_pos[index_s] - (self.u_all[index_s]) * pos_mask) / (self.u_all[index_s] ** 2)
130 | 
131 |         p.detach_()
132 |         loss = torch.sum(p * loss)
133 | 
134 |         return loss
135 | 
136 | 
137 | 
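# --- Editor's addition: minimal smoke test, not part of the original repo. ---
# Assumes a CUDA device, since SOAPLOSS keeps its u_all / u_pos buffers on the GPU.
if __name__ == '__main__':
    torch.manual_seed(0)
    criterion = SOAPLOSS(threshold=10, data_length=8, loss_type='sqh', gamma=0.9)
    scores = torch.rand(4, 1).cuda()               # 1 positive followed by 3 negatives, as AUPRCSampler arranges
    pos_index = torch.tensor([[0]]).long().cuda()  # dataset index of the positive sample
    loss = criterion(scores[:1], scores[1:], pos_index)
    print(loss.item())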
--------------------------------------------------------------------------------
/Graph/config/config_hiv.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration file
3 | """
4 | 
5 | 
6 | 
7 | conf = {}
8 | 
9 | 
10 | 
11 | ######################################################################################################################
12 | # Settings for the experimental setup
13 | ## 'arch' chooses the model from 'ml2' and 'ml3': ml2 is the proposed hierarchical message passing model; ml3 is the same model without subgraph-level representations.
14 | ## 'task_type': 'regression' or 'classification'.
15 | ## 'metric' is the evaluation method: 'mae', 'rmse', 'prc', 'roc'
16 | ## 'num_tasks' denotes how many tasks there are in the chosen dataset.
17 | ## 'graph_level_feature': if True, we concatenate the 200-d feature extracted by rdkit with the representation output by the network, and use the combined representation for property prediction.
18 | ######################################################################################################################
19 | 
20 | conf['task_type'] = 'classification'
21 | conf['metric'] = 'roc'
22 | conf['num_tasks'] = 1
23 | conf['arch'] = 'ml2'
24 | conf['graph_level_feature'] = True
25 | 
26 | ######################################################################################################################
27 | # Settings for training
28 | ## 'epochs': maximum training epochs
29 | ## 'early_stopping': patience used to stop training
30 | ## 'lr': starting learning rate
31 | ## 'lr_decay_factor': learning rate decay factor
32 | ## 'lr_decay_step_size': step size of learning rate decay
33 | ## 'dropout': dropout rate
34 | ## 'weight_decay': l2 regularizer term
35 | ## 'depth': number of layers
36 | ## 'batch_size': training batch_size
37 | ######################################################################################################################
38 | conf['epochs'] = 150
39 | conf['early_stopping'] = 50
40 | conf['lr'] = 0.0005
41 | conf['lr_decay_factor'] = 0.5
42 | conf['lr_decay_step_size'] = 50
43 | conf['dropout'] = 0
44 | conf['weight_decay'] = 0.00005
45 | conf['depth'] = 3
46 | conf['hidden'] = 32
47 | conf['batch_size'] = 64
48 | conf['loss_type'] = 'ce_loss'
49 | conf['loss_param'] = {'threshold': 10}
50 | conf['ft_mode'] = 'fc_random'
51 | conf['pre_train'] = None
52 | 
53 | ######################################################################################################################
54 | # Settings for val/test
55 | ## 'vt_batch_size': val/test batch_size
56 | ######################################################################################################################
57 | conf['vt_batch_size'] = 1000
58 | 
--------------------------------------------------------------------------------
/Graph/datasets.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pickle
3 | from sklearn.utils import shuffle
4 | import numpy as np
5 | import random
6 | import rdkit.Chem as Chem  # uncommented: needed by generate_scaffold below
7 | from rdkit.Chem.Scaffolds import MurckoScaffold  # uncommented: needed by ScaffoldGenerator below
8 | import csv
9 | import codecs
10 | import torch
11 | from torch_geometric.data import Data
12 | 
13 | 
14 | 
15 | 
16 | class JunctionTreeData(Data):
17 |     def __inc__(self, key, item):
18 |         if key == 'tree_edge_index':
19 |             return self.x_clique.size(0)
20 |         elif key == 'atom2clique_index':
21 |             return torch.tensor([[self.x.size(0)], [self.x_clique.size(0)]])
22 |         else:
23 |             return super(JunctionTreeData, self).__inc__(key, item)
24 | 
25 | 
26 | 
27 | def get_dataset(dataset_path, graph_level_feature=True, subgraph_level_feature=True):
28 |     data_set = torch.load(dataset_path)
29 |     num_node_features = data_set[0].x.size(1)
30 |     num_edge_features = data_set[-1].edge_attr.size(1)
31 |     num_graph_features = None
32 |     if graph_level_feature:
33 |         num_graph_features
= data_set[0].graph_attr.size(-1) 34 | if subgraph_level_feature: 35 | data_set = [JunctionTreeData(**{k: v for k, v in data}) for data in data_set] 36 | return data_set, num_node_features, num_edge_features, num_graph_features 37 | 38 | 39 | 40 | def split_data(dataset, split_file=None, ori_dataset_path=None, name=None, split_rule=None, seed=None, split_size=[0.8, 0.1, 0.1]): 41 | if split_file is not None: 42 | with open(split_file, 'rb') as f: 43 | inds = pickle.load(f, encoding='latin1') 44 | train_ids, val_ids, test_ids = inds[0], inds[1], inds[2] 45 | train_dataset = [dataset[i] for i in train_ids] 46 | val_dataset = [dataset[i] for i in val_ids] 47 | test_dataset = [dataset[i] for i in test_ids] 48 | 49 | return train_dataset, val_dataset, test_dataset 50 | 51 | elif split_rule == "random": 52 | print("-----Random splitting-----") 53 | dataset = shuffle(dataset, random_state=seed) 54 | assert sum(split_size) == 1 55 | train_size = int(split_size[0] * len(dataset)) 56 | train_val_size = int((split_size[0] + split_size[1]) * len(dataset)) 57 | train_dataset = dataset[:train_size] 58 | val_dataset = dataset[train_size:train_val_size] 59 | test_dataset = dataset[train_val_size:] 60 | 61 | return train_dataset, val_dataset, test_dataset 62 | 63 | elif split_rule == "scaffold": 64 | print("-----Scaffold splitting-----") 65 | assert sum(split_size) == 1 66 | smile_list = [] 67 | path = osp.join(osp.dirname(osp.realpath(__file__)), ori_dataset_path, name+'.csv') 68 | with codecs.open(path, "r", encoding="utf-8-sig") as f: 69 | reader = csv.DictReader(f) 70 | for row in reader: 71 | smiles = row['smiles'] 72 | smile_list.append(smiles) 73 | scaffolds = {} 74 | for ind, smiles in enumerate(smile_list): 75 | scaffold = generate_scaffold(smiles) 76 | if scaffold not in scaffolds: 77 | scaffolds[scaffold] = [ind] 78 | else: 79 | scaffolds[scaffold].append(ind) 80 | # Sort from largest to smallest scaffold sets 81 | scaffolds = {key: sorted(value) for key, value in scaffolds.items()} 82 | scaffold_sets = [ 83 | scaffold_set 84 | for (scaffold, scaffold_set) in sorted( 85 | scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True) 86 | ] 87 | train_size = split_size[0] * len(smile_list) 88 | train_val_size = (split_size[0] + split_size[1]) * len(smile_list) 89 | train_idx, val_idx, test_idx = [], [], [] 90 | for scaffold_set in scaffold_sets: 91 | if len(train_idx) + len(scaffold_set) > train_size: 92 | if len(train_idx) + len(val_idx) + len(scaffold_set) > train_val_size: 93 | test_idx += scaffold_set 94 | else: 95 | val_idx += scaffold_set 96 | else: 97 | train_idx += scaffold_set 98 | train_dataset = [dataset[i] for i in train_idx] 99 | val_dataset = [dataset[i] for i in val_idx] 100 | test_dataset = [dataset[i] for i in test_idx] 101 | 102 | return train_dataset, val_dataset, test_dataset 103 | 104 | elif split_rule == "stratified": 105 | print("-----stratified splitting-----") 106 | assert sum(split_size) == 1 107 | np.random.seed(seed) 108 | 109 | y = [] 110 | for data in dataset: 111 | y.append(data.y) 112 | assert len(y[0]) == 1 113 | y_s = np.array(y).squeeze(axis=1) 114 | sortidx = np.argsort(y_s) 115 | 116 | split_cd = 10 117 | train_cutoff = int(np.round(split_size[0] * split_cd))#8 118 | valid_cutoff = int(np.round(split_size[1] * split_cd)) + train_cutoff#9 119 | test_cutoff = int(np.round(split_size[2] * split_cd)) + valid_cutoff#10 120 | 121 | train_idx = np.array([]) 122 | valid_idx = np.array([]) 123 | test_idx = np.array([]) 124 | 125 | while sortidx.shape[0] >= 
split_cd:
126 |             sortidx_split, sortidx = np.split(sortidx, [split_cd])
127 |             shuffled = np.random.permutation(range(split_cd))
128 |             train_idx = np.hstack([train_idx, sortidx_split[shuffled[:train_cutoff]]])
129 |             valid_idx = np.hstack([valid_idx, sortidx_split[shuffled[train_cutoff:valid_cutoff]]])
130 |             test_idx = np.hstack([test_idx, sortidx_split[shuffled[valid_cutoff:]]])
131 | 
132 |         if sortidx.shape[0] > 0: train_idx = np.hstack([train_idx, sortidx])  # bug fix: the original discarded the np.hstack result, dropping the leftover samples from the train split
133 | 
134 |         train_dataset = [dataset[int(i)] for i in train_idx]
135 |         val_dataset = [dataset[int(i)] for i in valid_idx]
136 |         test_dataset = [dataset[int(i)] for i in test_idx]
137 | 
138 |         return train_dataset, val_dataset, test_dataset
139 | 
140 | 
141 | 
142 | class ScaffoldGenerator(object):
143 |     """
144 |     Generate molecular scaffolds.
145 |     """
146 |     def __init__(self, include_chirality=False):
147 |         self.include_chirality = include_chirality
148 | 
149 |     def get_scaffold(self, mol):
150 |         return MurckoScaffold.MurckoScaffoldSmiles(
151 |             mol=mol, includeChirality=self.include_chirality)
152 | 
153 | 
154 | 
155 | def generate_scaffold(smiles, include_chirality=False):
156 |     """Compute the Bemis-Murcko scaffold for a SMILES string."""
157 |     mol = Chem.MolFromSmiles(smiles)
158 |     engine = ScaffoldGenerator(include_chirality=include_chirality)
159 |     scaffold = engine.get_scaffold(mol)
160 |     return scaffold
--------------------------------------------------------------------------------
/Graph/gine_mpnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | from torch_scatter import scatter_mean, scatter_add, scatter_max
5 | from torch_geometric.nn import MessagePassing, GCNConv, NNConv, GINEConv
6 | from torch_geometric.utils import degree
7 | 
8 | 
9 | 
10 | ### A basic MLP
11 | class MLP(torch.nn.Module):
12 |     def __init__(self, input_dim, hidden_dim, output_dim, dropout):
13 |         super(MLP, self).__init__()
14 |         self.lin1 = torch.nn.Linear(input_dim, hidden_dim)
15 |         self.lin2 = torch.nn.Linear(hidden_dim, output_dim)
16 |         self.dropout = dropout
17 | 
18 |     def reset_parameters(self):
19 |         self.lin1.reset_parameters()
20 |         self.lin2.reset_parameters()
21 | 
22 |     def forward(self, x):
23 |         x = F.dropout(x, p=self.dropout, training=self.training)
24 |         x = F.relu(self.lin1(x))
25 |         x = F.dropout(x, p=self.dropout, training=self.training)
26 |         x = self.lin2(x)
27 | 
28 |         return x
29 | 
30 | 
31 | 
32 | class GINENet(torch.nn.Module):
33 |     def __init__(self, num_node_features, num_edge_features, hidden, dropout, num_tasks):
34 |         super(GINENet, self).__init__()
35 |         self.conv1 = GINEConv(torch.nn.Sequential(torch.nn.Linear(num_node_features, hidden)), eps=0, train_eps=True)
36 |         self.conv2 = GINEConv(torch.nn.Sequential(torch.nn.Linear(hidden, hidden)), eps=0, train_eps=True)
37 |         self.conv3 = GINEConv(torch.nn.Sequential(torch.nn.Linear(hidden, hidden)), eps=0, train_eps=True)
38 |         self.lin1 = torch.nn.Linear(num_edge_features, num_node_features, bias=True)
39 |         self.lin2 = torch.nn.Linear(num_node_features, hidden, bias=True)
40 |         self.dropout = dropout
41 |         self.mlp1 = MLP(hidden, hidden, num_tasks, dropout)
42 | 
43 |     def reset_parameters(self):
44 |         self.conv1.reset_parameters()
45 |         self.conv2.reset_parameters()
46 |         self.conv3.reset_parameters()
47 |         self.lin1.reset_parameters()
48 |         self.lin2.reset_parameters()
49 |         self.mlp1.reset_parameters()
50 | 
51 |     def forward(self, batch_data):
52 |         x, edge_index, edge_attr = batch_data.x, batch_data.edge_index,
batch_data.edge_attr
53 | 
54 |         edge_attr = self.lin1(edge_attr.float())
55 |         x = F.relu(self.conv1(x, edge_index, edge_attr))
56 |         x = F.dropout(x, p=self.dropout, training=self.training)
57 |         edge_attr = self.lin2(edge_attr.float())
58 |         x = F.relu(self.conv2(x, edge_index, edge_attr))
59 | 
60 |         x = self.conv3(x, edge_index, edge_attr)
61 |         out = scatter_mean(x, batch_data.batch, dim=0)
62 |         out = self.mlp1(out)  # [batch_size, num_tasks]
63 | 
64 |         return out
65 | 
66 | 
67 | class GraphSizeNorm(torch.nn.Module):
68 |     """Applies Graph Size Normalization over each individual graph in a batch
69 |     of node features as described in the
70 |     "Benchmarking Graph Neural Networks" paper
71 |     """
72 |     def __init__(self):
73 |         super(GraphSizeNorm, self).__init__()
74 | 
75 |     def forward(self, x, batch=None):
76 |         """"""
77 |         if batch is None:
78 |             batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
79 | 
80 |         inv_sqrt_deg = degree(batch, dtype=x.dtype).pow(-0.5)
81 |         return x * inv_sqrt_deg[batch].view(-1, 1)
82 | 
83 | 
84 | 
85 | class BatchNorm(torch.nn.BatchNorm1d):
86 |     def __init__(self, in_channels, eps=1e-5, momentum=0.1, affine=True,
87 |                  track_running_stats=True):
88 |         super(BatchNorm, self).__init__(in_channels, eps, momentum, affine,
89 |                                         track_running_stats)
90 | 
91 |     def forward(self, x):
92 |         return super(BatchNorm, self).forward(x)
93 | 
94 | 
95 |     def __repr__(self):
96 |         return ('{}({}, eps={}, momentum={}, affine={}, '
97 |                 'track_running_stats={})').format(self.__class__.__name__,
98 |                                                   self.num_features, self.eps,
99 |                                                   self.momentum, self.affine,
100 |                                                   self.track_running_stats)
101 | 
102 | 
103 | ### MPNN + GraphSizeNorm + BatchNorm
104 | class NNNet(torch.nn.Module):
105 |     def __init__(self, num_node_features, num_edge_features, hidden, dropout, num_tasks):
106 |         super(NNNet, self).__init__()
107 |         self.conv1 = NNConv(num_node_features, hidden, torch.nn.Sequential(torch.nn.Linear(num_edge_features, num_node_features*hidden)))
108 |         self.conv2 = NNConv(hidden, hidden, torch.nn.Sequential(torch.nn.Linear(num_edge_features, hidden*hidden)))
109 |         self.conv3 = NNConv(hidden, hidden, torch.nn.Sequential(torch.nn.Linear(num_edge_features, hidden*hidden)))
110 |         self.mlp1 = MLP(hidden, hidden, num_tasks, dropout)
111 |         self.dropout = dropout
112 |         self.norm1 = GraphSizeNorm()
113 |         self.bn1 = BatchNorm(hidden)
114 |         self.norm2 = GraphSizeNorm()
115 |         self.bn2 = BatchNorm(hidden)
116 | 
117 |     def reset_parameters(self):
118 |         self.conv1.reset_parameters()
119 |         self.conv2.reset_parameters()
120 |         self.conv3.reset_parameters()
121 |         self.mlp1.reset_parameters()
122 | 
123 |     def forward(self, batch_data):
124 |         x, edge_index, edge_attr = batch_data.x, batch_data.edge_index, batch_data.edge_attr
125 |         x = self.conv1(x, edge_index, edge_attr)
126 |         x = self.norm1(x, batch_data.batch)
127 |         x = self.bn1(x)
128 |         x = F.relu(x)
129 |         x = F.dropout(x, p=self.dropout, training=self.training)
130 |         x = self.conv2(x, edge_index, edge_attr)
131 |         x = self.norm2(x, batch_data.batch)
132 |         x = self.bn2(x)
133 |         x = F.relu(x)
134 |         x = self.conv3(x, edge_index, edge_attr)
135 |         out = scatter_mean(x, batch_data.batch, dim=0)
136 |         out = self.mlp1(out)  # [batch_size, num_tasks]
137 |         return out
138 | 
139 | 
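# --- Editor's addition: tiny CPU smoke test, not part of the original repo. ---
# Builds a 4-node toy graph with assumed feature sizes (16 node dims / 8 edge dims)
# and checks that both baselines produce a [1, num_tasks] output.
if __name__ == '__main__':
    from torch_geometric.data import Data, Batch
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])
    edge_attr = torch.randn(4, 8)
    batch = Batch.from_data_list([Data(x=x, edge_index=edge_index, edge_attr=edge_attr)])
    for Net in (GINENet, NNNet):
        model = Net(num_node_features=16, num_edge_features=8, hidden=32, dropout=0.0, num_tasks=1)
        model.eval()  # use BatchNorm running stats and disable dropout
        print(Net.__name__, model(batch).shape)  # torch.Size([1, 1])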
--------------------------------------------------------------------------------
/Graph/imbalanaced_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | sigmoidf = nn.Sigmoid()
6 | 
7 | 
8 | def squared_hinge_loss(predScore, targets, b):
9 | 
10 | 
11 |     squared_hinge = (1 - targets * (predScore - b))
12 |     squared_hinge[squared_hinge <= 0] = 0
13 | 
14 | 
15 |     return squared_hinge ** 2
16 | 
17 | 
18 | def sigmoid_loss(pos, neg, beta=2.0):
19 |     return 1.0 / (1.0 + torch.exp(beta * (pos - neg)))
20 | 
21 | def logistic_loss(pos, neg, beta=1):
22 |     return -torch.log(1 / (1 + torch.exp(-beta * (pos - neg))))
23 | 
--------------------------------------------------------------------------------
/Graph/main_hiv.py:
--------------------------------------------------------------------------------
1 | __author__ = 'Qi'
2 | # Created by Qi on 10/31/21.
3 | from gine_mpnn import NNNet, GINENet
4 | import os
5 | from datasets import *
6 | from train_eval import run_classification
7 | from model import *
8 | from config.config_hiv import conf
9 | 
10 | 
11 | def run_main(data_name='tox21_t0', arch_list=['gine', 'mpnn', 'mlpmpnn'], conf=conf):
12 |     conf['dataset'] = data_name
13 |     if conf['dataset'] == 'hiv':
14 |         data_file, split_file = 'datasets/gnn_feats/hiv.pt', 'datasets/split_inds/hivscaffold.pkl'
15 |     elif conf['dataset'] == 'muv':
16 |         data_file, split_file = 'datasets/gnn_feats/muv_2.pt', 'datasets/split_inds/muv_2.pkl'
17 |     elif conf['dataset'] == 'tox21_t0':
18 |         conf['hidden'] = 128
19 |         data_file, split_file = 'datasets/gnn_feats/tox21/tox21_0.pt', 'datasets/split_inds/tox21/tox21_0.pkl'
20 |     elif conf['dataset'] == 'tox21_t2':
21 |         conf['hidden'] = 128
22 |         data_file, split_file = 'datasets/gnn_feats/tox21/tox21_2.pt', 'datasets/split_inds/tox21/tox21_2.pkl'
23 |     elif conf['dataset'] == 'toxcast_t8':
24 |         conf['hidden'] = 64
25 |         data_file, split_file = 'datasets/gnn_feats/toxcast/toxcast_8.pt', 'datasets/split_inds/toxcast/toxcast_8.pkl'
26 |     elif conf['dataset'] == 'toxcast_t12':
27 |         conf['hidden'] = 64
28 |         data_file, split_file = 'datasets/gnn_feats/toxcast/toxcast_12.pt', 'datasets/split_inds/toxcast/toxcast_12.pkl'
29 | 
30 |     dataset, num_node_features, num_edge_features, num_graph_features = get_dataset(data_file, conf['graph_level_feature'])
31 |     assert conf['num_tasks'] == dataset[0].y.shape[-1]
32 |     train_dataset, val_dataset, test_dataset = split_data(dataset, split_file)
33 | 
34 | 
35 | 
36 |     conf['loss_param'] = {'K': 10, 'm': 5, 'gamma': 1000, 'tau': 3.0, 'bins': 2, 'mv_gamma': 0.9, 'type': 'sqh', 'threshold': 10}
37 |     conf['ft_mode'] = 'fc_random'
38 |     conf['posNum'] = 1
39 |     model = None
40 |     for j in range(2):
41 |         for method in ['SOAP']:  # ['wce', 'focal', 'ldam', 'auroc', 'smoothAP', 'fastAP', 'minmax', 'SOAP']
42 |             # for j in range(1):
43 |             # for method in ['ce']:
44 |             conf['loss_type'] = method
45 |             for i in range(len(arch_list)):
46 |                 if arch_list[i] == 'mpnn':
47 |                     conf['arch'] = 'mpnn'
48 |                     # conf['pre_train'] = None
49 |                     conf['pre_train'] = './pretrained_models/' + str.split(conf['dataset'], '_')[0] + '_pretrained_model/' + '_'.join([conf['dataset'], conf['arch'], 'ce.ckpt'])
50 |                     model = NNNet(num_node_features, num_edge_features, conf['hidden'], conf['dropout'], conf['num_tasks'])
51 |                 elif arch_list[i] == 'mlpmpnn':
52 |                     conf['arch'] = 'mlpmpnn'
53 |                     # conf['pre_train'] = None
54 |                     conf['pre_train'] = './pretrained_models/' + str.split(conf['dataset'], '_')[0] + '_pretrained_model/' + '_'.join([conf['dataset'], conf['arch'], 'ce.ckpt'])
55 |                     model = MLNet2(num_node_features,
num_edge_features, num_graph_features, conf['hidden'], conf['dropout'], 56 | conf['num_tasks'], conf['depth'], conf['graph_level_feature']) 57 | elif arch_list[i] == 'gine': 58 | conf['arch'] = 'gine' 59 | conf['epochs'] = 150 60 | # conf['pre_train'] = None 61 | conf['pre_train'] = './pretrained_models/' + str.split(conf['dataset'], '_')[0] +'_pretrained_model/' + '_'.join([conf['dataset'], conf['arch'], 'ce.ckpt']) 62 | model = GINENet(num_node_features, num_edge_features, conf['hidden'], conf['dropout'], conf['num_tasks']) 63 | print(model is not None) 64 | if model is not None: 65 | out_path = 'results_' + conf['loss_type'] + '/' + conf['arch'] + '/' + '_'.join( 66 | ['Re_SOAP', conf['arch'], conf['dataset'], conf['loss_type'], conf['loss_param']['type'], 'lr', str(conf['lr']), 'th', str(conf['loss_param']['threshold']), 'posNum', str(conf['posNum']), 'wd', str(conf['weight_decay']), 'repeats', str(j), 'epoch', str(conf['epochs']), 'mv_gamma', str(conf['loss_param']['mv_gamma'])]) 67 | if not os.path.exists(out_path): 68 | os.makedirs(out_path) 69 | print("conf:", conf) 70 | run_classification(train_dataset, val_dataset, test_dataset, model, conf['num_tasks'], conf['epochs'], conf['batch_size'], conf['vt_batch_size'], conf['lr'], conf['lr_decay_factor'], conf['lr_decay_step_size'], conf['weight_decay'], conf['posNum'], conf['loss_type'], conf['loss_param'], conf['ft_mode'], conf['pre_train'], out_path) 71 | print("conf:", conf) 72 | 73 | 74 | 75 | if __name__ == '__main__': 76 | 77 | data_list = ['hiv'] #'tox21_t0', 'tox21_t2', , 'toxcast_t12', 'tox21_t0' 78 | arch_list = ['mpnn', 'gine', 'mlpmpnn'] 79 | conf['lr'] = 5e-4 80 | for data in data_list: 81 | run_main(data_name = data, arch_list=arch_list, conf = conf) 82 | 83 | 84 | 85 | 86 | # out_path = 'results_gine_hiv_auprc2' 87 | # if not os.path.exists(out_path): 88 | # os.makedirs(out_path) 89 | # conf['loss_type'] = 'fastAP' #'smoothAP' 90 | # conf['loss_param'] = {'K':10, 'm':5, 'gamma':1000, 'tau':3.0, 'bins':3} 91 | # conf['ft_mode'] = 'fc_random' 92 | # conf['pre_train'] = 'results_gine_hiv_ce/params150.ckpt' 93 | # model = GINENet(num_node_features, num_edge_features, conf['hidden'], conf['dropout'], conf['num_tasks']) 94 | # run_classification(train_dataset, val_dataset, test_dataset, model, conf['num_tasks'], conf['epochs'], conf['batch_size'], conf['vt_batch_size'], conf['lr'], conf['lr_decay_factor'], conf['lr_decay_step_size'], conf['weight_decay'], conf['loss_type'], conf['loss_param'], conf['ft_mode'], conf['pre_train'], out_path) 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /Graph/metric.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import auc, precision_recall_curve, roc_auc_score, mean_absolute_error, mean_squared_error 2 | import numpy as np 3 | 4 | 5 | 6 | 7 | def prc_auc(targets, preds): 8 | precision, recall, _ = precision_recall_curve(targets, preds) 9 | return auc(recall, precision) 10 | 11 | 12 | 13 | 14 | def compute_cla_metric(targets, preds, num_tasks): 15 | 16 | prc_results = [] 17 | roc_results = [] 18 | for i in range(num_tasks): 19 | is_labeled = targets[:,i] == targets[:,i] ## filter some samples without groundtruth label 20 | target = targets[is_labeled,i] 21 | pred = preds[is_labeled,i] 22 | try: 23 | prc = prc_auc(target, pred) 24 | except ValueError: 25 | prc = np.nan 26 | print("In task #", i+1, " , there is only one class present in the set. 
PRC is not defined in this case.") 27 | try: 28 | roc = roc_auc_score(target, pred) 29 | except ValueError: 30 | roc = np.nan 31 | print("In task #", i+1, " , there is only one class present in the set. ROC is not defined in this case.") 32 | if not np.isnan(prc): 33 | prc_results.append(prc) 34 | else: 35 | print("PRC results do not consider task #", i+1) 36 | if not np.isnan(roc): 37 | roc_results.append(roc) 38 | else: 39 | print("ROC results do not consider task #", i+1) 40 | return prc_results, roc_results 41 | 42 | -------------------------------------------------------------------------------- /Graph/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch_scatter import scatter_mean, scatter_softmax, scatter_add, scatter_sum 5 | from torch_geometric.nn import MessagePassing, GCNConv, NNConv, GINConv 6 | from torch_geometric.utils import degree 7 | 8 | import numpy as np 9 | 10 | 11 | 12 | class MLP(torch.nn.Module): 13 | def __init__(self, input_dim, hidden_dim, output_dim, dropout, activation): 14 | super(MLP, self).__init__() 15 | self.lin1 = torch.nn.Linear(input_dim, hidden_dim) 16 | self.lin2 = torch.nn.Linear(hidden_dim, output_dim) 17 | self.dropout = dropout 18 | self.activation = activation 19 | 20 | def reset_parameters(self): 21 | self.lin1.reset_parameters() 22 | self.lin2.reset_parameters() 23 | 24 | def forward(self, x): 25 | x = F.dropout(x, p=self.dropout, training=self.training) 26 | x = self.activation(self.lin1(x)) 27 | x = F.dropout(x, p=self.dropout, training=self.training) 28 | x = self.lin2(x) 29 | 30 | return x 31 | 32 | 33 | 34 | class SizeNorm(torch.nn.Module): 35 | def __init__(self): 36 | super(SizeNorm, self).__init__() 37 | 38 | def forward(self, x, batch=None): 39 | """""" 40 | if batch is None: 41 | batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device) 42 | 43 | inv_sqrt_deg = degree(batch, dtype=x.dtype).pow(-0.5) 44 | return x * inv_sqrt_deg[batch].view(-1, 1) 45 | 46 | 47 | 48 | class BatchNorm(torch.nn.BatchNorm1d): 49 | def __init__(self, in_channels, eps=1e-5, momentum=0.1, affine=True, 50 | track_running_stats=True): 51 | super(BatchNorm, self).__init__(in_channels, eps, momentum, affine, 52 | track_running_stats) 53 | 54 | def forward(self, x): 55 | return super(BatchNorm, self).forward(x) 56 | 57 | 58 | def __repr__(self): 59 | return ('{}({}, eps={}, momentum={}, affine={}, ' 60 | 'track_running_stats={})').format(self.__class__.__name__, 61 | self.num_features, self.eps, 62 | self.momentum, self.affine, 63 | self.track_running_stats) 64 | 65 | 66 | ########################################### 67 | class EdgeModel_ml2(torch.nn.Module): 68 | def __init__(self, input_dim, hidden_dim, output_dim, dropout): 69 | super(EdgeModel_ml2, self).__init__() 70 | self.edge_mlp = MLP(input_dim, hidden_dim, output_dim, dropout, F.relu) 71 | def forward(self, x, edge_index, edge_attr, u, batch): 72 | row, col = edge_index 73 | out = torch.cat([x[row], x[col], edge_attr, u[batch[row]]], 1) 74 | return self.edge_mlp(out) ### Step 1 75 | 76 | class NodeModel_ml2(torch.nn.Module): 77 | def __init__(self, input_dim1, hidden_dim1, output_dim1, input_dim2, hidden_dim2, output_dim2, dropout): 78 | super(NodeModel_ml2, self).__init__() 79 | self.node_mlp_1 = MLP(input_dim1, hidden_dim1, output_dim1, dropout, F.relu) 80 | self.node_mlp_2 = MLP(input_dim2, hidden_dim2, output_dim2, dropout, F.relu) 81 | 82 | def forward(self, x, edge_index, edge_attr, 
u, batch): 83 | row, col = edge_index 84 | out = torch.cat([x[row], edge_attr], dim=1) 85 | out = scatter_sum(out, col, dim=0, dim_size=x.size(0)) 86 | out = self.node_mlp_1(out) ### Step 2 87 | out = torch.cat([x, out, u[batch]], dim=1) 88 | return self.node_mlp_2(out) ### Step 3 89 | 90 | class SubgraphModel_ml2(torch.nn.Module): 91 | def __init__(self, input_dim, hidden_dim, output_dim, dropout): 92 | super(SubgraphModel_ml2, self).__init__() 93 | self.mlp1 = MLP(hidden_dim, hidden_dim, hidden_dim, dropout, F.relu) 94 | self.mlp2 = MLP(hidden_dim, hidden_dim, hidden_dim, dropout, F.relu) 95 | self.subgraph_mlp = MLP(input_dim, hidden_dim, output_dim, dropout, F.relu) 96 | 97 | 98 | def forward(self, x, x_clique, tree_edge_index, atom2clique_index, u, tree_batch): 99 | row, col = tree_edge_index 100 | out = scatter_sum(x_clique[row], col, dim=0, dim_size=x_clique.size(0)) 101 | out = self.mlp1(out) 102 | row_assign, col_assign = atom2clique_index 103 | node_info = scatter_sum(x[row_assign], col_assign, dim=0, dim_size=x_clique.size(0)) 104 | node_info = self.mlp2(node_info) ### Step 4 105 | out = torch.cat([node_info, x_clique, out, u[tree_batch]], dim=1) 106 | return self.subgraph_mlp(out) ### Step 5 107 | 108 | 109 | 110 | class GlobalModel_ml2(torch.nn.Module): 111 | def __init__(self, input_dim, hidden_dim, output_dim, dropout): 112 | super(GlobalModel_ml2, self).__init__() 113 | self.global_mlp = MLP(input_dim, hidden_dim, output_dim, dropout, F.relu) 114 | 115 | 116 | def forward(self, x, edge_index, edge_attr, x_clique, u, batch, tree_batch): 117 | row, col = edge_index 118 | edge_info = scatter_mean(edge_attr, batch[row], dim=0, dim_size=u.size(0)) ### Step 6 119 | node_info = scatter_mean(x, batch, dim=0, dim_size=u.size(0)) ### Step 7 120 | subgraph_info = scatter_mean(x_clique, tree_batch, dim=0, dim_size=u.size(0)) ### Step 8 121 | out = torch.cat([u, node_info, edge_info, subgraph_info], dim=1) 122 | return self.global_mlp(out) ### Step 9 123 | 124 | 125 | 126 | class MetaLayer_ml2(torch.nn.Module): 127 | def __init__(self, input_node_rep_dim, input_edge_rep_dim, input_subgraph_rep_dim, input_global_rep_dim, output_node_rep_dim, output_edge_rep_dim, output_subgraph_rep_dim, output_global_rep_dim, hidden, dropout): 128 | super(MetaLayer_ml2, self).__init__() 129 | self.edge_model = EdgeModel_ml2(2*input_node_rep_dim+input_edge_rep_dim+input_global_rep_dim, hidden, output_edge_rep_dim, dropout) 130 | self.node_model = NodeModel_ml2(input_node_rep_dim+output_edge_rep_dim, hidden, hidden, hidden+input_node_rep_dim+input_global_rep_dim, hidden, output_node_rep_dim, dropout) 131 | self.subgraph_model = SubgraphModel_ml2(2*input_subgraph_rep_dim+output_node_rep_dim+input_global_rep_dim, hidden, output_subgraph_rep_dim, dropout) 132 | self.global_model = GlobalModel_ml2(input_global_rep_dim+output_node_rep_dim+output_edge_rep_dim+output_subgraph_rep_dim, hidden, output_global_rep_dim, dropout) 133 | 134 | 135 | def forward(self, x, edge_index, edge_attr, u, tree_edge_index, atom2clique_index, x_clique, ori_batch, tree_batch): 136 | 137 | edge_attr = self.edge_model(x, edge_index, edge_attr, u, ori_batch) 138 | x = self.node_model(x, edge_index, edge_attr, u, ori_batch) 139 | x_clique = self.subgraph_model(x, x_clique, tree_edge_index, atom2clique_index, u, tree_batch) 140 | u = self.global_model(x, edge_index, edge_attr, x_clique, u, ori_batch, tree_batch) 141 | 142 | return x, edge_attr, x_clique, u 143 | 144 | 145 | 146 | class MLNet2(torch.nn.Module): 147 | def __init__(self, 
num_node_features, num_edge_features, num_global_features, hidden, dropout, num_tasks, depth, graph_level_feature): 148 | super(MLNet2, self).__init__() 149 | self.mlp_node = MLP(num_node_features, hidden, hidden, dropout, F.relu) 150 | self.mlp_edge = MLP(num_edge_features, hidden, hidden, dropout, F.relu) 151 | self.emb_subgraph = torch.nn.Embedding(4, hidden) 152 | self.graph_level_feature = graph_level_feature 153 | if self.graph_level_feature: 154 | self.mlp_global = MLP(num_edge_features, hidden, hidden, dropout, F.relu) 155 | self.mlp1 = MLP(num_global_features+hidden, hidden, num_tasks, dropout, F.relu) 156 | else: 157 | self.mlp_global = MLP(num_edge_features, hidden, hidden, dropout, F.relu) 158 | self.mlp1 = MLP(hidden, hidden, num_tasks, dropout, F.relu) 159 | 160 | 161 | self.dropout = dropout 162 | self.num_global_features = num_global_features 163 | 164 | self.depth = depth 165 | 166 | self.gn = torch.nn.ModuleList([MetaLayer_ml2(hidden, hidden, hidden, hidden, hidden, hidden, hidden, hidden, hidden, dropout) for i in range(self.depth)]) 167 | self.norm_node = torch.nn.ModuleList([SizeNorm() for i in range(self.depth+1)]) 168 | self.norm_edge = torch.nn.ModuleList([SizeNorm() for i in range(self.depth+1)]) 169 | self.norm_subgraph = torch.nn.ModuleList([SizeNorm() for i in range(self.depth+1)]) 170 | self.bn_node = torch.nn.ModuleList([BatchNorm(hidden) for i in range(self.depth+1)]) 171 | self.bn_edge = torch.nn.ModuleList([BatchNorm(hidden) for i in range(self.depth+1)]) 172 | self.bn_subgraph = torch.nn.ModuleList([BatchNorm(hidden) for i in range(self.depth+1)]) 173 | self.bn_global = torch.nn.ModuleList([BatchNorm(hidden) for i in range(self.depth+1)]) 174 | 175 | 176 | 177 | def forward(self, batch_data): 178 | 179 | if self.graph_level_feature: ### Use rdkit_2d_normalized_features 180 | x, edge_index, edge_attr = batch_data.x, batch_data.edge_index, batch_data.edge_attr 181 | row, col = edge_index 182 | u = scatter_mean(edge_attr, batch_data.batch[row], dim=0, dim_size=max(batch_data.batch)+1) 183 | tree_edge_index, atom2clique_index, num_cliques, x_clique = batch_data.tree_edge_index, batch_data.atom2clique_index, batch_data.num_cliques, batch_data.x_clique 184 | aug_feat = batch_data.graph_attr 185 | if len(aug_feat.shape) != 2: 186 | aug_feat = torch.reshape(aug_feat, (-1, self.num_global_features)) 187 | else: 188 | x, edge_index, edge_attr = batch_data.x, batch_data.edge_index, batch_data.edge_attr 189 | row, col = edge_index 190 | u = scatter_mean(edge_attr, batch_data.batch[row], dim=0, dim_size=max(batch_data.batch)+1) 191 | tree_edge_index, atom2clique_index, num_cliques, x_clique = batch_data.tree_edge_index, batch_data.atom2clique_index, batch_data.num_cliques, batch_data.x_clique 192 | 193 | x = self.mlp_node(x) 194 | edge_attr = self.mlp_edge(edge_attr) 195 | x_clique = self.emb_subgraph(x_clique) 196 | u = self.mlp_global(u) 197 | 198 | row, col = edge_index 199 | 200 | ori_batch = batch_data.batch 201 | tree_batch = torch.repeat_interleave(num_cliques) 202 | 203 | x = self.norm_node[-1](x, ori_batch) 204 | edge_attr = self.norm_edge[-1](edge_attr, ori_batch[row]) 205 | x_clique = self.norm_subgraph[-1](x_clique, tree_batch) 206 | x = self.bn_node[-1](x) 207 | edge_attr = self.bn_edge[-1](edge_attr) 208 | x_clique = self.bn_subgraph[-1](x_clique) 209 | u = self.bn_global[-1](u) 210 | 211 | for i in range(self.depth): 212 | 213 | x, edge_attr, x_clique, u = self.gn[i](x, edge_index, edge_attr, u, tree_edge_index, atom2clique_index, x_clique, ori_batch, 
tree_batch) 214 | 215 | x = self.norm_node[i](x, batch_data.batch) 216 | edge_attr = self.norm_edge[i](edge_attr, batch_data.batch[row]) 217 | x_clique = self.norm_subgraph[i](x_clique, tree_batch) 218 | x = self.bn_node[i](x) 219 | edge_attr = self.bn_edge[i](edge_attr) 220 | x_clique = self.bn_subgraph[i](x_clique) 221 | u = self.bn_global[i](u) 222 | 223 | if self.graph_level_feature: 224 | u = torch.cat([u,aug_feat], dim=1) 225 | out = self.mlp1(u) 226 | 227 | 228 | return out 229 | 230 | 231 | 232 | 233 | ############## Ablation studey: w/o subgraph-level 234 | 235 | class GlobalModel_ml3(torch.nn.Module): 236 | def __init__(self, input_dim, hidden_dim, output_dim, dropout): 237 | super(GlobalModel_ml3, self).__init__() 238 | self.global_mlp = MLP(input_dim, hidden_dim, output_dim, dropout, F.relu) 239 | 240 | 241 | def forward(self, x, edge_index, edge_attr, u, batch): 242 | row, col = edge_index 243 | edge_info = scatter_mean(edge_attr, batch[row], dim=0, dim_size=u.size(0)) 244 | node_info = scatter_mean(x, batch, dim=0, dim_size=u.size(0)) 245 | out = torch.cat([u, node_info, edge_info], dim=1) 246 | return self.global_mlp(out) 247 | 248 | class MetaLayer_ml3(torch.nn.Module): 249 | def __init__(self, input_node_rep_dim, input_edge_rep_dim, input_global_rep_dim, output_node_rep_dim, output_edge_rep_dim, output_global_rep_dim, hidden, dropout): 250 | super(MetaLayer_ml3, self).__init__() 251 | self.edge_model = EdgeModel_ml2(2*input_node_rep_dim+input_edge_rep_dim+input_global_rep_dim, hidden, output_edge_rep_dim, dropout) 252 | self.node_model = NodeModel_ml2(input_node_rep_dim+output_edge_rep_dim, hidden, hidden, hidden+input_node_rep_dim+input_global_rep_dim, hidden, output_node_rep_dim, dropout) 253 | self.global_model = GlobalModel_ml3(input_global_rep_dim+output_node_rep_dim+output_edge_rep_dim, hidden, output_global_rep_dim, dropout) 254 | 255 | 256 | def forward(self, x, edge_index, edge_attr, u, ori_batch): 257 | 258 | edge_attr = self.edge_model(x, edge_index, edge_attr, u, ori_batch) 259 | x = self.node_model(x, edge_index, edge_attr, u, ori_batch) 260 | u = self.global_model(x, edge_index, edge_attr, u, ori_batch) 261 | 262 | return x, edge_attr, u 263 | 264 | 265 | 266 | class MLNet3(torch.nn.Module): 267 | def __init__(self, num_node_features, num_edge_features, num_global_features, hidden, dropout, num_tasks, depth, graph_level_feature): 268 | super(MLNet3, self).__init__() 269 | self.mlp_node = MLP(num_node_features, hidden, hidden, dropout, F.relu) 270 | self.mlp_edge = MLP(num_edge_features, hidden, hidden, dropout, F.relu) 271 | self.graph_level_feature = graph_level_feature 272 | if self.graph_level_feature: 273 | self.mlp_global = MLP(num_edge_features, hidden, hidden, dropout, F.relu) 274 | self.mlp1 = MLP(num_global_features+hidden, hidden, num_tasks, dropout, F.relu) 275 | else: 276 | self.mlp_global = MLP(num_edge_features, hidden, hidden, dropout, F.relu) 277 | self.mlp1 = MLP(hidden, hidden, num_tasks, dropout, F.relu) 278 | 279 | 280 | self.dropout = dropout 281 | self.num_global_features = num_global_features 282 | 283 | self.depth = depth 284 | 285 | self.gn = torch.nn.ModuleList([MetaLayer_ml3(hidden, hidden, hidden, hidden, hidden, hidden, hidden, dropout) for i in range(self.depth)]) 286 | self.norm_node = torch.nn.ModuleList([SizeNorm() for i in range(self.depth+1)]) 287 | self.norm_edge = torch.nn.ModuleList([SizeNorm() for i in range(self.depth+1)]) 288 | self.norm_subgraph = torch.nn.ModuleList([SizeNorm() for i in range(self.depth+1)]) 289 | 
self.bn_node = torch.nn.ModuleList([BatchNorm(hidden) for i in range(self.depth+1)]) 290 | self.bn_edge = torch.nn.ModuleList([BatchNorm(hidden) for i in range(self.depth+1)]) 291 | self.bn_subgraph = torch.nn.ModuleList([BatchNorm(hidden) for i in range(self.depth+1)]) 292 | self.bn_global = torch.nn.ModuleList([BatchNorm(hidden) for i in range(self.depth+1)]) 293 | 294 | 295 | 296 | def forward(self, batch_data): 297 | 298 | if self.graph_level_feature: ### Use rdkit_2d_normalized_features as input graph-level feature 299 | x, edge_index, edge_attr = batch_data.x, batch_data.edge_index, batch_data.edge_attr 300 | row, col = edge_index 301 | u = scatter_mean(edge_attr, batch_data.batch[row], dim=0, dim_size=max(batch_data.batch)+1) 302 | aug_feat = batch_data.graph_attr 303 | if len(aug_feat.shape) != 2: 304 | aug_feat = torch.reshape(aug_feat, (-1, self.num_global_features)) 305 | else: 306 | x, edge_index, edge_attr = batch_data.x, batch_data.edge_index, batch_data.edge_attr 307 | row, col = edge_index 308 | u = scatter_mean(edge_attr, batch_data.batch[row], dim=0, dim_size=max(batch_data.batch)+1) 309 | 310 | x = self.mlp_node(x) 311 | edge_attr = self.mlp_edge(edge_attr) 312 | u = self.mlp_global(u) 313 | 314 | row, col = edge_index 315 | 316 | ori_batch = batch_data.batch 317 | 318 | x = self.norm_node[-1](x, ori_batch) 319 | edge_attr = self.norm_edge[-1](edge_attr, ori_batch[row]) 320 | x = self.bn_node[-1](x) 321 | edge_attr = self.bn_edge[-1](edge_attr) 322 | u = self.bn_global[-1](u) 323 | 324 | for i in range(self.depth): 325 | 326 | x, edge_attr, u = self.gn[i](x, edge_index, edge_attr, u, ori_batch) 327 | 328 | x = self.norm_node[i](x, batch_data.batch) 329 | edge_attr = self.norm_edge[i](edge_attr, batch_data.batch[row]) 330 | x = self.bn_node[i](x) 331 | edge_attr = self.bn_edge[i](edge_attr) 332 | u = self.bn_global[i](u) 333 | 334 | if self.graph_level_feature: 335 | u = torch.cat([u,aug_feat], dim=1) 336 | out = self.mlp1(u) 337 | 338 | 339 | return out 340 | 341 | 342 | -------------------------------------------------------------------------------- /Graph/predict.py: -------------------------------------------------------------------------------- 1 | from datasets import * 2 | from model import * 3 | 4 | import os 5 | import torch 6 | 7 | from torch_geometric.data import DataLoader 8 | from metric import compute_cla_metric 9 | import numpy as np 10 | 11 | 12 | def get_data_files(name, seed=122): 13 | if name in ['hiv']: 14 | data_file = 'datasets/gnn_feats/{}.pt'.format(name) 15 | split_files = ['datasets/split_inds/{}scaffold{}.pkl'.format(name, x) for x in [122, 123, 124]] 16 | split_file = split_files[seed-122] 17 | elif name in ['pcba', 'muv', 'tox21', 'toxcast']: 18 | data_file = 'datasets/gnn_feats/{}.pt'.format(name) 19 | split_files = ['datasets/split_inds/{}random{}.pkl'.format(name, x) for x in [122, 123, 124]] 20 | split_file = split_files[seed-122] 21 | return data_file, split_file 22 | 23 | 24 | def test_classification(model, test_loader, num_tasks, device, save_pred=False, out_path=None): 25 | model.eval() 26 | 27 | preds = torch.Tensor([]).to(device) 28 | targets = torch.Tensor([]).to(device) 29 | for batch_data in test_loader: 30 | batch_data = batch_data.to(device) 31 | with torch.no_grad(): 32 | out = model(batch_data) 33 | if len(batch_data.y.shape) != 2: 34 | batch_data.y = torch.reshape(batch_data.y, (-1, num_tasks)) 35 | pred = torch.sigmoid(out) ### prediction real number between (0,1) 36 | preds = torch.cat([preds,pred], dim=0) 37 | targets = 
torch.cat([targets, batch_data.y], dim=0)
38 | 
39 |     if torch.cuda.is_available():
40 |         if save_pred:
41 |             np.save(out_path + '/' + 'pred.npy', preds.cpu().detach().numpy())
42 |         prc_results, roc_results = compute_cla_metric(targets.cpu().detach().numpy(), preds.cpu().detach().numpy(), num_tasks)
43 |     else:
44 |         if save_pred:
45 |             np.save(out_path + '/' + 'pred.npy', preds)
46 |         prc_results, roc_results = compute_cla_metric(targets, preds, num_tasks)
47 | 
48 |     return prc_results, roc_results
49 | 
50 | 
51 | from config.config_hiv import conf
52 | out_path = 'results_hiv'
53 | if not os.path.exists(out_path):
54 |     os.makedirs(out_path)
55 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
56 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
57 | 
58 | data_file, split_file = 'datasets/gnn_feats/hiv.pt', 'datasets/split_inds/hivscaffold.pkl'
59 | dataset, num_node_features, num_edge_features, num_graph_features = get_dataset(data_file, conf['graph_level_feature'])
60 | assert conf['num_tasks'] == dataset[0].y.shape[-1]
61 | train_dataset, val_dataset, test_dataset = split_data(dataset, split_file)
62 | 
63 | test_loader = DataLoader(test_dataset, conf['batch_size'], shuffle=False)
64 | model = MLNet2(num_node_features, num_edge_features, num_graph_features, conf['hidden'], conf['dropout'], conf['num_tasks'], conf['depth'], conf['graph_level_feature'])
65 | model = model.to(device)
66 | print('======================')
67 | print('Loading trained model and testing...')
68 | model_dir = 'bce_models/ml2features_hiv'
69 | model_dir = os.path.join(model_dir, 'params.ckpt')
70 | model.load_state_dict(torch.load(model_dir))
71 | num_tasks = conf['num_tasks']
72 | 
73 | test_prc_results, test_roc_results = test_classification(model, test_loader, num_tasks, device, out_path=out_path)
74 | print('======================')
75 | print('Test PRC (avg over multitasks): {:.4f}'.format(np.mean(test_prc_results)))
76 | 
77 | 
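# Editor's note: run as `CUDA_VISIBLE_DEVICES=0 python3 predict.py`. The checkpoint
# path above ('bce_models/ml2features_hiv/params.ckpt') is specific to the authors'
# setup -- adjust it and the dataset paths to your own locations, as described in
# the README "Tips" section.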
--------------------------------------------------------------------------------
/Graph/pretrained_models/hiv_pretrained_model/hiv_gine_ce.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Optimization-AI/NeurIPS2021_SOAP/c120d8793cbffe702f960527bd83d41f00065b5f/Graph/pretrained_models/hiv_pretrained_model/hiv_gine_ce.ckpt
--------------------------------------------------------------------------------
/Graph/pretrained_models/hiv_pretrained_model/hiv_mlpmpnn_ce.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Optimization-AI/NeurIPS2021_SOAP/c120d8793cbffe702f960527bd83d41f00065b5f/Graph/pretrained_models/hiv_pretrained_model/hiv_mlpmpnn_ce.ckpt
--------------------------------------------------------------------------------
/Graph/pretrained_models/hiv_pretrained_model/hiv_mpnn_ce.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Optimization-AI/NeurIPS2021_SOAP/c120d8793cbffe702f960527bd83d41f00065b5f/Graph/pretrained_models/hiv_pretrained_model/hiv_mpnn_ce.ckpt
--------------------------------------------------------------------------------
/Graph/train_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | from numpy.lib import ufunclike
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.optim import Adam
6 | from torch_geometric.data import DataLoader
7 | from metric import compute_cla_metric
8 | import numpy as np
9 | from SOAP import AUPRCSampler, SOAPLOSS
10 | # from imbalanced_loss import *  # needed by the non-SOAP baselines (wce/focal/ldam) below
11 | # from auprc_hinge import *      # needed by the minmax/smoothAP/fastAP baselines below
12 | import time
13 | 
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | 
16 | 
17 | ### Run function for classification tasks
18 | def run_classification(train_dataset, val_dataset, test_dataset, model, num_tasks, epochs, batch_size, vt_batch_size,
19 |                        lr, lr_decay_factor, lr_decay_step_size, weight_decay, posNum=1, loss_type='ce', loss_param={},
20 |                        ft_mode='fc_random', pre_train=None, save_dir=None, repeats=0):
21 |     model = model.to(device)
22 |     if pre_train is not None:
23 |         model.load_state_dict(torch.load(pre_train))
24 |         if ft_mode == 'fc_random':
25 |             model.mlp1.reset_parameters()
26 | 
27 |     optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
28 |     global u, a, b, m, alpha, lamda
29 |     if loss_type == 'ce':
30 |         criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
31 |     elif loss_type in ['SOAP']:
32 |         labels = [int(data.y.item()) for data in train_dataset]
33 |         criterion = SOAPLOSS(loss_param['threshold'],
34 |                              len(train_dataset) + len(val_dataset) + len(test_dataset), loss_param['type'], loss_param['mv_gamma'])  # bug fix: dropped a stray batch_size argument so the call matches SOAPLOSS.__init__(threshold, data_length, loss_type, gamma)
35 |     elif loss_type in ['sum']:
36 |         criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
37 |         labels = [int(data.y.item()) for data in train_dataset]
38 |         u = torch.zeros([len(train_dataset) + len(val_dataset) + len(test_dataset)])
39 |     elif loss_type in ['wce', 'focal', 'ldam']:
40 |         labels = [int(data.y.item()) for data in train_dataset]
41 |         n_pos = sum(labels)
42 |         n_neg = len(labels) - n_pos
43 |         cls_num_list = [n_neg, n_pos]
44 |         if loss_type == 'wce':
45 |             criterion = WeightedBCEWithLogitsLoss(cls_num_list=cls_num_list)
46 |         elif loss_type == 'focal':
47 |             criterion = FocalLoss(cls_num_list=cls_num_list)
48 |         elif loss_type == 'ldam':
49 |             criterion = BINARY_LDAMLoss(cls_num_list=cls_num_list)
50 |     elif loss_type in ['auroc']:
51 |         criterion = None
52 |         a, b, alpha, m = float(1), float(0), float(1), loss_param['m']
53 |         labels = [int(data.y.item()) for data in train_dataset]
54 |         loss_param['pos_ratio'] = sum(labels) / len(labels)
55 |     elif loss_type in ['minmax']:
56 |         criterion = AUCPRHingeLoss()
57 |     elif loss_type in ['smoothAP']:
58 |         criterion = smoothAP
59 |     elif loss_type in ['fastAP']:
60 |         criterion = fastAP
61 | 
62 |     train_loader_for_prc = DataLoader(train_dataset, vt_batch_size, shuffle=False)
63 |     val_loader = DataLoader(val_dataset, vt_batch_size, shuffle=False)
64 |     test_loader = DataLoader(test_dataset, vt_batch_size, shuffle=False)
65 | 
66 |     best_val_metric = 0
67 |     best_test_metric = 0
68 |     epoch_bvl = 0
69 |     epoch_test = 0
70 |     crs_test_metric = 0
71 | 
72 |     # save the training records
73 |     save_file = os.path.join(save_dir, 'record' + '_' + str(repeats) + '.txt')
74 |     labels = [int(data.y.item()) for data in train_dataset]
75 |     for epoch in range(1, epochs + 1):
76 |         # if loss_type in ['ce', 'wce', 'focal', 'ldam', 'auroc', 'auroc2', 'minmax', 'smoothAP', 'fastAP']:
77 |         #     train_loader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True)
78 |         # elif loss_type in ['auprc1', 'auprc2', 'SOAP', 'sum']:
79 | 
80 |         if loss_type == 'ce':
81 |             train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
82 |         else:
83 |             train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=AUPRCSampler(labels, batch_size, posNum=posNum))
84 | 
85 |         avg_train_loss = train_classification(model, optimizer,
train_loader, num_tasks, device, epoch, lr, posNum, criterion, 86 | loss_type, loss_param) 87 | train_prc_results, train_roc_results = test_classification(model, train_loader_for_prc, num_tasks, device) 88 | val_prc_results, val_roc_results = test_classification(model, val_loader, num_tasks, device) 89 | test_prc_results, test_roc_results = test_classification(model, test_loader, num_tasks, device) 90 | 91 | print('Epoch: {:03d}, Training Loss: {:.6f}, Val PRC (avg over multitasks): {:.4f}, Test PRC (avg over multitasks): {:.4f}'.format(epoch, avg_train_loss, np.mean(val_prc_results), np.mean(test_prc_results))) 92 | 93 | if epoch % lr_decay_step_size == 0: 94 | for param_group in optimizer.param_groups: 95 | param_group['lr'] = lr_decay_factor * param_group['lr'] 96 | 97 | if np.mean(val_prc_results) > best_val_metric: 98 | epoch_bvl = epoch 99 | best_val_metric, crs_test_metric = np.mean(val_prc_results), np.mean(test_prc_results) 100 | torch.save(model.state_dict(), os.path.join(save_dir, 'best_eval.ckpt')) 101 | 102 | 103 | if np.mean(test_prc_results) > best_test_metric: 104 | epoch_test = epoch 105 | best_test_metric = np.mean(test_prc_results) 106 | torch.save(model.state_dict(), os.path.join(save_dir, 'best_test.ckpt')) 107 | 108 | 109 | if epoch == 1: 110 | while os.path.exists(save_file): 111 | repeats += 1 112 | save_file = os.path.join(save_dir, 'record' + '_' + str(repeats) + '.txt') 113 | 114 | if save_file is not None: 115 | if epoch == epochs: 116 | torch.save(model.state_dict(), os.path.join(save_dir, 'params{}.ckpt'.format(epoch))) 117 | fp = open(save_file, 'a') 118 | fp.write( 119 | 'Epoch: {:03d}, Train avg loss: {:.4f}, Train PRC: {:.4f}, Val PRC: {:.4f}, Test PRC: {:.4f}\n'.format( 120 | epoch, avg_train_loss, np.mean(train_prc_results), np.mean(val_prc_results), 121 | np.mean(test_prc_results))) 122 | fp.close() 123 | 124 | 125 | if epoch - epoch_bvl >= 50:  # early stopping: no val improvement for 50 epochs 126 | break 127 | 128 | fp = open(save_file, 'a') 129 | fp.write( 130 | 'Best val metric is: {:.4f}, Best val metric is achieved at epoch: {:03d}, Corresponding test metric: {:.4f}\n'.format(best_val_metric, epoch_bvl, crs_test_metric)) 131 | fp.write('Best test metric is: {:.4f}, Best test metric is achieved at epoch: {:03d}\n'.format(best_test_metric, 132 | epoch_test)) 133 | fp.close() 134 | 135 | print('Best val metric is: {:.4f}, Best val metric is achieved at epoch: {:03d}, Corresponding test metric: {:.4f}\n'.format(best_val_metric, epoch_bvl, crs_test_metric)) 136 | print('Best test metric is: {:.4f}, Best test metric is achieved at epoch: {:03d}\n'.format(best_test_metric, epoch_test)) 137 | return crs_test_metric, best_test_metric 138 | 139 | def train_classification(model, optimizer, train_loader, num_tasks, device, epoch, lr, posNum, criterion=None, loss_type=None, 140 | loss_param={}): 141 | model.train() 142 | 143 | global a, b, m, alpha 144 | if loss_type == 'auroc' and epoch % 10 == 1: 145 | # Periodically update w_{ref}, a_{ref}, b_{ref} 146 | global state, a_0, b_0 147 | a_0, b_0 = a, b 148 | state = [] 149 | for name, param in model.named_parameters(): 150 | state.append(param.data) 151 | 152 | losses = [] 153 | for i, batch_data in enumerate(train_loader): 154 | optimizer.zero_grad() 155 | batch_data = batch_data.to(device) 156 | out = model(batch_data) 157 | 158 | if loss_type == 'ce': 159 | if len(batch_data.y.shape) != 2: 160 | batch_data.y = torch.reshape(batch_data.y, (-1, num_tasks)) 161 | mask = torch.Tensor([[not np.isnan(x) for x in tb] for tb in 162 | batch_data.y.cpu()]) # Skip those without 
targets (in PCBA, MUV, Tox21, ToxCast) 163 | mask = mask.to(device) 164 | target = torch.Tensor([[0 if np.isnan(x) else x for x in tb] for tb in batch_data.y.cpu()]) 165 | target = target.to(device) 166 | loss = criterion(out, target) * mask 167 | loss = loss.sum() 168 | loss.backward() 169 | optimizer.step() 170 | 171 | elif loss_type == 'SOAP': 172 | target = batch_data.y 173 | predScore = torch.nn.Sigmoid()(out) 174 | loss = criterion(predScore[:posNum], predScore[posNum:], batch_data.idx.view(-1, 1).long()) 175 | loss.backward() 176 | optimizer.step() 177 | elif loss_type == 'sum': 178 | target = batch_data.y 179 | if len(target.shape) != 2: 180 | target = torch.reshape(target, (-1, num_tasks)) 181 | loss1 = criterion(out, target) 182 | loss1 = loss1.sum() 183 | predScore = torch.nn.Sigmoid()(out) 184 | g = pairLossAlg2(10, predScore[0], predScore[1:]) 185 | p = calculateP(g, u, batch_data.idx[0], 1) 186 | loss2 = surrLoss(g, p) 187 | loss = loss1 + loss2 188 | loss.backward() 189 | optimizer.step() 190 | elif loss_type in ['wce', 'focal', 'ldam']: 191 | target = batch_data.y 192 | loss = criterion(out, target, epoch) 193 | loss.backward() 194 | optimizer.step() 195 | elif loss_type in ['auroc']: 196 | target = batch_data.y 197 | predScore = torch.nn.Sigmoid()(out) 198 | loss = AUROC_loss(predScore, target, a, b, m, alpha, loss_param['pos_ratio']) 199 | curRegularizer = calculateRegularizerWeights(lr, model, state, loss_param['gamma']) 200 | loss.backward() 201 | optimizer.step() 202 | regularizeUpdate(model, curRegularizer) 203 | a, b, alpha = PESG_update_a_b_alpha_2(lr, a, a_0, b, b_0, alpha, m, predScore, target, 204 | loss_param['pos_ratio'], loss_param['gamma']) 205 | elif loss_type in ['minmax']: 206 | target = batch_data.y 207 | loss = criterion(out, target) 208 | loss.backward() 209 | optimizer.step() 210 | elif loss_type in ['smoothAP']: 211 | target = batch_data.y 212 | predScore = torch.sigmoid(out) 213 | loss = criterion(predScore, target, tau=loss_param['tau']) 214 | loss.backward() 215 | optimizer.step() 216 | elif loss_type in ['fastAP']: 217 | target = batch_data.y 218 | predScore = torch.sigmoid(out) 219 | loss = criterion(predScore, target, bins=loss_param['bins']) 220 | loss.backward() 221 | optimizer.step() 222 | 223 | # print('Iter {} | Loss {}'.format(i, loss.cpu().item())) 224 | losses.append(loss) 225 | return sum(losses).item() / len(losses) 226 | 227 | 228 | def test_classification(model, test_loader, num_tasks, device): 229 | model.eval() 230 | 231 | preds = torch.Tensor([]).to(device) 232 | targets = torch.Tensor([]).to(device) 233 | for batch_data in test_loader: 234 | batch_data = batch_data.to(device) 235 | with torch.no_grad(): 236 | out = model(batch_data) 237 | if len(batch_data.y.shape) != 2: 238 | batch_data.y = torch.reshape(batch_data.y, (-1, num_tasks)) 239 | pred = torch.sigmoid(out) ### prediction real number between (0,1) 240 | preds = torch.cat([preds, pred], dim=0) 241 | targets = torch.cat([targets, batch_data.y], dim=0) 242 | prc_results, roc_results = compute_cla_metric(targets.cpu().detach().numpy(), preds.cpu().detach().numpy(), 243 | num_tasks) 244 | 245 | return prc_results, roc_results 246 | -------------------------------------------------------------------------------- /Graph/tran_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import codecs 4 | import numpy as np 5 | import networkx as nx 6 | import pickle 7 | import torch 8 | 9 | from rdkit import Chem 10 | 
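# --------------------------------------------------------------------------
# Overview (added note): this script converts the raw SMILES strings in
# datasets/hiv.csv into PyTorch Geometric graphs in three steps:
#   1. data_reader()     -- read (smiles, labels) from the csv file
#   2. smile_to_graph()  -- build a networkx graph per molecule with atom,
#                           bond and global (RDKit 2D) features
#   3. save_to_data()    -- convert each graph into a torch_geometric Data
#                           object and collect the objects into a list
# The list is finally torch.save()d to datasets/gnn_feats/hiv.pt, the file
# loaded by the training and prediction scripts above.
# --------------------------------------------------------------------------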
from typing import List, Tuple, Union 11 | from torch_geometric.utils import from_networkx, tree_decomposition 12 | 13 | ### Adaptively adjust from https://github.com/chemprop/chemprop 14 | 15 | BOND_FDIM = 14 16 | 17 | MAX_ATOMIC_NUM = 100 18 | ATOM_FEATURES = { 19 | 'atomic_num': list(range(MAX_ATOMIC_NUM)), 20 | 'degree': [0, 1, 2, 3, 4, 5], 21 | 'formal_charge': [-1, -2, 1, 2, 0], 22 | 'chiral_tag': [0, 1, 2, 3], 23 | 'num_Hs': [0, 1, 2, 3, 4], 24 | 'hybridization': [ 25 | Chem.rdchem.HybridizationType.SP, 26 | Chem.rdchem.HybridizationType.SP2, 27 | Chem.rdchem.HybridizationType.SP3, 28 | Chem.rdchem.HybridizationType.SP3D, 29 | Chem.rdchem.HybridizationType.SP3D2 30 | ], 31 | } 32 | 33 | # Distance feature sizes 34 | PATH_DISTANCE_BINS = list(range(10)) 35 | THREE_D_DISTANCE_MAX = 20 36 | THREE_D_DISTANCE_STEP = 1 37 | THREE_D_DISTANCE_BINS = list(range(0, THREE_D_DISTANCE_MAX + 1, THREE_D_DISTANCE_STEP)) 38 | 39 | 40 | def onek_encoding_unk(value: int, choices: List[int]) -> List[int]: 41 | """ 42 | Creates a one-hot encoding. 43 | :param value: The value for which the encoding should be one. 44 | :param choices: A list of possible values. 45 | :return: A one-hot encoding of the value in a list of length len(choices) + 1. 46 | If value is not in the list of choices, then the final element in the encoding is 1. 47 | """ 48 | encoding = [0] * (len(choices) + 1) 49 | index = choices.index(value) if value in choices else -1 50 | encoding[index] = 1 51 | 52 | return encoding 53 | 54 | 55 | def atom_features(atom: Chem.rdchem.Atom, functional_groups: List[int] = None) -> List[Union[bool, int, float]]: 56 | """ 57 | Builds a feature vector for an atom. 58 | :param atom: An RDKit atom. 59 | :param functional_groups: A k-hot vector indicating the functional groups the atom belongs to. 60 | :return: A list containing the atom features. 61 | """ 62 | features = onek_encoding_unk(atom.GetAtomicNum() - 1, ATOM_FEATURES['atomic_num']) + \ 63 | onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) + \ 64 | onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) + \ 65 | onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) + \ 66 | onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs']) + \ 67 | onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) + \ 68 | [1 if atom.GetIsAromatic() else 0] + \ 69 | [atom.GetMass() * 0.01] # scaled to about the same range as other features 70 | 71 | if functional_groups is not None: 72 | features += functional_groups 73 | return features 74 | 75 | 76 | def bond_features(bond: Chem.rdchem.Bond) -> List[Union[bool, int, float]]: 77 | """ 78 | Builds a feature vector for a bond. 79 | :param bond: A RDKit bond. 80 | :return: A list containing the bond features. 
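A ``None`` bond is encoded as a length-BOND_FDIM vector whose first entry (the "no bond" flag) is 1 and all other entries are 0.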
81 | """ 82 | if bond is None: 83 | fbond = [1] + [0] * (BOND_FDIM - 1) 84 | else: 85 | bt = bond.GetBondType() 86 | fbond = [ 87 | 0, # bond is not None 88 | bt == Chem.rdchem.BondType.SINGLE, 89 | bt == Chem.rdchem.BondType.DOUBLE, 90 | bt == Chem.rdchem.BondType.TRIPLE, 91 | bt == Chem.rdchem.BondType.AROMATIC, 92 | (bond.GetIsConjugated() if bt is not None else 0), 93 | (bond.IsInRing() if bt is not None else 0) 94 | ] 95 | fbond += onek_encoding_unk(int(bond.GetStereo()), list(range(6))) 96 | 97 | 98 | return fbond 99 | 100 | from descriptastorus.descriptors import rdNormalizedDescriptors 101 | 102 | 103 | def rdkit_2d_normalized_features_generator(mol): 104 | smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol 105 | generator = rdNormalizedDescriptors.RDKit2DNormalized() 106 | features = generator.process(smiles)[1:] 107 | 108 | 109 | return features 110 | 111 | 112 | def smile_to_graph(smile): 113 | mol = Chem.MolFromSmiles(smile) 114 | atoms = [atom_features(atom) for atom in mol.GetAtoms()] 115 | graph_feature = rdkit_2d_normalized_features_generator(smile) 116 | G = nx.Graph(graph_attr=graph_feature) 117 | 118 | G.graph['tree_edge_index']=1 119 | 120 | for i in range(len(atoms)): 121 | G.add_node(i) 122 | G.nodes[i]['x'] = atoms[i] 123 | 124 | for i in range(len(atoms)): 125 | for j in range(i, len(atoms)): 126 | bond = mol.GetBondBetweenAtoms(i, j) 127 | if bond: 128 | G.add_edge(i, j) 129 | G.edges[i, j]['edge_attr'] = bond_features(bond) 130 | 131 | out = tree_decomposition(mol, return_vocab=True) 132 | tree_edge_index, atom2clique_index, num_cliques, x_clique = out 133 | G.graph['tree_edge_index'] = tree_edge_index 134 | G.graph['atom2clique_index'] = atom2clique_index 135 | G.graph['num_cliques'] = num_cliques 136 | G.graph['x_clique'] = x_clique 137 | 138 | return G 139 | 140 | 141 | def mol_from_data(data): 142 | mol = Chem.RWMol() 143 | 144 | x = data.x if data.x.dim() == 1 else data.x[:, 0] 145 | for z in x.tolist(): 146 | mol.AddAtom(Chem.Atom(z)) 147 | 148 | row, col = data.edge_index 149 | mask = row < col 150 | row, col = row[mask].tolist(), col[mask].tolist() 151 | 152 | bond_type = data.edge_attr 153 | bond_type = bond_type if bond_type.dim() == 1 else bond_type[:, 0] 154 | bond_type = bond_type[mask].tolist() 155 | 156 | for i, j, bond in zip(row, col, bond_type): 157 | assert bond >= 1 and bond <= 4 158 | mol.AddBond(i, j, bonds[bond - 1]) 159 | 160 | return mol.GetMol() 161 | 162 | 163 | def data_reader(file_name): 164 | inputs = [] 165 | labels = [] 166 | with codecs.open(file_name, "r", encoding="utf-8-sig") as f: 167 | reader = csv.DictReader(f) 168 | for row in reader: 169 | smiles = row['smiles'] 170 | inputs.append(smiles) 171 | labels.append( 172 | [float(row[y]) if row[y] != '' else np.nan for y in row.keys() if y != 'smiles' and y != 'mol_id']) 173 | return inputs, np.array(labels) 174 | 175 | 176 | 177 | 178 | def save_to_data(networks, labels): 179 | dataset = [] 180 | for idx, nx in enumerate(networks): 181 | data = from_networkx(nx) 182 | 183 | ### data format 184 | # x: torch.float32 185 | # edge_index: torch.int64 186 | # edge_attr: torch.float32 187 | # y: torch.float32 188 | # graph_attr: torch.float32 189 | # tree_edge_index: torch.int64 190 | # atom2clique_index: torch.int64 191 | # num_cliques: torch.int64 192 | # x_clique: torch.int64 193 | 194 | data['y'] = torch.tensor(labels[idx], dtype=torch.float32) 195 | ### If there is no edge, we should give an empty tensor (size: [0,BOND_FDIM]) to edge_attr 196 | if 
data['edge_index'].shape[-1] == 0: 197 | print("No edges found; assigning an empty edge_attr tensor. Molecule is in row:", (idx+2)) 198 | data['edge_attr'] = torch.empty([0, BOND_FDIM]) 199 | data['x'] = data['x'].to(dtype=torch.float32) 200 | data['edge_index'] = data['edge_index'].to(dtype=torch.int64) 201 | data['edge_attr'] = data['edge_attr'].to(dtype=torch.float32) 202 | data['idx'] = torch.tensor(idx, dtype=torch.int32) 203 | 204 | graph_attr = torch.tensor(nx.graph['graph_attr'], dtype=torch.float32) 205 | is_nan = ~(graph_attr == graph_attr) 206 | graph_attr[is_nan] = 0 ### replace nan with 0 207 | data.graph_attr = torch.reshape(graph_attr, (1,200)) 208 | 209 | data.tree_edge_index = nx.graph['tree_edge_index'].to(dtype=torch.int64) 210 | data.atom2clique_index = nx.graph['atom2clique_index'].to(dtype=torch.int64) 211 | data.num_cliques = torch.tensor([nx.graph['num_cliques']], dtype=torch.int64) 212 | data.x_clique = nx.graph['x_clique'].to(dtype=torch.int64) 213 | 214 | # print(data) 215 | 216 | dataset.append(data) 217 | 218 | return dataset 219 | 220 | 221 | smiles, labels = data_reader('datasets/hiv.csv') 222 | 223 | networks = [smile_to_graph(smile) for smile in smiles] 224 | 225 | graph_set = save_to_data(networks, labels) 226 | 227 | torch.save(graph_set, 'datasets/gnn_feats/hiv.pt') -------------------------------------------------------------------------------- /Image/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ### Configuration 4 | Dependencies: \ 5 | python>=3.6.8 \ 6 | torch>=1.7.0 7 | 8 | 9 | ### Data 10 | 11 | ##### [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html), [CIFAR100](https://www.cs.toronto.edu/~kriz/cifar.html) 12 | For each dataset, we take half of the classes as the positive class and the other half as the negative class. \ 13 | **imb_ratio** = the number of samples in the positive class / the number of samples in the negative class \ 14 | The imb_ratio is 0.02 for both datasets. To construct highly imbalanced data, we remove 98% of the positive images from the training data and keep the test data unchanged.\ 15 | The split indices can be found in **imbalanced_cifar.py** 16 | ##### [Melanoma](https://www.kaggle.com/c/siim-isic-melanoma-classification/data) 17 | The Melanoma dataset is from a medical image Kaggle competition and serves as a naturally imbalanced real-world image dataset. It contains 33,126 labeled medical images, \ 18 | among which 584 images are related to malignant melanoma and labeled as positive samples. \ 19 | We manually split the training data into train/validation/test sets at an 80%/10%/10% ratio and report the AUPRC achieved on the test set. 20 | The split csv files are provided in: **./data/melanoma_split_inds/{train, test, valid}_split.csv** 21 | 22 | 23 | 24 | 25 | ### Model and Optimizer 26 | Arch: ResNet18, ResNet34 for all three datasets. \ 27 | Optimizer: SGD with Momentum 0.9 28 | 29 | 30 | 31 | ### Algorithm and Pretrained Models 32 | The proposed loss and the SOAP algorithm are implemented in SOAP.py \ 33 | The SOAP algorithm with the **squared hinge (sqh)** surrogate loss is trained from ce_pretrained models.
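The full fine-tuning loops live in train_eval.py and train_eval_melanoma.py; the snippet below is a minimal sketch of the intended two-stage pattern, not the exact training script. It assumes a torchvision ResNet18 as a stand-in for the repo's model, a CUDA device (SOAPLOSS allocates its moving-average buffers on the GPU), and an illustrative checkpoint path; the index argument mirrors how train_eval.py calls the loss.
```python
import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18

from SOAP import AUPRCSampler, SOAPLOSS
from imbalanced_cifar import IMBALANCECIFAR10, transform_train

# Stage 1 output: a CE-pretrained checkpoint (illustrative path; the saved
# state dict must match the model actually used for pretraining).
train_dataset = IMBALANCECIFAR10(root='./data', train=True, download=True,
                                 transform=transform_train)
model = resnet18(num_classes=1)
model.load_state_dict(torch.load('cepretrainmodels/cifar10_resnet18_002.ckpt'))
model.fc.reset_parameters()        # conf['ft_mode'] = 'fc_random'
model = model.cuda()

labels = train_dataset.targets     # binary labels after the step-imbalance split
loader = DataLoader(train_dataset, batch_size=64,
                    sampler=AUPRCSampler(labels, 64, posNum=1))

criterion = SOAPLOSS(threshold=0.6, data_length=len(train_dataset),
                     loss_type='sqh', gamma=0.9)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Stage 2: fine-tune with the SOAP loss.
model.train()
for index, data, targets in loader:   # IMBALANCECIFAR10 yields (index, image, target)
    data, index = data.cuda(), index.cuda()
    preds = torch.sigmoid(model(data))
    # AUPRCSampler puts the posNum positive samples first in every batch, so
    # preds[:1] are positive scores and preds[1:] negative ones; targets are
    # unused by SOAPLOSS because the sampler already fixes the batch layout.
    loss = criterion(preds[:1], preds[1:], index.view(-1, 1).long())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
The two pieces that differ from plain CE training are the sampler, which fixes the positives-first batch layout the loss relies on, and the per-sample index that SOAPLOSS needs to update its moving-average estimates.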
34 | The **pretrained models** for **CIFAR10, CIFAR100, Melanoma** data are provided in https://drive.google.com/drive/folders/13Bxt0eLeOKNEPbwbq1oEeOLNo9AhnQvr?usp=sharing \ 35 | Unzip the downloaded models to: \ 36 | **cepretrainmodels/**: 37 | - cifar10_resnet18_002.ckpt 38 | - cifar10_resnet34_002.ckpt 39 | - cifar100_resnet18_002.ckpt 40 | - cifar100_resnet34_002.ckpt 41 | - melanoma_ce_pretrain_resnet18.pth 42 | - melanoma_ce_pretrain_resnet34.pth 43 | 44 | 45 | **config_cifar.py**: The hyperparameter configurations for CIFAR10, CIFAR100 \ 46 | **config_melanoma.py**: The hyperparameter configuration for Melanoma \ 47 | 48 | 49 | 50 | ### The hyperparameters for SOAPLOSS: 51 | --**threshold**: the margin in squared hinge loss | **conf['loss_param']['threshold'] = 0.6** \ 52 | --**batch_size**: batch size | **conf['batch_size'] = 64** \ 53 | --**data_length**: length of the dataset \ 54 | --**loss_type**: squared hinge surrogate loss for SOAP | **conf['surr_loss'] = 'sqh'** \ 55 | --**gamma**: the gamma parameter in the paper | mv_gamma = {0.9, 0.99} 56 | 57 | **conf['ft_mode']** = 'fc_random': Reinitializing the Fully-Connected layer of the pretrained model when starting to train SOAP. 58 | **conf['pre_train']** = { None : training from scratch, 59 | 'path_of_pretrained_model': training from a pretrained model } 60 | **conf['posNum']** : Number of positive samples per batch, \{1,2,3,4,5\} 61 | 62 | 63 | 64 | 65 | ### Results 66 | To replicate the SOAP results for CIFAR10, CIFAR100, Melanoma, run: 67 | ``` 68 | CUDA_VISIBLE_DEVICES=0 python3 -W ignore main_cifar10_resnet18.py # ResNet18, CIFAR10 69 | CUDA_VISIBLE_DEVICES=0 python3 -W ignore main_cifar10_resnet34.py # ResNet34, CIFAR10 70 | CUDA_VISIBLE_DEVICES=0 python3 -W ignore main_cifar100_resnet18.py # ResNet18, CIFAR100 71 | CUDA_VISIBLE_DEVICES=0 python3 -W ignore main_cifar100_resnet34.py # ResNet34, CIFAR100 72 | CUDA_VISIBLE_DEVICES=0 python3 -W ignore main_melanoma_resnet18.py # ResNet18, Melanoma 73 | CUDA_VISIBLE_DEVICES=0 python3 -W ignore main_melanoma_resnet34.py # ResNet34, Melanoma 74 | ``` 75 | 76 | The wrapped package can be found at https://github.com/Optimization-AI/LibAUC/ 77 | with the following installation and usage example: 78 | ```python 79 | pip install libauc 80 | >>> #import library 81 | >>> from libauc.losses import APLoss_SH 82 | >>> from libauc.optimizers import SOAP_SGD, SOAP_ADAM 83 | ... 84 | >>> #define loss 85 | >>> Loss = APLoss_SH() 86 | >>> optimizer = SOAP_ADAM() 87 | ... 88 | >>> #training 89 | >>> model.train() 90 | >>> for index, data, targets in trainloader: 91 | data, targets = data.cuda(), targets.cuda() 92 | logits = model(data) 93 | preds = torch.sigmoid(logits) 94 | loss = Loss(preds, targets, index) 95 | optimizer.zero_grad() 96 | loss.backward() 97 | optimizer.step() 98 | ``` 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /Image/SOAP.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Qi' 2 | # Created by on 11/3/21. 
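# ---------------------------------------------------------------------------
# Added note: this module provides the two pieces of the SOAP training loop.
#   * AUPRCSampler -- a batch sampler that places `posNum` positive samples
#     at the front of every batch of size `batchSize`.
#   * SOAPLOSS -- the AP surrogate loss. For each positive anchor it keeps
#     two moving averages, u_pos (surrogate loss against positives) and
#     u_all (surrogate loss against the whole batch), and reweights the
#     pairwise losses by p = (u_pos - u_all * 1[positive]) / u_all**2,
#     which is (up to sign) the gradient of the ratio u_pos / u_all with
#     respect to the per-pair losses.
# ---------------------------------------------------------------------------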
3 | import torch 4 | import numpy as np 5 | from torch.utils.data.sampler import Sampler 6 | import random 7 | import torch.nn as nn 8 | from loss import logistic_loss, sigmoid_loss 9 | 10 | class AUPRCSampler(Sampler): 11 | 12 | def __init__(self, labels, batchSize, posNum=1): 13 | # positive class: minority class 14 | # negative class: majority class 15 | 16 | self.labels = labels 17 | self.posNum = posNum 18 | self.batchSize = batchSize 19 | 20 | self.clsLabelList = np.unique(labels) 21 | self.dataDict = {} 22 | 23 | for label in self.clsLabelList: 24 | self.dataDict[str(label)] = [] 25 | 26 | for i in range(len(self.labels)): 27 | self.dataDict[str(self.labels[i])].append(i) 28 | 29 | self.ret = [] 30 | 31 | 32 | def __iter__(self): 33 | minority_data_list = self.dataDict[str(1)] 34 | majority_data_list = self.dataDict[str(0)] 35 | 36 | # print(len(minority_data_list), len(majority_data_list)) 37 | np.random.shuffle(minority_data_list) 38 | np.random.shuffle(majority_data_list) 39 | 40 | # In every iteration: sample posNum positive sample(s) and batchSize - posNum negative samples 41 | if len(minority_data_list) // self.posNum > len(majority_data_list)//(self.batchSize - self.posNum): # In this case, we go over all the positive samples in every epoch. 42 | # extend majority_data_list from len(majority_data_list) to len(minority_data_list) // posNum * (batchSize - posNum) 43 | majority_data_list.extend(np.random.choice(majority_data_list, len(minority_data_list) // self.posNum * (self.batchSize - self.posNum) - len(majority_data_list), replace=True).tolist()) 44 | 45 | elif len(minority_data_list) // self.posNum < len(majority_data_list)//(self.batchSize - self.posNum): # In this case, we go over all the negative samples in every epoch. 46 | # extend minority_data_list from len(minority_data_list) to len(majority_data_list) // (batchSize - posNum) * posNum 47 | 48 | minority_data_list.extend(np.random.choice(minority_data_list, len(majority_data_list) // (self.batchSize - self.posNum)*self.posNum - len(minority_data_list), replace=True).tolist()) 49 | 50 | self.ret = [] 51 | for i in range(len(minority_data_list) // self.posNum): 52 | self.ret.extend(minority_data_list[i*self.posNum:(i+1)*self.posNum]) 53 | startIndex = i*(self.batchSize - self.posNum) 54 | endIndex = (i+1)*(self.batchSize - self.posNum) 55 | self.ret.extend(majority_data_list[startIndex:endIndex]) 56 | 57 | return iter(self.ret) 58 | 59 | 60 | def __len__(self): 61 | return len(self.ret) 62 | 63 | class SOAPLOSS(nn.Module): 64 | def __init__(self, threshold, data_length, loss_type = 'sqh', gamma = 0.9): 65 | ''' 66 | threshold: margin for squared hinge loss, e.g. 
0.6 67 | data_length: number of samples in the dataset for moving avearage variable updates 68 | loss_type(str): type of surrogate losses, including, square hinge loss (sqh), logistic loss (lgs), and sigmoid loss (sgm) 69 | gamma (Tensor float): algorithm momentum parameter, by default 0.9 70 | ''' 71 | super(SOAPLOSS, self).__init__() 72 | self.u_all = torch.tensor([0.0]*data_length).view(-1, 1).cuda() 73 | self.u_pos = torch.tensor([0.0]*data_length).view(-1, 1).cuda() 74 | self.threshold = threshold 75 | self.loss_type = loss_type 76 | self.gamma = gamma 77 | print('The loss type is :', self.loss_type) 78 | 79 | 80 | def forward(self,f_ps, f_ns, index_s): 81 | ''' 82 | Params: 83 | f_ps (Tensor array): positive prediction scores 84 | f_ns (Tensor array): negative prediction scores 85 | index_s (Tensor array): positive sample indexes 86 | Return: 87 | Mean Average Precision (AP) loss. 88 | ''' 89 | f_ps = f_ps.view(-1) 90 | f_ns = f_ns.view(-1) 91 | 92 | vec_dat = torch.cat((f_ps, f_ns), 0) 93 | mat_data = vec_dat.repeat(len(f_ps), 1) 94 | 95 | # print(mat_data.shape) 96 | 97 | f_ps = f_ps.view(-1, 1) 98 | 99 | neg_mask = torch.ones_like(mat_data) 100 | neg_mask[:, 0:f_ps.size(0)] = 0 101 | 102 | pos_mask = torch.zeros_like(mat_data) 103 | pos_mask[:, 0:f_ps.size(0)] = 1 104 | 105 | # test_tmp = f_ps- mat_data 106 | # print(f_ps.size(), mat_data.size(), test_tmp.size()) 107 | 108 | # 3*1 - 3*64 ==> 3*64 109 | if self.loss_type == 'sqh': 110 | 111 | neg_loss = torch.max(self.threshold - (f_ps - mat_data), torch.zeros_like(mat_data)) ** 2 * neg_mask 112 | pos_loss = torch.max(self.threshold - (f_ps - mat_data), torch.zeros_like(mat_data)) ** 2 * pos_mask 113 | 114 | elif self.loss_type == 'lgs': 115 | 116 | neg_loss = logistic_loss(f_ps, mat_data, self.threshold) * neg_mask 117 | pos_loss = logistic_loss(f_ps, mat_data, self.threshold) * pos_mask 118 | 119 | elif self.loss_type == 'sgm': 120 | neg_loss = sigmoid_loss(f_ps, mat_data, self.threshold) * neg_mask 121 | pos_loss = sigmoid_loss(f_ps, mat_data, self.threshold) * pos_mask 122 | 123 | 124 | loss = pos_loss + neg_loss 125 | 126 | 127 | if f_ps.size(0) == 1: 128 | 129 | self.u_pos[index_s] = (1 - self.gamma) * self.u_pos[index_s] + self.gamma * (pos_loss.mean()) 130 | self.u_all[index_s] = (1 - self.gamma) * self.u_all[index_s] + self.gamma * (loss.mean()) 131 | else: 132 | # print(self.u_all[index_s], loss.size(), loss.sum(1, keepdim = 1)) 133 | self.u_all[index_s] = (1 - self.gamma) * self.u_all[index_s] + self.gamma * (loss.mean(1, keepdim=True)) 134 | self.u_pos[index_s] = (1 - self.gamma) * self.u_pos[index_s] + self.gamma * (pos_loss.mean(1, keepdim=True)) 135 | 136 | 137 | 138 | p = (self.u_pos[index_s] - (self.u_all[index_s]) * pos_mask) / (self.u_all[index_s] ** 2) 139 | 140 | 141 | p.detach_() 142 | loss = torch.mean(p * loss) 143 | 144 | return loss 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /Image/cepretrainmodels/tmp.md.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Optimization-AI/NeurIPS2021_SOAP/c120d8793cbffe702f960527bd83d41f00065b5f/Image/cepretrainmodels/tmp.md.docx -------------------------------------------------------------------------------- /Image/config_cifar.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Qi' 2 | # Created by on 9/7/21. 
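# Added note: the defaults below correspond to a baseline configuration. For
# a SOAP fine-tuning run, the README describes settings along these lines
# (a sketch only -- the exact keys consumed by each main_*.py may differ):
#   conf['loss_type'] = 'SOAP'
#   conf['surr_loss'] = 'sqh'
#   conf['loss_param'] = {'threshold': 0.6, 'mv_gamma': 0.9}
#   conf['posNum'] = 1
#   conf['ft_mode'] = 'fc_random'
#   conf['pre_train'] = 'cepretrainmodels/cifar10_resnet18_002.ckpt'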
3 | """ 4 | Configuration file 5 | """ 6 | 7 | 8 | 9 | conf = {} 10 | conf['input_size'] = 384 11 | conf['num_tasks'] = 1 12 | 13 | 14 | ###################################################################################################################### 15 | # Settings for training 16 | ## 'epochs': maximum training epochs 17 | ## 'early_stopping': patience used to stop training 18 | ## 'lr': starting learning rate 19 | ## 'lr_decay_factor': learning rate decay factor 20 | ## 'lr_decay_step_size': step size of learning rate decay 21 | ## 'dropout': dropout rate 22 | ## 'weight_decay': l2 regularizer term 23 | ## 'depth': number of layers 24 | ## 'batch_size': training batch_size 25 | ###################################################################################################################### 26 | 27 | 28 | 29 | # cifar10 30 | conf['epochs'] = 64 31 | conf['early_stopping'] = 50 32 | conf['lr'] = 0.001 33 | conf['lr_decay_factor'] = 0.1 34 | conf['lr_decay_step_size'] = 32 35 | conf['dropout'] = 0 36 | conf['weight_decay'] = 2e-4 37 | conf['batch_size'] = 64 38 | conf['loss_type'] = 'auprc2' 39 | conf['loss_param'] = None 40 | conf['ft_mode'] = None 41 | conf['pre_train'] = None 42 | 43 | 44 | 45 | ###################################################################################################################### 46 | # Settings for val/test 47 | ## 'vt_batch_size': val/test batch_size 48 | ###################################################################################################################### 49 | conf['vt_batch_size'] = 64 -------------------------------------------------------------------------------- /Image/config_melanoma.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration file 3 | """ 4 | 5 | 6 | 7 | conf = {} 8 | conf['input_size'] = 384 9 | conf['num_tasks'] = 1 10 | 11 | 12 | ###################################################################################################################### 13 | # Settings for training 14 | ## 'epochs': maximum training epochs 15 | ## 'early_stopping': patience used to stop training 16 | ## 'lr': starting learning rate 17 | ## 'lr_decay_factor': learning rate decay factor 18 | ## 'lr_decay_step_size': step size of learning rate decay 19 | ## 'dropout': dropout rate 20 | ## 'weight_decay': l2 regularizer term 21 | ## 'depth': number of layers 22 | ## 'batch_size': training batch_size 23 | ###################################################################################################################### 24 | 25 | # melanoma 26 | conf['epochs'] = 100 27 | conf['early_stopping'] = 50 28 | conf['lr'] = 0.0001 29 | conf['lr_decay_factor'] = 0.5 30 | conf['lr_decay_step_size'] = 100 31 | conf['dropout'] = 0 32 | conf['weight_decay'] = 1e-5 33 | conf['batch_size'] = 64 34 | conf['loss_type'] = 'ce' 35 | conf['loss_param'] = None 36 | conf['ft_mode'] = None 37 | conf['pre_train'] = None 38 | 39 | 40 | ###################################################################################################################### 41 | # Settings for val/test 42 | ## 'vt_batch_size': val/test batch_size 43 | ###################################################################################################################### 44 | conf['vt_batch_size'] = 64 45 | -------------------------------------------------------------------------------- /Image/data_split.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Qi' 2 | # Created by 
on 8/1/21. 3 | 4 | 5 | import pandas as pd 6 | import os 7 | import shutil 8 | 9 | train_dat = pd.read_csv("./data/melanoma_split_inds/train_split.csv") 10 | test_dat = pd.read_csv("./data/melanoma_split_inds/test_split.csv") 11 | valid_dat = pd.read_csv("./data/melanoma_split_inds/valid_split.csv") 12 | 13 | print(test_dat.shape[0] + valid_dat.shape[0]) 14 | 15 | 16 | j = 0 17 | for img in os.listdir('/data/qiuzh/melanoma/jpeg/mytrain/0'): 18 | for i in range(test_dat.shape[0]): 19 | if test_dat['image_name'][i] in img: 20 | j+=1 21 | shutil.move(os.path.join("/data/qiuzh/melanoma/jpeg/mytrain/0", img), os.path.join("/data/qiuzh/melanoma/jpeg/mytest/0", img)) 22 | for i in range(valid_dat.shape[0]): 23 | if valid_dat['image_name'][i] in img: 24 | j+=1 25 | shutil.move(os.path.join("/data/qiuzh/melanoma/jpeg/mytrain/0", img), os.path.join("/data/qiuzh/melanoma/jpeg/myval/0", img)) 26 | print("Moved Nagative Samples:", j) 27 | 28 | j=0 29 | for img in os.listdir('/data/qiuzh/melanoma/jpeg/mytrain/1'): 30 | for i in range(test_dat.shape[0]): 31 | if test_dat['image_name'][i] in img: 32 | j += 1 33 | shutil.move(os.path.join("/data/qiuzh/melanoma/jpeg/mytrain/1", img), os.path.join("/data/qiuzh/melanoma/jpeg/mytest/1", img)) 34 | for i in range(valid_dat.shape[0]): 35 | if valid_dat['image_name'][i] in img: 36 | j += 1 37 | shutil.move(os.path.join("/data/qiuzh/melanoma/jpeg/mytrain/1", img), os.path.join("/data/qiuzh/melanoma/jpeg/myval/1", img)) 38 | print("Moved Positive Samples:", j) 39 | 40 | -------------------------------------------------------------------------------- /Image/imbalanced_cifar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | from collections import Counter 8 | 9 | # __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 10 | # 'std': [0.229, 0.224, 0.225]} 11 | # 12 | # __tiny_imagenet_stats = {'mean': [0.4802, 0.4481, 0.3975], 13 | # 'std': [0.2302, 0.2265, 0.2262]} 14 | # 15 | # __imagenet_pca = { 16 | # 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 17 | # 'eigvec': torch.Tensor([ 18 | # [-0.5675, 0.7192, 0.4009], 19 | # [-0.5808, -0.0045, -0.8140], 20 | # [-0.5836, -0.6948, 0.4203], 21 | # ]) 22 | # } 23 | # 24 | # __cifar10_stats = {'mean': [0.4914, 0.4822, 0.4465], 25 | # 'std': [0.2023, 0.1994, 0.2010]} 26 | # 27 | # __cifar100_stats = {'mean': [0.5071, 0.4867, 0.4408], 28 | # 'std': [0.2675, 0.2565, 0.2761]} 29 | 30 | 31 | 32 | transform_train = transforms.Compose([ 33 | transforms.RandomCrop(32, padding=4), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 37 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 38 | ]) 39 | 40 | transform_val = transforms.Compose([ 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 43 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 44 | ]) 45 | 46 | 47 | 48 | class IMBALANCECIFAR10(torchvision.datasets.CIFAR10): 49 | cls_num = 10 50 | 51 | def __init__(self, root, imb_type='step', imb_factor=0.02, rand_number=0, train=True, 52 | transform=None, target_transform=None, 53 | download=False, val = False): 54 | super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download) 55 | np.random.seed(rand_number) 56 | 
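# Added note: with the default imb_type='step' and imb_factor=0.02,
# get_img_num_per_cls() keeps all images for the first half of the classes
# and only 2% for the second half; gen_imbalanced_data() then collapses the
# original classes into binary labels (0 for the majority half, 1 for the
# minority half) and applies the fixed train/val split defined in
# train_val_split() below.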
img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) 57 | self.gen_imbalanced_data(img_num_list, val, imb_factor) 58 | 59 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor): 60 | img_max = len(self.data) / cls_num 61 | 62 | img_num_per_cls = [] 63 | if imb_type == 'exp': 64 | for cls_idx in range(cls_num): 65 | num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0))) 66 | img_num_per_cls.append(int(num)) 67 | elif imb_type == 'step': 68 | for cls_idx in range(cls_num // 2): 69 | img_num_per_cls.append(int(img_max)) 70 | for cls_idx in range(cls_num // 2): 71 | img_num_per_cls.append(int(img_max * imb_factor)) 72 | else: 73 | img_num_per_cls.extend([int(img_max)] * cls_num) 74 | 75 | return img_num_per_cls 76 | 77 | def gen_imbalanced_data(self, img_num_per_cls, val, imb_factor): 78 | new_data = [] 79 | new_targets = [] 80 | targets_np = np.array(self.targets, dtype=np.int64) 81 | classes = np.unique(targets_np) 82 | # np.random.shuffle(classes) 83 | self.num_per_cls_dict = dict() 84 | for the_class, the_img_num in zip(classes, img_num_per_cls): 85 | self.num_per_cls_dict[the_class] = the_img_num 86 | idx = np.where(targets_np == the_class)[0] 87 | np.random.shuffle(idx) 88 | selec_idx = idx[:the_img_num] 89 | new_data.append(self.data[selec_idx, ...]) 90 | new_targets.extend([the_class, ] * the_img_num) 91 | new_data = np.vstack(new_data) 92 | 93 | new_targets[0:25000] =[0]*25000 94 | new_targets[25000:] = [1]*int(25000*imb_factor) 95 | print(type(new_data), type(new_targets)) 96 | train_index, val_index = self.train_val_split(imb_factor) 97 | # print('length of train data : {}, length of val data : {} and length of new_data : {}'.format(len(train_index), len(val_index), len(new_data))) 98 | 99 | if not val: 100 | self.data = new_data[train_index] 101 | self.targets = np.array(new_targets)[train_index].tolist() 102 | print('train self.data:', len(self.data)) 103 | else: 104 | self.data = new_data[val_index] 105 | self.targets = np.array(new_targets)[val_index].tolist() 106 | print('val self.data:', len(self.data)) 107 | 108 | def get_cls_num_list(self): 109 | cls_num_list = [] 110 | for i in range(self.cls_num): 111 | cls_num_list.append(self.num_per_cls_dict[i]) 112 | return cls_num_list 113 | 114 | def __getitem__(self, index: int): 115 | """ 116 | Args: 117 | index (int): Index 118 | 119 | Returns: 120 | tuple: (image, target) where target is index of the target class. 
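Note that, unlike the stock torchvision datasets, the sample's dataset index is returned first; the training loops pass it to SOAPLOSS so the loss can update its per-sample moving-average statistics.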
121 | """ 122 | img, target = super().__getitem__(index) 123 | return index, img, target 124 | 125 | 126 | # Data Splits 127 | def train_val_split(self, imb_factor): 128 | 129 | x = list(range(25000)) 130 | y = list(range(25000, 25000 + int(25000 * imb_factor))) 131 | neg_val_index = [21213, 14463, 23845, 14944, 2689, 17624, 14107, 1405, 1080, 14571, 4736, 5306, 5868, 19834, 132 | 24703, 2227, 6772, 10587, 17733, 8773, 20457, 18566, 15481, 7499, 21949, 5058, 22630, 6100, 133 | 6736, 21827, 16088, 4477, 17764, 19173, 16863, 23057, 18482, 15327, 16082, 15862, 12906, 1929, 134 | 6506, 10115, 17442, 4997, 7181, 14347, 12015, 17355, 23572, 16726, 24332, 23843, 13433, 294, 135 | 11934, 2301, 13822, 12883, 23335, 21670, 5018, 4847, 2212, 4484, 24912, 21024, 578, 11709, 136 | 24488, 16212, 980, 18485, 21198, 17718, 123, 11490, 5777, 24968, 18821, 4058, 17112, 14801, 137 | 15763, 3525, 8457, 15070, 10125, 3534, 4588, 5767, 17926, 3946, 3442, 14125, 9635, 20801, 138 | 13138, 2328, 15619, 4328, 20052, 873, 13204, 23368, 2857, 22528, 1165, 5524, 23547, 13178, 139 | 4546, 24799, 6675, 23945, 10981, 19914, 2506, 1275, 6211, 11482, 16066, 6018, 15219, 4478, 140 | 24672, 7007, 9295, 14118, 19998, 12858, 21751, 3593, 11944, 15714, 3321, 2859, 921, 18916, 141 | 2225, 4068, 4451, 22481, 1597, 2889, 3191, 16380, 10762, 20742, 12539, 10541, 13025, 21265, 142 | 6147, 10798, 21911, 10602, 8111, 8241, 14482, 3395, 4356, 23442, 11376, 21676, 14211, 16183, 143 | 14630, 796, 8896, 23784, 441, 2467, 18443, 17087, 12717, 5721, 9673, 19487, 21891, 19068, 144 | 16444, 22367, 10561, 24913, 20760, 2491, 7494, 23918, 11638, 8329, 10710, 4063, 5033, 12109, 145 | 22786, 222, 9477, 24546, 6523, 10887, 9850, 3420, 20551, 19793, 9000, 383, 23708, 18533, 2529, 146 | 17125, 14587, 15813, 10715, 18008, 7356, 18605, 19001, 2235, 11715, 1569, 23519, 19482, 8727, 147 | 11171, 16499, 17939, 13740, 15377, 8706, 10879, 16811, 17820, 10085, 11054, 20270, 2627, 13478, 148 | 20979, 9264, 18557, 9525, 23895, 18135, 1188, 18453, 1834, 22416, 21170, 13986, 6123, 19709, 149 | 8364, 16994, 21015, 3746, 5409, 15425, 10950, 22882, 5449, 16879, 17030, 961, 2968, 13891, 150 | 7704, 12047, 451, 5171, 14874, 22064, 5202, 2279, 11616, 22618, 18421, 17085, 21417, 13154, 151 | 540, 13517, 8711, 22137, 22311, 17037, 24379, 24028, 9693, 2693, 13001, 5863, 12388, 23897, 152 | 21014, 7734, 22916, 24240, 3242, 7314, 8052, 663, 12385, 11182, 6778, 8316, 12568, 10540, 489, 153 | 16776, 6677, 13035, 627, 16440, 11333, 7618, 14269, 7172, 4412, 15171, 16803, 14025, 22668, 154 | 1746, 9292, 13171, 20968, 13382, 1499, 5589, 2598, 20005, 8527, 24668, 16387, 6061, 19597, 155 | 14068, 3908, 22010, 13581, 8380, 9031, 9824, 11713, 3295, 18154, 12868, 14145, 15289, 1546, 156 | 5239, 14247, 21003, 5652, 11237, 1956, 10411, 7188, 872, 22539, 19972, 3719, 18805, 1119, 8459, 157 | 24214, 24258, 13296, 17122, 2218, 21838, 21618, 13857, 22803, 13659, 1702, 24882, 17135, 15543, 158 | 12716, 3572, 24748, 18344, 4037, 828, 11737, 17231, 5007, 5632, 23994, 17653, 9652, 23325, 159 | 3780, 20958, 3690, 1560, 14742, 18054, 10127, 14259, 22544, 9362, 2709, 11205, 13158, 13604, 160 | 11884, 21337, 8632, 14945, 9507, 11297, 10815, 15351, 12028, 18445, 14218, 6257, 925, 3166, 161 | 1897, 21376, 212, 1675, 2267, 20037, 16735, 7528, 12584, 13222, 274, 2745, 14553, 22833, 9437, 162 | 21632, 6089, 1318, 5441, 2349, 21508, 1763, 20064, 24178, 7959, 9271, 18543, 20317, 14999, 163 | 4826, 13837, 21534, 13498, 10663, 15111, 19920, 24126, 551, 8349, 14624, 15692, 11288, 19830, 164 | 552, 21136, 
22458, 8184, 1685, 14077, 5046, 21954, 8385, 18446, 8775, 15096, 5484, 14865, 7769, 165 | 24752, 20101, 14028, 2403, 3716, 20124, 11825, 14556, 19400, 9565, 2722, 2569, 368, 7487, 9793, 166 | 12478, 6010, 2567, 19, 20309, 12864, 23848, 21870, 11492, 19286, 11531, 5497, 22926, 2861, 167 | 6554, 23571, 17840, 10218, 21877, 3428, 15294, 6833, 21496, 10311, 5698, 17987, 12246, 14005, 168 | 3970, 20151, 5057, 17289, 19611, 16682, 18936, 4434, 11898, 11475, 10372, 19867, 17720, 13225, 169 | 12133, 2752, 5078, 7397, 15404, 3150, 9709, 22181, 295, 9981, 23063, 24815, 5625, 2725, 24702, 170 | 23967, 18846, 7722, 9944, 2622, 12669, 14626, 8799, 19168, 16568, 2323, 13587, 1879, 11488, 171 | 3874, 21861, 20649, 9503, 22346, 10698, 4877, 8424, 7536, 9584, 23699, 18789, 22006, 18820, 172 | 9305, 12035, 16290, 22608, 19599, 2714, 20020, 22688, 22880, 3464, 24727, 2332, 23374, 18273, 173 | 13876, 1500, 14639, 10491, 17437, 22384, 23751, 22958, 13860, 20907, 23054, 15762, 10721, 4466, 174 | 17675, 24826, 9066, 7113, 1677, 14109, 6845, 1037, 15880, 9882, 15680, 10207, 3678, 16256, 175 | 19303, 18165, 16329, 4072, 4036, 15909, 13941, 3726, 9684, 20647, 15659, 781, 24345, 20785, 176 | 18278, 22401, 10799, 14635, 4617, 21652, 7237, 10650, 23405, 21809, 20618, 12436, 10862, 17911, 177 | 23801, 494, 16427, 18400, 23371, 270, 12369, 10469, 18607, 7399, 18632, 19384, 14007, 10101, 178 | 3995, 7918, 8231, 20631, 307, 7874, 20088, 20787, 20452, 17215, 9531, 22784, 4300, 8410, 13370, 179 | 16840, 6455, 22447, 3813, 1900, 3704, 11994, 3061, 9809, 14090, 18502, 20666, 15609, 13698, 180 | 17713, 14207, 16999, 16579, 7416, 9806, 2141, 633, 9528, 18257, 12781, 19997, 20864, 3812, 181 | 11139, 12707, 2823, 13569, 18700, 15941, 14910, 23878, 13800, 17810, 24674, 22809, 301, 17273, 182 | 11831, 17072, 12337, 14732, 17791, 7694, 17839, 19413, 5955, 4213, 12295, 265, 15959, 13325, 183 | 22890, 16582, 1489, 21553, 10513, 4457, 3945, 12757, 16000, 6487, 19145, 2366, 2250, 1608, 661, 184 | 19399, 5690, 1983, 11696, 20868, 2383, 1796, 10539, 24903, 19220, 7290, 2955, 7888, 17022, 185 | 14837, 1334, 10757, 5147, 10275, 3003, 7651, 6175, 3859, 5577, 12996, 21308, 14552, 12181, 186 | 21464, 3576, 16225, 7774, 2308, 19906, 24256, 9659, 280, 4816, 4685, 38, 17530, 18989, 17421, 187 | 5772, 17237, 19278, 10907, 7257, 2794, 14346, 4724, 13464, 9510, 9459, 19751, 4835, 12668, 188 | 17036, 24943, 17063, 12039, 11566, 17367, 19886, 7673, 20447, 13114, 16976, 9077, 1804, 3415, 189 | 20482, 4330, 5062, 18196, 8536, 16327, 13291, 15669, 2476, 2915, 13000, 23143, 9682, 11938, 190 | 180, 7636, 4907, 19999, 18162, 18182, 17485, 12676, 15246, 16882, 10811, 14254, 10874, 17726, 191 | 20041, 1353, 19006, 7793, 3518, 15654, 1414, 10505, 24698, 9081, 9221, 8107, 13883, 11094, 192 | 23785, 15822, 14175, 13010, 19962, 8793, 11461, 23814, 15561, 24120, 9022, 443, 23350, 2664, 193 | 21525, 24244, 4257, 10916, 2186, 16122, 2732, 13018, 23694, 2878, 8704, 16746, 12979, 20798, 194 | 22501, 16142, 18934, 15821, 4692, 19869, 6435, 11308, 14103, 5518, 14208, 20143, 17050, 11300, 195 | 4537, 11290, 24683, 487, 21375, 14810, 19350, 10806, 20736, 24549, 24095, 3040, 11677, 24775, 196 | 1409, 4906, 9648, 16182, 16941, 548, 12680, 21918, 12797, 21012, 10480, 1864, 5689, 2652, 7691, 197 | 14505, 1976, 310, 7013, 6206, 17742, 1760, 13313, 18192, 20826, 6082, 23852, 5336, 23719, 1239, 198 | 9121, 6658, 7019, 9160, 1327, 23633, 14681, 3819, 16094, 12682, 11483, 5943, 23658, 23649, 747, 199 | 8430, 13366, 5304, 2958, 3162, 3741, 6831, 9136, 679, 16944, 16770, 19098, 371, 
5481, 19982, 200 | 9676, 14878, 19040, 1085, 14802, 21643, 2851, 23574, 1783, 19526, 8858, 12912, 15145, 6484, 201 | 24174, 4469, 4757, 18643, 10998, 13642, 3333, 24904, 11538, 16683, 17494, 14543, 2069, 3796, 202 | 5264, 21450, 22494, 21888, 4499, 7616, 3580, 14091, 23485, 7958, 12268, 17693, 17246, 5750, 203 | 21762, 20290, 12164, 14828, 12454, 9304, 19800, 14870, 22930, 22861, 9139, 6606, 19525, 4306, 204 | 8501, 136, 14036, 10109, 18464, 7479, 4251, 7809, 838, 22124, 16513, 12826, 7561, 23505, 23806, 205 | 5778, 20822, 14362, 9450, 19199, 2886, 768, 19877, 11028, 20000, 9971, 9831, 7321, 14984, 206 | 20858, 20369, 597, 2655, 11034, 2255, 17168, 7256, 5924, 12499, 10447, 24494, 20080, 13535, 207 | 22513, 23677, 17613, 20514, 17248, 14968, 12763, 5118, 9675, 19043, 20684, 160, 8600, 20072, 208 | 6415, 22713, 17838, 2191, 21270, 16587, 4455, 17573, 12784, 13861, 22869, 23628, 23454, 5925, 209 | 12334, 19853, 1594, 5946, 10708, 20187, 18742, 3143, 22103, 12522, 23598, 9497, 24122, 10033, 210 | 24902, 12683, 11817, 13978, 23116, 11858, 4891, 9003, 7560, 8900, 1634, 3140, 8644, 16917, 211 | 24044, 15061, 23742, 2884, 19769, 20013, 12945, 23062, 1894, 8041, 19836, 22981, 17401, 17795, 212 | 11002, 8664, 24657, 4493, 18023, 23959, 8236, 11802, 444, 24490, 10897, 6260, 11132, 16360, 213 | 24280, 24837, 10985, 7988, 11158, 9597, 9953, 22756, 158, 5931, 22680, 2245, 4467, 2247, 13667, 214 | 2843, 24331, 10972, 583, 10571, 4773, 7749, 13873, 8875, 14486, 11891, 21956, 23697, 6192, 215 | 5349, 9132, 20856, 14469, 14775, 19519, 17641, 23577, 15184, 22708, 13147, 5159, 9017, 7583, 216 | 21740, 3481, 23990, 24368, 4456, 18736, 1008, 12701, 15944, 19747, 23379, 2240, 12141, 1002, 217 | 22258, 8161, 20904, 23769, 21224, 16149, 13966, 4549, 20573, 14749, 12767, 10448, 16868, 2995, 218 | 24687, 997, 18513, 15236, 13490, 5020, 19775, 3639, 3587, 20536, 3274, 16968, 14444, 19658, 219 | 7656, 23120, 9328, 21994, 1978, 15794, 14061, 8619, 24232, 5870, 20955, 10253, 10695, 19221, 220 | 18642, 18025, 24117, 9088, 11265, 4425, 1052, 12887, 23973, 2766, 15656, 16939, 9246, 14663, 221 | 6060, 3279, 14129, 16320, 12403, 6174, 24356, 11152, 17160, 4220, 22402, 23079, 9352, 14414, 222 | 9724, 16912, 508, 8035, 10769, 21560, 6161, 20816, 14540, 13056, 16430, 22097, 14012, 21322, 223 | 12467, 9726, 6485, 15371, 23146, 20335, 11211, 8432, 22033, 22507, 2785, 7292, 22963, 7757, 224 | 20126, 401, 13492, 2885, 3740, 8066, 13152, 7052, 7480, 23624, 2099, 1704, 1392, 12741, 21130, 225 | 18401, 22548, 20871, 14631, 22818, 15915, 21396, 10424, 14141, 7946, 19062, 21665, 22245, 226 | 12405, 8238, 2775, 10982, 24326, 20278, 22428, 16424, 19172, 8383, 3565, 15956, 18127, 36, 227 | 17956, 5109, 13440, 24722, 19825, 23924, 8787, 10552, 1867, 9644, 18481, 3709, 10584, 21118, 228 | 22225, 22333, 11561, 12182, 14345, 17358, 22876, 9338, 11486, 5957, 1045, 2110, 8575, 6185, 229 | 20569, 5357, 7167, 22389, 17188, 3654, 23554, 19921, 13161, 9491, 17781, 12393, 16455, 13180, 230 | 20947, 6819, 23520, 8744, 10040, 1340, 9055, 13794, 1184, 20597, 10623, 5536, 17096, 3604, 231 | 4578, 1010, 18384, 10348, 13965, 10246, 12132, 21806, 20067, 13727, 17435, 387, 15785, 5647, 232 | 4510, 18268, 7289, 1833, 14198, 6416, 17723, 24533, 21075, 8837, 20561, 1794, 14363, 23408, 233 | 23864, 2231, 5080, 1536, 1928, 12050, 14933, 18998, 6757, 22712, 8273, 8334, 15716, 1265, 3933, 234 | 24563, 20695, 19294, 20829, 14617, 20123, 8068, 18094, 16422, 23910, 24465, 13954, 1821, 2597, 235 | 8801, 15382, 12770, 1070, 2224, 21682, 6872, 16135, 12544, 8915, 3410, 
19305, 22799, 23319, 236 | 22765, 11827, 15877, 17913, 20096, 16865, 6682, 14295, 10243, 19986, 245, 21970, 12322, 9406, 237 | 11109, 9552, 5312, 19152, 2907, 13933, 19653, 14339, 18743, 9300, 22646, 10899, 7493, 24014, 238 | 18352, 14009, 4088, 3952, 5155, 13389, 2044, 9413, 11509, 17754, 6442, 9807, 7302, 19695, 239 | 24121, 21237, 19087, 16335, 4382, 5528, 6413, 2108, 18723, 18773, 20782, 3966, 24538, 2090, 240 | 17849, 525, 2449, 14806, 8211, 14709, 1053, 19727, 20608, 19958, 24873, 8482, 10607, 22391, 241 | 6864, 7827, 20960, 16862, 21421, 9561, 17748, 13589, 21944, 19300, 12662, 14172, 17232, 13421, 242 | 3679, 3008, 20986, 12491, 12293, 5627, 12276, 14468, 6463, 4704, 19898, 12348, 12727, 23983, 243 | 24529, 3141, 23595, 5782, 5570, 6239, 15752, 20299, 8581, 11190, 12292, 12558, 24961, 22215, 244 | 4627, 22162, 9519, 16716, 13078, 22845, 13674, 13214, 5439, 17695, 17480, 9327, 2684, 14135, 245 | 10314, 17227, 20423, 6182, 19858, 9064, 4696, 1899, 16695, 7224, 18115, 13512, 12774, 19985, 246 | 14270, 16376, 3246, 16567, 6570, 19104, 9396, 20308, 12115, 5766, 24879, 23559, 14840, 20180, 247 | 7274, 10285, 24493, 17996, 17803, 14929, 21634, 591, 3895, 3222, 6093, 15168, 4239, 13824, 248 | 24476, 8919, 1339, 6759, 10431, 14039, 12250, 11960, 7366, 18928, 6466, 12365, 2791, 14827, 249 | 15053, 22198, 14765, 21435, 7587, 6250, 3569, 7574, 16613, 3826, 6602, 3056, 17003, 18022, 250 | 15887, 10948, 4528, 4814, 17250, 5235, 12442, 8794, 3721, 23862, 12252, 16488, 12943, 8638, 251 | 2973, 12020, 23133, 16878, 7668, 23131, 14675, 9270, 665, 8243, 7437, 21171, 5739, 12375, 252 | 23322, 12492, 17592, 4821, 21547, 21360, 763, 8544, 8982, 22039, 8991, 16247, 20324, 743, 253 | 18092, 7627, 5391, 3029, 23517, 23722, 21041, 4684, 21320, 13967, 3129, 6821, 8254, 1622, 254 | 11220, 6287, 5054, 8859, 16510, 22034, 5937, 3633, 18114, 17035, 13946, 13487, 6335, 13618, 255 | 825, 7934, 15279, 16365, 24654, 13051, 23568, 14, 16364, 21894, 20718, 21084, 15521, 17646, 256 | 9656, 20511, 7109, 6655, 5310, 2043, 6749, 22488, 1510, 17729, 758, 12821, 19505, 16569, 9432, 257 | 17049, 4550, 8005, 13499, 2169, 4953, 20437, 21610, 21901, 6614, 17553, 21172, 12617, 4108, 258 | 19407, 21105, 5889, 13579, 5498, 23842, 3067, 7619, 24478, 12034, 10092, 14989, 21430, 1258, 259 | 1176, 869, 17971, 15652, 18873, 10058, 19474, 21529, 14328, 22643, 23386, 12742, 15931, 15440, 260 | 6264, 24819, 3215, 1320, 21968, 24774, 832, 1980, 24264, 3896, 7319, 10387, 22122, 16233, 1359, 261 | 7457, 24175, 22059, 5002, 984, 12342, 21628, 19621, 19203, 5621, 3344, 15336, 11533, 15340, 262 | 5094, 15323, 8209, 11257, 23552, 3991, 19073, 24761, 23863, 2193, 20145, 8767, 6035, 4218, 263 | 15033, 22597, 1690, 18167, 7803, 11654, 977, 15544, 2585, 5216, 8164, 17964, 15946, 22692, 264 | 16338, 2649, 10136, 10520, 24608, 23636, 21292, 20746, 10399, 22523, 13014, 8473, 17499, 24960, 265 | 22927, 2, 183, 6843, 10420, 2089, 20556, 11921, 918, 8734, 8956, 8248, 5874, 9889, 21863, 266 | 24093, 24097, 5656, 908, 6492, 12886, 10412, 18191, 6694, 24519, 399, 17580, 4784, 16572, 267 | 20646, 12505, 16006, 20734, 2209, 21457, 5989, 17478, 16824, 780, 21107, 9774, 14317, 3318, 268 | 3783, 14517, 7431, 4190, 7768, 19406, 24699, 24579, 17543, 1105, 10822, 1761, 5694, 11775, 269 | 6092, 24343, 18310, 13675, 19521, 1850, 3937, 13142, 17220, 21373, 5362, 3685, 16213, 20031, 270 | 6968, 16995, 20208, 9407, 21306, 24836, 1524, 17456, 1247, 17164, 13089, 9490, 11471, 76, 271 | 13869, 5701, 21847, 20330, 21977, 6476, 1502, 14479, 6812, 24285, 21392, 1649, 24154, 
21813, 272 | 3676, 24442, 2641, 16795, 17465, 14029, 2373, 9567, 22102, 3011, 30, 16880, 18249, 8102, 23437, 273 | 14603, 6380, 958, 9649, 23601, 5134, 16821, 8854, 21541, 20095, 7615, 9629, 6921, 2626, 18689, 274 | 14293, 19423, 5950, 2151, 2441, 17638, 23463, 3413, 6178, 17191, 10567, 14950, 24024, 10707, 275 | 24440, 16370, 12724, 18866, 14192, 16983, 21200, 17277, 8753, 11328, 10276, 1739, 23002, 7151, 276 | 18422, 861, 12753, 16781, 1376, 12803, 20819, 8352, 20844, 14643, 15789, 3347, 4914, 9978, 277 | 1168, 23969, 15913, 16121, 2855, 13718, 18213, 21314, 5079, 13818, 23850, 13554, 3714, 12719, 278 | 23361, 20138, 17745, 7002, 17780, 10857, 945, 20368, 23710, 13336, 12974, 10459, 7277, 24269, 279 | 1543, 536, 4145, 16715, 1951, 19394, 6932, 24979, 19476, 12202, 2804, 10930, 13310, 16136, 280 | 11495, 5523, 24951, 11814, 11586, 20572, 810, 23066, 12807, 1054, 3698, 6166, 5529, 4521, 9369, 281 | 17738, 24791, 11159, 24357, 10403, 24945, 20872, 14241, 3507, 16336, 11107, 20453, 2199, 6965, 282 | 21733, 20524, 11618, 3622, 5373, 3641, 9892, 21924, 20585, 2502, 16588, 13877, 22144, 9043, 283 | 19261, 3731, 17153, 24544, 23227, 23977, 19157, 9950, 1256, 22223, 4302, 1298, 6000, 20792, 284 | 20710, 24949, 24360, 15783, 16703, 6517, 17193, 6840, 12286, 10321, 515, 9732, 1038, 12721, 285 | 5049, 8151, 304, 1497, 19714, 10915, 13337, 10405, 18593, 9819, 4417, 15176, 2371, 5238, 17214, 286 | 23776, 18194, 5231, 2064, 18823, 17585, 7828, 11907, 12415, 16130, 1315, 18795, 12555, 20683, 287 | 16234, 18185, 20922, 6227, 4656, 9332, 20797, 7444, 6462, 16684, 261, 15694, 24002, 2142, 721, 288 | 3907, 19837, 14163, 13416, 6318, 22211, 11666, 17252, 24682, 20032, 484, 5935, 10047, 13643, 289 | 21379, 17892, 2461, 24704, 16661, 6408, 19524, 11272, 24835, 8015, 22315, 16808, 15031, 740, 290 | 20392, 21439, 17402, 7451, 22788, 14913, 14634, 10466, 5012, 5142, 24321, 13208, 2934, 19548, 291 | 10523, 23229, 10310, 75, 6665, 23869, 16410, 17924, 8556, 23412, 4157, 22655, 11329, 22698, 292 | 2046, 16820, 4989, 10269, 8997, 19538, 19460, 6427, 1355, 10941, 2302, 20719, 2831, 20667, 293 | 5742, 11803, 8114, 4570, 8564, 7105, 18850, 6954, 8637, 6842, 9513, 2707, 20894, 16799, 15623, 294 | 3476, 2025, 10345, 6031, 10939, 13930, 11102, 2147, 5300, 19002, 1863, 5124, 12459, 17912, 295 | 4084, 23473, 14770, 14456, 12714, 13079, 12726, 2896, 4253, 8308, 23611, 1554, 5564, 8036, 296 | 4202, 19742, 1099, 17155, 11577, 15363, 11545, 16630, 23980, 6133, 11748, 22633, 2076, 965, 297 | 1587, 23803, 17737, 2509, 11657, 10168, 4711, 947, 18922, 7340, 12042, 3091, 3800, 1605, 102, 298 | 4815, 3751, 20913, 18523, 16696, 13902, 24928, 19195, 24151, 5585, 21300, 12882, 3514, 23048, 299 | 21986, 6942, 1866, 15924, 5072, 10210, 4955, 1260, 3182, 15013, 11234, 20087, 9604, 12786, 300 | 21402, 16402, 10766, 251, 8909, 17312, 5265, 895, 3663, 11312, 3085, 2264, 24713, 9384, 1940, 301 | 11829, 15305, 7563, 3076, 15256, 15087, 12848, 8755, 13515, 2599, 9626, 17953, 13437, 17488, 302 | 17527, 10022, 5964, 12531, 105, 7147, 16479, 15912, 19673, 20283, 19458, 11786, 16894, 6888, 303 | 7629, 9714, 19271, 651, 21321, 4051, 21331, 6891, 11161, 11073, 16461, 8551, 15787, 22185, 304 | 23279, 8576, 7031, 23482, 13469, 11723, 14526, 5713, 715, 1952, 19923, 3522, 17989, 2377, 305 | 10301, 21401, 9848, 6477, 17025, 15265, 14313, 22365, 15546, 18395, 21971, 18799, 12323, 19639, 306 | 2495, 20192, 24034, 6194, 2135, 6958, 9666, 2747, 3573, 8635, 19428, 21635, 11887, 9070, 22068, 307 | 3203, 1853, 22044, 6091, 8517, 8405, 2103, 7853, 13524, 8968, 
23534, 20403, 3493, 21046, 380, 308 | 13137, 8890, 14844, 12598, 24623, 13741, 22161, 12397, 18788, 20906, 18657, 18848, 18051, 309 | 17651, 9260, 8074, 18245, 5240, 3275, 7490, 19017, 21395, 23922, 7194, 3650, 12228, 19740, 310 | 8343, 7012, 11668, 17365, 10337, 4922, 4947, 15157, 23157, 1046, 1706, 12465, 14586, 16890, 311 | 12937, 12183, 2472, 2152, 19773, 4121, 7108, 12139, 4567, 22658, 20865, 23351, 12510, 22078, 312 | 9190, 15961, 850, 7380, 10792, 2183, 24412, 10169, 3694, 23587, 23888, 20376, 21842, 9589, 313 | 22023, 19194, 22069, 10538, 13213, 22678, 10557, 17270, 16205, 12176, 2527, 4520, 19027, 11597, 314 | 18929, 4807, 6903, 11289, 5815, 6680, 23266, 13060, 22693, 20548, 19235, 23976, 5263, 9504, 315 | 18430, 1791, 3948, 21205, 1167, 21223, 8634, 16409, 2392, 17823, 3253, 19371, 6347, 19942, 316 | 4775, 24205, 4294, 24816, 1387, 22512, 22582, 22957, 6199, 4928, 8453, 17481, 22878, 1911, 317 | 13775, 17307, 21122, 20778, 15499, 19077, 12157, 170, 22647, 6309, 9111, 14079, 11313, 2687, 318 | 17242, 10892, 5678, 23781, 24771, 24785, 5035, 22046, 9994, 24536, 14264, 10313, 8227, 10159, 319 | 18520, 3276, 16221, 4984, 8415, 20643, 13764, 5592, 1044, 3083, 17642, 19052, 24021, 11222, 320 | 24625, 22383, 20352, 9193, 8537, 7501, 13759, 19564, 6715, 24817, 9033, 13235, 10995, 3976, 321 | 19466, 18676, 6045, 24166, 23549, 11138, 20758, 90, 20181, 8764, 14326, 18026, 8508, 24537, 322 | 10057, 11232, 9378, 10008, 8725, 5507, 10691, 13545, 6125, 13050, 1505, 3727, 5554, 6098, 323 | 12547, 22960, 10694, 5013, 17751, 6652, 3300, 15874, 9402, 24222, 12900, 20807, 6593, 20166, 324 | 4733, 1145, 23415, 332, 3809, 223, 11630, 3081, 15032, 10673, 8670, 11442, 14383, 20571, 12258, 325 | 6306, 2994, 1104, 15235, 4683, 21354, 4720, 3646, 12927, 18658, 3552, 24604, 23484, 23663, 326 | 24985, 9326, 23284, 22100, 21312, 20338, 7578, 19757, 7459, 13820, 20600, 19587, 23294, 20051, 327 | 23504, 17083, 1171, 7523, 7677, 21408, 10994, 11007, 15529, 22865, 22008, 13929, 22282, 7488, 328 | 12343, 19342, 24661, 4804, 9781, 7497, 12005, 8026, 22683, 8244, 11359, 21731, 24000, 2695, 329 | 12904, 6028, 16388, 4851, 9493, 1079, 14206, 1927, 4608, 15345, 18528, 1292, 15886, 8714, 330 | 21935, 10241, 8375, 20897, 6962, 19852, 9556, 9029, 21527, 17561, 21586, 6786, 12251, 1386, 331 | 24049, 18071, 20744, 730, 23389, 4897, 2048, 15768, 22171, 18198, 11640, 15432, 16184, 22323, 332 | 9697, 12315, 2686, 2944, 7345, 3923, 8294, 6418, 13084, 4150, 24784, 6957, 20134, 5316, 22490, 333 | 19205, 9638, 23867, 7370, 2298, 1009, 24430, 9919, 12404, 3915, 20795, 19379, 23318, 7879, 334 | 20456, 315, 13593, 10081, 17311, 15179, 11823, 6436, 9672, 20786, 989, 6376, 9920, 4290, 13236, 335 | 17302, 8004, 8431, 6708, 11233, 21896, 2657, 23007, 13029, 23501, 15943, 12586, 12341, 21217, 336 | 20817, 23182, 4025, 20432, 23492, 10002, 4824, 7889, 15038, 15402, 49, 13125, 22899, 5590, 337 | 9385, 13155, 4514, 12647, 23808, 1735, 20296, 8172, 21301, 15064, 13298, 22538, 3190, 18662, 338 | 7141, 3447, 17950, 23640, 11570, 2413, 3736, 24719, 14911, 7904, 13383, 12787, 13365, 19711, 339 | 18169, 12898, 7530, 859, 23805, 6160, 9357, 6527, 6706, 23159, 22943, 23044, 16024, 20366, 340 | 11772, 9061, 19885, 19720, 11348, 8384, 10277, 139, 852, 6657, 16926, 23252, 3172, 1057, 870, 341 | 10905, 1274, 678, 14797, 6448, 17505, 3250, 10884, 8803, 12131, 7862, 17189, 9746, 13947, 342 | 19188, 14779, 2683, 4772, 5130, 298, 16732, 22886, 22032, 11061, 4612, 23849, 10514, 23652, 343 | 19086, 14019, 20893, 12986, 12731, 22871, 16589, 3360, 5613, 15585, 
185, 24560, 10067, 23477, 344 | 23605, 19414, 8006, 8180, 10938, 19527, 1670, 19897, 13710, 550, 21821, 24545, 24975, 16801, 345 | 3111, 15494, 11195, 20119, 8013, 15338, 22240, 4407, 20502, 6765, 18715, 22901, 17679, 6790, 346 | 24425, 16934, 11030, 23508, 10031, 21126, 10558, 11230, 12798, 11989, 22013, 15531, 10598, 347 | 4791, 2578, 15869, 18378, 1992, 19882, 7781, 8633, 22339, 7040, 3149, 3942, 6272, 15043, 4311, 348 | 19318, 7906, 4307, 23684, 13686, 4476, 20763, 20026, 8662, 8776, 5933, 22133, 9559, 1220, 349 | 20723, 8718, 20910, 22922, 4004, 15234, 2097, 16159, 4776, 22504, 14580, 11409, 188, 8198, 350 | 22457, 22995, 22028, 20316, 19273, 23241, 7845, 20363, 2835, 812, 16368, 8296, 7427, 19030, 351 | 12264, 14189, 1681, 15101, 12275, 7386, 11714, 12825, 13449, 292, 21564, 6262, 8279, 19383, 352 | 11219, 10789, 9442, 9109, 845, 2078, 21425, 23914, 13095, 4303, 5541, 23593, 8145, 22816, 6073, 353 | 2203, 21979, 18222, 12401, 5380, 18345, 11543, 21963, 18768, 589, 18663, 1798, 13530, 20851, 354 | 4492, 8826, 6176, 15322, 14656, 19966, 6971, 14763, 3086, 4849, 2682, 16323, 16549, 21521, 355 | 22464, 4274, 21796, 4981, 1755, 22154, 4916, 5017, 16405, 21216, 23052, 9148, 24513, 22761, 356 | 13897, 2737, 7553, 21072, 15464, 22083, 1211, 7525, 10479, 9517, 17306, 4834, 12696, 21900, 357 | 13989, 23528, 17194, 16639, 11968, 3540, 7117, 20239, 23375, 8946, 15606, 20743, 23487, 10184, 358 | 5794, 17297, 18549, 18398, 20874, 24304, 18153, 12642, 15066, 6057, 15228, 1474, 7588, 8128, 359 | 24621, 6232, 9575, 22043, 7948, 3857, 12995, 21071, 23544, 17292, 9788, 8322, 13720, 8147, 881, 360 | 21981, 14488, 4216, 3137, 6430, 23540, 18733, 13730, 10906, 8806, 23152, 8686, 4186, 23196, 361 | 22011, 13328, 338, 2483, 15169, 8989, 10595, 1299, 12105, 14343, 13400, 407, 20670, 11810, 362 | 13373, 8940, 16051, 13157, 8317, 23632, 12808, 9045, 5139, 18903, 5679, 8595, 6580, 24965, 363 | 1159, 882, 5157, 23392, 8604, 20387, 16264, 13536, 22677, 5526, 510, 7917, 13425, 18038, 253, 364 | 10724, 9541, 10118, 482, 2104, 9474, 18828, 8171, 9374, 3478, 14834, 6731, 15243, 17811, 14549, 365 | 5010, 13836, 14336, 22690, 1801, 22422, 4663, 10103, 55, 24559, 2172, 7645, 17706, 5186, 1959, 366 | 15581, 16699, 308, 17479, 24685, 9191, 7018, 3441, 889, 2619, 10201, 8656, 19572, 1410, 20103, 367 | 10864, 7489, 23416, 15558, 1936, 15570, 6214, 7119, 13935, 11568, 21315, 10501, 23214, 23023, 368 | 5607, 9849, 5068, 23137, 8703, 18889, 15472, 11033, 22151, 2633, 14599, 4110, 21372, 6039, 369 | 20044, 15107, 11236, 5389, 16007, 9602, 13091, 15125, 16743, 17697, 3330, 21293, 11215, 4266, 370 | 15712, 22856, 8874, 16535, 15889, 24263, 18759, 16686, 21625, 23686, 21389, 12560, 16202, 764, 371 | 17688, 8813, 5173, 957, 19010, 8299, 15430, 101, 1116, 854, 3427, 3592, 21353, 16933, 16548, 372 | 1508, 12722, 11055, 9153, 364, 18616, 9301, 23012, 12062, 11875, 9269, 8818, 17615, 15306, 373 | 6378, 15122, 15147, 18888, 19296, 3438, 14703, 4288, 7697, 24590, 9344, 17982, 14905, 16279, 374 | 20362, 19431, 8008, 23277, 20673, 19192, 23896, 9008, 7739, 6662, 19767, 4539, 23427, 21157, 375 | 12827, 23226, 18692, 13385, 6693, 6265, 14756, 7409, 86, 19343, 2837, 13342, 3662, 3305, 15350, 376 | 23756, 24432, 2374, 14780, 7968, 3628, 9703, 5560, 24531, 14255, 992, 21097, 11052, 19675, 377 | 18011, 17123, 13270, 10603, 7455, 24605, 13398, 7025, 18762, 7610, 15879, 2922, 2156, 8897, 378 | 3346, 5835, 14406, 2828, 23573, 7419, 19500, 15755, 9918, 9297, 4089, 11271, 1030, 3737, 15550, 379 | 5788, 12509, 5229, 11952, 1541, 7772, 3612, 19310, 
10055, 9608, 14310, 14102, 4351, 19083, 380 | 22272, 21504, 2455, 18396, 19297, 24381, 11136, 2981, 12981, 15650, 21626, 16041, 14839, 513, 381 | 3109, 4679, 21238, 18612, 12635, 16539, 1656, 9526, 17539, 11975, 17902, 18488, 6933, 14516, 382 | 6565, 333, 10720, 1487, 18449, 23857, 14509, 1759, 16982, 10341, 1519, 6953, 24521, 13302, 383 | 20973, 24482, 794, 5254, 5247, 16957, 19327, 15238, 506, 19702, 8510, 7653, 8805, 4179, 24539, 384 | 8542, 20173, 4191, 4668, 14113, 877, 2723, 12921, 18909, 23329, 18073, 23695, 16103, 21820, 385 | 1628, 22694, 23099, 2592, 13470, 21318, 18357, 692, 3913, 15538, 17755, 14685, 16485, 7100, 386 | 7469, 8571, 17217, 5364, 2380, 4018, 13423, 7931, 22009, 17923, 13144, 3505, 10731, 24302, 387 | 14758, 12713, 10404, 3343, 21100, 16288, 1317, 7492, 1503, 14501, 8210, 277, 21154, 22545, 388 | 7787, 5671, 20688, 3835, 8275, 22183, 3837, 12587, 1154, 17327, 11544, 13955, 18209, 12281, 389 | 12750, 5031, 20288, 9020, 21279, 4095, 10947, 17909, 16261, 20759, 15174, 8917, 862, 4694, 390 | 17405, 17637, 6064, 23303, 8849, 14448, 446, 13919, 4459, 19964, 12022, 14759, 20355, 24769, 391 | 13259, 620, 6767, 19368, 9351, 9713, 10077, 8295, 7568, 6345, 8908, 1518, 10188, 11724, 3910, 392 | 18790, 19584, 13798, 24681, 23887, 1584, 14306, 11766, 5574, 10203, 13611, 20883, 6203, 17829, 393 | 10896, 13431, 21550, 22063, 9115, 8150, 11774, 14520, 12366, 9864, 16924, 1621, 21483, 13544, 394 | 22707, 9393, 9244, 4169, 23180, 17524, 11242, 18842, 16348, 1590, 8303, 8642, 8769, 5972, 395 | 18265, 5646, 8702, 17946, 8, 21186, 21240, 1016, 3473, 11140, 10492, 13216, 19020, 1205, 18480, 396 | 19392, 23032, 10705, 20411, 7313, 18407, 18785, 10943, 13563, 21694, 19107, 13263, 23255, 397 | 15492, 18113, 22150, 4346, 2423, 16867, 5582, 20803, 23166, 21490, 12896, 4557, 13493, 18343, 398 | 16848, 4077, 24277, 5406, 19812, 3989, 24073, 18132, 12473, 17406, 17027, 7429, 16248, 10619, 399 | 16841, 22373, 24453, 24718, 20902, 23378, 20513, 13960, 7044, 24414, 20901, 15938, 23565, 1912, 400 | 6261, 1402, 9990, 7387, 11791, 20090, 15906, 2260, 23435, 10091, 24759, 14169, 19293, 5634, 401 | 13427, 15586, 5454, 2830, 23965, 13273, 6219, 21362, 9646, 2030, 15888, 18594, 24210, 9603, 402 | 13778, 21001, 18779, 6139, 20301, 13991, 9487, 12735, 12136, 7965, 20470, 21416, 11244, 750, 403 | 9467, 12390, 17806, 17382, 13907, 9375, 1417, 430, 24757, 9720, 16893, 6052, 13169, 19928, 404 | 19125, 4224, 15973, 7692, 15063, 23858, 12959, 4446, 2375, 19325, 3118, 18483, 20745, 17568, 405 | 1222, 20773, 17210, 19382, 16355, 2278, 14846, 6381, 19707, 4126, 956, 6827, 12689, 7414, 406 | 13543, 5534, 8051, 2770, 16680, 24435, 936, 2814, 8222, 16031, 13132, 10524, 1015, 13093, 407 | 22187, 23072, 18359, 7220, 1108, 5791, 18259, 5558, 4836, 17844, 10580, 21451, 21845, 11872, 408 | 10814, 19893, 14112, 16935, 19871, 4164, 10402, 16975, 10627, 18455, 7079, 16681, 11425, 15647, 409 | 23635, 3231, 14637, 11327, 4971, 3969, 16271, 652, 6344, 15580, 6566, 2583, 23882, 13046, 1509, 410 | 15426, 1051, 3508, 14356, 4750, 13731, 16473, 16996, 20920, 13528, 11314, 13503, 7986, 9742, 411 | 10430, 15894, 18731, 13033, 10976, 18570, 16884, 15079, 23662, 12149, 3599, 17132, 13739, 412 | 15917, 5479, 23396, 5075, 11014, 3139, 16046, 6428, 5302, 17663, 3711, 6459, 5255, 2948, 4886, 413 | 18707, 22279, 3963, 10418, 3591, 5824, 9478, 19287, 5132, 16421, 5493, 12329, 9333, 23564, 414 | 19307, 7837, 23782, 16886, 10617, 4777, 6383, 8014, 8406, 11259, 22585, 23627, 22111, 4196, 415 | 6895, 20755, 11494, 10342, 13446, 22126, 8173, 15761, 
22862, 3669, 1352, 19595, 2669, 5954, 416 | 11194, 21128, 531, 12759, 3070, 7201, 3322, 732, 7154, 13387, 8532, 15308, 16546, 21507, 19955, 417 | 16447, 7685, 3867, 557, 9777, 15876, 3041, 22807, 20029, 16372, 2788, 23989, 14282, 18005, 418 | 12118, 7829, 20203, 6372, 5740, 14550, 14508, 9371, 13689, 10527, 20770, 17410, 16296, 2471, 419 | 6707, 13443, 14992, 14219, 1542, 6369, 3026, 18337, 4350, 4644, 20721, 13664, 14726, 15720, 420 | 7532, 10439, 5696, 17887, 7511, 23955, 4818, 17074, 4595, 20227, 19409, 19916, 2672, 6839, 421 | 15490, 20637, 16034, 7045, 4321, 12592, 19351, 22562, 20555, 10742, 2681, 24206, 790, 24924, 422 | 1250, 8533, 8827, 6146, 12854, 23254, 18524, 5218, 24458, 20421, 16111, 13901, 1903, 9515, 423 | 24587, 12053, 12799, 11270, 12903, 9660, 18282, 4911, 3237, 642, 4728, 14457, 6898, 22615, 424 | 17687, 5790, 7817, 10973, 17691, 20056, 16120, 23024, 19803, 21289, 15721, 20880, 12284, 4614, 425 | 8460, 10041, 359, 9530, 22075, 9103, 13206, 2116, 18883, 9535, 13074, 13319, 21802, 20523, 426 | 5873, 24104, 4365, 7039, 19945, 15039, 10626, 23637, 21497, 11520, 18329, 3693, 15569, 2061, 427 | 19744, 16063, 14872, 19374, 8116, 2616, 40, 4227, 12004, 21474, 6761, 6426, 7026, 12030, 5991, 428 | 9488, 21512, 9048, 3890, 4286, 7392, 11368, 15983, 1564, 12304, 14949, 7050, 18091, 11379, 429 | 6714, 13167, 15089, 14745, 21794, 11038, 12386, 571, 3390, 10187, 18469, 2006, 6811, 4460, 430 | 6307, 6575, 6019, 22768, 24637, 15864, 849, 22522, 11902, 12331, 19838, 191, 5833, 2935, 1604, 431 | 9078, 15261, 14197, 3458, 9223, 19075, 20961, 19334, 11699, 4230, 10088, 17020, 16672, 24352, 432 | 22325, 8387, 23421, 16561, 4263, 564, 7799, 7973, 10263, 19492, 24894, 18564, 19268, 18971, 433 | 19749, 11437, 12085, 6209, 15282, 12788, 23556, 10202, 4359, 6333, 23019, 6451, 11593, 14095, 434 | 3268, 12695, 1701, 11596, 2139, 9999, 8505, 21287, 2548, 21203, 23103, 1215, 22747, 7907, 435 | 17722, 14753, 5061, 8416, 17834, 9404, 6717, 17889, 21328, 15624, 14491, 22917, 8905, 726, 436 | 8086, 12367, 21250, 4333, 8469, 13842, 10361, 1319, 15139, 24084, 3001, 11332, 11278, 5374, 437 | 20016, 4743, 10569, 15646, 11084, 12014, 20280, 15092, 24091, 14484, 12287, 2291, 15229, 20474, 438 | 18049, 3049, 5125, 11427, 21221, 7832, 4819, 12745, 3657, 19939, 13632, 4200, 7281, 8115, 6516, 439 | 11248, 5489, 565, 22254, 18992, 19177, 21479, 6948, 4078, 353, 488, 6403, 11623, 15722, 22626, 440 | 19590, 198, 723, 20866, 13053, 24755, 14117, 5878, 20228, 11119, 5898, 1271, 20890, 6202, 441 | 15416, 23671, 24436, 2414, 10741, 7406, 23265, 24315, 21103, 21648, 9501, 19298, 2629, 6331, 442 | 17790, 1756, 5352, 6862, 21363, 12493, 18251, 7280, 14815, 24558, 985, 9211, 22755, 19719, 443 | 16899, 18217, 9704, 22045, 13290, 7067, 41, 7541, 11849, 6149, 6251, 18500, 9585, 3524, 20245, 444 | 24732, 5221, 10239, 6314, 10070, 23366, 13804, 13556, 9289, 3849, 24072, 19233, 20436, 4254, 445 | 22018, 23280, 12515, 14529, 20218, 23717, 12159, 11862, 16842, 200, 16851, 24027, 11264, 18830, 446 | 13210, 1881, 17336, 16442, 89, 24995, 24394, 4686, 19462, 2000, 11624, 5533, 18263, 2800, 447 | 14380, 19335, 17928, 1208, 24353, 19651, 6163, 16476, 15588, 10452, 23224, 6604, 19357, 2513, 448 | 138, 6775, 16769, 17862, 12206, 5583, 17567, 12978, 10931, 13408, 898, 7460, 3586, 20161, 449 | 10771, 14736, 10477, 21867, 12448, 12239, 20927, 7156, 23893, 8265, 18427, 2254, 18813, 1217, 450 | 11610, 15791, 5146, 13098, 24377, 11208, 3754, 972, 21969, 13987, 4367, 19139, 12381, 12363, 451 | 9926, 22636, 24294, 10530, 8977, 11363, 22049, 18, 
17498, 17268, 2452, 15760, 3645, 7249, 452 | 19522, 8058, 21818, 4414, 2634, 11686, 11019, 4378, 13695, 13910, 18728, 24973, 840, 21214, 453 | 9025, 23712, 10859, 5456, 2568, 3692, 19069, 12220, 6235, 3509, 9651, 996, 17804, 17507, 20027, 454 | 9480, 8931, 21675, 13694, 2541, 8582, 9053, 22863, 16662, 6595, 9034, 1277, 4090, 10318, 10180, 455 | 9973, 24709, 18141, 22815, 11558, 9817, 17387, 24399, 4993, 22997, 21499, 10042, 3381, 17452, 456 | 3176, 10444, 10727, 21853, 2136, 11365, 19004, 15077, 6738, 24778, 16439, 8028, 17888, 11380, 457 | 5119, 1576, 18868, 724, 816, 5091, 8427, 5920, 2586, 14547, 18103, 3862, 17677, 20373, 6946, 458 | 24424, 16570, 21095, 18995, 5967, 9316, 7355, 6818, 1960, 5019, 10630, 63, 21683, 13514, 5557, 459 | 10693, 24144, 163, 11198, 9826, 8495, 18540, 4500, 22139, 23846, 19247, 7871, 24127, 21771, 460 | 19158, 20939, 21543, 22604, 18865, 6945, 11400, 13653, 14761, 2504, 3794, 13758, 13227, 4861, 461 | 22193, 11167, 23045, 14953, 1979, 273, 7269, 7807, 15895, 24856, 1300, 23618, 6489, 6395, 462 | 12008, 8692, 12800, 13688, 21767, 6490, 5162, 14881, 17323, 4601, 21220, 4597, 6224, 18706, 463 | 8880, 8371, 7403, 10747, 17140, 10227, 17082, 3511, 1362, 20329, 13459, 24361, 20035, 24196, 464 | 8204, 24223, 1289, 8089, 24328, 8652, 5847, 1563, 6909, 14799, 13790, 1404, 11223, 24358, 465 | 22425, 15123, 16350, 4007, 10932, 1637, 5226, 4124, 20994, 1772, 20209, 4758, 1114, 16307, 466 | 4687, 22631, 17783, 20384, 4779, 24341, 20788, 19516, 16250, 4319, 18237, 23327, 13015, 9216, 467 | 6412, 20089, 14969, 7966, 23140, 18861, 11704, 2118, 19309, 22881, 14551, 3195, 13165, 14942, 468 | 3025, 10378, 15801, 2125, 17873, 23261, 31, 15297, 19634, 11712, 4970, 4133, 1358, 18618, 469 | 13101, 24134, 17404, 10802, 18582, 20714, 6519, 6164, 21766, 10749, 24908, 9313, 3090, 351, 470 | 15288, 3601, 4825, 7671, 7413, 24058, 17551, 16412, 24607, 4237, 17236, 20996, 206, 22262, 471 | 21915, 3715, 5520, 10390, 5260, 13276, 5024, 18148, 14629, 9157, 23373, 5397, 18006, 10107, 472 | 17489, 6643, 3451, 11369, 20451, 24603, 11384, 9868, 554, 14073, 22775, 2660, 8955, 10312, 473 | 18377, 5968, 6619, 2305, 19544, 413, 18704, 16622, 22411, 514, 4268, 16831, 15507, 7761, 21248, 474 | 8555, 5757, 8048, 6346, 11117, 5726, 19164, 12751, 23205, 16792, 23273, 8338, 9618, 22939, 475 | 15390, 20969, 17084, 22837, 9989, 9388, 3443, 15140, 3096, 3651, 2716, 12776, 7938, 21768, 476 | 5549, 10634, 13522, 5251, 5746, 7097, 6504, 7543, 23364, 3206, 5695, 8216, 11835, 17660, 19502, 477 | 17848, 24141, 8306, 10370, 11214, 16484, 5143, 16881, 21226, 16306, 18001, 17969, 15293, 10548, 478 | 5040, 20104, 7032, 20953, 15272, 3652, 11912, 19918, 6179, 5104, 15809, 12870, 4128, 3409, 479 | 1870, 11097, 4242, 13104, 9058, 799, 22590, 12335, 18599, 15266, 20428, 6611, 21732, 4774, 480 | 11980, 3365, 22573, 11325, 13957, 11594, 1367, 1842, 12525, 17886, 21992, 17033, 24612, 13292, 481 | 15527, 9356, 24948, 17281, 19115, 3024, 10666, 18678, 24236, 11480, 5165, 19037, 12439, 9107, 482 | 15482, 14150, 1799, 24301, 23413, 11013, 14119, 8779, 15841, 22956, 7951, 14616, 218, 21893, 483 | 23013, 2017, 3483, 18804, 5572, 4419, 15291, 22259, 4261, 5747, 10608, 22310, 14897, 22601, 484 | 13781] 485 | if imb_factor == 0.02: 486 | pos_val_index = [25273, 25375, 25169, 25433, 25260, 25140, 25446, 25135, 25233, 25380, 25459, 25399, 25125, 487 | 25283, 25391, 25308, 25189, 25317, 25148, 25496, 25200, 25276, 25296, 25067, 25461, 25094, 488 | 25173, 25054, 25055, 25010, 25357, 25381, 25383, 25156, 25366, 25343, 25101, 25424, 25150, 
489 | 25307, 25076, 25028, 25334, 25342, 25398, 25415, 25295, 25111, 25466, 25093, 25217, 25099, 490 | 25390, 25179, 25001, 25032, 25474, 25197, 25147, 25323, 25485, 25404, 25105, 25473, 25498, 491 | 25185, 25171, 25182, 25329, 25113, 25312, 25340, 25305, 25228, 25102, 25163, 25151, 25281, 492 | 25298, 25145, 25355, 25059, 25423, 25419, 25053, 25186, 25412, 25368, 25349, 25351, 25049, 493 | 25254, 25370, 25063, 25253, 25071, 25002, 25223, 25096, 25241] 494 | 495 | elif imb_factor == 0.2: 496 | 497 | pos_val_index = [28265, 26111, 26380, 25713, 26329, 29556, 28543, 25658, 25725, 25121, 29996, 29917, 27708, 26068, 25246, 26344, 498 | 27077, 28079, 29389, 28622, 28596, 27750, 27062, 29548, 29992, 27675, 25184, 29177, 28495, 27619, 27860, 28757, 499 | 29235, 26250, 25285, 29112, 29684, 28723, 27448, 25824, 29298, 25909, 27834, 29995, 29073, 27180, 28695, 27120, 500 | 26211, 29557, 29188, 29151, 26712, 26481, 26562, 28445, 28538, 27774, 25909, 29179, 29868, 28598, 28553, 25546, 501 | 29670, 28978, 29220, 25266, 28191, 27698, 28952, 28146, 27518, 25586, 26011, 28359, 29381, 27716, 25475, 28424, 502 | 29130, 29189, 25341, 29108, 29855, 28241, 25497, 28489, 27055, 26521, 28073, 28131, 26457, 26100, 26090, 27844, 503 | 28618, 28158, 28764, 28531, 28399, 27226, 28986, 26192, 27723, 27425, 29495, 25863, 28530, 27060, 26301, 28337, 504 | 25722, 25282, 29781, 28511, 27393, 27130, 25115, 25912, 27136, 27483, 27862, 27038, 29128, 26547, 25804, 26145, 505 | 27394, 28025, 28565, 28853, 28205, 26975, 26597, 29791, 28805, 26015, 25505, 27621, 29498, 29083, 25185, 25298, 506 | 28625, 28131, 25492, 27811, 26508, 29567, 28858, 28652, 27347, 27996, 25285, 26315, 27164, 25237, 26685, 27164, 507 | 26632, 27662, 29459, 28742, 26591, 27792, 28704, 27336, 25055, 29845, 27362, 26691, 27473, 29049, 26079, 29608, 508 | 28676, 25757, 26365, 29997, 28131, 26353, 27769, 25074, 27292, 25610, 26451, 28128, 25394, 29335, 26449, 26396, 509 | 27409, 26818, 25449, 28014, 28248, 26563, 27643, 25759, 26274, 25715, 26016, 25376, 26933, 29155, 28532, 25978, 510 | 27434, 25592, 27951, 27145, 27979, 29660, 29050, 29894, 28110, 28907, 27617, 27141, 25449, 29754, 27600, 28999, 511 | 29523, 26728, 28308, 25993, 26747, 25942, 25344, 29481, 26460, 25201, 27075, 28321, 27517, 27203, 28584, 27243, 512 | 27849, 26108, 29908, 28457, 25309, 26607, 27666, 25529, 29648, 25626, 27442, 29659, 27167, 26687, 25927, 29411, 513 | 28562, 27141, 29753, 28765, 26116, 25217, 29395, 29307, 27419, 26365, 29575, 27822, 25246, 26885, 27096, 25950, 514 | 28454, 26125, 26404, 29055, 29116, 27008, 28223, 26643, 25279, 28220, 28947, 28600, 29796, 28815, 27392, 28117, 515 | 26926, 28742, 28408, 26643, 29049, 25726, 29290, 25121, 25624, 28804, 25327, 28503, 25843, 28659, 25170, 29288, 516 | 27127, 26698, 27867, 27579, 29871, 26952, 26185, 26108, 28365, 25753, 29169, 27098, 28229, 28738, 28255, 28625, 517 | 26859, 28649, 29384, 28975, 27193, 25853, 26753, 25018, 28715, 26255, 29832, 25797, 28837, 26694, 27621, 28610, 518 | 25414, 29334, 26682, 26672, 29158, 29232, 26841, 29203, 28576, 28970, 25231, 28605, 26029, 27395, 29093, 25722, 519 | 26615, 26198, 29879, 25979, 26245, 27872, 29614, 29852, 29058, 27360, 27788, 25509, 27679, 26605, 26475, 28220, 520 | 27146, 27031, 28085, 28985, 26747, 27492, 29181, 29654, 28895, 28755, 26904, 25607, 26740, 26875, 26979, 26905, 521 | 25523, 26947, 26496, 25986, 25072, 26612, 25298, 28382, 26895, 29189, 28072, 26841, 26876, 27867, 27731, 25112, 522 | 27430, 26419, 29409, 29899, 26203, 25578, 28907, 28385, 29591, 28418, 25619, 
28219, 27668, 29705, 26904, 29429, 523 | 29542, 25209, 25419, 25177, 26792, 25338, 29960, 25417, 27057, 26633, 25219, 28848, 25967, 28330, 25966, 28141, 524 | 25586, 25884, 28810, 29299, 27927, 28162, 29875, 28696, 27154, 29303, 26239, 27587, 27796, 25560, 27943, 25499, 525 | 25548, 26869, 29616, 29819, 28552, 25345, 26517, 27280, 29217, 28631, 27795, 26931, 29536, 29482, 25047, 29126, 526 | 29713, 26127, 26189, 29016, 27925, 26578, 29720, 29057, 29791, 28379, 25016, 25117, 26801, 27929, 27085, 26603, 527 | 26989, 29719, 25937, 26048, 26198, 28635, 25580, 26324, 27080, 27036, 28656, 29715, 27463, 25729, 29385, 28660, 528 | 28746, 27778, 29283, 28438, 27524, 28043, 29018, 25433, 26282, 29711, 27135, 25044, 26395, 25656, 29487, 28502, 529 | 27745, 27878, 27967, 25432, 29746, 26940, 29861, 27240, 29589, 27213, 28150, 27485, 25815, 26010, 29743, 27267, 530 | 27445, 25077, 29535, 29134, 27197, 27864, 26832, 29744, 26778, 29019, 28576, 25066, 28918, 29355, 26436, 29954, 531 | 29363, 27204, 29809, 25989, 26337, 27736, 26843, 25320, 26794, 28999, 29946, 28206, 26277, 29770, 26381, 26155, 532 | 26120, 27437, 28806, 29030, 26305, 26266, 28888, 28384, 25600, 26612, 25128, 28134, 25297, 26122, 26810, 25248, 533 | 25246, 28291, 25598, 29971, 28698, 28319, 29195, 27282, 29390, 28098, 26978, 26861, 25544, 26092, 29142, 28568, 534 | 25967, 26598, 29404, 28195, 26927, 28108, 27969, 27480, 29913, 25897, 29853, 29587, 25507, 29736, 26132, 25097, 535 | 28362, 28816, 25318, 27784, 29495, 26938, 29295, 25657, 25755, 28313, 26516, 28146, 28659, 28939, 28009, 26299, 536 | 28430, 28108, 28959, 27662, 28856, 26459, 25716, 26051, 25715, 26969, 27427, 26207, 25264, 27009, 29348, 29037, 537 | 28293, 25449, 29548, 25682, 29116, 27642, 27461, 29933, 28799, 26977, 25776, 26375, 27126, 25614, 27858, 26414, 538 | 26634, 28033, 28470, 29757, 25558, 26230, 26783, 27089, 25931, 25578, 29434, 26288, 29616, 25438, 29904, 27982, 539 | 29231, 28727, 28345, 25595, 27040, 26870, 26688, 26261, 26438, 27105, 29542, 25125, 25504, 29879, 27876, 25609, 540 | 29673, 25428, 29683, 28027, 26514, 26006, 25155, 25583, 25494, 25236, 28129, 29016, 25112, 29187, 27633, 26321, 541 | 26718, 26688, 28921, 28166, 28885, 28483, 26899, 25505, 29192, 27648, 28806, 26515, 28162, 27944, 26668, 25517, 542 | 28005, 26823, 25425, 27753, 28386, 29001, 27228, 27393, 26889, 27477, 26024, 29418, 26568, 28414, 29378, 26237, 543 | 27247, 27522, 27439, 28917, 25128, 28274, 25604, 27741, 26476, 26928, 27063, 27371, 29470, 25519, 25879, 27073, 544 | 27878, 29903, 26261, 26460, 26677, 28412, 26069, 27245, 27360, 26436, 27884, 28359, 29696, 26198, 27584, 26221, 545 | 25277, 28502, 26153, 25039, 26192, 27762, 29567, 26392, 27079, 27131, 26593, 29932, 25809, 29215, 29711, 25717, 546 | 27353, 27775, 29316, 26539, 28866, 26838, 28764, 29971, 28951, 27365, 26870, 26185, 26052, 28736, 25647, 25370, 547 | 25896, 26733, 27877, 28686, 25376, 26289, 26683, 25741, 29546, 27328, 28359, 27280, 28027, 29443, 27164, 27194, 548 | 29034, 25443, 26212, 29084, 25284, 26273, 26741, 25890, 29214, 29783, 29533, 27769, 28831, 26706, 25028, 29328, 549 | 29235, 29829, 26814, 29181, 26630, 29991, 27557, 29558, 29171, 29728, 27261, 27215, 26466, 28454, 29631, 27114, 550 | 27706, 26803, 28461, 29645, 26158, 25104, 27777, 27003, 27292, 25936, 29566, 28918, 27688, 27573, 26635, 25639, 551 | 26591, 29010, 25935, 27518, 28421, 27065, 25203, 25105, 27467, 27177, 29955, 27614, 25351, 26577, 28134, 27689, 552 | 25510, 29021, 28517, 27122, 29832, 25526, 26264, 29168, 27878, 26214, 25037, 28548, 26535, 
28676, 29012, 29521, 553 | 25816, 26949, 27427, 28851, 29043, 29868, 27364, 29640, 25145, 28293, 29176, 25126, 27139, 29224, 27593, 29291, 554 | 28692, 25794, 28040, 26709, 26158, 27195, 25285, 27501, 26145, 27942, 29758, 29734, 29355, 29319, 29718, 25377, 555 | 26963, 29750, 28057, 26421, 26018, 26559, 26789, 29365, 28550, 27710, 25076, 29960, 29813, 27505, 27435, 28094, 556 | 28858, 28332, 28770, 26033, 28365, 29034, 28236, 26272, 28619, 26767, 29603, 25199, 28265, 29023, 25845, 25937, 557 | 28889, 28943, 29967, 26860, 28270, 26616, 29659, 25338, 25927, 27249, 29603, 26582, 27704, 28289, 25390, 28704, 558 | 27754, 28324, 28409, 25531, 27831, 26301, 29653, 28636, 26863, 26958, 25709, 29313, 25088, 29044, 26873, 29393, 559 | 27625, 27433, 27983, 29564, 29862, 28733, 26683, 28698] 560 | 561 | neg_train_index = list(set(x) ^ set(neg_val_index)) 562 | pos_train_index = list(set(y) ^ set(pos_val_index)) 563 | 564 | 565 | train_index = neg_train_index 566 | train_index.extend(pos_train_index) 567 | 568 | val_index = neg_val_index 569 | val_index.extend(pos_val_index) 570 | 571 | return train_index, val_index 572 | 573 | 574 | 575 | 576 | class IMBALANCECIFAR100(IMBALANCECIFAR10): 577 | """`CIFAR100 `_ Dataset. 578 | This is a subclass of the `CIFAR10` Dataset. 579 | """ 580 | base_folder = 'cifar-100-python' 581 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 582 | filename = "cifar-100-python.tar.gz" 583 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 584 | train_list = [ 585 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 586 | ] 587 | 588 | test_list = [ 589 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 590 | ] 591 | meta = { 592 | 'filename': 'meta', 593 | 'key': 'fine_label_names', 594 | 'md5': '7973b15100ade9c7d40fb424638fde48', 595 | } 596 | cls_num = 100 597 | 598 | def __init__(self, root, imb_type='step', imb_factor=0.02, rand_number=0, train=True, 599 | transform=None, target_transform=None, 600 | download=False, val=False): 601 | IMBALANCECIFAR10.__init__(self, root, imb_type= imb_type , imb_factor= imb_factor, rand_number=rand_number, train=train, 602 | transform=transform, target_transform=target_transform, 603 | download=download, val=val) 604 | 605 | 606 | # root, 607 | # transform=None, 608 | # target_transform=None, 609 | # loader=, 610 | # is_valid_file = None 611 | # 612 | 613 | class TINYIMAGENET(torchvision.datasets.ImageFolder): 614 | cls_num = 200 615 | 616 | def __init__(self, root, imb_type='exp', imb_factor=0.02, rand_number=0, 617 | transform=None, target_transform=None, is_valid_file=None): 618 | super(TINYIMAGENET, self).__init__(root, transform, target_transform, is_valid_file) 619 | np.random.seed(rand_number) 620 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) 621 | self.gen_imbalanced_data(img_num_list) 622 | 623 | # print("ImageFolder:", self.imgs[0:100]) 624 | 625 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor): 626 | img_max = len(self.imgs) / cls_num 627 | img_num_per_cls = [] 628 | if imb_type == 'exp': 629 | for cls_idx in range(cls_num): 630 | num = img_max * (imb_factor ** (cls_idx / (cls_num - 1.0))) 631 | img_num_per_cls.append(int(num)) 632 | elif imb_type == 'step': 633 | for cls_idx in range(cls_num // 2): 634 | img_num_per_cls.append(int(img_max)) 635 | for cls_idx in range(cls_num // 2): 636 | img_num_per_cls.append(int(img_max * imb_factor)) 637 | else: 638 | img_num_per_cls.extend([int(img_max)] * cls_num) 639 | return img_num_per_cls 640 | 641 | def gen_imbalanced_data(self, 
img_num_per_cls):
642 | new_data = []
643 | new_targets = []
644 | # subsample every class down to its img_num_per_cls quota
645 | # print("img_num_per_cls:", img_num_per_cls)
646 | # self.data = np.array([x[0] for x in self.imgs])
647 | self.targets = [x[1] for x in self.imgs]
648 | self.imgs = np.array(self.imgs)
649 | # print(self.data[0:10])
650 | targets_np = np.array(self.targets, dtype=np.int64)
651 | classes = np.unique(targets_np)
652 | # np.random.shuffle(classes)
653 | self.num_per_cls_dict = dict()
654 | for the_class, the_img_num in zip(classes, img_num_per_cls):
655 | self.num_per_cls_dict[the_class] = the_img_num
656 | idx = np.where(targets_np == the_class)[0]
657 | np.random.shuffle(idx)
658 | selec_idx = idx[:the_img_num]
659 | new_data.extend(self.imgs[selec_idx, ...])
660 | new_targets.extend([the_class, ] * the_img_num)
661 | 
662 | self.samples = new_data
663 | self.targets = new_targets
664 | 
665 | def get_cls_num_list(self):
666 | cls_num_list = []
667 | for i in range(self.cls_num):
668 | cls_num_list.append(self.num_per_cls_dict[i])
669 | return cls_num_list
670 | 
671 | def __getitem__(self, index):
672 | """
673 | Args:
674 | index (int): Index
675 | 
676 | Returns:
677 | tuple: (sample, target) where target is class_index of the target class.
678 | """
679 | 
680 | path, target = self.samples[index]
681 | target = int(target)
682 | assert os.path.isfile(path), "File does not exist"
683 | img = Image.open(path)
684 | sample = img.convert('RGB')
685 | 
686 | if self.transform is not None:
687 | sample = self.transform(sample)
688 | if self.target_transform is not None:
689 | target = self.target_transform(target)
690 | 
691 | return sample, target
692 | 
693 | 
694 | if __name__ == '__main__':
695 | transform = transforms.Compose(
696 | [transforms.ToTensor(),
697 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
698 | trainset = IMBALANCECIFAR100(root='./data', train=True,
699 | download=True, transform=transform)
700 | trainloader = iter(trainset)
701 | data, label = next(trainloader)
702 | import pdb;
703 | 
704 | pdb.set_trace()
705 | 
706 | 
--------------------------------------------------------------------------------
/Image/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | sigmoidf = nn.Sigmoid()
6 | 
7 | def squared_hinge_loss(predScore, targets, b):
8 | 
9 | # squared hinge surrogate: max(0, 1 - targets * (predScore - b)) ** 2
10 | squared_hinge = (1-targets*(predScore - b))
11 | squared_hinge[squared_hinge <=0] = 0
12 | 
13 | # after clamping, only margin violations contribute to the loss
14 | return squared_hinge ** 2
15 | 
16 | # smooth pairwise surrogates: small when a positive scores above a negative
17 | def sigmoid_loss(pos, neg, beta=2.0):
18 | return 1.0 / (1.0 + torch.exp(beta * (pos - neg)))
19 | 
20 | def logistic_loss(pos, neg, beta = 1):
21 | # equivalently -log(sigmoid(beta * (pos - neg)))
22 | return -torch.log(1/(1+torch.exp(-beta * (pos - neg))))
23 | 
--------------------------------------------------------------------------------
/Image/main_cifar100_resnet18.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torchvision.datasets.folder import ImageFolder
4 | import numpy as np
5 | from torchvision.models import resnet
6 | from preprocess import *
7 | from utils import *
8 | from config_cifar import conf
9 | from train_eval import *
10 | from imbalanced_cifar import *
11 | from torchvision import datasets
12 | # train_dataset = ('/mnt/dive/shared/lyz/auprc_img/datasets/train', get_transform(input_size=conf['input_size'], augment=True))
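# Driver note: this script fine-tunes a CE-pretrained ResNet-18 on the
# binarized, step-imbalanced CIFAR-100 with the SOAP loss. The checkpoint
# suffix encodes the imbalance ratio (0.02 -> '002', 0.2 -> '020'), so the
# default run loads './cepretrainmodels/cifar100_resnet18_002.ckpt', and
# run_classification re-initializes the final fc layer (ft_mode = 'fc_random')
# before training. The commented-out ImageFolder lines below are development
# leftovers from an earlier image pipeline:
13 | # val_dataset = 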
ImageFolder('/mnt/dive/shared/lyz/auprc_img/datasets/valid', get_transform(input_size=conf['input_size'], augment=False)) 14 | # test_dataset = ImageFolder('/mnt/dive/shared/lyz/auprc_img/datasets/test', get_transform(input_size=conf['input_size'], augment=False)) 15 | 16 | 17 | imb_ratio = 0.02 18 | train_dataset = IMBALANCECIFAR100(root='./data', download=True, transform = transform_train, imb_factor=imb_ratio ) 19 | val_dataset = IMBALANCECIFAR100(root='./data', download=True, transform=transform_train, imb_factor= imb_ratio, val = True) 20 | test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_val) 21 | 22 | 23 | 24 | model = resnet18() 25 | model_name = 'resnet18' 26 | 27 | 28 | if imb_ratio == 0.02: 29 | apdix = '002' 30 | elif imb_ratio == 0.2: 31 | apdix = '020' 32 | 33 | 34 | for conf['batch_size'] in [64]: 35 | for i in range(3): 36 | for loss_type in ['SOAP']: 37 | conf['epochs'] = 64 38 | conf['ft_mode'] = 'fc_random' 39 | conf['lr'] = 1e-4 40 | conf['pre_train'] = './cepretrainmodels/cifar100_' + model_name + '_' + apdix + '.ckpt' # imb_factor 0.02 41 | conf['surr_loss'] = 'sqh' 42 | tau = 1 43 | conf['posNum'] = 2 44 | th = 0.6 45 | print(conf) 46 | print(i) 47 | bins = 2 48 | mv_gamma = 0.999 49 | out_path = './Released_results/{}/cifar100/SGD_results_{}_bth_{}_epoch_{}_lr_{}_ft_mode_{}_tau_{}_posNum_{}_threshold_{}_repeats_{}_imb_{}_surrloss_{}_gamma_{}'.format(model_name, loss_type, conf['batch_size'], conf['epochs'], conf['lr'], conf['ft_mode'],tau, conf['posNum'], th, i, imb_ratio, conf['surr_loss'], str(mv_gamma)) 50 | if not os.path.exists(out_path): 51 | os.makedirs(out_path) 52 | conf['loss_type'] = loss_type 53 | conf['loss_param'] = {'threshold': th, 'm':5, 'gamma':1000} 54 | 55 | run_classification(i, train_dataset, val_dataset, test_dataset, model, conf['num_tasks'], conf['epochs'], conf['batch_size'], conf['vt_batch_size'], conf['lr'], conf['lr_decay_factor'], conf['lr_decay_step_size'], conf['weight_decay'], conf['loss_type'], conf['loss_param'], conf['ft_mode'], conf['pre_train'], out_path, 56 | bins = bins, tau = tau, posNum = conf['posNum'], surr_loss= conf['surr_loss'], dataset = 'cifar100', mv_gamma = mv_gamma) 57 | print(mv_gamma, conf) 58 | print(i) 59 | -------------------------------------------------------------------------------- /Image/main_cifar100_resnet34.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision.datasets.folder import ImageFolder 4 | import numpy as np 5 | from torchvision.models import resnet 6 | from preprocess import * 7 | from utils import * 8 | from config_cifar import conf 9 | from train_eval import * 10 | from imbalanced_cifar import * 11 | from torchvision import datasets 12 | # train_dataset = ('/mnt/dive/shared/lyz/auprc_img/datasets/train', get_transform(input_size=conf['input_size'], augment=True)) 13 | # val_dataset = ImageFolder('/mnt/dive/shared/lyz/auprc_img/datasets/valid', get_transform(input_size=conf['input_size'], augment=False)) 14 | # test_dataset = ImageFolder('/mnt/dive/shared/lyz/auprc_img/datasets/test', get_transform(input_size=conf['input_size'], augment=False)) 15 | 16 | 17 | imb_ratio = 0.02 18 | train_dataset = IMBALANCECIFAR100(root='./data', download=True, transform = transform_train, imb_factor=imb_ratio ) 19 | val_dataset = IMBALANCECIFAR100(root='./data', download=True, transform=transform_train, imb_factor= imb_ratio, val = True) 20 | 
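# Imbalance arithmetic (cf. get_img_num_per_cls in imbalanced_cifar.py):
# CIFAR-100 has img_max = 50000 / 100 = 500 training images per class; with
# step imbalance and imb_factor = 0.02, the first 50 classes keep 500 images
# and the last 50 keep int(500 * 0.02) = 10. Under the binary relabeling
# used at evaluation time (classes 0-49 -> 0, 50-99 -> 1) and the fixed
# validation split, this matches the counts hard-coded in train_eval.py for
# imb_factor 0.02: n_train = 20400 samples, of which n_train_pos = 400 are positive.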
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_val) 21 | 22 | 23 | 24 | model = resnet34() 25 | model_name = 'resnet34' 26 | 27 | if imb_ratio == 0.02: 28 | apdix = '002' 29 | elif imb_ratio == 0.2: 30 | apdix = '020' 31 | 32 | np.random.seed(0) 33 | for conf['batch_size'] in [64]: 34 | for i in range(5): 35 | for loss_type in ['SOAP']:# , 'focal', 'auprc_lang', 'auroc2' 36 | conf['epochs'] = 64 37 | conf['ft_mode'] = 'fc_random' 38 | conf['lr'] = 1e-3 39 | # rebuttal 40 | conf['pre_train'] = './cepretrainmodels/cifar100_' + model_name + '_' + apdix + '.ckpt' # imb_factor 0.02 41 | conf['surr_loss'] = 'sqh' 42 | tau = 1 43 | conf['posNum'] = 3 44 | th = 0.6 45 | out_path = './Released_results/{}/cifar100/Pretrained_SGD_results_{}_bth_{}_epoch_{}_lr_{}_ft_mode_{}_tau_{}_posNum_{}_threshold_{}_repeats_{}_imb_{}_surrloss_{}'.format(model_name, loss_type, conf['batch_size'], conf['epochs'], conf['lr'], conf['ft_mode'],tau, conf['posNum'], th, i, imb_ratio, conf['surr_loss']) 46 | if not os.path.exists(out_path): 47 | os.makedirs(out_path) 48 | conf['loss_type'] = loss_type 49 | conf['loss_param'] = {'threshold': th, 'm':5, 'gamma':1000} 50 | 51 | print(i, ":", conf) 52 | bins = 2 53 | mv_gamma = 0.999 54 | run_classification(i, train_dataset, val_dataset, test_dataset, model, conf['num_tasks'], conf['epochs'], conf['batch_size'], conf['vt_batch_size'], conf['lr'], conf['lr_decay_factor'], conf['lr_decay_step_size'], conf['weight_decay'], conf['loss_type'], conf['loss_param'], conf['ft_mode'], conf['pre_train'], out_path, 55 | bins = bins, tau = tau, posNum = conf['posNum'], surr_loss= conf['surr_loss'], dataset = 'cifar100', mv_gamma = mv_gamma) 56 | print(mv_gamma, conf) 57 | print(i) 58 | 59 | -------------------------------------------------------------------------------- /Image/main_cifar10_resnet18.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision.datasets.folder import ImageFolder 4 | import numpy as np 5 | from torchvision.models import resnet 6 | from preprocess import * 7 | from utils import * 8 | from config_cifar import conf 9 | from train_eval import * 10 | from imbalanced_cifar import * 11 | from torchvision import datasets 12 | 13 | 14 | imb_ratio = 0.02 15 | train_dataset = IMBALANCECIFAR10(root='./data', download=True, transform = transform_train, imb_factor=imb_ratio ) 16 | val_dataset = IMBALANCECIFAR10(root='./data', download=True, transform=transform_train, imb_factor= imb_ratio, val = True) 17 | test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_val) 18 | 19 | 20 | 21 | print(len(train_dataset), len(val_dataset), len(test_dataset)) 22 | model = resnet18() 23 | model_name = 'resnet18' 24 | 25 | if imb_ratio == 0.02: 26 | apdix = '002' 27 | elif imb_ratio == 0.2: 28 | apdix = '020' 29 | 30 | 31 | for conf['batch_size'] in [64]: 32 | for i in range(3): 33 | for loss_type in ['SOAP']: 34 | conf['epochs'] = 64 35 | conf['ft_mode'] = 'fc_random' 36 | conf['lr'] = 1e-4 37 | conf['pre_train'] = './cepretrainmodels/cifar10_' + model_name + '_' + apdix + '.ckpt' # imb_factor 0.02 38 | conf['surr_loss'] = 'sqh' 39 | tau = 1 40 | conf['posNum'] = 2 41 | th = 0.6 42 | out_path = './Released_results/{}/cifar10/SGD_results_{}_bth_{}_epoch_{}_lr_{}_ft_mode_{}_tau_{}_posNum_{}_threshold_{}_repeats_{}_imb_{}_surrloss_{}'.\ 43 | format(model_name, loss_type, conf['batch_size'], 
conf['epochs'], conf['lr'], conf['ft_mode'],tau, conf['posNum'], th, i, imb_ratio, conf['surr_loss']) 44 | if not os.path.exists(out_path): 45 | os.makedirs(out_path) 46 | conf['loss_type'] = loss_type 47 | conf['loss_param'] = {'threshold': th, 'm':5, 'gamma':1000} 48 | print(conf) 49 | print(i) 50 | bins = 2 51 | mv_gamma = 0.999 52 | run_classification(i, train_dataset, val_dataset, test_dataset, model, conf['num_tasks'], conf['epochs'], conf['batch_size'], conf['vt_batch_size'], conf['lr'], conf['lr_decay_factor'], conf['lr_decay_step_size'], conf['weight_decay'], conf['loss_type'], conf['loss_param'], conf['ft_mode'], conf['pre_train'], out_path, 53 | bins = bins, tau = tau, posNum = conf['posNum'], surr_loss= conf['surr_loss'], dataset = 'cifar10', mv_gamma = mv_gamma) 54 | print(mv_gamma, conf) 55 | -------------------------------------------------------------------------------- /Image/main_cifar10_resnet34.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision.datasets.folder import ImageFolder 4 | import numpy as np 5 | from torchvision.models import resnet 6 | from preprocess import * 7 | from utils import * 8 | from config_cifar import conf 9 | from train_eval import * 10 | from imbalanced_cifar import * 11 | from torchvision import datasets 12 | # train_dataset = ('/mnt/dive/shared/lyz/auprc_img/datasets/train', get_transform(input_size=conf['input_size'], augment=True)) 13 | # val_dataset = ImageFolder('/mnt/dive/shared/lyz/auprc_img/datasets/valid', get_transform(input_size=conf['input_size'], augment=False)) 14 | # test_dataset = ImageFolder('/mnt/dive/shared/lyz/auprc_img/datasets/test', get_transform(input_size=conf['input_size'], augment=False)) 15 | 16 | 17 | imb_ratio = 0.02 18 | train_dataset = IMBALANCECIFAR10(root='./data', download=True, transform = transform_train, imb_factor=imb_ratio ) 19 | val_dataset = IMBALANCECIFAR10(root='./data', download=True, transform=transform_train, imb_factor= imb_ratio, val = True) 20 | test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_val) 21 | 22 | 23 | 24 | model = resnet34() 25 | model_name = 'resnet34' 26 | 27 | 28 | if imb_ratio == 0.02: 29 | apdix = '002' 30 | elif imb_ratio == 0.2: 31 | apdix = '020' 32 | 33 | 34 | np.random.seed(0) 35 | for conf['batch_size'] in [64]: 36 | for i in range(3): 37 | for loss_type in ['SOAP']: #, 'auroc2', 'auprc_lang', 'wce', 'ldam', 'focal', 'smoothAP', 'fastAP' 38 | conf['epochs'] = 64 39 | conf['ft_mode'] = 'fc_random' 40 | conf['lr'] = 1e-2 41 | conf['pre_train'] = './cepretrainmodels/cifar10_' + model_name + '_' + apdix + '.ckpt' # imb_factor 0.02 42 | conf['surr_loss'] = 'sqh' 43 | tau = 1 44 | posNum = 3 45 | th = 0.5 46 | out_path = './Released_results/{}/cifar10/SGD_results_{}_bth_{}_epoch_{}_lr_{}_ft_mode_{}_tau_{}_posNum_{}_threshold_{}_repeats_{}_imb_{}_surrloss_{}'.format(model_name, loss_type, conf['batch_size'], conf['epochs'], conf['lr'], conf['ft_mode'],tau, posNum, th, i, imb_ratio, conf['surr_loss']) 47 | if not os.path.exists(out_path): 48 | os.makedirs(out_path) 49 | conf['loss_type'] = loss_type 50 | conf['loss_param'] = {'threshold': th, 'm':5, 'gamma':1000} 51 | print(conf) 52 | print(i, posNum) 53 | bins = 2 54 | mv_gamma = 0.99 55 | run_classification(i, train_dataset, val_dataset, test_dataset, model, conf['num_tasks'], conf['epochs'], conf['batch_size'], conf['vt_batch_size'], conf['lr'], conf['lr_decay_factor'], 
conf['lr_decay_step_size'], conf['weight_decay'], conf['loss_type'], conf['loss_param'], conf['ft_mode'], conf['pre_train'], out_path, 56 | bins = bins, tau = tau, posNum = posNum, surr_loss= conf['surr_loss'], dataset = 'cifar10', mv_gamma = mv_gamma) 57 | 58 | # gamma = 1 1-gamma = 0 59 | # gamma = 0.1 1- gamma = 0.9, 1- 0.05 = 0.95 60 | # gamma = 0.01 1- gamma = 0.99 61 | -------------------------------------------------------------------------------- /Image/main_melanoma_resnet18.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*-coding=utf-8 -*- 3 | __author__ = 'Qi' 4 | # Created by on 7/3/21. 5 | 6 | from torchvision.models import resnet 7 | from preprocess import * 8 | from utils import * 9 | from config_melanoma import conf 10 | from train_eval_melanoma import * 11 | import os 12 | 13 | 14 | if 'amax' in os.uname()[1]: 15 | data_path = '/data/qiuzh/melanoma/jpeg/' 16 | elif 'optimus' in os.uname()[1]: 17 | data_path = '/optimus_data/backed_up/qqi7/melanoma/data/melanoma/' 18 | 19 | 20 | 21 | train_dataset = MyImageFolder(data_path + 'mytrain', get_transform(input_size=conf['input_size'], augment=True)) 22 | val_dataset = ImageFolder(data_path + 'myval', get_transform(input_size=conf['input_size'], augment=False)) 23 | test_dataset = ImageFolder( data_path + 'mytest', get_transform(input_size=conf['input_size'], augment=False)) 24 | 25 | model = resnet18() 26 | model_name = 'resnet18' 27 | 28 | for conf['batch_size'] in [64]: 29 | for i in range(1,3): 30 | print(i) 31 | for loss_type in ['smoothAP', 'fastAP']: # 'auprc_lang', 'ldam', 'wce','focal', 'auroc2', 'SOAP' # 32 | conf['ft_mode'] = 'fc_random' 33 | 34 | if 'amax' in os.uname()[1]: 35 | conf['pre_train'] = './cepretrainmodels/melanoma_ce_pretrain_' + model_name + '.pth' 36 | elif 'optimus' in os.uname()[1]: 37 | conf['pre_train'] = './cepretrainmodels/melanoma_ce_pretrain_' + model_name + '.pth' # last.ckpt 38 | 39 | conf['lr'] = 1e-4 40 | conf['epochs'] = 100 41 | tau = 1 42 | conf['posNum'] = 1 43 | th = 5 44 | 45 | if 'amax' in os.uname()[1]: 46 | out_path = '/data/qiuzh/qiqi_res/{}/melanoma/SGD_results_{}_bth_{}_epoch_{}_lr_{}_ft_mode_{}_th_{}_tau_{}_posNum_{}'.format( 47 | model_name, loss_type, conf['batch_size'], conf['epochs'], conf['lr'], conf['ft_mode'], th, tau, 48 | conf['posNum']) 49 | elif 'optimus' in os.uname()[1]: 50 | out_path = './Released_results/{}/melanoma/SGD_results_{}_bth_{}_epoch_{}_lr_{}_ft_mode_{}_th_{}_tau_{}_posNum_{}'.format( 51 | model_name, loss_type, conf['batch_size'], conf['epochs'], conf['lr'], conf['ft_mode'], th, tau, 52 | conf['posNum']) 53 | 54 | 55 | if not os.path.exists(out_path): 56 | os.makedirs(out_path) 57 | conf['loss_type'] = loss_type 58 | conf['loss_param'] = {'threshold':th, 'm':5, 'gamma':1000} 59 | 60 | print(conf) 61 | print("posNum: ", conf['posNum'], ' option: ', 1) 62 | mv_gamma = 0.999 63 | bins = 5 64 | run_classification(i, train_dataset, val_dataset, test_dataset, model, conf['num_tasks'], conf['epochs'], conf['batch_size'], conf['vt_batch_size'], conf['lr'], conf['lr_decay_factor'], conf['lr_decay_step_size'], conf['weight_decay'], conf['loss_type'], conf['loss_param'], conf['ft_mode'], conf['pre_train'], out_path, 65 | bins = bins, tau = tau, posNum = conf['posNum'], dataset = 'melanoma', mv_gamma = mv_gamma) 66 | 67 | print(mv_gamma, conf) 68 | print("i: ", i) 69 | -------------------------------------------------------------------------------- /Image/main_melanoma_resnet34.py: 
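(Note on both melanoma drivers: data_path is selected by hostname via os.uname()[1], so on a machine whose hostname contains neither 'amax' nor 'optimus' it is never assigned and the ImageFolder calls fail with a NameError. A minimal guard, sketched below with a placeholder fallback and an assumed MELANOMA_DATA environment variable that are not part of this repo:)

```python
import os

hostname = os.uname()[1]
if 'amax' in hostname:
    data_path = '/data/qiuzh/melanoma/jpeg/'
elif 'optimus' in hostname:
    data_path = '/optimus_data/backed_up/qqi7/melanoma/data/melanoma/'
else:
    # placeholder fallback: point this at your own melanoma ImageFolder root
    data_path = os.environ.get('MELANOMA_DATA', './data/melanoma/')
```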
-------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*-coding=utf-8 -*- 3 | __author__ = 'Qi' 4 | # Created by on 7/3/21. 5 | 6 | from torchvision.models import resnet 7 | from preprocess import * 8 | from utils import * 9 | from config_melanoma import conf 10 | from train_eval_melanoma import * 11 | import os 12 | 13 | 14 | if 'amax' in os.uname()[1]: 15 | data_path = '/data/qiuzh/melanoma/jpeg/' 16 | elif 'optimus' in os.uname()[1]: 17 | data_path = '/optimus_data/backed_up/qqi7/melanoma/data/melanoma/' 18 | 19 | 20 | 21 | train_dataset = MyImageFolder(data_path + 'mytrain', get_transform(input_size=conf['input_size'], augment=True)) 22 | val_dataset = ImageFolder(data_path + 'myval', get_transform(input_size=conf['input_size'], augment=False)) 23 | test_dataset = ImageFolder( data_path + 'mytest', get_transform(input_size=conf['input_size'], augment=False)) 24 | 25 | model = resnet34() 26 | model_name = 'resnet34' 27 | 28 | for conf['batch_size'] in [64]: 29 | for i in range(1,3): 30 | for loss_type in ['auprc_lang', 'wce']: # 'ldam', 'wce','focal', 'auroc2', # ['focal', 'smoothAP', 'fastAP'] 31 | 32 | print(i) 33 | conf['ft_mode'] = 'fc_random' 34 | 35 | if 'amax' in os.uname()[1]: 36 | conf['pre_train'] = './cepretrainmodels/YZ_melanoma_ce_pretrain_' + model_name + '.pth' 37 | elif 'optimus' in os.uname()[1]: 38 | conf['pre_train'] = './cepretrainmodels/YZ_melanoma_ce_pretrain_' + model_name + '.pth' # last.ckpt 39 | 40 | conf['lr'] = 1e-4 41 | conf['epochs'] = 100 42 | tau = 1 43 | conf['posNum'] = 1 44 | conf['th'] = 5 45 | 46 | if 'amax' in os.uname()[1]: 47 | out_path = '/data/qiuzh/qiqi_res/{}/melanoma/YZAdam_results_{}_bth_{}_epoch_{}_lr_{}_ft_mode_{}_th_{}_tau_{}_posNum_{}'.format( 48 | model_name, loss_type, conf['batch_size'], conf['epochs'], conf['lr'], conf['ft_mode'], conf['th'], tau, 49 | conf['posNum']) 50 | elif 'optimus' in os.uname()[1]: 51 | out_path = './Released_results/{}/melanoma/YZAdam_results_{}_bth_{}_epoch_{}_lr_{}_ft_mode_{}_th_{}_tau_{}_posNum_{}'.format( 52 | model_name, loss_type, conf['batch_size'], conf['epochs'], conf['lr'], conf['ft_mode'], conf['th'], tau, 53 | conf['posNum']) 54 | 55 | 56 | if not os.path.exists(out_path): 57 | os.makedirs(out_path) 58 | conf['loss_type'] = loss_type 59 | conf['loss_param'] = {'threshold': conf['th'], 'm':5, 'gamma':1000} 60 | 61 | print(conf) 62 | print("posNum: ", conf['posNum'], ' option: ', 1) 63 | mv_gamma = 0.999 64 | bins = 5 65 | run_classification(i, train_dataset, val_dataset, test_dataset, model, conf['num_tasks'], conf['epochs'], conf['batch_size'], conf['vt_batch_size'], conf['lr'], conf['lr_decay_factor'], conf['lr_decay_step_size'], conf['weight_decay'], conf['loss_type'], conf['loss_param'], conf['ft_mode'], conf['pre_train'], out_path, 66 | bins = bins, tau = tau, posNum = conf['posNum'], dataset = 'melanoma', mv_gamma = mv_gamma) 67 | print(mv_gamma, conf) 68 | -------------------------------------------------------------------------------- /Image/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import cv2 4 | import random 5 | import numpy as np 6 | import os 7 | 8 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 9 | 'std': [0.229, 0.224, 0.225]} 10 | 11 | 12 | 13 | class Microscope: 14 | """ 15 | Cutting out the edges around the center circle of the image 16 | Imitating a picture, taken through the microscope 17 | 
Args: 18 | p (float): probability of applying an augmentation 19 | 20 | """ 21 | 22 | def __init__(self, p: float = 0.5): 23 | self.p = p 24 | 25 | def __call__(self, img): 26 | """ 27 | 28 | Args: 29 | 30 | img (PIL Image): Image to apply transformation to. 31 | 32 | Returns: 33 | 34 | PIL Image: Image with transformation. 35 | 36 | """ 37 | 38 | img = np.array(img) 39 | if random.random() < self.p: 40 | circle = cv2.circle((np.ones(img.shape) * 255).astype(np.uint8), # image placeholder 41 | 42 | (img.shape[0] // 2, img.shape[1] // 2), # center point of circle 43 | 44 | random.randint(img.shape[0] // 2 - 3, img.shape[0] // 2 + 15), # radius 45 | 46 | (0, 0, 0), # color 47 | 48 | -1) 49 | mask = circle - 255 50 | 51 | img = np.multiply(img, mask) 52 | 53 | return img 54 | 55 | def __repr__(self): 56 | return f'{self.__class__.__name__}(p={self.p})' 57 | 58 | def random_crop_and_filp(input_size, normalize_1=None): 59 | 60 | # VERSION: 1 61 | # t_list = [ 62 | # transforms.RandomResizedCrop(input_size, scale=(0.7, 1)), 63 | # transforms.RandomHorizontalFlip(), 64 | # transforms.RandomVerticalFlip(), 65 | # transforms.ColorJitter(brightness=32. / 255., saturation=0.5), 66 | # Microscope(p=0.6), 67 | # transforms.ToTensor(), 68 | # normalize_1 69 | # # normalize_2 70 | # ] 71 | # VERSION: 2 72 | t_list = [ 73 | transforms.Resize(size=(input_size, input_size)), 74 | transforms.RandomHorizontalFlip(), 75 | transforms.RandomVerticalFlip(), 76 | transforms.ColorJitter(brightness=32. / 255., contrast=0.2, saturation=0.3), 77 | Microscope(p=0.6), 78 | transforms.ToTensor(), 79 | normalize_1 80 | ] 81 | 82 | 83 | 84 | return transforms.Compose(t_list) 85 | 86 | 87 | def scale_and_center_crop(input_size, normalize_1=None): 88 | t_list = [ 89 | transforms.Resize(size = (input_size, input_size)), # center crop 90 | #transforms.Resize(size=input_size), 91 | transforms.ToTensor(), 92 | normalize_1 # Normalization Data 93 | # normalize_2 94 | ] 95 | # if scale_size != input_size: 96 | # t_list = [transforms.Resize(scale_size)] + t_list 97 | 98 | return transforms.Compose(t_list) 99 | 100 | 101 | def get_transform(input_size=None, aug_type=1, augment=True): 102 | 103 | input_size = input_size or 256 104 | 105 | normalize_1 = transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]) 106 | 107 | if aug_type == 1: 108 | if augment: 109 | return random_crop_and_filp(input_size=input_size, normalize_1=normalize_1) 110 | else: 111 | return scale_and_center_crop(input_size=input_size, normalize_1=normalize_1) 112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /Image/train_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.optim import Adam, SGD 5 | from utils import compute_cla_metric, ave_prc, global_surrogate_loss_with_sqh 6 | from torch.utils.data import DataLoader 7 | import numpy as np 8 | from SOAP import SOAPLOSS, AUPRCSampler 9 | # from loss import * 10 | # from auprc_hinge import * 11 | # import wandb 12 | from sklearn.metrics import precision_recall_curve 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | 20 | 21 | 22 | 23 | ### This is run function for classification tasks 24 | def run_classification(i, train_dataset, val_dataset, test_dataset, model, num_tasks, epochs, batch_size, vt_batch_size, lr, 25 | lr_decay_factor, 
lr_decay_step_size, weight_decay, loss_type='ce', loss_param={}, ft_mode='fc_random', pre_train=None, save_dir=None, bins = 5, tau = 1, posNum = 1, surr_loss = 'sqh', dataset = 'cifar10', mv_gamma = 0.999, imb_factor = 0.02): 26 | 27 | if dataset == "melanoma": 28 | n_train = 26500 29 | n_train_pos = 467 30 | else: 31 | if imb_factor == 0.02: 32 | n_train = 20400 33 | n_train_pos = 400 34 | elif imb_factor == 0.2: 35 | n_train = 24000 36 | n_train_pos = 4000 37 | 38 | model = model.to(device) 39 | if pre_train is not None: 40 | print('we are loading pretrain model') 41 | state_key = torch.load(pre_train) 42 | print('pretrain model is loaded from {} epoch'.format(state_key['epoch'])) 43 | filtered = {k:v for k,v in state_key['model'].items() if 'fc' not in k} 44 | model.load_state_dict(filtered, False) 45 | if ft_mode == 'frozen': 46 | for key,param in model.named_parameters(): 47 | if 'fc' in key and 'gn' not in key: 48 | param.requires_grad = True 49 | else: 50 | param.requires_grad = False 51 | elif ft_mode == 'fc_random': 52 | model.fc.reset_parameters() 53 | 54 | optimizer = SGD(model.parameters(), lr=lr, weight_decay=weight_decay) 55 | global u, a, b, m, alpha 56 | #_1, u_2, 57 | labels = [0] * (n_train - n_train_pos) + [1] * n_train_pos 58 | 59 | if loss_type == 'ce': 60 | criterion = torch.nn.BCEWithLogitsLoss(reduction='none') 61 | elif loss_type in ['auprc2']: 62 | labels = [0] * (n_train - n_train_pos) + [1] * n_train_pos 63 | criterion = None 64 | u = torch.zeros([len(train_dataset)]) 65 | elif loss_type in ['wce','focal','ldam']: 66 | n_pos = n_train_pos 67 | n_neg = n_train - n_train_pos 68 | cls_num_list = [n_neg, n_pos] 69 | if loss_type == 'wce': 70 | criterion = WeightedBCEWithLogitsLoss(cls_num_list=cls_num_list) 71 | elif loss_type == 'focal': 72 | criterion = FocalLoss(cls_num_list=cls_num_list) 73 | elif loss_type == 'ldam': 74 | criterion = BINARY_LDAMLoss(cls_num_list=cls_num_list) 75 | elif loss_type in ['auroc2']: 76 | criterion = None 77 | a, b, alpha, m = float(1), float(0), float(1), loss_param['m'] 78 | loss_param['pos_ratio'] = n_train_pos / n_train 79 | elif loss_type in ['auprc_lang']: 80 | criterion = AUCPRHingeLoss() 81 | elif loss_type in ['fastAP']: 82 | criterion = fastAP 83 | elif loss_type in ['smoothAP']: 84 | criterion = smoothAP 85 | elif loss_type in ['expAP']: 86 | criterion = expAP 87 | elif loss_type in ['SOAP']: 88 | labels = [0] * (n_train - n_train_pos) + [1] * n_train_pos 89 | criterion = SOAPLOSS(threshold=loss_param['threshold'], data_length = len(train_dataset) + len(val_dataset), loss_type=surr_loss, gamma = mv_gamma) 90 | # elif loss_type in ['SOAPINDI']: 91 | # labels = [0] * (n_train - n_train_pos) + [1] * n_train_pos 92 | # criterion = SOAPLOSSINDICATOR(threshold=loss_param['threshold'], batch_size=batch_size, data_length = len(train_dataset) + len(val_dataset), loss_type=surr_loss) 93 | # elif loss_type in ['GENERALSOAPLOSS']: 94 | # labels = [0] * (n_train - n_train_pos) + [1] * n_train_pos 95 | # criterion = GENERALSOAPLOSS(threshold=loss_param['threshold'], batch_size = batch_size) 96 | 97 | 98 | val_loader = DataLoader(val_dataset, vt_batch_size, shuffle=False, num_workers=16, pin_memory=True) 99 | test_loader = DataLoader(test_dataset, vt_batch_size, shuffle=False, num_workers=16, pin_memory=True) 100 | 101 | best_auprc_score = 0 102 | final_auprc = 0 103 | best_test_auprc_score = 0 104 | 105 | 106 | for epoch in range(1, epochs + 1): 107 | 108 | # if loss_type in ['ce', 'wce', 'focal', 'ldam', 'auroc2', 'auprc_lang','fastAP', 
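'smoothAP', 'expAP']:

# Batching note: for every loss type except plain 'ce', batches come from
# AUPRCSampler (defined in SOAP.py), which is expected to place posNum
# positive examples at the front of each batch and fill the remainder with
# negatives; train_classification depends on that ordering when it splits
# predScore[0:posNum] (positives) from predScore[posNum:] (negatives) for
# the SOAP loss.
# The commented-out dispatch from development continues:
# if loss_type in ['ce', 'wce', 'focal', 'ldam', 'auroc2', 'auprc_lang','fastAP', 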
'smoothAP', 'expAP']: 109 | # train_loader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True, num_workers=16, pin_memory=True) 110 | # elif loss_type in ['auprc2', 'SOAP', 'GENERALSOAPLOSS']: 111 | if loss_type == 'ce': 112 | train_loader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True, num_workers=16, 113 | pin_memory=True) 114 | else: 115 | train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=AUPRCSampler(labels, batch_size, posNum=posNum), num_workers=16, pin_memory=True) 116 | # 117 | avg_train_loss = train_classification(model, optimizer, train_loader, lr_decay_step_size, num_tasks, device, epoch, lr, criterion, loss_type, loss_param, bins, tau, posNum, mv_gamma = mv_gamma) 118 | train_auprc, train_roc, train_ap, train_surr_loss = val_train_classification(model, train_loader, num_tasks, device, loss_param) 119 | val_auprc, val_roc, val_ap, val_surr_loss = val_train_classification(model, val_loader, num_tasks, device, loss_param) 120 | test_auprc, test_roc, test_ap, test_surr_loss = test_classification(model, test_loader, num_tasks, device, loss_param, dataset = dataset) 121 | 122 | 123 | if best_auprc_score <= np.mean(val_auprc): 124 | best_auprc_score = np.mean(val_auprc) 125 | final_auprc = np.mean(test_auprc) 126 | if save_dir is not None: 127 | torch.save({'model':model.state_dict(), 'epoch':epoch}, os.path.join(save_dir, str(i) + '_best.ckpt')) 128 | 129 | 130 | if best_test_auprc_score <= np.mean(test_auprc): 131 | best_test_auprc_score = np.mean(test_auprc) 132 | 133 | print('Epoch: {:03d}, Training Loss: {:.6f}, Train AUPRC: {:.4f}, Val AUPRC (avg over multitasks): {:.4f}, Best AUPRC: {:.4f}, Test AUPRC: {:.4f} Final AUPRC: {:.4f} Best AUPRC: {:.4f}' 134 | .format(epoch, avg_train_loss, np.mean(train_auprc), np.mean(val_auprc), best_auprc_score, np.mean(test_auprc), final_auprc, best_test_auprc_score)) 135 | print('Train AP {:.4f}, Val AP: {}, Test AP: {:.4f}\n'.format(train_ap, val_ap, test_ap)) 136 | print('Train Surr Loss {:.4f}, Val Surr Loss: {}, Test Surr Loss: {:.4f}\n'.format(train_surr_loss, val_surr_loss, test_surr_loss)) 137 | 138 | if epoch % lr_decay_step_size == 0: 139 | for param_group in optimizer.param_groups: 140 | param_group['lr'] = lr_decay_factor * param_group['lr'] 141 | 142 | if save_dir is not None: 143 | fp = open(os.path.join(save_dir, str(i)+'_res_auprc_avepre.txt'), 'a') 144 | fp.write('Train AUPRC {:.4f}, Val AUPRC: {}, Test AUPRC: {:.4f}, Final AUPRC: {:.4f}, Train avg loss: {:.4f}\n'.format(np.mean(train_auprc), np.mean(val_auprc), np.mean(test_auprc), final_auprc, avg_train_loss)) 145 | fp.close() 146 | fp = open(os.path.join(save_dir, str(i)+'_res_auroc.txt'), 'a') 147 | fp.write('Train avg loss: {:.4f}, Val AUROC: {:.4f} Test AUCROC: {:.4f}\n'.format(avg_train_loss, np.mean(val_roc), np.mean(test_roc))) 148 | fp.close() 149 | fp = open(os.path.join(save_dir, str(i)+'_ap.txt'), 'a') 150 | fp.write( 151 | 'Train AP {:.4f}, Val AP: {}, Test AP: {:.4f}\n'.format( 152 | train_ap, val_ap, test_ap)) 153 | fp.close() 154 | fp = open(os.path.join(save_dir, str(i) + '_surr_loss.txt'), 'a') 155 | fp.write( 156 | 'Train Surr Loss {:.4f}, Val Surr Loss: {}, Test Surr Loss: {:.4f}\n'.format( 157 | train_surr_loss, val_surr_loss, test_surr_loss)) 158 | fp.close() 159 | 160 | 161 | 162 | 163 | if save_dir is not None: 164 | torch.save({'model':model.state_dict(), 'epoch':epochs}, os.path.join(save_dir, str(i) + '_last.ckpt')) 165 | 166 | 167 | 168 | def train_classification(model, optimizer, 
train_loader, lr_decay_step_size, num_tasks, device, epoch, lr, criterion=None, loss_type=None, loss_param={}, bins = 5, tau = 1.0, posNum = 1, mv_gamma=0.999):
169 | model.train()
170 | 
171 | global a, b, m, alpha
172 | if loss_type == 'auroc2' and epoch % 10 == 1:
173 | # Periodically update w_{ref}, a_{ref}, b_{ref}
174 | global state, a_0, b_0
175 | a_0, b_0 = a, b
176 | state = []
177 | for name, param in model.named_parameters():
178 | state.append(param.data)
179 | losses = []
180 | for i, (index, inputs, target) in enumerate(train_loader):
181 | 
182 | if i%50 == 0:
183 | print(epoch, " : ", i, "/", len(train_loader))
184 | # warmup_learning_rate(epoch, i, lr, len(train_loader), optimizer)
185 | # print(index, target)
186 | optimizer.zero_grad()
187 | inputs = inputs.to(device)
188 | target = target.to(device).float()
189 | out = model(inputs)
190 | 
191 | if loss_type == 'ce':
192 | if len(target.shape) != 2:
193 | target = torch.reshape(target, (-1, num_tasks))
194 | loss = criterion(out, target)
195 | loss = loss.sum()
196 | loss.backward()
197 | optimizer.step()
198 | elif loss_type in ['wce','focal','ldam']:
199 | loss = criterion(out, target, epoch)
200 | loss.backward()
201 | optimizer.step()
202 | elif loss_type in ['auroc2']:
203 | predScore = torch.nn.Sigmoid()(out)
204 | loss = AUROC_loss(predScore, target, a, b, m, alpha, loss_param['pos_ratio'])
205 | curRegularizer = calculateRegularizerWeights(lr, model, state, loss_param['gamma'])
206 | loss.backward()
207 | optimizer.step()
208 | regularizeUpdate(model, curRegularizer)
209 | a, b, alpha = PESG_update_a_b_alpha_2(lr, a, a_0, b, b_0, alpha, m, predScore, target, loss_param['pos_ratio'], loss_param['gamma'])
210 | elif loss_type in ['auprc_lang']:
211 | loss = criterion(out, target)
212 | loss.backward()
213 | optimizer.step()
214 | elif loss_type in ['smoothAP']:
215 | predScore = torch.sigmoid(out)
216 | loss = criterion(predScore, target, tau = tau)
217 | loss.backward()
218 | optimizer.step()
219 | elif loss_type in ['fastAP']:
220 | predScore = torch.sigmoid(out)
221 | # predScore = out/torch.norm(out)
222 | loss = criterion(predScore, target, bins = bins)
223 | loss.backward()
224 | optimizer.step()
225 | elif loss_type in ['expAP']:
226 | # predScore = out / torch.norm(out)
227 | predScore = torch.sigmoid(out)
228 | loss = criterion(predScore, target, tau = tau)
229 | loss.backward()
230 | optimizer.step()
231 | elif loss_type in ['SOAP']:
232 | predScore = torch.nn.Sigmoid()(out)
233 | loss = criterion(f_ps=predScore[0:posNum], f_ns=predScore[posNum:], index_s=index[0:posNum])
234 | # index_s identifies the positive samples, whose running estimates the loss tracks
235 | loss.backward()
236 | optimizer.step()
237 | 
238 | losses.append(loss.detach())  # detach so the epoch-long list does not hold computation graphs
239 | return sum(losses).item() / len(losses)
240 | 
241 | 
242 | def val_train_classification(model, test_loader, num_tasks, device, loss_param):
243 | model.eval()
244 | preds = torch.Tensor([]).to(device)
245 | targets = torch.Tensor([]).to(device)
246 | 
247 | for (index, inputs, target) in test_loader:
248 | 
249 | inputs = inputs.to(device)
250 | target = target.to(device).float()
251 | with torch.no_grad():
252 | out = model(inputs)
253 | if len(target.shape) != 2:
254 | target = torch.reshape(target, (-1, num_tasks))
255 | if out.shape[1] == 1:
256 | pred = torch.sigmoid(out) ### prediction is a real number in (0,1)
257 | else:
258 | pred = torch.softmax(out, dim=-1)[:, 1:2]
259 | preds = torch.cat([preds, pred], dim=0)
260 | targets = torch.cat([targets, target], dim=0)
261 | 
262 | 
263 | auprc, auroc = compute_cla_metric(targets.cpu().detach().numpy(), 
preds.cpu().detach().numpy(), num_tasks) 264 | ap = ave_prc(targets.cpu().detach().numpy(), preds.cpu().detach().numpy()) 265 | 266 | surro_loss = global_surrogate_loss_with_sqh(targets.cpu().detach().numpy(), preds.cpu().detach().numpy(), loss_param['threshold']) 267 | 268 | 269 | 270 | return auprc, auroc, ap, surro_loss 271 | 272 | 273 | def test_classification(model, test_loader, num_tasks, device, loss_param, dataset = 'cifar10'): 274 | model.eval() 275 | 276 | preds = torch.Tensor([]).to(device) 277 | targets = torch.Tensor([]).to(device) 278 | 279 | for (inputs, target) in test_loader: 280 | inputs = inputs.to(device) 281 | target = target.to(device).float() 282 | if dataset == 'cifar10': 283 | target[target <= 4] = 0 284 | target[target > 4] = 1 285 | elif dataset == 'cifar100': 286 | target[target <= 49] = 0 287 | target[target > 49] = 1 288 | 289 | 290 | with torch.no_grad(): 291 | out = model(inputs) 292 | if len(target.shape) != 2: 293 | target = torch.reshape(target, (-1, num_tasks)) 294 | 295 | if out.shape[1] == 1: 296 | pred = torch.sigmoid(out) ### prediction real number between (0,1) 297 | else: 298 | pred = torch.softmax(out,dim=-1)[:,1:2] 299 | preds = torch.cat([preds,pred], dim=0) 300 | # print(preds) 301 | targets = torch.cat([targets, target], dim=0) 302 | # print(targets) 303 | auprc, auroc = compute_cla_metric(targets.cpu().detach().numpy(), preds.cpu().detach().numpy(), num_tasks) 304 | ap = ave_prc(targets.cpu().detach().numpy(), preds.cpu().detach().numpy()) 305 | 306 | surro_loss = global_surrogate_loss_with_sqh(targets.cpu().detach().numpy(), preds.cpu().detach().numpy(), loss_param['threshold']) 307 | 308 | 309 | return auprc, auroc, ap, surro_loss 310 | 311 | 312 | def warmup_learning_rate(epoch, batch_id, lr, total_batches, optimizer): 313 | if epoch <= 5: 314 | p = (batch_id + (epoch - 1) * total_batches) / \ 315 | (5 * total_batches) 316 | lr = 0.01 + p * (lr - 0.01) 317 | 318 | for param_group in optimizer.param_groups: 319 | param_group['lr'] = lr 320 | 321 | 322 | def plot_precision_recall_curve(model, vt_batch_size, test_dataset, saved_model, method, dataset = 'cifar10'): 323 | 324 | test_loader = DataLoader(test_dataset, vt_batch_size, shuffle=False, num_workers=16, pin_memory=True) 325 | 326 | model = model.to(device) 327 | # wandb.watch(model) 328 | 329 | 330 | state_key = torch.load(saved_model) 331 | print('pretrain model is loaded from {} epoch'.format(state_key['epoch'])) 332 | model.load_state_dict(state_key['model']) 333 | 334 | model.eval() 335 | preds = torch.Tensor([]).to(device) 336 | targets = torch.Tensor([]).to(device) 337 | 338 | for (inputs, target) in test_loader: 339 | inputs = inputs.to(device) 340 | target = target.to(device).float() 341 | if dataset == 'cifar10': 342 | target[target <= 4] = 0 343 | target[target > 4] = 1 344 | elif dataset == 'cifar100': 345 | target[target <= 49] = 0 346 | target[target > 49] = 1 347 | 348 | with torch.no_grad(): 349 | out = model(inputs) 350 | 351 | if out.shape[1] == 1: 352 | pred = torch.sigmoid(out) ### prediction real number between (0,1) 353 | else: 354 | pred = torch.softmax(out, dim=-1)[:, 1:2] 355 | preds = torch.cat([preds, pred], dim=0) 356 | 357 | # print(preds) 358 | targets = torch.cat([targets, target], dim=0) 359 | precision, recall, _ = precision_recall_curve(targets.cpu().detach().numpy(), preds.cpu().detach().numpy()) 360 | # disp = PrecisionRecallDisplay(precision=precision, recall=recall) 361 | # precision_recall_plt = disp.plot() 362 | 363 | 364 | plt.plot(recall, 
precision, label = method, linewidth = 2) 365 | if dataset == 'cifar10': 366 | plt.title('CIFAR-10', fontsize = 25) 367 | else: 368 | plt.title('CIFAR-100', fontsize = 25) 369 | plt.xlabel('Recall', fontsize = 20) 370 | plt.ylabel('Precision', fontsize = 20) 371 | plt.hlines(0.5, -0.03, 1.03, colors='gray', linestyles = '--', linewidth = 2) 372 | plt.ylim(0.45,1) 373 | plt.legend(fontsize = 13) 374 | plt.savefig(os.path.join('results', dataset, dataset +'_'+method+'_precision_recall_curve.png')) 375 | -------------------------------------------------------------------------------- /Image/train_eval_melanoma.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Qi' 2 | # Created by on 7/19/21. 3 | import os 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.optim import Adam, SGD 7 | from utils import compute_cla_metric, ave_prc 8 | from torch.utils.data import DataLoader 9 | import numpy as np 10 | from SOAP import SOAPLOSS, AUPRCSampler 11 | from loss import *  # restored: the wce/focal/ldam/auroc2/auprc_lang branches below reference loss classes assumed to be defined in loss.py 12 | # from auprc_hinge import * 13 | # import wandb 14 | from sklearn.metrics import precision_recall_curve 15 | import matplotlib.pyplot as plt 16 | 17 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 18 | 19 | # n_train = 20400 20 | # n_train_pos = 400 21 | 22 | 23 | n_train = 26500 24 | n_train_pos = 467 25 | ### This is the run function for classification tasks 26 | def run_classification(i, train_dataset, val_dataset, test_dataset, model, num_tasks, epochs, batch_size, vt_batch_size, 27 | lr, lr_decay_factor, lr_decay_step_size, weight_decay, loss_type='ce', loss_param={}, 28 | ft_mode='fc_random', pre_train=None, save_dir=None, bins=5, tau=1, posNum=1, dataset='cifar10', 29 | mv_gamma=0.999): 30 | # 1. Start a W&B run 31 | # wandb.init(project='auprc', entity='qiqi-helloworld') 32 | # 2.
Save model inputs and hyperparameters 33 | model = model.to(device) 34 | # wandb.watch(model) 35 | if pre_train is not None: 36 | print('we are loading pretrain model') 37 | state_key = torch.load(pre_train) 38 | print(pre_train) 39 | filtered = {k: v for k, v in state_key.items() if 'fc' not in k} 40 | model.load_state_dict(filtered, False) 41 | if ft_mode == 'frozen': 42 | for key, param in model.named_parameters(): 43 | if 'fc' in key and 'gn' not in key: 44 | param.requires_grad = True 45 | else: 46 | param.requires_grad = False 47 | elif ft_mode == 'fc_random': 48 | model.fc.reset_parameters() 49 | 50 | # optimizer = SGD(model.parameters(), lr=lr, weight_decay=weight_decay) 51 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 52 | global u, a, b, m, alpha 53 | # _1, u_2, 54 | labels = [0] * (n_train - n_train_pos) + [1] * n_train_pos 55 | if loss_type == 'ce': 56 | criterion = torch.nn.BCEWithLogitsLoss(reduction='none') 57 | elif loss_type in ['auprc2']: 58 | labels = [0] * (n_train - n_train_pos) + [1] * n_train_pos 59 | criterion = None 60 | u = torch.zeros([len(train_dataset)]) 61 | elif loss_type in ['wce', 'focal', 'ldam']: 62 | n_pos = n_train_pos 63 | n_neg = n_train - n_train_pos 64 | cls_num_list = [n_neg, n_pos] 65 | if loss_type == 'wce': 66 | criterion = WeightedBCEWithLogitsLoss(cls_num_list=cls_num_list) 67 | elif loss_type == 'focal': 68 | criterion = FocalLoss(cls_num_list=cls_num_list) 69 | elif loss_type == 'ldam': 70 | criterion = BINARY_LDAMLoss(cls_num_list=cls_num_list) 71 | elif loss_type in ['auroc2']: 72 | criterion = None 73 | a, b, alpha, m = float(1), float(0), float(1), loss_param['m'] 74 | loss_param['pos_ratio'] = n_train_pos / n_train 75 | elif loss_type in ['auprc_lang']: 76 | criterion = AUCPRHingeLoss() 77 | elif loss_type in ['fastAP']: 78 | criterion = fastAP 79 | elif loss_type in ['smoothAP']: 80 | criterion = smoothAP 81 | elif loss_type in ['expAP']: 82 | criterion = expAP 83 | elif loss_type in ['SOAP']: 84 | labels = [0] * (n_train - n_train_pos) + [1] * n_train_pos 85 | criterion = SOAPLOSS(threshold=loss_param['threshold'], data_length = len(train_dataset) + len(val_dataset), gamma = mv_gamma) 86 | 87 | val_loader = DataLoader(val_dataset, vt_batch_size, shuffle=False, num_workers=16, pin_memory=True) 88 | test_loader = DataLoader(test_dataset, vt_batch_size, shuffle=False, num_workers=16, pin_memory=True) 89 | 90 | best_auprc = 0 91 | final_auprc = 0 92 | final_ap = 0 93 | 94 | for epoch in range(1, epochs + 1): 95 | 96 | 97 | train_loader = DataLoader(train_dataset, batch_size=batch_size, 98 | sampler=AUPRCSampler(labels, batch_size, posNum=posNum), num_workers=16, 99 | pin_memory=True) 100 | 101 | avg_train_loss = train_classification(model, optimizer, train_loader, lr_decay_step_size, num_tasks, device, 102 | epoch, lr, criterion, loss_type, loss_param, bins, tau, posNum) 103 | 104 | 105 | val_auprc, val_auc, val_ap = test_classification(model, val_loader, num_tasks, device, dataset = dataset) 106 | test_auprc, test_auc, test_ap = test_classification(model, test_loader, num_tasks, 107 | device, dataset=dataset) 108 | 109 | if best_auprc <= np.mean(val_auprc): 110 | best_auprc = np.mean(val_auprc) 111 | final_auprc = np.mean(test_auprc) 112 | final_ap = test_ap 113 | if save_dir is not None: 114 | torch.save({'model': model.state_dict(), 'epoch': epoch}, os.path.join(save_dir, str(i) + '_best.ckpt')) 115 | 116 | 117 | print('Epoch: {:03d}, Training Loss: {:.4f}, Val AUPRC: {:.4f}, Best AUPRC: {:.4f}, Test AUPRC: 
{:.4f} Final AUPRC: {:.4f}' 118 | .format(epoch, avg_train_loss, np.mean(val_auprc), best_auprc, np.mean(test_auprc), final_auprc)) # np.mean(train_auprc), Train AUPRC: {:.4f}, 119 | 120 | print('Epoch: {:03d}, Val AP: {:.4f}, Test AP: {:.4f} Final AP: {:.4f}'.format(epoch, val_ap, test_ap, final_ap)) # train_ap, Train AP: {:.4f}, 121 | 122 | if epoch % lr_decay_step_size == 0: 123 | for param_group in optimizer.param_groups: 124 | param_group['lr'] = lr_decay_factor * param_group['lr'] 125 | 126 | if save_dir is not None: 127 | fp = open(os.path.join(save_dir, str(i) + '_res_auprc.txt'), 'a') 128 | fp.write( 129 | 'Val AUPRC: {}, Test AUPRC: {:.4f}, Train avg loss: {:.4f}, Val AP: {:.4f}, Test AP: {:.4f}, Final AUPRC: {:.4f}, Final AP: {:.4f}\n'.format( 130 | np.mean(val_auprc), np.mean(test_auprc), 131 | avg_train_loss, val_ap, test_ap, final_auprc, final_ap)) # np.mean(train_auprc), Train AUPRC {:.4f}, 132 | fp.close() 133 | 134 | fp = open(os.path.join(save_dir, str(i) + '_res_auroc.txt'), 'a') 135 | fp.write('Train avg loss: {:.4f}, Val AUROC: {:.4f} Test AUROC: {:.4f}\n'.format(avg_train_loss, 136 | np.mean(val_auc), 137 | np.mean(test_auc))) 138 | 139 | fp.close() 140 | 141 | fp = open(os.path.join(save_dir, str(i) + '_ap.txt'), 'a') 142 | fp.write('Val AP: {:.4f} Test AP: {:.4f}\n'.format(val_ap, test_ap)) 143 | fp.close() 144 | 145 | if save_dir is not None: 146 | torch.save({'model': model.state_dict(), 'epoch': epochs}, os.path.join(save_dir, str(i) + '_last.ckpt')) 147 | 148 | 149 | def train_classification(model, optimizer, train_loader, lr_decay_step_size, num_tasks, device, epoch, lr, 150 | criterion=None, loss_type=None, loss_param={}, bins=5, tau=1.0, posNum=1, mv_gamma = 0.999): 151 | model.train() 152 | 153 | global a, b, m, alpha 154 | if loss_type == 'auroc2' and epoch % 10 == 1: 155 | # Periodically update w_{ref}, a_{ref}, b_{ref} 156 | global state, a_0, b_0 157 | a_0, b_0 = a, b 158 | state = [] 159 | for name, param in model.named_parameters(): 160 | state.append(param.data) 161 | losses = [] 162 | 163 | for i, (index, inputs, target) in enumerate(train_loader): 164 | 165 | if i % 50 == 0: 166 | print(epoch, " : ", i, "/", len(train_loader)) 167 | # warmup_learning_rate(epoch, i, lr, len(train_loader), optimizer) 168 | # print(index, target) 169 | optimizer.zero_grad() 170 | inputs = inputs.to(device) 171 | target = target.to(device).float() 172 | out = model(inputs) 173 | 174 | if loss_type == 'ce': 175 | if len(target.shape) != 2: 176 | target = torch.reshape(target, (-1, num_tasks)) 177 | loss = criterion(out, target) 178 | loss = loss.sum() 179 | loss.backward() 180 | optimizer.step() 181 | elif loss_type in ['wce', 'focal', 'ldam']: 182 | loss = criterion(out, target, epoch) 183 | loss.backward() 184 | optimizer.step() 185 | elif loss_type in ['auroc2']: 186 | predScore = torch.nn.Sigmoid()(out) 187 | loss = AUROC_loss(predScore, target, a, b, m, alpha, loss_param['pos_ratio']) 188 | curRegularizer = calculateRegularizerWeights(lr, model, state, loss_param['gamma']) 189 | loss.backward() 190 | optimizer.step() 191 | regularizeUpdate(model, curRegularizer) 192 | a, b, alpha = PESG_update_a_b_alpha_2(lr, a, a_0, b, b_0, alpha, m, predScore, target, 193 | loss_param['pos_ratio'], loss_param['gamma']) 194 | elif loss_type in ['auprc_lang']: 195 | loss = criterion(out, target) 196 | loss.backward() 197 | optimizer.step() 198 | elif loss_type in ['smoothAP']: 199 | predScore = torch.sigmoid(out) 200 | loss = criterion(predScore, target, tau=tau) 201 |
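# smoothAP surrogate: tau controls how sharply the ranking indicator is smoothed before backpropagation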
loss.backward() 202 | optimizer.step() 203 | elif loss_type in ['fastAP']: 204 | predScore = torch.sigmoid(out) 205 | # predScore = out/torch.norm(out) 206 | loss = criterion(predScore, target, bins=bins) 207 | loss.backward() 208 | optimizer.step() 209 | elif loss_type in ['expAP']: 210 | # predScore = out / torch.norm(out) 211 | predScore = torch.sigmoid(out) 212 | loss = criterion(predScore, target, tau=tau) 213 | loss.backward() 214 | optimizer.step() 215 | elif loss_type in ['SOAP']: 216 | predScore = torch.nn.Sigmoid()(out) 217 | loss = criterion(f_ps=predScore[0:posNum], f_ns=predScore[posNum:], index_s=index[0:posNum]) 218 | loss.backward() 219 | optimizer.step() 220 | 221 | losses.append(loss) 222 | return sum(losses).item() / len(losses) 223 | 224 | 225 | def val_train_classification(model, test_loader, num_tasks, device): 226 | model.eval() 227 | preds = torch.Tensor([]).to(device) 228 | targets = torch.Tensor([]).to(device) 229 | 230 | for (index, inputs, target) in test_loader: 231 | 232 | inputs = inputs.to(device) 233 | target = target.to(device).float() 234 | with torch.no_grad(): 235 | out = model(inputs) 236 | if len(target.shape) != 2: 237 | target = torch.reshape(target, (-1, num_tasks)) 238 | if out.shape[1] == 1: 239 | pred = torch.sigmoid(out) ### prediction real number between (0,1) 240 | else: 241 | pred = torch.softmax(out, dim=-1)[:, 1:2] 242 | preds = torch.cat([preds, pred], dim=0) 243 | targets = torch.cat([targets, target], dim=0) 244 | 245 | prc_results, roc_results = compute_cla_metric(targets.cpu().detach().numpy(), preds.cpu().detach().numpy(), 246 | num_tasks) 247 | ap = ave_prc(targets.cpu().detach().numpy(), preds.cpu().detach().numpy()) 248 | 249 | return prc_results, roc_results, ap 250 | 251 | 252 | def test_classification(model, test_loader, num_tasks, device, dataset='cifar10'): 253 | model.eval() 254 | 255 | preds = torch.Tensor([]).to(device) 256 | targets = torch.Tensor([]).to(device) 257 | 258 | for (inputs, target) in test_loader: 259 | inputs = inputs.to(device) 260 | target = target.to(device).float() 261 | if dataset == 'cifar10': 262 | target[target <= 4] = 0 263 | target[target > 4] = 1 264 | elif dataset == 'cifar100': 265 | target[target <= 49] = 0 266 | target[target > 49] = 1 267 | 268 | with torch.no_grad(): 269 | out = model(inputs) 270 | if len(target.shape) != 2: 271 | target = torch.reshape(target, (-1, num_tasks)) 272 | 273 | if out.shape[1] == 1: 274 | pred = torch.sigmoid(out) ### prediction real number between (0,1) 275 | else: 276 | pred = torch.softmax(out, dim=-1)[:, 1:2] 277 | preds = torch.cat([preds, pred], dim=0) 278 | # print(preds) 279 | targets = torch.cat([targets, target], dim=0) 280 | # print(targets) 281 | prc_results, roc_results = compute_cla_metric(targets.cpu().detach().numpy(), preds.cpu().detach().numpy(), 282 | num_tasks) 283 | ap = ave_prc(targets.cpu().detach().numpy(), preds.cpu().detach().numpy()) 284 | 285 | return prc_results, roc_results, ap 286 | 287 | 288 | def warmup_learning_rate(epoch, batch_id, lr, total_batches, optimizer): 289 | if epoch <= 5: 290 | p = (batch_id + (epoch - 1) * total_batches) / \ 291 | (5 * total_batches) 292 | lr = 0.01 + p * (lr - 0.01) 293 | 294 | for param_group in optimizer.param_groups: 295 | param_group['lr'] = lr 296 | 297 | 298 | def plot_precision_recall_curve(model, vt_batch_size, test_dataset, saved_model, method, dataset='cifar10'): 299 | test_loader = DataLoader(test_dataset, vt_batch_size, shuffle=False, num_workers=16, pin_memory=True) 300 | 301 | model = 
model.to(device) 302 | # wandb.watch(model) 303 | state_key = torch.load(saved_model) 304 | # print('pretrain model is loaded from {} epoch'.format(state_key['epoch'])) 305 | model.load_state_dict(state_key) 306 | 307 | model.eval() 308 | preds = torch.Tensor([]).to(device) 309 | targets = torch.Tensor([]).to(device) 310 | 311 | for (inputs, target) in test_loader: 312 | inputs = inputs.to(device) 313 | target = target.to(device).float() 314 | if dataset == 'cifar10': 315 | target[target <= 4] = 0 316 | target[target > 4] = 1 317 | elif dataset == 'cifar100': 318 | target[target <= 49] = 0 319 | target[target > 49] = 1 320 | 321 | with torch.no_grad(): 322 | out = model(inputs) 323 | 324 | if out.shape[1] == 1: 325 | pred = torch.sigmoid(out) ### prediction real number between (0,1) 326 | else: 327 | pred = torch.softmax(out, dim=-1)[:, 1:2] 328 | preds = torch.cat([preds, pred], dim=0) 329 | 330 | # print(preds) 331 | targets = torch.cat([targets, target], dim=0) 332 | precision, recall, _ = precision_recall_curve(targets.cpu().detach().numpy(), preds.cpu().detach().numpy()) 333 | 334 | plt.plot(recall, precision, label=method, linewidth=2) 335 | if dataset == 'cifar10': 336 | plt.title('CIFAR-10', fontsize=25) 337 | else: 338 | plt.title('CIFAR-100', fontsize=25) 339 | plt.xlabel('Recall', fontsize=20) 340 | plt.ylabel('Precision', fontsize=20) 341 | plt.hlines(0.5, -0.03, 1.03, colors='gray', linestyles='--', linewidth=2) 342 | plt.ylim(0.45, 1) 343 | plt.legend(fontsize=13) 344 | plt.savefig(os.path.join('results', dataset, dataset + '_' + method + '_precision_recall_curve.png')) 345 | -------------------------------------------------------------------------------- /Image/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models.resnet import ResNet, Bottleneck, BasicBlock 3 | from torchvision.datasets.folder import ImageFolder 4 | import numpy as np 5 | from sklearn.metrics import auc, roc_auc_score, average_precision_score 6 | from sklearn.metrics import precision_recall_curve 7 | 8 | def resnet18(): 9 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=1) 10 | return model 11 | 12 | def resnet34(): 13 | model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=1) 14 | return model 15 | 16 | def resnet50(): 17 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=1) 18 | return model 19 | 20 | 21 | 22 | class MyImageFolder(ImageFolder): 23 | def __init__(self, root, transform): 24 | super().__init__(root, transform=transform) 25 | 26 | def __getitem__(self, index): 27 | sample, target = super().__getitem__(index) 28 | return index, sample, target 29 | 30 | 31 | def prc_auc(targets, preds): 32 | precision, recall, _ = precision_recall_curve(targets, preds) 33 | # disp = PrecisionRecallDisplay(precision=precision, recall=recall) 34 | return auc(recall, precision) 35 | 36 | 37 | def prc_recall_curve(targets, preds): 38 | precision, recall, _ = precision_recall_curve(targets, preds) 39 | return precision, recall 40 | 41 | def ave_prc(targets, preds): 42 | return average_precision_score(targets, preds) 43 | 44 | def compute_cla_metric(targets, preds, num_tasks): 45 | 46 | prc_results = [] 47 | roc_results = [] 48 | for i in range(num_tasks): 49 | is_labeled = targets[:,i] == targets[:,i] ## filter out samples without a ground-truth label (NaN != NaN) 50 | target = targets[is_labeled,i] 51 | pred = preds[is_labeled,i] 52 | try: 53 | prc = prc_auc(target, pred) 54 | except ValueError: 55 | prc = np.nan 56 | print("In task #", i+1, ", there is only one
class present in the set. PRC is not defined in this case.") 57 | try: 58 | roc = roc_auc_score(target, pred) 59 | except ValueError: 60 | roc = np.nan 61 | print("In task #", i+1, ", there is only one class present in the set. ROC is not defined in this case.") 62 | if not np.isnan(prc): 63 | prc_results.append(prc) 64 | else: 65 | print("PRC results do not consider task #", i+1) 66 | if not np.isnan(roc): 67 | roc_results.append(roc) 68 | else: 69 | print("ROC results do not consider task #", i+1) 70 | return prc_results, roc_results 71 | 72 | 73 | def global_surrogate_loss_with_sqh(target, pred, threshold): # negated squared-hinge AP surrogate over the full dataset, averaged over the positives 74 | 75 | 76 | posNum = np.sum(target) 77 | target, pred = target.reshape(-1), pred.reshape(-1) 78 | # print(target, pred) 79 | # print(posNum) 80 | loss = 0 81 | for t in range(len(target)): 82 | if target[t] == 1: 83 | # print(t) 84 | all_surr_loss = np.maximum(threshold - (pred[t] - pred), np.array([0]*len(target)))**2 85 | num = np.sum(all_surr_loss * (target == 1)) 86 | dem = np.sum(all_surr_loss) 87 | 88 | loss += -num/dem 89 | 90 | 91 | return loss/posNum 92 | 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stochastic Optimization of Areas Under Precision-Recall Curves with Provable Convergence [![pdf](https://img.shields.io/badge/Arxiv-pdf-orange.svg?style=flat)](https://arxiv.org/pdf/2104.08736.pdf) 2 | This is the official implementation of the paper "**Stochastic Optimization of Areas Under Precision-Recall Curves with Provable Convergence**" published at **NeurIPS 2021**. 3 | 4 | 5 | Benchmark Datasets 6 | --------- 7 | **Image**: [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html), [CIFAR100](https://www.cs.toronto.edu/~kriz/cifar.html), [Melanoma](https://www.kaggle.com/c/siim-isic-melanoma-classification/data) \ 8 | **Graph**: HIV, MUV, AICures 9 | 10 | Package 11 | ---------- 12 | The main algorithm **SOAP** has been implemented in [LibAUC](https://github.com/Optimization-AI/LibAUC/), with 13 | ```python 14 | >>> from libauc.optimizers import SOAP_SGD, SOAP_ADAM 15 | ``` 16 | You can also design your own loss. The following is a use case: 17 | ```python 18 | pip install libauc 19 | >>> #import library 20 | >>> from libauc.losses import APLoss_SH 21 | >>> from libauc.optimizers import SOAP_SGD, SOAP_ADAM 22 | ... 23 | >>> #define loss 24 | >>> Loss = APLoss_SH() 25 | >>> optimizer = SOAP_ADAM() 26 | ...
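>>> # 'index' is each sample's position in the dataset; the loss uses it to maintain per-sample moving-average statistics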
27 | >>> #training 28 | >>> model.train() 29 | >>> for index, data, targets in trainloader: 30 | data, targets = data.cuda(), targets.cuda() 31 | logits = model(data) 32 | preds = torch.sigmoid(logits) 33 | loss = Loss(preds, targets, index) 34 | optimizer.zero_grad() 35 | loss.backward() 36 | optimizer.step() 37 | ``` 38 | 39 | Reminder 40 | ---------- 41 | **If you want to download the code that reproduces the reported table results of the NeurIPS 2021 paper, please go to the Graph/Image subdirectories and refer to the README.md inside.** 42 | 43 | 44 | 45 | Citation 46 | --------- 47 | If you find this repo helpful, please cite the following paper: 48 | ``` 49 | @article{qi2021stochastic, 50 | title={Stochastic Optimization of Area Under Precision-Recall Curve for Deep Learning with Provable Convergence}, 51 | author={Qi, Qi and Luo, Youzhi and Xu, Zhao and Ji, Shuiwang and Yang, Tianbao}, 52 | journal={arXiv preprint arXiv:2104.08736}, 53 | year={2021} 54 | } 55 | ``` 56 | 57 | Contact 58 | ---------- 59 | If you have any questions, please contact us @ [Qi Qi](https://qiqi-helloworld.github.io/) [qi-qi@uiowa.edu] and [Tianbao Yang](https://homepage.cs.uiowa.edu/~tyng/) [tianbao-yang@uiowa.edu], or open a new issue on GitHub. 60 | --------------------------------------------------------------------------------