├── .gitignore ├── LICENSE ├── README.md ├── classification.py ├── imgs ├── framework.png ├── lanl_result.png └── optc_result.png ├── libauc ├── __init__.py ├── datasets │ ├── __init__.py │ ├── cat_vs_dog.py │ ├── chexpert.py │ ├── cifar.py │ ├── melanoma.py │ ├── movielens.py │ └── stl10.py ├── losses │ ├── __init__.py │ ├── auc.py │ ├── constrastive.py │ ├── losses.py │ ├── losses_v1.py │ ├── perf_at_top.py │ ├── ranking.py │ └── surrogate.py ├── metrics │ ├── __init__.py │ └── metrics.py ├── models │ ├── __init__.py │ ├── densenet.py │ ├── gcn.py │ ├── neumf.py │ ├── perceptron.py │ ├── resnet.py │ └── resnet_cifar.py ├── optimizers │ ├── __init__.py │ ├── adam.py │ ├── pdsca.py │ ├── pesg.py │ ├── sgd.py │ ├── soap.py │ ├── song.py │ ├── sopa.py │ ├── sopa_s.py │ └── sota_s.py ├── sampler │ ├── __init__.py │ ├── ranking.py │ └── sampler.py └── utils │ ├── __init__.py │ ├── generator.py │ └── helper.py ├── loaders ├── load_lanl.py ├── load_optc.py ├── load_utils.py ├── split_lanl.py ├── split_optc.py └── tdata.py ├── main.py ├── models ├── argus.py └── recurrent.py ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | **.csv
131 | **.pkl
132 | Exps
133 |
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Jiacen Xu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Argus
2 | This is the repo for the paper "Understanding and Bridging the Gap Between Unsupervised Network Representation Learning and Security Analytics", accepted at the 2024 IEEE Symposium on Security and Privacy (S&P).
3 | There is a [blog](https://c0ldstudy.github.io/posts/GSA/) post summarizing the main idea of the paper, or you can check the [paper](https://www.computer.org/csdl/proceedings-article/sp/2024/313000a012/1RjE9Q5gQrm) directly.
4 |
5 | ## Setup
6 |
7 | #### Python Environment
8 | Set up a Python environment and install the required packages:
9 | ```bash
10 | # Create a virtual Python environment
11 | conda create -n argus python==3.9
12 | # Activate the environment
13 | conda activate argus
14 | # Install PyTorch, PyTorch Geometric, and related packages
15 | pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 -f https://download.pytorch.org/whl/cu111/torch_stable.html
16 | pip install -r requirements.txt
17 | pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-1.10.1+cu111.html --no-index
18 | ```
19 | #### Dataset
20 | For the [LANL Dataset](https://csr.lanl.gov/data/cyber1/), we use `auth.txt.gz`, `redteam.txt.gz`, and `flows.txt.gz`.
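The path variables set at the top of `./loaders/split_lanl.py` (shown below) refer to the uncompressed `auth.txt`, `flows.txt`, and `redteam.txt`, so the archives will likely need to be decompressed first. A minimal sketch, assuming the raw files were downloaded into a hypothetical `/path/to/lanl_raw` directory:

```bash
cd /path/to/lanl_raw
# -k keeps the original .gz archives alongside the decompressed files
gunzip -k auth.txt.gz redteam.txt.gz flows.txt.gz
```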
21 |
22 | For the [OpTC Dataset](https://github.com/FiveDirections/OpTC-data), we use the "START" events related to the "FLOW" objects (i.e., network flows), filtered following the [paper](https://ieeexplore.ieee.org/abstract/document/9789921); the reported statistics refer to the filtered data. The dataset is available at this [link](https://drive.google.com/drive/folders/1pTU-ZcyJbzoB1FuvujXe-ynaUy8O-PVD?usp=sharing).
23 |
24 | The datasets need to be preprocessed with `./loaders/split_lanl.py` and `./loaders/split_optc.py` after setting the dataset paths at the beginning of each file.
25 |
26 | ```bash
27 | # revise Lines 6-9 of ./loaders/split_lanl.py to set where the preprocessed LANL dataset is stored
28 | RED = '' # Location of redteam.txt
29 | SRC = '' # Location of auth.txt
30 | DST = '' # Directory to save output files to
31 | SRC_DIR = '' # Directory of flows.txt, auth.txt
32 |
33 | cd loaders
34 | python split_lanl.py
35 |
36 | # revise Line 20 of ./loaders/load_lanl.py to point to the DST path set in ./loaders/split_lanl.py
37 | LANL_FOLDER = ''
38 |
39 |
40 | # revise Lines 7-9 of ./loaders/split_optc.py to set where the preprocessed OpTC dataset is stored
41 | RED = '' # Location of redteam.txt
42 | SRC = '' # Location of auth.txt
43 | DST = '' # Directory to save output files to
44 |
45 | cd loaders
46 | python split_optc.py
47 |
48 | # revise Line 19 of ./loaders/load_optc.py to point to the DST path set in ./loaders/split_optc.py
49 | OPTC_FOLDER = ''
50 |
51 | ```
52 |
53 | ## System Structure
54 | ![Framework](./imgs/framework.png)
55 |
56 |
57 |
58 | ## Experiments
59 |
60 | **Note:** Argus models use APLoss by default, which is memory-intensive. If you encounter OOM errors during training, try adding `--loss bce` to use BCE loss instead. This may reduce AP performance.
61 |
62 | `python main.py --dataset LANL --delta 1 --lr 0.01`
63 | ![LANL](./imgs/lanl_result.png)
64 |
65 | `python main.py --dataset OPTC --delta 0.1 --lr 0.005 --patience 10`
66 | ![OPTC](./imgs/optc_result.png)
67 |
68 |
69 | Thanks to [Euler](https://github.com/iHeartGraph/Euler) and [LibAUC](https://github.com/Optimization-AI/LibAUC) for their support.
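**Note on outputs:** Trained model states are pickled under `./Exps/` (e.g., `./Exps/model_save_<DATASET>.pkl` in `classification.py`), and `Exps` is listed in `.gitignore`. A small convenience step, based on the save paths in the code rather than the original instructions, is to create the directory before the first run, since `open(..., 'wb+')` will not create it:

```bash
mkdir -p Exps
```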
70 | 71 | 72 | ### Citation 73 | ```bibtex 74 | @inproceedings{xu2023understanding, 75 | title={Understanding and Bridging the Gap Between Unsupervised Network Representation Learning and Security Analytics}, 76 | author={Xu, Jiacen and Shu, Xiaokui and Li, Zhou}, 77 | booktitle={2024 IEEE Symposium on Security and Privacy (SP)}, 78 | pages={12--12}, 79 | year={2023}, 80 | organization={IEEE Computer Society} 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /classification.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import pickle 4 | import time 5 | import pandas as pd 6 | import numpy as np 7 | from sklearn.metrics import roc_auc_score as auc_score, f1_score, average_precision_score as ap_score, precision_recall_curve, confusion_matrix 8 | from libauc.optimizers import SOAP 9 | import torch 10 | from torch.optim import Adam, Adadelta 11 | from loaders.tdata import TData 12 | from loaders.load_optc import load_optc_dist 13 | from models.argus import Argus, DetectorEncoder 14 | from utils import get_score, get_optimal_cutoff 15 | 16 | TMP_FILE = 'tmp.dat' 17 | 18 | def classification(args, rnn_args, worker_args, OUTPATH, device): 19 | if args.val_times is None: 20 | val = max((args.tr_end - args.tr_start) // 20, args.delta*2) 21 | args.val_start = args.tr_end-val 22 | args.val_end = args.tr_end 23 | args.tr_end = args.val_start 24 | else: 25 | args.val_start = args.val_times[0] 26 | args.val_end = args.val_times[1] 27 | 28 | times = { 29 | 'tr_start': args.tr_start, 30 | 'tr_end': args.tr_end, 31 | 'val_start': args.val_start, 32 | 'val_end': args.val_end, 33 | 'te_times': args.te_times, 34 | 'delta': args.delta 35 | } 36 | global LOAD_FN 37 | LOAD_FN = args.loader 38 | 39 | # Evaluating a pre-trained model, so no need to train 40 | if args.load: 41 | kwargs = { 42 | 'start': None, 43 | 'end': None, 44 | 'use_flows': args.flows, 45 | 'device': device 46 | } 47 | rrefs = args.encoder(LOAD_FN, kwargs, *worker_args) 48 | rnn = args.rnn(*rnn_args) 49 | model = Argus(rnn, rrefs, args.loss, device) 50 | 51 | states = pickle.load(open('./Exps/model_save_'+args.dataset+'.pkl', 'rb')) 52 | model.load_states(*states['states']) 53 | h0 = states['h0'] 54 | tpe = 0 55 | tr_time = 0 56 | 57 | # Building and training a fresh model 58 | else: 59 | kwargs = { 60 | 'start': times['tr_start'], 61 | 'end': times['tr_end'], 62 | 'delta': times['delta'], 63 | 'is_test': False, 64 | 'use_flows': args.flows, 65 | 'device': device} 66 | rrefs = args.encoder(LOAD_FN, kwargs, *worker_args) 67 | tmp = time.time() 68 | model, h0, tpe = train(rrefs, args, rnn_args, device) 69 | tr_time = time.time() - tmp 70 | model = model.to(device) 71 | h0, zs = get_cutoff(model, h0, times, args, args.fpweight, args.flows, device) 72 | stats = [] 73 | 74 | for te_start,te_end in times['te_times']: 75 | test_times = { 76 | 'te_start': te_start, 77 | 'te_end': te_end, 78 | 'delta': times['delta'] 79 | } 80 | st = test(model, h0, test_times, rrefs, args.flows, device, args) 81 | for s in st: 82 | s['TPE'] = tpe 83 | stats += st 84 | 85 | pickle.dump(stats, open(OUTPATH+TMP_FILE, 'wb+'), protocol=pickle.HIGHEST_PROTOCOL) 86 | 87 | # Retrieve stats, and cleanup temp file 88 | stats = pickle.load(open(OUTPATH+TMP_FILE, 'rb')) 89 | return stats 90 | 91 | 92 | 93 | def train(rrefs, args, rnn_args, device): 94 | rnn_constructor = args.rnn 95 | dataset = args.dataset 96 | rnn = rnn_constructor(*rnn_args) 97 | model = 
Argus(rnn, rrefs, args.loss, device) 98 | model = model.to(device) 99 | # opt = torch.optim.Adam(model.parameters(), lr=args.lr) 100 | opt = SOAP(model.parameters(), lr=args.lr, mode='adam', weight_decay=0.0) 101 | times = [] 102 | best = (model.save_states(), 0) 103 | no_progress = 0 104 | for e in range(args.epochs): 105 | # Get loss and send backward 106 | model.train() 107 | st = time.time() 108 | zs = model.forward(TData.TRAIN) 109 | loss = model.loss_fn(zs, TData.TRAIN, nratio=args.nratio, device=device, encoder_name=args.encoder_name) 110 | loss.backward() 111 | opt.step() 112 | elapsed = time.time()-st 113 | times.append(elapsed) 114 | l = loss.sum() 115 | print('[%d] Loss %0.4f %0.2fs' % (e, l.item(), elapsed)) 116 | 117 | # Get validation info to prevent overfitting 118 | model.eval() 119 | with torch.no_grad(): 120 | zs = model.forward(TData.TRAIN, no_grad=True) 121 | p,n = model.score_edges(zs, TData.VAL) 122 | auc,ap = get_score(p,n) 123 | print("\tValidation: AP: %0.4f AUC: %0.4f" % (ap, auc), end='') 124 | 125 | # Either incriment or update early stopping criteria 126 | tot = auc+ap 127 | if tot > best[1]: 128 | print('*\n') 129 | best = (model.save_states(), tot) 130 | no_progress = 0 131 | else: 132 | print('\n') 133 | if e >= 1: 134 | no_progress += 1 135 | if no_progress == args.patience: 136 | print("Early stopping!") 137 | break 138 | 139 | model.load_states(*best[0]) 140 | 141 | # Get the best possible h0 to eval with 142 | zs, h0 = model(TData.TEST, include_h=True) 143 | states = {'states': best[0], 'h0': h0} 144 | f = open('./Exps/model_save_'+dataset+'.pkl', 'wb+') 145 | pickle.dump(states, f, protocol=pickle.HIGHEST_PROTOCOL) 146 | tpe = sum(times)/len(times) 147 | print("Exiting train loop") 148 | print("Avg TPE: %0.4fs" % tpe) 149 | return model, h0, tpe 150 | 151 | 152 | def get_cutoff(model, h0, times, args, lambda_param, use_flows, device): 153 | Encoder = DetectorEncoder 154 | ld_args = { 155 | 'start': times['val_start'], 156 | 'end': times['val_end'], 157 | 'delta': times['delta'], 158 | 'is_test': False, 159 | 'use_flows': use_flows 160 | } 161 | 162 | Encoder.load_new_data(model.gcns, LOAD_FN, ld_args) 163 | # Then generate GCN embeds 164 | model.eval() 165 | 166 | zs = Encoder.forward(model.gcns.module, TData.ALL, True).to(device) 167 | # Finally, generate actual embeds 168 | with torch.no_grad(): 169 | zs, h0 = model.rnn(zs, h0, include_h=True) 170 | 171 | # Then score them 172 | p, n = Encoder.score_edges(model.gcns, zs, TData.ALL, args.nratio) 173 | # Finally, figure out the optimal cutoff score 174 | p = p.cpu() 175 | n = n.cpu() 176 | model.cutoff = get_optimal_cutoff(p,n,fw=lambda_param) 177 | return h0, zs[-1] 178 | 179 | def test(model, h0, times, rrefs, use_flows, device, args): 180 | Encoder = DetectorEncoder 181 | # Load train data into workers 182 | ld_args = {'start': times['te_start'], 183 | 'end': times['te_end'], 184 | 'delta': times['delta'], 185 | 'is_test': True, 186 | 'use_flows': use_flows} 187 | 188 | print("Loading test data") 189 | 190 | Encoder.load_new_data(rrefs, LOAD_FN, ld_args) 191 | stats = [] 192 | model = model.to(device) 193 | print("Embedding Test Data...") 194 | test_tmp = time.time() 195 | with torch.no_grad(): 196 | model.eval() 197 | s = time.time() 198 | zs = model.forward(TData.TEST, h0=h0, no_grad=True) 199 | ctime = time.time()-s 200 | # Scores all edges and matches them with name/timestamp 201 | scores, labels, weights = model.score_all(zs) 202 | test_time = time.time() - test_tmp 203 | 204 | 
stats.append(score_stats(args,scores, labels, weights, model.cutoff, ctime)) 205 | return stats 206 | 207 | def score_stats(args, scores, labels, weights, cutoff, ctime): 208 | scores = np.concatenate(scores, axis=0) 209 | labels = np.concatenate(labels, axis=0).clip(max=1) 210 | weights = np.concatenate(weights, axis=0) 211 | 212 | # Classify using cutoff from earlier 213 | classified = np.zeros(labels.shape) 214 | classified[scores <= cutoff] = 1 215 | 216 | # Calculate TPR 217 | p = classified[labels==1] 218 | tpr = p.mean() 219 | tp = p.sum() 220 | del p 221 | 222 | # Calculate FPR 223 | f = classified[labels==0] 224 | fp = f.sum() 225 | fpr = f.mean() 226 | del f 227 | 228 | cm = confusion_matrix(labels, classified, labels=[0,1]) 229 | tn, fp, fn, tp = cm.ravel() 230 | print("tn, fp, fn, tp: ", tn, fp, fn, tp) 231 | scores = 1-scores 232 | 233 | 234 | # Get metrics 235 | auc = auc_score(labels, scores) 236 | ap = ap_score(labels, scores) 237 | f1 = f1_score(labels, classified) 238 | 239 | print("Learned Cutoff %0.4f" % cutoff) 240 | print("TPR: %0.4f, FPR: %0.4f" % (tpr, fpr)) 241 | print("TP: %d FP: %d" % (tp, fp)) 242 | print("F1: %0.8f" % f1) 243 | print("AUC: %0.4f AP: %0.4f\n" % (auc,ap)) 244 | print("FwdTime", ctime, ) 245 | title = "test" 246 | return { 247 | 'Model': title, 248 | 'TPR':tpr.item(), 249 | 'FPR':fpr.item(), 250 | 'TP':tp.item(), 251 | 'FP':fp.item(), 252 | 'F1':f1, 253 | 'AUC':auc, 254 | 'AP': ap, 255 | 'FwdTime':ctime, 256 | 'tn': tn, 257 | 'fp': fp, 258 | 'fn': fn, 259 | 'tp': tp 260 | } 261 | 262 | 263 | -------------------------------------------------------------------------------- /imgs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/C0ldstudy/Argus/0bd70e6f486fe0c787a40db17a0530b2c96732c9/imgs/framework.png -------------------------------------------------------------------------------- /imgs/lanl_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/C0ldstudy/Argus/0bd70e6f486fe0c787a40db17a0530b2c96732c9/imgs/lanl_result.png -------------------------------------------------------------------------------- /imgs/optc_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/C0ldstudy/Argus/0bd70e6f486fe0c787a40db17a0530b2c96732c9/imgs/optc_result.png -------------------------------------------------------------------------------- /libauc/__init__.py: -------------------------------------------------------------------------------- 1 | __name__ = "libauc" 2 | __author__ = 'Zhuoning Yuan' 3 | __contact__ = 'yzhuoning@gmail.com' 4 | __version__ = '1.2.0' 5 | -------------------------------------------------------------------------------- /libauc/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar import * 2 | from .cifar import * 3 | from .stl10 import * 4 | from .cat_vs_dog import * 5 | from .chexpert import * 6 | from .melanoma import * 7 | from .movielens import * 8 | -------------------------------------------------------------------------------- /libauc/datasets/cat_vs_dog.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import numpy as np 4 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive 5 | # reference: 
https://pytorch.org/vision/0.8/_modules/torchvision/datasets/cifar.html#CIFAR10 6 | # Dataset credit goes to https://www.microsoft.com/en-us/download/details.aspx?id=54765 7 | 8 | def _check_integrity(root, train_list, test_list, base_folder): 9 | for fentry in (train_list + test_list): 10 | filename, md5 = fentry[0], fentry[1] 11 | fpath = os.path.join(root, base_folder, filename) 12 | if not check_integrity(fpath, md5): 13 | return False 14 | print('Files already downloaded and verified') 15 | return True 16 | 17 | def load_data(data_path, label_path): 18 | data = np.load(data_path) 19 | targets = np.load(label_path) 20 | return data, targets 21 | 22 | def CAT_VS_DOG(root='./data/', train=True): 23 | base_folder = "cat_vs_dog" 24 | url = 'https://homepage.divms.uiowa.edu/~zhuoning/datasets/cat_vs_dog.tar.gz' 25 | filename = "cat_vs_dog.tar.gz" 26 | train_list = [ 27 | ['cat_vs_dog_data.npy', None], 28 | ['cat_vs_dog_label.npy', None], 29 | ] 30 | test_list = [] 31 | 32 | # download dataset 33 | if not _check_integrity(root, train_list, test_list, base_folder): 34 | download_and_extract_archive(url=url, download_root=root, filename=filename) 35 | 36 | # train or test set 37 | if train: 38 | data_path = os.path.join(root, base_folder, train_list[0][0]) 39 | label_path = os.path.join(root, base_folder, train_list[1][0]) 40 | data, targets = load_data(data_path, label_path) 41 | data = data[:-5000] 42 | targets = targets[:-5000] 43 | else: 44 | data_path = os.path.join(root, base_folder, train_list[0][0]) 45 | label_path = os.path.join(root, base_folder, train_list[1][0]) 46 | data, targets = load_data(data_path, label_path) 47 | data = data[-5000:] 48 | targets = targets[-5000:] 49 | 50 | return data, targets 51 | 52 | if __name__ == '__main__': 53 | data, targets = CAT_VS_DOG('./data/', train=True) 54 | print (data.shape, targets.shape) 55 | data, targets = CAT_VS_DOG('./data/', train=False) 56 | print (data.shape, targets.shape) 57 | -------------------------------------------------------------------------------- /libauc/datasets/chexpert.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | import torchvision.transforms as tfs 5 | import cv2 6 | from PIL import Image 7 | import pandas as pd 8 | 9 | class CheXpert(Dataset): 10 | ''' 11 | Reference: 12 | Large-scale Robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification 13 | Zhuoning Yuan, Yan Yan, Milan Sonka, Tianbao Yang 14 | International Conference on Computer Vision (ICCV 2021) 15 | ''' 16 | def __init__(self, 17 | csv_path, 18 | image_root_path='', 19 | image_size=320, 20 | class_index=0, 21 | use_frontal=True, 22 | use_upsampling=True, 23 | flip_label=False, 24 | shuffle=True, 25 | seed=123, 26 | verbose=True, 27 | transforms=None, 28 | upsampling_cols=['Cardiomegaly', 'Consolidation'], 29 | train_cols=['Cardiomegaly', 'Edema', 'Consolidation', 'Atelectasis', 'Pleural Effusion'], 30 | mode='train'): 31 | 32 | 33 | # load data from csv 34 | self.df = pd.read_csv(csv_path) 35 | self.df['Path'] = self.df['Path'].str.replace('CheXpert-v1.0-small/', '', regex=True) 36 | self.df['Path'] = self.df['Path'].str.replace('CheXpert-v1.0/', '', regex=True) 37 | if use_frontal: 38 | self.df = self.df[self.df['Frontal/Lateral'] == 'Frontal'] 39 | 40 | # upsample selected cols 41 | if use_upsampling: 42 | assert isinstance(upsampling_cols, list), 'Input should be list!' 
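            # (added note) the loop below duplicates every row whose value in one of the
            # `upsampling_cols` equals 1 and concatenates those rows back onto self.df,
            # oversampling the positives of the selected classes before training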
43 | sampled_df_list = [] 44 | for col in upsampling_cols: 45 | print ('Upsampling %s...'%col) 46 | sampled_df_list.append(self.df[self.df[col] == 1]) 47 | self.df = pd.concat([self.df] + sampled_df_list, axis=0) 48 | 49 | 50 | # impute missing values 51 | for col in train_cols: 52 | if col in ['Edema', 'Atelectasis']: 53 | self.df[col].replace(-1, 1, inplace=True) 54 | self.df[col].fillna(0, inplace=True) 55 | elif col in ['Cardiomegaly','Consolidation', 'Pleural Effusion']: 56 | self.df[col].replace(-1, 0, inplace=True) 57 | self.df[col].fillna(0, inplace=True) 58 | elif col in ['No Finding', 'Enlarged Cardiomediastinum', 'Lung Opacity', 'Lung Lesion', 'Pneumonia', 'Pneumothorax', 'Pleural Other','Fracture','Support Devices']: # other labels 59 | self.df[col].replace(-1, 0, inplace=True) 60 | self.df[col].fillna(0, inplace=True) 61 | else: 62 | self.df[col].fillna(0, inplace=True) 63 | 64 | self._num_images = len(self.df) 65 | 66 | # 0 --> -1 67 | if flip_label and class_index != -1: # In multi-class mode we disable this option! 68 | self.df.replace(0, -1, inplace=True) 69 | 70 | # shuffle data 71 | if shuffle: 72 | data_index = list(range(self._num_images)) 73 | np.random.seed(seed) 74 | np.random.shuffle(data_index) 75 | self.df = self.df.iloc[data_index] 76 | 77 | 78 | #assert class_index in [-1, 0, 1, 2, 3, 4], 'Out of selection!' 79 | assert image_root_path != '', 'You need to pass the correct location for the dataset!' 80 | 81 | if class_index == -1: # 5 classes 82 | if verbose: 83 | print ('Multi-label mode: True, Number of classes: [%d]'%len(train_cols)) 84 | print ('-'*30) 85 | self.select_cols = train_cols 86 | self.value_counts_dict = {} 87 | for class_key, select_col in enumerate(train_cols): 88 | class_value_counts_dict = self.df[select_col].value_counts().to_dict() 89 | self.value_counts_dict[class_key] = class_value_counts_dict 90 | else: 91 | self.select_cols = [train_cols[class_index]] # this var determines the number of classes 92 | self.value_counts_dict = self.df[self.select_cols[0]].value_counts().to_dict() 93 | 94 | self.mode = mode 95 | self.class_index = class_index 96 | self.image_size = image_size 97 | self.transforms = transforms 98 | 99 | self._images_list = [image_root_path+path for path in self.df['Path'].tolist()] 100 | if class_index != -1: 101 | self.targets = self.df[train_cols].values[:, class_index].tolist() 102 | else: 103 | self.targets = self.df[train_cols].values.tolist() 104 | 105 | if True: 106 | if class_index != -1: 107 | if flip_label: 108 | self.imratio = self.value_counts_dict[1]/(self.value_counts_dict[-1]+self.value_counts_dict[1]) 109 | if verbose: 110 | print ('-'*30) 111 | print('Found %s images in total, %s positive images, %s negative images'%(self._num_images, self.value_counts_dict[1], self.value_counts_dict[-1] )) 112 | print ('%s(C%s): imbalance ratio is %.4f'%(self.select_cols[0], class_index, self.imratio )) 113 | print ('-'*30) 114 | else: 115 | self.imratio = self.value_counts_dict[1]/(self.value_counts_dict[0]+self.value_counts_dict[1]) 116 | if verbose: 117 | print ('-'*30) 118 | print('Found %s images in total, %s positive images, %s negative images'%(self._num_images, self.value_counts_dict[1], self.value_counts_dict[0] )) 119 | print ('%s(C%s): imbalance ratio is %.4f'%(self.select_cols[0], class_index, self.imratio )) 120 | print ('-'*30) 121 | else: 122 | imratio_list = [] 123 | for class_key, select_col in enumerate(train_cols): 124 | try: 125 | imratio = 
self.value_counts_dict[class_key][1]/(self.value_counts_dict[class_key][0]+self.value_counts_dict[class_key][1]) 126 | except: 127 | if len(self.value_counts_dict[class_key]) == 1 : 128 | only_key = list(self.value_counts_dict[class_key].keys())[0] 129 | if only_key == 0: 130 | self.value_counts_dict[class_key][1] = 0 131 | imratio = 0 # no postive samples 132 | else: 133 | self.value_counts_dict[class_key][1] = 0 134 | imratio = 1 # no negative samples 135 | 136 | imratio_list.append(imratio) 137 | if verbose: 138 | #print ('-'*30) 139 | print('Found %s images in total, %s positive images, %s negative images'%(self._num_images, self.value_counts_dict[class_key][1], self.value_counts_dict[class_key][0] )) 140 | print ('%s(C%s): imbalance ratio is %.4f'%(select_col, class_key, imratio )) 141 | print () 142 | #print ('-'*30) 143 | self.imratio = np.mean(imratio_list) 144 | self.imratio_list = imratio_list 145 | 146 | 147 | @property 148 | def class_counts(self): 149 | return self.value_counts_dict 150 | 151 | @property 152 | def imbalance_ratio(self): 153 | return self.imratio 154 | 155 | @property 156 | def num_classes(self): 157 | return len(self.select_cols) 158 | 159 | @property 160 | def data_size(self): 161 | return self._num_images 162 | 163 | def image_augmentation(self, image): 164 | img_aug = tfs.Compose([tfs.RandomAffine(degrees=(-15, 15), translate=(0.05, 0.05), scale=(0.95, 1.05), fill=128)]) # pytorch 3.7: fillcolor --> fill 165 | image = img_aug(image) 166 | return image 167 | 168 | def __len__(self): 169 | return self._num_images 170 | 171 | def __getitem__(self, idx): 172 | 173 | image = cv2.imread(self._images_list[idx], 0) 174 | image = Image.fromarray(image) 175 | if self.mode == 'train' : 176 | if self.transforms is None: 177 | image = self.image_augmentation(image) 178 | else: 179 | image = self.transforms(image) 180 | image = np.array(image) 181 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 182 | 183 | # resize and normalize; e.g., ToTensor() 184 | image = cv2.resize(image, dsize=(self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR) 185 | image = image/255.0 186 | __mean__ = np.array([[[0.485, 0.456, 0.406]]]) 187 | __std__ = np.array([[[0.229, 0.224, 0.225] ]]) 188 | image = (image-__mean__)/__std__ 189 | image = image.transpose((2, 0, 1)).astype(np.float32) 190 | if self.class_index != -1: # multi-class mode 191 | label = np.array(self.targets[idx]).reshape(-1).astype(np.float32) 192 | else: 193 | label = np.array(self.targets[idx]).reshape(-1).astype(np.float32) 194 | return image, label 195 | 196 | 197 | if __name__ == '__main__': 198 | root = '../chexpert/dataset/CheXpert-v1.0-small/' 199 | traindSet = CheXpert(csv_path=root+'train.csv', image_root_path=root, use_upsampling=True, use_frontal=True, image_size=320, mode='train', class_index=0) 200 | testSet = CheXpert(csv_path=root+'valid.csv', image_root_path=root, use_upsampling=False, use_frontal=True, image_size=320, mode='valid', class_index=0) 201 | trainloader = torch.utils.data.DataLoader(traindSet, batch_size=32, num_workers=2, drop_last=True, shuffle=True) 202 | testloader = torch.utils.data.DataLoader(testSet, batch_size=32, num_workers=2, drop_last=False, shuffle=False) 203 | 204 | 205 | -------------------------------------------------------------------------------- /libauc/datasets/cifar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import pickle 4 | import numpy as np 5 | from torchvision.datasets.utils 
import check_integrity, download_and_extract_archive 6 | # reference: https://pytorch.org/vision/0.8/_modules/torchvision/datasets/cifar.html#CIFAR10 7 | 8 | def _check_integrity(root, train_list, test_list, base_folder): 9 | for fentry in (train_list + test_list): 10 | filename, md5 = fentry[0], fentry[1] 11 | fpath = os.path.join(root, base_folder, filename) 12 | if not check_integrity(fpath, md5): 13 | return False 14 | print('Files already downloaded and verified') 15 | return True 16 | 17 | def CIFAR10(root='./data', train=True): 18 | base_folder = "cifar-10-batches-py" 19 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 20 | filename = "cifar-10-python.tar.gz" 21 | train_list = [ 22 | ["data_batch_1", "c99cafc152244af753f735de768cd75f"], 23 | ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"], 24 | ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"], 25 | ["data_batch_4", "634d18415352ddfa80567beed471001a"], 26 | ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"], 27 | ] 28 | test_list = [ 29 | ["test_batch", "40351d587109b95175f43aff81a1287e"], 30 | ] 31 | # download dataset 32 | if not _check_integrity(root, train_list, test_list, base_folder): 33 | download_and_extract_archive(url=url, download_root=root, filename=filename) 34 | 35 | # train or test set 36 | if train: 37 | downloaded_list = train_list 38 | else: 39 | downloaded_list = test_list 40 | 41 | data = [] 42 | targets = [] 43 | 44 | # now load the picked numpy arrays 45 | for file_name, checksum in downloaded_list: 46 | file_path = os.path.join(root, base_folder, file_name) 47 | with open(file_path, "rb") as f: 48 | entry = pickle.load(f, encoding="latin1") 49 | data.append(entry["data"]) 50 | if "labels" in entry: 51 | targets.extend(entry["labels"]) 52 | else: 53 | targets.extend(entry["fine_labels"]) 54 | 55 | # reshape data and targets 56 | data = np.vstack(data).reshape(-1, 3, 32, 32) 57 | data = data.transpose((0, 2, 3, 1)) 58 | targets = np.array(targets).astype(np.int32) 59 | return data, targets 60 | 61 | 62 | def CIFAR100(root='./data', train=True): 63 | base_folder = "cifar-100-python" 64 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 65 | filename = "cifar-100-python.tar.gz" 66 | tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85" 67 | train_list = [ 68 | ["train", "16019d7e3df5f24257cddd939b257f8d"], 69 | ] 70 | 71 | test_list = [ 72 | ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"], 73 | ] 74 | 75 | # download dataset 76 | if not _check_integrity(root, train_list, test_list, base_folder): 77 | download_and_extract_archive(url=url, download_root=root, filename=filename) 78 | 79 | # train or test set 80 | if train: 81 | downloaded_list = train_list 82 | else: 83 | downloaded_list = test_list 84 | 85 | data = [] 86 | targets = [] 87 | 88 | # now load the picked numpy arrays 89 | for file_name, checksum in downloaded_list: 90 | file_path = os.path.join(root, base_folder, file_name) 91 | with open(file_path, "rb") as f: 92 | entry = pickle.load(f, encoding="latin1") 93 | data.append(entry["data"]) 94 | if "labels" in entry: 95 | targets.extend(entry["labels"]) 96 | else: 97 | targets.extend(entry["fine_labels"]) 98 | 99 | # reshape data and targets 100 | data = np.vstack(data).reshape(-1, 3, 32, 32) 101 | data = data.transpose((0, 2, 3, 1)) 102 | targets = np.array(targets).astype(np.int32) 103 | return data, targets 104 | 105 | 106 | if __name__ == '__main__': 107 | # return numpy array 108 | data, targets = CIFAR10(root='./data', train=True) 109 | data, targets = 
CIFAR100(root='./data', train=True) 110 | 111 | -------------------------------------------------------------------------------- /libauc/datasets/melanoma.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from PIL import Image 4 | from torch.utils.data import Dataset, DataLoader 5 | import torch 6 | import os 7 | 8 | 9 | def get_augmentations_v1(image_size=256, is_test=True): 10 | import albumentations as A 11 | from albumentations.pytorch.transforms import ToTensor 12 | ''' 13 | https://www.kaggle.com/vishnus/a-simple-pytorch-starter-code-single-fold-93 14 | ''' 15 | imagenet_stats = {'mean':[0.485, 0.456, 0.406], 'std':[0.229, 0.224, 0.225]} 16 | train_tfms = A.Compose([ 17 | A.Cutout(p=0.5), 18 | A.RandomRotate90(p=0.5), 19 | A.Flip(p=0.5), 20 | A.OneOf([ 21 | A.RandomBrightnessContrast(brightness_limit=0.2, 22 | contrast_limit=0.2, 23 | ), 24 | A.HueSaturationValue( 25 | hue_shift_limit=20, 26 | sat_shift_limit=50, 27 | val_shift_limit=50) 28 | ], p=0.5), 29 | A.OneOf([ 30 | A.IAAAdditiveGaussianNoise(), 31 | A.GaussNoise(), 32 | ], p=0.5), 33 | A.OneOf([ 34 | A.MotionBlur(p=0.2), 35 | A.MedianBlur(blur_limit=3, p=0.1), 36 | A.Blur(blur_limit=3, p=0.1), 37 | ], p=0.5), 38 | A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.5), 39 | A.OneOf([ 40 | A.OpticalDistortion(p=0.3), 41 | A.GridDistortion(p=0.1), 42 | A.IAAPiecewiseAffine(p=0.3), 43 | ], p=0.5), 44 | ToTensor(normalize=imagenet_stats) 45 | ]) 46 | 47 | test_tfms = A.Compose([ 48 | ToTensor(normalize=imagenet_stats) 49 | ]) 50 | if is_test: 51 | return test_tfms 52 | else: 53 | return train_tfms 54 | 55 | class Melanoma(Dataset): 56 | ''' 57 | Reference: 58 | - Large-scale Robust deep auc maximization: A new surrogate loss and empirical studies on medical image classification 59 | Zhuoning Yuan, Yan Yan, Milan Sonka, Tianbao Yang 60 | International Conference on Computer Vision (ICCV 2021) 61 | - https://www.kaggle.com/cdeotte/jpeg-melanoma-256x256 62 | - https://www.kaggle.com/vishnus/a-simple-pytorch-starter-code-single-fold-93 63 | - https://www.kaggle.com/haqishen/1st-place-soluiton-code-small-ver 64 | ''' 65 | def __init__(self, root, test_size=0.2, is_test=False, transforms=None): 66 | assert os.path.isfile(root + '/train.csv'), 'There is no train.csv in %s!'%root 67 | self.data = pd.read_csv(root + '/train.csv') 68 | self.train_df, self.test_df = self.get_train_val_split(self.data, test_size=test_size) 69 | self.is_test = is_test 70 | 71 | if is_test: 72 | self.df = self.test_df.copy() 73 | else: 74 | self.df = self.train_df.copy() 75 | 76 | self._num_images = len(self.df) 77 | self.value_counts_dict = self.df.target.value_counts().to_dict() 78 | self.imratio = self.value_counts_dict[1]/self.value_counts_dict[0] 79 | print ('Found %s image in total, %s postive images, %s negative images.'%(self._num_images, self.value_counts_dict[1], self.value_counts_dict[0])) 80 | 81 | # get path 82 | dir_name = 'train' 83 | self._images_list = [f"{root}/{dir_name}/{img}.jpg" for img in self.df.image_name] 84 | self._labels_list = self.df.target.values.tolist() 85 | if not transforms: 86 | self.transforms = get_augmentations_v1(is_test=is_test) 87 | else: 88 | self.transforms = transforms(is_test=is_test) 89 | 90 | @property 91 | def class_counts(self): 92 | return self.value_counts_dict 93 | 94 | @property 95 | def imbalance_ratio(self): 96 | return self.imratio 97 | 98 | @property 99 | def num_classes(self): 100 | return 1 
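    # (added note) `imratio` above is the positive-to-negative count ratio
    # (value_counts_dict[1] / value_counts_dict[0]); the CheXpert dataset class
    # instead reports positives over the total number of samples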
101 | 102 | def get_train_val_split(self, df, test_size=0.2): 103 | print ('test set split is %s'%test_size) 104 | #Remove Duplicates 105 | df = df[df.tfrecord != -1].reset_index(drop=True) 106 | #We are splitting data based on triple stratified kernel provided here https://www.kaggle.com/c/siim-isic-melanoma-classification/discussion/165526 107 | num_tfrecords = len(df.tfrecord.unique()) 108 | train_tf_records = list(range(len(df.tfrecord.unique())))[:-int(num_tfrecords*test_size)] 109 | split_cond = df.tfrecord.apply(lambda x: x in train_tf_records) 110 | train_df = df[split_cond].reset_index() 111 | valid_df = df[~split_cond].reset_index() 112 | return train_df, valid_df 113 | 114 | def __len__(self): 115 | return self.df.shape[0] 116 | 117 | def __getitem__(self,idx): 118 | img_path = self._images_list[idx] 119 | image = Image.open(img_path) 120 | image = self.transforms(**{"image": np.array(image)})["image"] 121 | target = torch.tensor([self._labels_list[idx]],dtype=torch.float32) 122 | return image, target 123 | 124 | if __name__ == '__main__': 125 | trainSet = Melanoma(root='./datasets/256x256/', is_test=False, test_size=0.2) 126 | testSet = Melanoma(root='./datasets/256x256/', is_test=True, test_size=0.2) 127 | bs = 128 128 | train_dl = DataLoader(dataset=trainSet,batch_size=bs,shuffle=True, num_workers=0) 129 | 130 | 131 | -------------------------------------------------------------------------------- /libauc/datasets/stl10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import numpy as np 4 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive, verify_str_arg 5 | # reference: https://pytorch.org/vision/0.8/_modules/torchvision/datasets/stl10.html#STL10 6 | 7 | def load_file(data_file, labels_file=None): 8 | labels = None 9 | if labels_file: 10 | with open(labels_file, 'rb') as f: 11 | labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based 12 | with open(data_file, 'rb') as f: 13 | # read whole file in uint8 chunks 14 | everything = np.fromfile(f, dtype=np.uint8) 15 | images = np.reshape(everything, (-1, 3, 96, 96)) 16 | images = np.transpose(images, (0, 1, 3, 2)) 17 | 18 | return images, labels 19 | 20 | def _check_integrity(root, train_list, test_list, base_folder): 21 | for fentry in (train_list + test_list): 22 | filename, md5 = fentry[0], fentry[1] 23 | fpath = os.path.join(root, base_folder, filename) 24 | if not check_integrity(fpath, md5): 25 | return False 26 | print('Files already downloaded and verified') 27 | return True 28 | 29 | def STL10(root='./data/', split='train'): 30 | base_folder = 'stl10_binary' 31 | url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz" 32 | filename = "stl10_binary.tar.gz" 33 | class_names_file = 'class_names.txt' 34 | folds_list_file = 'fold_indices.txt' 35 | train_list = [ 36 | ['train_X.bin', '918c2871b30a85fa023e0c44e0bee87f'], 37 | ['train_y.bin', '5a34089d4802c674881badbb80307741'], 38 | ['unlabeled_X.bin', '5242ba1fed5e4be9e1e742405eb56ca4'] 39 | ] 40 | 41 | test_list = [ 42 | ['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'], 43 | ['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e'] 44 | ] 45 | splits = ('train', 'train+unlabeled', 'unlabeled', 'test') 46 | 47 | # download dataset 48 | fpath = os.path.join(root, base_folder, filename) 49 | if not _check_integrity(root, train_list, test_list, base_folder): 50 | download_and_extract_archive(url=url, download_root=root, filename=filename) 51 | 52 | # choose which set to 
load 53 | if split=='train': 54 | path_to_data = os.path.join(root, base_folder, train_list[0][0]) 55 | path_to_labels = os.path.join(root, base_folder, train_list[1][0]) 56 | data, targets = load_file(path_to_data, path_to_labels) 57 | elif split == 'unlabeled': 58 | path_to_data = os.path.join(root, base_folder, train_list[2][0]) 59 | data, _ = load_file(path_to_data) 60 | targets = np.asarray([-1] * data.shape[0]) 61 | elif split == 'test': 62 | path_to_data = os.path.join(root, base_folder, test_list[0][0]) 63 | path_to_labels = os.path.join(root, base_folder, test_list[1][0]) 64 | data, targets = load_file(path_to_data, path_to_labels) 65 | else: 66 | raise ValueError('Out of option!') 67 | 68 | return data, targets 69 | 70 | 71 | 72 | if __name__ == '__main__': 73 | data, targets = STL10(root='./data/', split='test') # return numpy array 74 | print (data.shape, targets.shape) 75 | -------------------------------------------------------------------------------- /libauc/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .auc import * 2 | from .ranking import * 3 | from .constrastive import * 4 | from .losses import * 5 | from .losses_v1 import * 6 | from .perf_at_top import * 7 | -------------------------------------------------------------------------------- /libauc/losses/constrastive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import distributed as dist 4 | from torch.nn import functional as F 5 | 6 | 7 | class GlobalContrastiveLoss(nn.Module): 8 | """For MoCov3 9 | """ 10 | def __init__(self, N=1.2e6, T=1.0): 11 | super(GlobalContrastiveLoss, self).__init__() 12 | self.u = torch.zeros(N).reshape(-1, 1) 13 | self.T = T 14 | 15 | def forward(self, q, k, index, gamma): 16 | # normalize 17 | q = nn.functional.normalize(q, dim=1) 18 | k = nn.functional.normalize(k, dim=1) 19 | # gather all targets 20 | k = concat_all_gather(k) 21 | N_lagre = k.shape[0] # batch size of total GPUs 22 | # Einstein sum is more intuitive 23 | logits = torch.einsum('nc,mc->nm', [q, k]) 24 | N = logits.shape[0] # batch size per GPU 25 | labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda() 26 | 27 | # compute negative weights 28 | labels_one_hot = F.one_hot(labels, N_lagre) 29 | neg_mask = 1-labels_one_hot 30 | neg_logits = torch.exp(logits/self.T)*neg_mask 31 | u = (1 - gamma) * self.u[index].cuda() + gamma * torch.sum(neg_logits, dim=1, keepdim=True)/(N_lagre-1) 32 | p_neg_weights = (neg_logits/u).detach_() 33 | 34 | # gather all u & index from all machines 35 | u_all = concat_all_gather(u) 36 | index_all = concat_all_gather(index) 37 | self.u[index_all] = u_all.cpu() 38 | 39 | # compute loss 40 | expsum_neg_logits = torch.sum(p_neg_weights*logits, dim=1, keepdim=True)/(N_lagre-1) 41 | normalized_logits = logits - expsum_neg_logits 42 | loss = -torch.sum(labels_one_hot * normalized_logits, dim=1) 43 | 44 | return loss.mean()* (2 * self.T) 45 | 46 | 47 | @torch.no_grad() 48 | def concat_all_gather(tensor): 49 | """ 50 | Performs all_gather operation on the provided tensors. 51 | *** Warning ***: torch.distributed.all_gather has no gradient. 
52 |     """
53 |     tensors_gather = [torch.ones_like(tensor)
54 |         for _ in range(torch.distributed.get_world_size())]
55 |     torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
56 |
57 |     output = torch.cat(tensors_gather, dim=0)
58 |     return output
59 |
60 |
61 | # alias
62 | GCLoss = GlobalContrastiveLoss
63 |
-------------------------------------------------------------------------------- /libauc/losses/losses.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | class CrossEntropyLoss(torch.nn.Module):
6 |     """
7 |     Cross Entropy Loss with Sigmoid Function
8 |     Reference:
9 |         https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
10 |     """
11 |     def __init__(self):
12 |         super(CrossEntropyLoss, self).__init__()
13 |         self.criterion = F.binary_cross_entropy_with_logits # with sigmoid
14 |
15 |     def forward(self, y_pred, y_true): # TODO: handle the tensor shapes
16 |         if len(y_pred.shape) == 1:
17 |             y_pred = y_pred.reshape(-1, 1)
18 |         if len(y_true.shape) == 1:
19 |             y_true = y_true.reshape(-1, 1)
20 |         return self.criterion(y_pred, y_true)
21 |
22 | class FocalLoss(torch.nn.Module):
23 |     """
24 |     Focal Loss
25 |     Reference:
26 |         https://amaarora.github.io/2020/06/29/FocalLoss.html
27 |     """
28 |     def __init__(self, alpha=.25, gamma=2):
29 |         super(FocalLoss, self).__init__()
30 |         self.alpha = torch.tensor([alpha, 1-alpha]).cuda()
31 |         self.gamma = gamma
32 |
33 |     def forward(self, inputs, targets):
34 |         if len(inputs.shape) == 1:
35 |             inputs = inputs.reshape(-1, 1)
36 |         if len(targets.shape) == 1:
37 |             targets = targets.reshape(-1, 1)
38 |
39 |         BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
40 |         targets = targets.type(torch.long)
41 |         at = self.alpha.gather(0, targets.data.view(-1))
42 |         pt = torch.exp(-BCE_loss)
43 |         F_loss = at*(1-pt)**self.gamma * BCE_loss
44 |         return F_loss.mean()
45 |
46 |
47 |
-------------------------------------------------------------------------------- /libauc/losses/losses_v1.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | class AUCMLoss_V1(torch.nn.Module):
6 |     """
7 |     AUCM Loss with squared-hinge function: a novel loss function to directly optimize AUROC
8 |
9 |     inputs:
10 |         margin: margin term for AUCM loss, e.g., m in [0, 1]
11 |         imratio: imbalance ratio, i.e., the ratio of number of positive samples to number of total samples
12 |     outputs:
13 |         loss value
14 |
15 |     Reference:
16 |         Yuan, Z., Yan, Y., Sonka, M. and Yang, T.,
17 |         Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification.
18 | International Conference on Computer Vision (ICCV 2021) 19 | Link: 20 | https://arxiv.org/abs/2012.03173 21 | """ 22 | def __init__(self, margin=None, imratio=None, device=None): 23 | super(AUCMLoss_V1, self).__init__() 24 | if not device: 25 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 26 | else: 27 | self.device = device 28 | self.margin = margin 29 | self.p = imratio 30 | self.a = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device) 31 | self.b = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device) 32 | self.alpha = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device) 33 | 34 | def forward(self, y_pred, y_true, auto=True): 35 | if auto: 36 | self.p = (y_true==1).sum()/y_true.shape[0] 37 | if len(y_pred.shape) == 1: 38 | y_pred = y_pred.reshape(-1, 1) 39 | if len(y_true.shape) == 1: 40 | y_true = y_true.reshape(-1, 1) 41 | loss = (1-self.p)*torch.mean((y_pred - self.a)**2*(1==y_true).float()) + \ 42 | self.p*torch.mean((y_pred - self.b)**2*(0==y_true).float()) + \ 43 | 2*self.alpha*(self.p*(1-self.p) + \ 44 | torch.mean((self.p*y_pred*(0==y_true).float() - (1-self.p)*y_pred*(1==y_true).float())) )- \ 45 | self.p*(1-self.p)*self.alpha**2 46 | return loss 47 | 48 | 49 | class AUCM_MultiLabel_V1(torch.nn.Module): 50 | """ 51 | Reference: 52 | Yuan, Z., Yan, Y., Sonka, M. and Yang, T., 53 | Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification. 54 | International Conference on Computer Vision (ICCV 2021) 55 | Link: 56 | https://arxiv.org/abs/2012.03173 57 | """ 58 | def __init__(self, margin=1.0, imratio=None, num_classes=10, device=None): 59 | super(AUCM_MultiLabel_V1, self).__init__() 60 | if not device: 61 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 62 | else: 63 | self.device = device 64 | self.margin = margin 65 | self.p = imratio 66 | self.num_classes = num_classes 67 | if self.p: 68 | assert len(imratio)==num_classes, 'Length of imratio needs to be same as num_classes!' 
69 | else: 70 | self.p = [0.0]*num_classes 71 | self.a = torch.zeros(num_classes, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device) 72 | self.b = torch.zeros(num_classes, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device) 73 | self.alpha = torch.zeros(num_classes, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device) 74 | 75 | @property 76 | def get_a(self): 77 | return self.a.mean() 78 | @property 79 | def get_b(self): 80 | return self.b.mean() 81 | @property 82 | def get_alpha(self): 83 | return self.alpha.mean() 84 | 85 | def forward(self, y_pred, y_true, auto=True): 86 | total_loss = 0 87 | for idx in range(self.num_classes): 88 | if len(y_pred[:, idx].shape) == 1: 89 | y_pred_i = y_pred[:, idx].reshape(-1, 1) 90 | if len(y_true[:, idx].shape) == 1: 91 | y_true_i = y_true[:, idx].reshape(-1, 1) 92 | if auto: 93 | self.p[idx] = (y_true_i==1).sum()/y_true_i.shape[0] 94 | loss = (1-self.p[idx])*torch.mean((y_pred_i - self.a[idx])**2*(1==y_true_i).float()) + \ 95 | self.p[idx]*torch.mean((y_pred_i - self.b[idx])**2*(0==y_true_i).float()) + \ 96 | 2*self.alpha[idx]*(self.p[idx]*(1-self.p[idx]) + \ 97 | torch.mean((self.p[idx]*y_pred_i*(0==y_true_i).float() - (1-self.p[idx])*y_pred_i*(1==y_true_i).float())) )- \ 98 | self.p[idx]*(1-self.p[idx])*self.alpha[idx]**2 99 | total_loss += loss 100 | return total_loss 101 | 102 | class CompositionalAUCLoss_V1(torch.nn.Module): 103 | """ 104 | Compositional AUC Loss: a novel loss function to directly optimize AUROC 105 | inputs: 106 | margin: margin term for AUCM loss, e.g., m in [0, 1] 107 | imratio: imbalance ratio, i.e., the ratio of number of postive samples to number of total samples 108 | outputs: 109 | loss 110 | Reference: 111 | @inproceedings{ 112 | yuan2022compositional, 113 | title={Compositional Training for End-to-End Deep AUC Maximization}, 114 | author={Zhuoning Yuan and Zhishuai Guo and Nitesh Chawla and Tianbao Yang}, 115 | booktitle={International Conference on Learning Representations}, 116 | year={2022}, 117 | url={https://openreview.net/forum?id=gPvB4pdu_Z} 118 | } 119 | """ 120 | def __init__(self, imratio=None, margin=1, backend='ce', last_activation='sigmoid', device=None): 121 | super(CompositionalAUCLoss_V1, self).__init__() 122 | if not device: 123 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 124 | else: 125 | self.device = device 126 | self.margin = margin 127 | self.p = imratio 128 | self.a = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device) 129 | self.b = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device) 130 | self.alpha = torch.zeros(1, dtype=torch.float32, device=self.device, requires_grad=True).to(self.device) 131 | self.L_AVG = F.binary_cross_entropy_with_logits # include sigmoid 132 | self.backend = 'ce' 133 | self.last_activation = last_activation # whether to use output activation for computing AUC loss 134 | 135 | def forward(self, y_pred, y_true, auto=True): 136 | if len(y_pred.shape) == 1: 137 | y_pred = y_pred.reshape(-1, 1) 138 | if len(y_true.shape) == 1: 139 | y_true = y_true.reshape(-1, 1) 140 | if self.backend == 'ce': 141 | self.backend = 'auc' 142 | return self.L_AVG(y_pred, y_true) 143 | else: 144 | self.backend = 'ce' 145 | if auto: 146 | self.p = (y_true==1).sum()/y_true.shape[0] 147 | if self.last_activation == 'sigmoid': 148 | y_pred = torch.sigmoid(y_pred) 149 | elif self.last_activation == 'l2': 150 | y_pred 
= F.normalize(y_pred, p=2, dim=0) 151 | self.L_AUC = (1-self.p)*torch.mean((y_pred - self.a)**2*(1==y_true).float()) + \ 152 | self.p*torch.mean((y_pred - self.b)**2*(0==y_true).float()) + \ 153 | 2*self.alpha*(self.p*(1-self.p) + \ 154 | torch.mean((self.p*y_pred*(0==y_true).float() - (1-self.p)*y_pred*(1==y_true).float())) )- \ 155 | self.p*(1-self.p)*self.alpha**2 156 | return self.L_AUC 157 | 158 | # alias 159 | AUCMLoss = AUCMLoss_V1 160 | AUCM_MultiLabel = AUCM_MultiLabel_V1 161 | CompositionalAUCLoss = CompositionalAUCLoss_V1 162 | -------------------------------------------------------------------------------- /libauc/losses/perf_at_top.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .surrogate import squared_loss, squared_hinge_loss, logistic_loss 4 | 5 | def _check_tensor_shape(inputs, shape=(-1, 1)): 6 | input_shape = inputs.shape 7 | target_shape = shape 8 | if len(input_shape) != len(target_shape): 9 | inputs = inputs.reshape(target_shape) 10 | return inputs 11 | 12 | def _get_surrogate_loss(backend='squared_hinge'): 13 | if backend == 'squared_hinge': 14 | surr_loss = squared_hinge_loss 15 | elif backend == 'squared': 16 | surr_loss = squared_loss 17 | elif backend == 'logistic': 18 | surr_loss = logistic_loss 19 | else: 20 | raise ValueError('Out of options!') 21 | return surr_loss 22 | 23 | class TopPush_Loss(torch.nn.Module): 24 | """Partial AUC Loss: a stochastic one-way partial AUC based on DRO-CVaR (Top Push Loss) 25 | 26 | Args: 27 | pos_length: number of positive examples for the training data 28 | num_neg: number of negative samples for each mini-batch 29 | threshold: margin for basic AUC loss 30 | beta: FPR upper bound for pAUC used for SOTA 31 | eta: stepsize for CVaR regularization term 32 | loss type: basic AUC loss to apply. 33 | 34 | Return: 35 | loss value (scalar) 36 | 37 | Reference: 38 | Zhu, D., Li, G., Wang, B., Wu, X., and Yang, T., 2022. 39 | When AUC meets DRO: Optimizing Partial AUC for Deep Learning with Non-Convex Convergence Guarantee. 40 | arXiv preprint arXiv:2203.00176. 
41 |
42 |     """
43 |     def __init__(self,
44 |                  pos_length,
45 |                  num_neg,
46 |                  margin=1.0,
47 |                  alpha=0,
48 |                  beta=0.2,
49 |                  surrogate_loss='squared_hinge',
50 |                  top_push=False):
51 |
52 |         super(TopPush_Loss, self).__init__()
53 |         self.beta = 1/num_neg # choose hardest negative samples in mini-batch
54 |         self.alpha = alpha
55 |         self.eta = 1.0
56 |         self.num_neg = num_neg
57 |         self.pos_length = pos_length
58 |         self.u_pos = torch.tensor([0.0]*pos_length).reshape(-1, 1).cuda()
59 |         self.margin = margin
60 |         self.surrogate_loss = _get_surrogate_loss(surrogate_loss)
61 |
62 |     def set_coef(self, eta):
63 |         self.eta = eta
64 |
65 |     def update_coef(self, decay_factor):
66 |         self.eta = self.eta/decay_factor
67 |
68 |     @property
69 |     def get_coef(self):
70 |         return self.eta
71 |
72 |     def forward(self, y_pred, y_true, index_p):
73 |         y_pred = _check_tensor_shape(y_pred, (-1, 1))
74 |         y_true = _check_tensor_shape(y_true, (-1, 1))
75 |         index_p = _check_tensor_shape(index_p, (-1,))
76 |
77 |         f_ps = y_pred[y_true == 1].reshape(-1, 1)
78 |         f_ns = y_pred[y_true == 0].reshape(-1, 1)
79 |         f_ps = f_ps.repeat(1, len(f_ns))
80 |         f_ns = f_ns.repeat(1, len(f_ps))
81 |         index_p = index_p[index_p>=0]
82 |
83 |         loss = self.surrogate_loss(self.margin, f_ps - f_ns.transpose(0,1)) # returns element-wise loss
84 |         p = loss > self.u_pos[index_p]
85 |         self.u_pos[index_p] = self.u_pos[index_p]-self.eta/self.pos_length*(1 - p.sum(dim=1, keepdim=True)/(self.beta*self.num_neg))
86 |
87 |         p.detach_()
88 |         loss = torch.mean(p * loss) / self.beta
89 |         return loss
90 |
-------------------------------------------------------------------------------- /libauc/losses/ranking.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from scipy.sparse import dok_matrix
6 | from .surrogate import squared_loss, squared_hinge_loss, logistic_loss
7 |
8 | def _get_surrogate_loss(backend='squared_hinge'):
9 |     if backend == 'squared_hinge':
10 |         surr_loss = squared_hinge_loss
11 |     elif backend == 'squared':
12 |         surr_loss = squared_loss
13 |     elif backend == 'logistic':
14 |         surr_loss = logistic_loss
15 |     else:
16 |         raise ValueError('Out of options!')
17 |     return surr_loss
18 |
19 | class ListwiseCE_Loss(torch.nn.Module):
20 |     """
21 |     Stochastic Optimization of Listwise CE loss: a novel listwise cross-entropy loss that
22 |     computes the cross-entropy between predicted and ground truth top-one probability distribution
23 |
24 |     Inputs:
25 |         id_mapper (scipy.sparse.dok_matrix): map 2d index (user_id, item_id) to 1d index
26 |         total_relevant_pairs (int): number of all relevant pairs
27 |         num_pos (int): the number of positive items sampled for each user
28 |         gamma0 (float): the factor for moving average, i.e., \gamma_0 in our paper, in range (0.0, 1.0)
29 |                 this hyper-parameter can be tuned for better performance
30 |         eps (float, optional): a small value to avoid divide-by-zero error
31 |     Outputs:
32 |         loss value
33 |     Reference:
34 |         Qiu, Z., Hu, Q., Zhong, Y., Zhang, L. and Yang, T.
35 | Large-scale Stochastic Optimization of NDCG Surrogates for Deep Learning with Provable Convergence 36 | https://arxiv.org/abs/2202.12183 37 | """ 38 | def __init__(self, 39 | id_mapper, 40 | total_relevant_pairs, 41 | num_pos, 42 | gamma0, 43 | eps=1e-10, 44 | device=None): 45 | super(ListwiseCE_Loss, self).__init__() 46 | if not device: 47 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 48 | else: 49 | self.device = device 50 | self.id_mapper = id_mapper 51 | self.num_pos = num_pos 52 | self.gamma0 = gamma0 53 | self.eps = eps 54 | self.u = torch.zeros(total_relevant_pairs).to(self.device) 55 | 56 | def forward(self, predictions, batch): 57 | """ 58 | Args: 59 | predictions: predicted socres from the model, shape: [batch_size, num_pos + num_neg] 60 | batch: a dict that contains two keys: user_id and item_id 61 | """ 62 | batch_size = predictions.size(0) 63 | neg_pred = torch.repeat_interleave(predictions[:, self.num_pos:], self.num_pos, dim=0) # [batch_size * num_pos, num_neg] 64 | pos_pred = torch.cat(torch.chunk(predictions[:, :self.num_pos], batch_size, dim=0), dim=1).permute(1,0) # [batch_size * num_pos, 1] 65 | 66 | margin = neg_pred - pos_pred 67 | exp_margin = torch.exp(margin - torch.max(margin)).detach_() 68 | 69 | user_ids_tsfd = batch['user_id'].repeat_interleave(self.num_pos) 70 | pos_item_ids_tsfd = torch.cat(torch.chunk(batch['item_id'][:, :self.num_pos] , batch_size, dim=0), dim=1).squeeze() 71 | 72 | user_item_ids = self.id_mapper[user_ids_tsfd.tolist(), pos_item_ids_tsfd.tolist()].toarray().squeeze() 73 | self.u[user_item_ids] = (1-self.gamma0) * self.u[user_item_ids] + self.gamma0 * torch.mean(exp_margin, dim=1) 74 | 75 | exp_margin_softmax = exp_margin / (self.u[user_item_ids][:, None] + self.eps) 76 | 77 | loss = torch.sum(margin * exp_margin_softmax) 78 | loss /= batch_size 79 | 80 | return loss 81 | 82 | 83 | class NDCG_Loss(torch.nn.Module): 84 | """ 85 | Stochastic Optimization of NDCG (SONG) and top-K NDCG (K-SONG) 86 | 87 | Inputs: 88 | id_mapper (scipy.sparse.dok_matrix): map 2d index (user_id, item_id) to 1d index 89 | total_relevant_pairs (int): number of all relevant pairs 90 | num_user (int): the number of users in the dataset 91 | num_item (int): the number of items in the dataset 92 | num_pos (int): the number of positive items sampled for each user 93 | gamma0 (float): the moving average factor of u_{q,i}, i.e., \beta_0 in our paper, in range (0.0, 1.0) 94 | this hyper-parameter can be tuned for better performance 95 | gamma1 (float, optional): the moving average factor of s_{q} and v_{q} 96 | eta0 (float, optional): step size of \lambda 97 | margin (float, optional): margin for squared hinge loss 98 | topk (int, optional): NDCG@k optimization is activated if topk > 0; topk=-1 represents SONG 99 | topk_version (string, optional): 'theo' or 'prac' 100 | tau_1 (float, optional): \tau_1 in Eq. (6), \tau_1 << 1 101 | tau_2 (float, optional): \tau_2 in Eq. (6), \tau_2 << 1 102 | psi_func (str, optional): can be 'sigmoid' or 'hinge' 103 | hinge_margin (float, optional): a hyperparameter for hinge function, psi(x) = max(x + hinge_margin, 0) 104 | sigmoid_alpha (float, optional): a hyperparameter for sigmoid function, psi(x) = sigmoid(x * sigmoid_alpha) 105 | Outputs: 106 | loss value 107 | Reference: 108 | Qiu, Z., Hu, Q., Zhong, Y., Zhang, L. and Yang, T. 
109 | Large-scale Stochastic Optimization of NDCG Surrogates for Deep Learning with Provable Convergence 110 | https://arxiv.org/abs/2202.12183 111 | """ 112 | def __init__(self, 113 | id_mapper, 114 | total_relevant_pairs, 115 | num_user, 116 | num_item, 117 | num_pos, 118 | gamma0, 119 | gamma1=0.9, 120 | eta0=0.01, 121 | margin=1.0, 122 | topk=-1, 123 | topk_version='theo', 124 | tau_1=0.01, 125 | tau_2=0.0001, 126 | psi_func='sigmoid', 127 | topk_margin=2.0, 128 | sigmoid_alpha=2.0, 129 | surrogate_loss='squared_hinge', 130 | device=None): 131 | super(NDCG_Loss, self).__init__() 132 | if not device: 133 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 134 | else: 135 | self.device = device 136 | self.id_mapper = id_mapper 137 | self.num_pos = num_pos 138 | self.margin = margin 139 | self.gamma0 = gamma0 140 | self.topk = topk 141 | self.lambda_q = torch.zeros(num_user+1).to(self.device) # learnable thresholds for all querys (users) 142 | self.v_q = torch.zeros(num_user+1).to(self.device) # moving average estimator for \nabla_{\lambda} L_q 143 | self.gamma1 = gamma1 144 | self.tau_1 = tau_1 145 | self.tau_2 = tau_2 146 | self.eta0 = eta0 147 | self.num_item = num_item 148 | self.topk_version = topk_version 149 | self.s_q = torch.zeros(num_user+1).to(self.device) # moving average estimator for \nabla_{\lambda}^2 L_q 150 | self.psi_func = psi_func 151 | self.topk_margin = topk_margin 152 | self.sigmoid_alpha = sigmoid_alpha 153 | self.u = torch.zeros(total_relevant_pairs).to(self.device) 154 | self.surrogate_loss = _get_surrogate_loss(surrogate_loss) 155 | 156 | def forward(self, predictions, batch): 157 | """ 158 | Args: 159 | predictions: predicted socres from the model, shape: [batch_size, num_pos + num_neg] 160 | batch: a dict that contains the following keys: user_id, item_id, rating, num_pos_items, ideal_dcg 161 | """ 162 | device = predictions.device 163 | ratings = batch['rating'][:, :self.num_pos] # [batch_size, num_pos] 164 | batch_size = ratings.size()[0] 165 | predictions_expand = torch.repeat_interleave(predictions, self.num_pos, dim=0) # [batch_size*num_pos, num_pos+num_neg] 166 | predictions_pos = torch.cat(torch.chunk(predictions[:, :self.num_pos], batch_size, dim=0), dim=1).permute(1,0) # [batch_suze*num_pos, 1] 167 | 168 | num_pos_items = batch['num_pos_items'].float() # [batch_size], the number of positive items for each user 169 | ideal_dcg = batch['ideal_dcg'].float() # [batch_size], the ideal dcg for each user 170 | 171 | g = torch.mean(self.surrogate_loss(self.margin, predictions_pos-predictions_expand), dim=-1) # [batch_size*num_pos] 172 | g = g.reshape(batch_size, self.num_pos) # [batch_size, num_pos], line 5 in Algo 2. 
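                # g: per-(user, positive item) surrogate pairwise loss; self.u keeps its moving-average estimate.
                # G (next statement): DCG gain term, 2^rating - 1.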
173 | 174 | G = (2.0 ** ratings - 1).float() 175 | 176 | user_ids = batch['user_id'] 177 | pos_item_ids = batch['item_id'][:, :self.num_pos] # [batch_size, num_pos] 178 | 179 | pos_item_ids = torch.cat(torch.chunk(pos_item_ids, batch_size, dim=0), dim=1).squeeze() 180 | user_ids_repeat = user_ids.repeat_interleave(self.num_pos) 181 | 182 | user_item_ids = self.id_mapper[user_ids_repeat.tolist(), pos_item_ids.tolist()].toarray().squeeze() 183 | self.u[user_item_ids] = (1-self.gamma0) * self.u[user_item_ids] + self.gamma0 * g.clone().detach_().reshape(-1) 184 | g_u = self.u[user_item_ids].reshape(batch_size, self.num_pos) 185 | 186 | nabla_f_g = (G * self.num_item) / ((torch.log2(1 + self.num_item*g_u))**2 * (1 + self.num_item*g_u) * np.log(2)) # \nabla f(g) 187 | 188 | if self.topk > 0: 189 | user_ids = user_ids.long() 190 | pos_preds_lambda_diffs = predictions[:, :self.num_pos].clone().detach_() - self.lambda_q[user_ids][:, None].to(device) 191 | preds_lambda_diffs = predictions.clone().detach_() - self.lambda_q[user_ids][:, None].to(device) 192 | 193 | # the gradient of lambda 194 | grad_lambda_q = self.topk/self.num_item + self.tau_2*self.lambda_q[user_ids] - torch.mean(torch.sigmoid(preds_lambda_diffs.to(device) / self.tau_1), dim=-1) 195 | self.v_q[user_ids] = self.gamma1 * grad_lambda_q + (1-self.gamma1) * self.v_q[user_ids] 196 | self.lambda_q[user_ids] = self.lambda_q[user_ids] - self.eta0 * self.v_q[user_ids] 197 | 198 | if self.topk_version == 'prac': 199 | nabla_f_g *= torch.sigmoid(pos_preds_lambda_diffs * self.sigmoid_alpha) 200 | 201 | elif self.topk_version == 'theo': 202 | nabla_f_g *= torch.sigmoid(pos_preds_lambda_diffs * self.sigmoid_alpha) 203 | d_psi = torch.sigmoid(pos_preds_lambda_diffs * self.sigmoid_alpha) * (1 - torch.sigmoid(pos_preds_lambda_diffs * self.sigmoid_alpha)) 204 | 205 | temp_term = torch.sigmoid(preds_lambda_diffs / self.tau_1) * (1 - torch.sigmoid(preds_lambda_diffs / self.tau_1)) / self.tau_1 206 | L_lambda_hessian = self.tau_2 + torch.mean(temp_term, dim=1) # \nabla_{\lambda}^2 L_q in Eq. (7) in the paper 207 | self.s_q[user_ids] = self.gamma1 * L_lambda_hessian.to(device) + (1-self.gamma1) * self.s_q[user_ids] # line 10 in Algorithm 2 in the paper 208 | hessian_term = torch.mean(temp_term * predictions, dim=1) / self.s_q[user_ids].to(device) # \nabla_{\lambda,w}^2 L_q * s_q in Eq. 
(7) in the paper 209 | f_g_u = -G / torch.log2(1 + self.num_item*g_u) 210 | loss = (num_pos_items * torch.mean(nabla_f_g * g + d_psi * f_g_u * (predictions[:, :self.num_pos] - hessian_term[:, None]), dim=-1) / ideal_dcg).mean() 211 | return loss 212 | 213 | loss = (num_pos_items * torch.mean(nabla_f_g * g, dim=-1) / ideal_dcg).mean() 214 | return loss 215 | 216 | # alias 217 | ListwiseCELoss = ListwiseCE_Loss 218 | NDCGLoss = NDCG_Loss 219 | -------------------------------------------------------------------------------- /libauc/losses/surrogate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def squared_loss(margin, t): 4 | return (margin - t)** 2 5 | 6 | def squared_hinge_loss(margin, t): 7 | return torch.max(margin - t, torch.zeros_like(t)) ** 2 8 | 9 | def logistic_loss(margin, t): 10 | return torch.log(1+torch.exp(-margin*t)) 11 | -------------------------------------------------------------------------------- /libauc/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import * 2 | -------------------------------------------------------------------------------- /libauc/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import roc_auc_score 2 | from sklearn.metrics import average_precision_score 3 | import numpy as np 4 | 5 | def auroc(labels, scores, **kwargs): 6 | if isinstance(labels, list): 7 | labels = np.array(labels) 8 | if isinstance(scores, list): 9 | scores = np.array(scores) 10 | if len(labels.shape) == 1: 11 | labels = labels.reshape(-1, 1) 12 | if len(scores.shape) == 1: 13 | scores = scores.reshape(-1, 1) 14 | if scores.shape[-1] != 1 and len(scores.shape)>1: 15 | class_auc_list = [] 16 | for i in range(scores.shape[-1]): 17 | try: 18 | local_auc = roc_auc_score(labels[:, i], scores[:, i], **kwargs) 19 | class_auc_list.append(local_auc) 20 | except: 21 | class_auc_list.append(0.0) # if only one class 22 | return class_auc_list 23 | return roc_auc_score(labels, scores, **kwargs) 24 | 25 | def auprc(labels, scores, **kwargs): 26 | if isinstance(labels, list): 27 | labels = np.array(labels) 28 | if isinstance(scores, list): 29 | scores = np.array(scores) 30 | if len(labels.shape) == 1: 31 | labels = labels.reshape(-1, 1) 32 | if len(scores.shape) == 1: 33 | scores = scores.reshape(-1, 1) 34 | if scores.shape[-1] != 1 and len(scores.shape)>1: 35 | class_auc_list = [] 36 | for i in range(scores.shape[-1]): 37 | try: 38 | local_auc = average_precision_score(labels[:, i], scores[:, i]) 39 | class_auc_list.append(local_auc) 40 | except: 41 | class_auc_list.append(0.0) 42 | return class_auc_list 43 | return average_precision_score(labels, scores) 44 | 45 | def pauc(labels, scores, max_fpr=1.0, min_tpr=0.0, **kwargs): 46 | # multi-task support: TBD 47 | if isinstance(labels, list): 48 | labels = np.array(labels) 49 | if isinstance(scores, list): 50 | scores = np.array(scores) 51 | labels = labels.reshape(-1) 52 | scores = scores.reshape(-1) 53 | if min_tpr == 0: 54 | # one-way partial AUC 55 | return roc_auc_score(labels, scores, max_fpr=max_fpr) 56 | # two-way partial AUC 57 | pos_idx = np.where(labels == 1)[0] 58 | neg_idx = np.where(labels != 1)[0] 59 | num_pos = round(len(pos_idx)*(1-min_tpr)) 60 | num_neg = round(len(neg_idx)*max_fpr) 61 | num_pos = 1 if num_pos < 1 else num_pos 62 | num_neg = 1 if num_neg < 1 else num_neg 63 | if len(pos_idx)==1: 64 | selected_pos = [0] 65 |
else: 66 | selected_pos = np.argpartition(scores[pos_idx], num_pos)[:num_pos] 67 | if len(neg_idx)==1: 68 | selected_neg = [0] 69 | else: 70 | selected_neg = np.argpartition(-scores[neg_idx], num_neg)[:num_neg] 71 | selected_target = np.concatenate((labels[pos_idx][selected_pos], labels[neg_idx][selected_neg])) 72 | selected_pred = np.concatenate((scores[pos_idx][selected_pos], scores[neg_idx][selected_neg])) 73 | return roc_auc_score(selected_target, selected_pred) 74 | 75 | # TODO: make individual function 76 | def map_at_k(hit, gt_rank): 77 | ap_list = [] 78 | hit_gt_rank = (hit * gt_rank).astype(float) 79 | sorted_hit_gt_rank = np.sort(hit_gt_rank) 80 | for idx, row in enumerate(sorted_hit_gt_rank): 81 | precision_list = [] 82 | counter = 1 83 | for item in row: 84 | if item > 0: 85 | precision_list.append(counter / item) 86 | counter += 1 87 | ap = np.sum(precision_list) / np.sum(hit[idx]) if np.sum(hit[idx]) > 0 else 0 88 | ap_list.append(ap) 89 | return np.mean(ap_list) 90 | 91 | # TODO: make individual function 92 | def ndcg_at_k(ratings, normalizer_mat, hit, gt_rank, k): 93 | # calculate the normalizer first 94 | normalizer = np.sum(normalizer_mat[:, :k], axis=1) 95 | # calculate DCG 96 | DCG = np.sum(((np.exp2(ratings) - 1) / np.log2(gt_rank+1)) * hit.astype(float), axis=1) 97 | return np.mean(DCG / normalizer) 98 | 99 | # alias 100 | auc_roc_score = auroc 101 | auc_prc_score = auprc 102 | pauc_roc_score = pauc 103 | 104 | if __name__ == '__main__': 105 | # import numpy as np 106 | preds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] 107 | labels = [1, 1, 1, 0, 0, 0, 1, 1, 1, 0] 108 | # print (preds.shape, labels.shape) 109 | print (auprc(labels, preds)) 110 | print (auroc(labels, preds)) 111 | 112 | print (roc_auc_score(labels, preds)) 113 | print (average_precision_score(labels, preds)) 114 | 115 | 116 | -------------------------------------------------------------------------------- /libauc/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .densenet import * 2 | from .resnet import * 3 | from .resnet_cifar import * 4 | from .neumf import * 5 | from .perceptron import * 6 | -------------------------------------------------------------------------------- /libauc/models/gcn.py: -------------------------------------------------------------------------------- 1 | # TODO 2 | -------------------------------------------------------------------------------- /libauc/models/neumf.py: -------------------------------------------------------------------------------- 1 | # This implementation is from https://github.com/hexiangnan/neural_collaborative_filtering 2 | 3 | import torch 4 | import torch.nn as nn 5 | import logging 6 | 7 | class NeuMF(nn.Module): 8 | """ 9 | NeuMF is a widely-used model for recommender systems 10 | 11 | Inputs: 12 | user_num (int): the number of users in the dataset 13 | item_num (int): the number of items in the dataset 14 | dropout (float, optional): dropout ratio for the model 15 | emb_size (int, optional): embedding size of the model 16 | layers (string, optional): describe the layer information of the model 17 | Outputs: 18 | predicted score for each input user-item pair 19 | Reference: 20 | He, X., Liao, L., Zhang, H., Nie, L., Hu, X., and Chua, T. 
21 | Neural Collaborative Filtering 22 | https://arxiv.org/abs/1708.05031 23 | """ 24 | def __init__(self, user_num: int, item_num: int, dropout: float=0.2, emb_size: int=64, layers: str='[64]'): 25 | super(NeuMF, self).__init__() 26 | self.user_num = user_num 27 | self.item_num = item_num 28 | self.emb_size = emb_size 29 | self.dropout = dropout 30 | self.layers = eval(layers) 31 | 32 | self.mf_u_embeddings = nn.Embedding(self.user_num, self.emb_size) 33 | self.mf_i_embeddings = nn.Embedding(self.item_num, self.emb_size) 34 | self.mlp_u_embeddings = nn.Embedding(self.user_num, self.emb_size) 35 | self.mlp_i_embeddings = nn.Embedding(self.item_num, self.emb_size) 36 | 37 | self.mlp = nn.ModuleList([]) 38 | pre_size = 2 * self.emb_size 39 | for i, layer_size in enumerate(self.layers): 40 | self.mlp.append(nn.Linear(pre_size, layer_size)) 41 | pre_size = layer_size 42 | self.dropout_layer = nn.Dropout(p=self.dropout) 43 | self.prediction = nn.Linear(pre_size + self.emb_size, 1, bias=False) 44 | 45 | def reset_last_layer(self): 46 | self.prediction.reset_parameters() 47 | 48 | @staticmethod 49 | def init_weights(m): 50 | if 'Linear' in str(type(m)): 51 | nn.init.normal_(m.weight, mean=0.0, std=0.01) 52 | if m.bias is not None: 53 | nn.init.normal_(m.bias, mean=0.0, std=0.01) 54 | elif 'Embedding' in str(type(m)): 55 | nn.init.normal_(m.weight, mean=0.0, std=0.01) 56 | 57 | def save_model(self, model_path=None): 58 | if model_path is None: 59 | model_path = self.model_path 60 | torch.save(self.state_dict(), model_path) 61 | 62 | def load_model(self, model_path=None): 63 | if model_path is None: 64 | model_path = self.model_path 65 | self.load_state_dict(torch.load(model_path)) 66 | logging.info('Load model from ' + model_path) 67 | 68 | def forward(self, feed_dict): 69 | u_ids = feed_dict['user_id'].long() # [batch_size] 70 | i_ids = feed_dict['item_id'].long() # [batch_size, -1] 71 | 72 | u_ids = u_ids.unsqueeze(-1).repeat((1, i_ids.shape[1])) # [batch_size, -1] 73 | 74 | mf_u_vectors = self.mf_u_embeddings(u_ids) 75 | mf_i_vectors = self.mf_i_embeddings(i_ids) 76 | mlp_u_vectors = self.mlp_u_embeddings(u_ids) 77 | mlp_i_vectors = self.mlp_i_embeddings(i_ids) 78 | 79 | mf_vector = mf_u_vectors * mf_i_vectors 80 | mlp_vector = torch.cat([mlp_u_vectors, mlp_i_vectors], dim=-1) 81 | for layer in self.mlp: 82 | mlp_vector = layer(mlp_vector).relu() 83 | mlp_vector = self.dropout_layer(mlp_vector) 84 | 85 | output_vector = torch.cat([mf_vector, mlp_vector], dim=-1) 86 | prediction = self.prediction(output_vector) 87 | return {'prediction': prediction.view(feed_dict['batch_size'], -1)} 88 | -------------------------------------------------------------------------------- /libauc/models/perceptron.py: -------------------------------------------------------------------------------- 1 | # credit to LibAUC 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | # Multilayer Perceptron 7 | class MLP(torch.nn.Module): 8 | def __init__(self, input_dim=29, hidden_sizes=(16,), activation='relu', num_classes=1): 9 | super().__init__() 10 | self.inputs = torch.nn.Linear(input_dim, hidden_sizes[0]) 11 | layers = [] 12 | for i in range(len(hidden_sizes)-1): 13 | layers.append(torch.nn.Linear(hidden_sizes[i], hidden_sizes[i+1])) 14 | if activation=='relu': 15 | print ('relu') 16 | layers.append(nn.ReLU()) 17 | elif activation=='elu': 18 | layers.append(nn.ELU()) 19 | else: 20 | pass 21 | self.layers = nn.Sequential(*layers) 22 | self.classifer = torch.nn.Linear(hidden_sizes[-1], 
num_classes) 23 | 24 | def forward(self, x): 25 | """forward pass""" 26 | x = self.inputs(x) 27 | x = self.layers(x) 28 | return self.classifer(x) 29 | 30 | -------------------------------------------------------------------------------- /libauc/models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 3 | The implementation and structure of this file is hugely influenced by [2] 4 | which is implemented for ImageNet and doesn't have option A for identity. 5 | Moreover, most of the implementations on the web is copy-paste from 6 | torchvision's resnet and has wrong number of params. 7 | Proper ResNet-s for CIFAR10 (for fair comparision and etc.) has following 8 | number of layers and parameters: 9 | name | layers | params 10 | ResNet20 | 20 | 0.27M 11 | ResNet32 | 32 | 0.46M 12 | ResNet44 | 44 | 0.66M 13 | ResNet56 | 56 | 0.85M 14 | ResNet110 | 110 | 1.7M 15 | ResNet1202| 1202 | 19.4m 16 | which this implementation indeed has. 17 | Reference: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 21 | If you use this implementation in you work, please don't forget to mention the 22 | author, Yerlan Idelbayev. 23 | ''' 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import torch.nn.init as init 28 | 29 | from torch.autograd import Variable 30 | 31 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 32 | 33 | 34 | def _weights_init(m): 35 | classname = m.__class__.__name__ 36 | #print(classname) 37 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 38 | #init.kaiming_normal_(m.weight) 39 | init.xavier_normal_(m.weight) 40 | 41 | class LambdaLayer(nn.Module): 42 | def __init__(self, lambd): 43 | super(LambdaLayer, self).__init__() 44 | self.lambd = lambd 45 | 46 | def forward(self, x): 47 | return self.lambd(x) 48 | 49 | from torch.nn import Parameter 50 | class NormedLinear(nn.Module): 51 | 52 | def __init__(self, in_features, out_features): 53 | super(NormedLinear, self).__init__() 54 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 55 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 56 | 57 | def forward(self, x): 58 | out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 59 | return out 60 | 61 | class BasicBlock(nn.Module): 62 | expansion = 1 63 | 64 | def __init__(self, in_planes, planes, stride=1, option='A'): 65 | super(BasicBlock, self).__init__() 66 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | 71 | self.shortcut = nn.Sequential() 72 | if stride != 1 or in_planes != planes: 73 | if option == 'A': 74 | """ 75 | For CIFAR10 ResNet paper uses option A. 
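                Option A keeps the shortcut parameter-free: the input is spatially
                subsampled with stride 2 and the extra channels are zero-padded
                (see the LambdaLayer shortcut below), so no new weights are introduced.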
76 | """ 77 | self.shortcut = LambdaLayer(lambda x: 78 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 79 | elif option == 'B': 80 | self.shortcut = nn.Sequential( 81 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 82 | nn.BatchNorm2d(self.expansion * planes) 83 | ) 84 | 85 | def forward(self, x): 86 | out = activation_func(self.bn1(self.conv1(x))) 87 | out = self.bn2(self.conv2(out)) 88 | out += self.shortcut(x) 89 | out = activation_func(out) 90 | return out 91 | 92 | 93 | class ResNet(nn.Module): 94 | def __init__(self, block, num_blocks, num_classes=1, last_activation='sigmoid', pretrained=False): 95 | super(ResNet, self).__init__() 96 | self.in_planes = 16 97 | 98 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 99 | self.bn1 = nn.BatchNorm2d(16) 100 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 101 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 102 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 103 | self.linear = nn.Linear(64, num_classes) 104 | 105 | self.apply(_weights_init) 106 | 107 | self.sigmoid = nn.Sigmoid() 108 | self.last_activation = last_activation 109 | 110 | def _make_layer(self, block, planes, num_blocks, stride): 111 | strides = [stride] + [1]*(num_blocks-1) 112 | layers = [] 113 | for stride in strides: 114 | layers.append(block(self.in_planes, planes, stride)) 115 | self.in_planes = planes * block.expansion 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | out = activation_func(self.bn1(self.conv1(x))) 121 | out = self.layer1(out) 122 | out = self.layer2(out) 123 | out = self.layer3(out) 124 | 125 | out = F.avg_pool2d(out, out.size()[3]) 126 | out = out.view(out.size(0), -1) 127 | out = self.linear(out) 128 | if self.last_activation == 'sigmoid': 129 | out = self.sigmoid(out) 130 | elif self.last_activation == 'none' or self.last_activation==None: 131 | out = out 132 | elif self.last_activation == 'l2': 133 | out= F.normalize(out,dim=0,p=2) 134 | else: 135 | out = self.sigmoid(out) 136 | return out 137 | 138 | 139 | def resnet20(pretrained=False, activations='relu', last_activation=None, **kwargs): 140 | global activation_func 141 | activation_func = F.relu if activations=='relu' else F.elu 142 | # print (activation_func) 143 | return ResNet(BasicBlock, [3, 3, 3], last_activation=last_activation, **kwargs) 144 | 145 | 146 | def resnet32(pretrained=False, activations='relu', last_activation=None, **kwargs): 147 | global activation_func 148 | activation_func = F.relu if activations=='relu' else F.elu 149 | # print (activation_func) 150 | return ResNet(BasicBlock, [5, 5, 5], last_activation=last_activation, **kwargs) 151 | 152 | 153 | def resnet44(pretrained=False, activations='relu', last_activation=None, **kwargs): 154 | global activation_func 155 | activation_func = F.relu if activations=='relu' else F.elu 156 | # print (activation_func) 157 | return ResNet(BasicBlock, [7, 7, 7], last_activation=last_activation, **kwargs) 158 | 159 | 160 | def resnet56(pretrained=False, activations='relu', last_activation=None, **kwargs): 161 | global activation_func 162 | activation_func = F.relu if activations=='relu' else F.elu 163 | # print (activation_func) 164 | return ResNet(BasicBlock, [9, 9, 9], last_activation=last_activation, **kwargs) 165 | 166 | 167 | def resnet110(pretrained=False, activations='relu', last_activation=None, **kwargs): 168 | global activation_func 169 | 
activation_func = F.relu if activations=='relu' else F.elu 170 | # print (activation_func) 171 | return ResNet(BasicBlock, [18, 18, 18], last_activation=last_activation, **kwargs) 172 | 173 | 174 | def resnet1202(pretrained=False, activations='relu', last_activation=None, **kwargs): 175 | global activation_func 176 | activation_func = F.relu if activations=='relu' else F.elu 177 | # print (activation_func) 178 | return ResNet(BasicBlock, [200, 200, 200], last_activation=last_activation, **kwargs) 179 | 180 | 181 | def test(net): 182 | import numpy as np 183 | total_params = 0 184 | 185 | for x in filter(lambda p: p.requires_grad, net.parameters()): 186 | total_params += np.prod(x.data.numpy().shape) 187 | print("Total number of params", total_params) 188 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 189 | 190 | 191 | # alias 192 | ResNet20 = resnet20 193 | ResNet32 = resnet32 194 | ResNet44 = resnet44 195 | ResNet56 = resnet56 196 | ResNet110 = resnet110 197 | ResNet1202 = resnet1202 198 | 199 | if __name__ == "__main__": 200 | for net_name in __all__: 201 | if net_name.startswith('resnet'): 202 | print(net_name) 203 | test(globals()[net_name]()) 204 | print() 205 | -------------------------------------------------------------------------------- /libauc/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # LibAUC optimizers 2 | from .pesg import * 3 | from .pdsca import * 4 | from .soap import * 5 | from .sopa import * 6 | from .sopa_s import * 7 | from .sota_s import * 8 | from .song import * 9 | 10 | # PyTorch optimizers (reference) 11 | from .sgd import * 12 | from .adam import * 13 | -------------------------------------------------------------------------------- /libauc/optimizers/adam.py: -------------------------------------------------------------------------------- 1 | # Code is modified based on PyTorch implementation: https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py 2 | 3 | import math 4 | import torch 5 | 6 | class Adam(torch.optim.Optimizer): 7 | r"""Implements Adam algorithm. 8 | 9 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 10 | The implementation of the L2 penalty follows changes proposed in 11 | `Decoupled Weight Decay Regularization`_. 12 | 13 | Arguments: 14 | params (iterable): iterable of parameters to optimize or dicts defining 15 | parameter groups 16 | lr (float, optional): learning rate (default: 1e-3) 17 | betas (Tuple[float, float], optional): coefficients used for computing 18 | running averages of gradient and its square (default: (0.9, 0.999)) 19 | eps (float, optional): term added to the denominator to improve 20 | numerical stability (default: 1e-8) 21 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 22 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 23 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 24 | (default: False) 25 | 26 | .. _Adam\: A Method for Stochastic Optimization: 27 | https://arxiv.org/abs/1412.6980 28 | .. _Decoupled Weight Decay Regularization: 29 | https://arxiv.org/abs/1711.05101 30 | .. 
_On the Convergence of Adam and Beyond: 31 | https://openreview.net/forum?id=ryQu7f-RZ 32 | """ 33 | 34 | def __init__(self, 35 | params, 36 | lr=1e-3, 37 | betas=(0.9, 0.999), 38 | eps=1e-8, 39 | weight_decay=0, 40 | amsgrad=False, 41 | verbose=True): 42 | if not 0.0 <= lr: 43 | raise ValueError("Invalid learning rate: {}".format(lr)) 44 | if not 0.0 <= eps: 45 | raise ValueError("Invalid epsilon value: {}".format(eps)) 46 | if not 0.0 <= betas[0] < 1.0: 47 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 48 | if not 0.0 <= betas[1] < 1.0: 49 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 50 | if not 0.0 <= weight_decay: 51 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 52 | 53 | try: 54 | params = params.parameters() 55 | except: 56 | params = params 57 | 58 | self.lr = lr 59 | self.verbose = verbose 60 | 61 | defaults = dict(lr=lr, betas=betas, eps=eps, 62 | weight_decay=weight_decay, amsgrad=amsgrad) 63 | super(Adam, self).__init__(params, defaults) 64 | 65 | def __setstate__(self, state): 66 | super(Adam, self).__setstate__(state) 67 | for group in self.param_groups: 68 | group.setdefault('amsgrad', False) 69 | 70 | @torch.no_grad() 71 | def step(self, closure=None): 72 | """Performs a single optimization step. 73 | 74 | Arguments: 75 | closure (callable, optional): A closure that reevaluates the model 76 | and returns the loss. 77 | """ 78 | loss = None 79 | if closure is not None: 80 | with torch.enable_grad(): 81 | loss = closure() 82 | 83 | for group in self.param_groups: 84 | self.lr = group['lr'] 85 | for p in group['params']: 86 | if p.grad is None: 87 | continue 88 | grad = p.grad 89 | if grad.is_sparse: 90 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 91 | amsgrad = group['amsgrad'] 92 | 93 | state = self.state[p] 94 | 95 | # State initialization 96 | if len(state) == 0: 97 | state['step'] = 0 98 | # Exponential moving average of gradient values 99 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 100 | # Exponential moving average of squared gradient values 101 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 102 | if amsgrad: 103 | # Maintains max of all exp. moving avg. of sq. grad. values 104 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 105 | 106 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 107 | if amsgrad: 108 | max_exp_avg_sq = state['max_exp_avg_sq'] 109 | beta1, beta2 = group['betas'] 110 | 111 | state['step'] += 1 112 | bias_correction1 = 1 - beta1 ** state['step'] 113 | bias_correction2 = 1 - beta2 ** state['step'] 114 | 115 | if group['weight_decay'] != 0: 116 | grad = grad.add(p, alpha=group['weight_decay']) 117 | 118 | # Decay the first and second moment running average coefficient 119 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 120 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 121 | if amsgrad: 122 | # Maintains the maximum of all 2nd moment running avg. till now 123 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 124 | # Use the max. for normalizing running avg. 
of gradient 125 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 126 | else: 127 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 128 | 129 | step_size = group['lr'] / bias_correction1 130 | 131 | p.addcdiv_(exp_avg, denom, value=-step_size) 132 | 133 | return loss 134 | 135 | def update_lr(self, decay_factor=None): 136 | if decay_factor != None: 137 | self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor 138 | if self.verbose: 139 | print ('Reducing learning rate to %.5f !'%(self.param_groups[0]['lr'])) 140 | 141 | -------------------------------------------------------------------------------- /libauc/optimizers/pdsca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | 4 | class PDSCA(torch.optim.Optimizer): 5 | """ 6 | Reference: 7 | @inproceedings{ 8 | yuan2022compositional, 9 | title={Compositional Training for End-to-End Deep AUC Maximization}, 10 | author={Zhuoning Yuan and Zhishuai Guo and Nitesh Chawla and Tianbao Yang}, 11 | booktitle={International Conference on Learning Representations}, 12 | year={2022}, 13 | url={https://openreview.net/forum?id=gPvB4pdu_Z} 14 | } 15 | """ 16 | def __init__(self, 17 | model, 18 | loss_fn=None, 19 | a=None, # to be deprecated 20 | b=None, # to be deprecated 21 | alpha=None, # to be deprecated 22 | margin=1.0, 23 | lr=0.1, 24 | lr0=None, 25 | gamma=None, # to be deprecated 26 | beta1=0.99, 27 | beta2=0.999, 28 | clip_value=1.0, 29 | weight_decay=1e-5, 30 | epoch_decay=2e-3, #gamma=500 31 | verbose=True, 32 | device='cuda', 33 | **kwargs): 34 | 35 | if not device: 36 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | else: 38 | self.device = device 39 | assert (gamma is None) or (epoch_decay is None), 'You can only use one of gamma and epoch_decay!' 
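        # backward compatibility: a deprecated gamma value is converted into epoch_decay = 1/gamma below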
40 | if gamma is not None: 41 | assert gamma > 0 42 | epoch_decay = 1/gamma 43 | 44 | self.margin = margin 45 | self.model = model 46 | if lr0 is None: 47 | lr0 = lr 48 | self.lr = lr 49 | self.lr0 = lr0 50 | self.gamma = gamma 51 | self.clip_value = clip_value 52 | self.weight_decay = weight_decay 53 | self.epoch_decay = epoch_decay 54 | 55 | self.beta1 = beta1 56 | self.beta2 = beta2 57 | 58 | self.loss_fn = loss_fn 59 | if loss_fn != None: 60 | try: 61 | self.a = loss_fn.a 62 | self.b = loss_fn.b 63 | self.alpha = loss_fn.alpha 64 | except: 65 | print('AUCLoss is not found!') 66 | else: 67 | self.a = a 68 | self.b = b 69 | self.alpha = alpha 70 | 71 | self.model_ref = self.init_model_ref() 72 | self.model_acc = self.init_model_acc() 73 | 74 | self.T = 0 # for epoch_decay 75 | self.steps = 0 # total optim steps 76 | self.verbose = verbose # print updates for lr/regularizer 77 | 78 | def get_parameters(params): 79 | for p in params: 80 | yield p 81 | if self.a is not None or self.b is not None: 82 | self.params = get_parameters(list(model.parameters())+[self.a, self.b]) 83 | else: 84 | self.params = get_parameters(list(model.parameters())) 85 | self.defaults = dict(lr=self.lr, 86 | lr0=self.lr0, 87 | margin=margin, 88 | a=self.a, 89 | b=self.b, 90 | alpha=self.alpha, 91 | clip_value=self.clip_value, 92 | weight_decay=self.weight_decay, 93 | epoch_decay=self.epoch_decay, 94 | beta1=self.beta1, 95 | beta2=self.beta2, 96 | model_ref=self.model_ref, 97 | model_acc=self.model_acc) 98 | 99 | super(PDSCA, self).__init__(self.params, self.defaults) 100 | 101 | def __setstate__(self, state): 102 | super(PDSCA, self).__setstate__(state) 103 | for group in self.param_groups: 104 | group.setdefault('nesterov', False) 105 | 106 | def init_model_ref(self): 107 | self.model_ref = [] 108 | for var in list(self.model.parameters())+[self.a, self.b]: 109 | if var is not None: 110 | self.model_ref.append(torch.empty(var.shape).normal_(mean=0, std=0.01).to(self.device)) 111 | return self.model_ref 112 | 113 | def init_model_acc(self): 114 | self.model_acc = [] 115 | for var in list(self.model.parameters())+[self.a, self.b]: 116 | if var is not None: 117 | self.model_acc.append(torch.zeros(var.shape, dtype=torch.float32, device=self.device, requires_grad=False).to(self.device)) 118 | return self.model_acc 119 | 120 | @property 121 | def optim_steps(self): 122 | return self.steps 123 | 124 | @property 125 | def get_params(self): 126 | return list(self.model.parameters()) 127 | 128 | @torch.no_grad() 129 | def step(self, closure=None): 130 | """Performs a single optimization step. 131 | """ 132 | loss = None 133 | if closure is not None: 134 | with torch.enable_grad(): 135 | loss = closure() 136 | 137 | for group in self.param_groups: 138 | weight_decay = group['weight_decay'] 139 | clip_value = group['clip_value'] 140 | self.lr = group['lr'] 141 | self.lr0 = group['lr0'] 142 | 143 | epoch_decay = group['epoch_decay'] 144 | beta1 = group['beta1'] 145 | beta2 = group['beta2'] 146 | model_ref = group['model_ref'] 147 | model_acc = group['model_acc'] 148 | 149 | m = group['margin'] 150 | a = group['a'] 151 | b = group['b'] 152 | alpha = group['alpha'] 153 | 154 | for i, p in enumerate(group['params']): 155 | if p.grad is None: 156 | continue 157 | d_p = torch.clamp(p.grad.data , -clip_value, clip_value) + epoch_decay*(p.data - model_ref[i].data) + weight_decay*p.data 158 | if alpha.grad is None: # sgd + moving p. 
# TODO: alpha=None mode 159 | p.data = p.data - group['lr0']*d_p 160 | if beta1!= 0: 161 | param_state = self.state[p] 162 | if 'weight_buffer' not in param_state: 163 | buf = param_state['weight_buffer'] = torch.clone(p).detach() 164 | else: 165 | buf = param_state['weight_buffer'] 166 | buf.mul_(1-beta1).add_(p, alpha=beta1) 167 | p.data = buf.data # Note: use buf(s) to compute the gradients w.r.t AUC loss can lead to a slight worse performance 168 | elif alpha.grad is not None: # auc + moving g. # TODO: alpha=None mode 169 | if beta2!= 0: 170 | param_state = self.state[p] 171 | if 'momentum_buffer' not in param_state: 172 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 173 | else: 174 | buf = param_state['momentum_buffer'] 175 | buf.mul_(1-beta2).add_(d_p, alpha=beta2) 176 | d_p = buf 177 | p.data = p.data - group['lr']*d_p 178 | else: 179 | NotImplementedError 180 | model_acc[i].data = model_acc[i].data + p.data 181 | 182 | if alpha is not None: 183 | if alpha.grad is not None: 184 | alpha.data = alpha.data + group['lr']*(2*(m + b.data - a.data)-2*alpha.data) 185 | alpha.data = torch.clamp(alpha.data, 0, 999) 186 | 187 | self.T += 1 188 | self.steps += 1 189 | return loss 190 | 191 | def zero_grad(self): 192 | self.model.zero_grad() 193 | if self.a is not None and self.b is not None: 194 | self.a.grad = None 195 | self.b.grad = None 196 | if self.alpha is not None: 197 | self.alpha.grad = None 198 | 199 | def update_lr(self, decay_factor=None, decay_factor0=None): 200 | if decay_factor != None: 201 | self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor 202 | if self.verbose: 203 | print ('Reducing learning rate to %.5f @ T=%s!'%(self.param_groups[0]['lr'], self.steps)) 204 | if decay_factor0 != None: 205 | self.param_groups[0]['lr0'] = self.param_groups[0]['lr0']/decay_factor0 206 | if self.verbose: 207 | print ('Reducing learning rate (inner) to %.5f @ T=%s!'%(self.param_groups[0]['lr0'], self.steps)) 208 | 209 | def update_regularizer(self, decay_factor=None, decay_factor0=None): 210 | if decay_factor != None: 211 | self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor 212 | if self.verbose: 213 | print ('Reducing learning rate to %.5f @ T=%s!'%(self.param_groups[0]['lr'], self.steps)) 214 | if decay_factor0 != None: 215 | self.param_groups[0]['lr0'] = self.param_groups[0]['lr0']/decay_factor0 216 | if self.verbose: 217 | print ('Reducing learning rate (inner) to %.5f @ T=%s!'%(self.param_groups[0]['lr0'], self.steps)) 218 | if self.verbose: 219 | print ('Updating regularizer @ T=%s!'%(self.steps)) 220 | for i, param in enumerate(self.model_ref): 221 | self.model_ref[i].data = self.model_acc[i].data/self.T 222 | for i, param in enumerate(self.model_acc): 223 | self.model_acc[i].data = torch.zeros(param.shape, dtype=torch.float32, device=self.device, requires_grad=False).to(self.device) 224 | self.T = 0 225 | 226 | 227 | -------------------------------------------------------------------------------- /libauc/optimizers/pesg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class PESG(torch.optim.Optimizer): 4 | """Proximal Epoch Stochastic Gradient (PESG) 5 | 6 | Reference: 7 | Yuan, Z., Yan, Y., Sonka, M. and Yang, T., 8 | Large-scale Robust Deep AUC Maximization: A New Surrogate Loss and Empirical Studies on Medical Image Classification. 
9 | International Conference on Computer Vision (ICCV 2021) 10 | Link: 11 | https://arxiv.org/abs/2012.03173 12 | """ 13 | def __init__(self, 14 | model, 15 | loss_fn=None, 16 | a=None, # to be deprecated 17 | b=None, # to be deprecated 18 | alpha=None, # to be deprecated 19 | margin=1.0, 20 | lr=0.1, 21 | gamma=None, # to be deprecated 22 | clip_value=1.0, 23 | weight_decay=1e-5, 24 | epoch_decay=2e-3, # default: gamma=500 25 | momentum=0, 26 | verbose=True, 27 | device=None, 28 | **kwargs): 29 | 30 | if not device: 31 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | else: 33 | self.device = device 34 | assert (gamma is None) or (epoch_decay is None), 'You can only use one of gamma and epoch_decay!' 35 | if gamma is not None: 36 | assert gamma > 0 37 | epoch_decay = 1/gamma 38 | 39 | self.margin = margin 40 | self.model = model 41 | self.lr = lr 42 | self.gamma = gamma # to be deprecated 43 | self.clip_value = clip_value 44 | self.momentum = momentum 45 | self.weight_decay = weight_decay 46 | self.epoch_decay = epoch_decay 47 | 48 | self.loss_fn = loss_fn 49 | if loss_fn != None: 50 | try: 51 | self.a = loss_fn.a 52 | self.b = loss_fn.b 53 | self.alpha = loss_fn.alpha 54 | except: 55 | print('AUCLoss is not found!') 56 | else: 57 | self.a = a 58 | self.b = b 59 | self.alpha = alpha 60 | 61 | self.model_ref = self.init_model_ref() 62 | self.model_acc = self.init_model_acc() 63 | self.T = 0 # for epoch_decay 64 | self.steps = 0 # total optim steps 65 | self.verbose = verbose # print updates for lr/regularizer 66 | 67 | def get_parameters(params): 68 | for p in params: 69 | yield p 70 | if self.a is not None and self.b is not None: 71 | self.params = get_parameters(list(model.parameters())+[self.a, self.b]) 72 | else: 73 | self.params = get_parameters(list(model.parameters())) 74 | self.defaults = dict(lr=self.lr, 75 | margin=margin, 76 | a=self.a, 77 | b=self.b, 78 | alpha=self.alpha, 79 | clip_value=clip_value, 80 | momentum=momentum, 81 | weight_decay=weight_decay, 82 | epoch_decay=epoch_decay, 83 | model_ref=self.model_ref, 84 | model_acc=self.model_acc 85 | ) 86 | 87 | super(PESG, self).__init__(self.params, self.defaults) 88 | 89 | def __setstate__(self, state): 90 | super(PESG, self).__setstate__(state) 91 | for group in self.param_groups: 92 | group.setdefault('nesterov', False) 93 | 94 | def init_model_ref(self): 95 | self.model_ref = [] 96 | for var in list(self.model.parameters())+[self.a, self.b]: 97 | if var is not None: 98 | self.model_ref.append(torch.empty(var.shape).normal_(mean=0, std=0.01).to(self.device)) 99 | return self.model_ref 100 | 101 | def init_model_acc(self): 102 | self.model_acc = [] 103 | for var in list(self.model.parameters())+[self.a, self.b]: 104 | if var is not None: 105 | self.model_acc.append(torch.zeros(var.shape, dtype=torch.float32, device=self.device, requires_grad=False).to(self.device)) 106 | return self.model_acc 107 | 108 | @property 109 | def optim_steps(self): 110 | return self.steps 111 | 112 | @property 113 | def get_params(self): 114 | return list(self.model.parameters()) 115 | 116 | @torch.no_grad() 117 | def step(self, closure=None): 118 | """Performs a single optimization step. 
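        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.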
119 | """ 120 | loss = None 121 | if closure is not None: 122 | with torch.enable_grad(): 123 | loss = closure() 124 | 125 | for group in self.param_groups: 126 | weight_decay = group['weight_decay'] 127 | clip_value = group['clip_value'] 128 | momentum = group['momentum'] 129 | self.lr = group['lr'] 130 | 131 | epoch_decay = group['epoch_decay'] 132 | model_ref = group['model_ref'] 133 | model_acc = group['model_acc'] 134 | 135 | m = group['margin'] 136 | a = group['a'] 137 | b = group['b'] 138 | alpha = group['alpha'] 139 | 140 | # updates 141 | for i, p in enumerate(group['params']): 142 | if p.grad is None: 143 | continue 144 | d_p = torch.clamp(p.grad.data , -clip_value, clip_value) + epoch_decay*(p.data - model_ref[i].data) + weight_decay*p.data 145 | if momentum != 0: 146 | param_state = self.state[p] 147 | if 'momentum_buffer' not in param_state: 148 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 149 | else: 150 | buf = param_state['momentum_buffer'] 151 | buf.mul_(1-momentum).add_(d_p, alpha=momentum) 152 | d_p = buf 153 | p.data = p.data - group['lr']*d_p 154 | model_acc[i].data = model_acc[i].data + p.data 155 | 156 | if alpha is not None: 157 | if alpha.grad is not None: 158 | alpha.data = alpha.data + group['lr']*(2*(m + b.data - a.data)-2*alpha.data) 159 | alpha.data = torch.clamp(alpha.data, 0, 999) 160 | 161 | self.T += 1 162 | self.steps += 1 163 | return loss 164 | 165 | def zero_grad(self): 166 | self.model.zero_grad() 167 | if self.a is not None and self.b is not None: 168 | self.a.grad = None 169 | self.b.grad = None 170 | if self.alpha is not None: 171 | self.alpha.grad = None 172 | 173 | def update_lr(self, decay_factor=None): 174 | if decay_factor != None: 175 | self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor 176 | if self.verbose: 177 | print ('Reducing learning rate to %.5f @ T=%s!'%(self.param_groups[0]['lr'], self.steps)) 178 | 179 | def update_regularizer(self, decay_factor=None): 180 | if decay_factor != None: 181 | self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor 182 | if self.verbose: 183 | print ('Reducing learning rate to %.5f @ T=%s!'%(self.param_groups[0]['lr'], self.steps)) 184 | if self.verbose: 185 | print ('Updating regularizer @ T=%s!'%(self.steps)) 186 | for i, param in enumerate(self.model_ref): 187 | self.model_ref[i].data = self.model_acc[i].data/self.T 188 | for i, param in enumerate(self.model_acc): 189 | self.model_acc[i].data = torch.zeros(param.shape, dtype=torch.float32, device=self.device, requires_grad=False).to(self.device) 190 | self.T = 0 191 | 192 | 193 | -------------------------------------------------------------------------------- /libauc/optimizers/sgd.py: -------------------------------------------------------------------------------- 1 | # Code is modified based on PyTorch implementation: https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer, required 5 | 6 | class SGD(torch.optim.Optimizer): 7 | r"""Implements stochastic gradient descent (optionally with momentum). 8 | 9 | Nesterov momentum is based on the formula from 10 | `On the importance of initialization and momentum in deep learning`__. 
11 | 12 | Args: 13 | params (iterable): iterable of parameters to optimize or dicts defining 14 | parameter groups 15 | lr (float): learning rate 16 | momentum (float, optional): momentum factor (default: 0) 17 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 18 | dampening (float, optional): dampening for momentum (default: 0) 19 | nesterov (bool, optional): enables Nesterov momentum (default: False) 20 | 21 | Example: 22 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 23 | >>> optimizer.zero_grad() 24 | >>> loss_fn(model(input), target).backward() 25 | >>> optimizer.step() 26 | 27 | __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf 28 | 29 | .. note:: 30 | The implementation of SGD with Momentum/Nesterov subtly differs from 31 | Sutskever et. al. and implementations in some other frameworks. 32 | 33 | Considering the specific case of Momentum, the update can be written as 34 | 35 | .. math:: 36 | \begin{aligned} 37 | v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ 38 | p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, 39 | \end{aligned} 40 | 41 | where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the 42 | parameters, gradient, velocity, and momentum respectively. 43 | 44 | This is in contrast to Sutskever et. al. and 45 | other frameworks which employ an update of the form 46 | 47 | .. math:: 48 | \begin{aligned} 49 | v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\ 50 | p_{t+1} & = p_{t} - v_{t+1}. 51 | \end{aligned} 52 | 53 | The Nesterov version is analogously modified. 54 | """ 55 | 56 | def __init__(self, 57 | params, 58 | lr=required, 59 | momentum=0, 60 | dampening=0, 61 | weight_decay=0, 62 | nesterov=False, 63 | verbose=True): 64 | 65 | if lr is not required and lr < 0.0: 66 | raise ValueError("Invalid learning rate: {}".format(lr)) 67 | if momentum < 0.0: 68 | raise ValueError("Invalid momentum value: {}".format(momentum)) 69 | if weight_decay < 0.0: 70 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 71 | 72 | try: 73 | params = params.parameters() 74 | except: 75 | params = params 76 | 77 | self.lr = lr 78 | self.verbose = verbose 79 | 80 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 81 | weight_decay=weight_decay, nesterov=nesterov) 82 | if nesterov and (momentum <= 0 or dampening != 0): 83 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 84 | super(SGD, self).__init__(params, defaults) 85 | 86 | 87 | 88 | def __setstate__(self, state): 89 | super(SGD, self).__setstate__(state) 90 | for group in self.param_groups: 91 | group.setdefault('nesterov', False) 92 | 93 | @torch.no_grad() 94 | def step(self, closure=None): 95 | """Performs a single optimization step. 96 | 97 | Arguments: 98 | closure (callable, optional): A closure that reevaluates the model 99 | and returns the loss. 
100 | """ 101 | loss = None 102 | if closure is not None: 103 | with torch.enable_grad(): 104 | loss = closure() 105 | 106 | for group in self.param_groups: 107 | weight_decay = group['weight_decay'] 108 | momentum = group['momentum'] 109 | dampening = group['dampening'] 110 | nesterov = group['nesterov'] 111 | self.lr = group['lr'] 112 | 113 | for p in group['params']: 114 | if p.grad is None: 115 | continue 116 | d_p = p.grad 117 | if weight_decay != 0: 118 | d_p = d_p.add(p, alpha=weight_decay) # d_p = (d_p + p*weight_decy) 119 | if momentum != 0: 120 | param_state = self.state[p] 121 | if 'momentum_buffer' not in param_state: 122 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 123 | else: 124 | buf = param_state['momentum_buffer'] 125 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) # [v = v*beta + d_p ] --> new d_p 126 | if nesterov: 127 | d_p = d_p.add(buf, alpha=momentum) 128 | else: 129 | d_p = buf 130 | 131 | p.add_(d_p, alpha=-group['lr']) 132 | 133 | return loss 134 | 135 | def update_lr(self, decay_factor=None): 136 | if decay_factor != None: 137 | self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor 138 | if self.verbose: 139 | print ('Reducing learning rate to %.5f !'%(self.param_groups[0]['lr'])) 140 | -------------------------------------------------------------------------------- /libauc/optimizers/soap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class SOAP(torch.optim.Optimizer): 6 | r""" 7 | # This is a wrapper of SOAP_ADAM and SOAP_SGD 8 | """ 9 | def __init__(self, params, 10 | lr=required, weight_decay=0, 11 | mode='sgd', 12 | momentum=0, dampening=0, nesterov=False, # sgd 13 | betas=(0.9, 0.999), eps=1e-8, amsgrad=False, # adam 14 | ): 15 | if lr is not required and lr < 0.0: 16 | raise ValueError("Invalid learning rate: {}".format(lr)) 17 | if momentum < 0.0: 18 | raise ValueError("Invalid momentum value: {}".format(momentum)) 19 | if not 0.0 <= eps: 20 | raise ValueError("Invalid epsilon value: {}".format(eps)) 21 | if not 0.0 <= betas[0] < 1.0: 22 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 23 | if not 0.0 <= betas[1] < 1.0: 24 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 25 | if not 0.0 <= weight_decay: 26 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 27 | if not isinstance(mode, str): 28 | raise ValueError("Invalid mode type: {}".format(mode)) 29 | 30 | try: 31 | params = params.parameters() 32 | except: 33 | params = params 34 | 35 | self.lr = lr 36 | self.steps = 0 37 | self.mode = mode.lower() 38 | defaults = dict(lr=lr, weight_decay=weight_decay, 39 | momentum=momentum, dampening=dampening, nesterov=nesterov, 40 | betas=betas, eps=eps, amsgrad=amsgrad) 41 | if nesterov and (momentum <= 0 or dampening != 0): 42 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 43 | 44 | super(SOAP, self).__init__(params, defaults) 45 | 46 | def __setstate__(self, state): 47 | super(SOAP, self).__setstate__(state) 48 | for group in self.param_groups: 49 | if self.mode == 'sgd': 50 | group.setdefault('nesterov', False) 51 | elif self.mode == 'adam': 52 | group.setdefault('amsgrad', False) 53 | else: 54 | NotImplementedError 55 | 56 | 57 | @torch.no_grad() 58 | def step(self, closure=None): 59 | """Performs a single optimization step. 
60 | 61 | Arguments: 62 | closure (callable, optional): A closure that reevaluates the model 63 | and returns the loss. 64 | """ 65 | loss = None 66 | if closure is not None: 67 | with torch.enable_grad(): 68 | loss = closure() 69 | 70 | for group in self.param_groups: 71 | if self.mode == 'sgd': 72 | weight_decay = group['weight_decay'] 73 | momentum = group['momentum'] 74 | dampening = group['dampening'] 75 | nesterov = group['nesterov'] 76 | self.lr = group['lr'] 77 | for p in group['params']: 78 | if p.grad is None: 79 | print(p.shape) 80 | continue 81 | d_p = p.grad 82 | if weight_decay != 0: 83 | d_p = d_p.add(p, alpha=weight_decay) # d_p = (d_p + p*weight_decy) 84 | if momentum != 0: 85 | param_state = self.state[p] 86 | if 'momentum_buffer' not in param_state: 87 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 88 | else: 89 | buf = param_state['momentum_buffer'] 90 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) # [v = v*beta + d_p ] --> new d_p 91 | if nesterov: 92 | d_p = d_p.add(buf, alpha=momentum) 93 | else: 94 | d_p = buf 95 | p.add_(d_p, alpha=-group['lr']) 96 | 97 | elif self.mode == 'adam': 98 | self.lr = group['lr'] 99 | for p in group['params']: 100 | if p.grad is None: 101 | continue 102 | grad = p.grad 103 | if grad.is_sparse: 104 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 105 | amsgrad = group['amsgrad'] 106 | state = self.state[p] 107 | 108 | # State initialization 109 | if len(state) == 0: 110 | state['step'] = 0 111 | # Exponential moving average of gradient values 112 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 113 | # Exponential moving average of squared gradient values 114 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 115 | if amsgrad: 116 | # Maintains max of all exp. moving avg. of sq. grad. values 117 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 118 | 119 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 120 | if amsgrad: 121 | max_exp_avg_sq = state['max_exp_avg_sq'] 122 | beta1, beta2 = group['betas'] 123 | 124 | state['step'] += 1 125 | bias_correction1 = 1 - beta1 ** state['step'] 126 | bias_correction2 = 1 - beta2 ** state['step'] 127 | 128 | if group['weight_decay'] != 0: 129 | grad = grad.add(p, alpha=group['weight_decay']) 130 | 131 | # Decay the first and second moment running average coefficient 132 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 133 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 134 | if amsgrad: 135 | # Maintains the maximum of all 2nd moment running avg. till now 136 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 137 | # Use the max. for normalizing running avg. 
of gradient 138 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 139 | else: 140 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 141 | 142 | step_size = group['lr'] / bias_correction1 143 | 144 | p.addcdiv_(exp_avg, denom, value=-step_size) 145 | 146 | self.steps += 1 147 | return loss 148 | 149 | def update_lr(self, decay_factor=None): 150 | if decay_factor != None: 151 | self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor 152 | print ('Reducing learning rate to %.5f @ T=%s!'%(self.param_groups[0]['lr'], self.steps)) 153 | 154 | -------------------------------------------------------------------------------- /libauc/optimizers/song.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | 8 | class SONG(torch.optim.Optimizer): 9 | """ 10 | This is a wrapper of SONG_ADAM and SONG_SGD: 11 | - the optimization mode can be 'sgd' (with momentum) or 'adam' 12 | - you can specify \eta in our paper by setting lr 13 | - you can specify \beta_1 in our paper by setting momentum=1-\beta_1 (in 'sgd' mode) 14 | or by setting betas[0]=1-\beta_1 (in 'adam' mode) 15 | """ 16 | def __init__(self, params, 17 | lr=required, weight_decay=0, 18 | mode='sgd', 19 | momentum=0.9, dampening=0, nesterov=False, # sgd 20 | betas=(0.9, 0.999), eps=1e-8, amsgrad=False, # adam 21 | ): 22 | 23 | if lr is not required and lr < 0.0: 24 | raise ValueError("Invalid learning rate: {}".format(lr)) 25 | if momentum < 0.0: 26 | raise ValueError("Invalid momentum value: {}".format(momentum)) 27 | if not 0.0 <= eps: 28 | raise ValueError("Invalid epsilon value: {}".format(eps)) 29 | if not 0.0 <= betas[0] < 1.0: 30 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 31 | if not 0.0 <= betas[1] < 1.0: 32 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 33 | if not 0.0 <= weight_decay: 34 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 35 | if not isinstance(mode, str): 36 | raise ValueError("Invalid mode type: {}".format(mode)) 37 | 38 | self.mode = mode.lower() 39 | 40 | defaults = dict(lr=lr, weight_decay=weight_decay, 41 | momentum=momentum, dampening=dampening, nesterov=nesterov, 42 | betas=betas, eps=eps, amsgrad=amsgrad) 43 | 44 | if nesterov and (momentum <= 0 or dampening != 0): # sgd 45 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 46 | 47 | super(SONG, self).__init__(params, defaults) 48 | 49 | 50 | def __setstate__(self, state): 51 | r""" 52 | # Set default options for sgd mode and adam mode 53 | """ 54 | super(SONG, self).__setstate__(state) 55 | for group in self.param_groups: 56 | if self.mode == 'sgd': 57 | group.setdefault('nesterov', False) 58 | elif self.mode == 'adam': 59 | group.setdefault('amsgrad', False) 60 | else: 61 | NotImplementedError 62 | 63 | 64 | @torch.no_grad() 65 | def step(self, closure=None): 66 | """Performs a single optimization step. 67 | 68 | Arguments: 69 | closure (callable, optional): A closure that reevaluates the model 70 | and returns the loss. 
71 | """ 72 | loss = None 73 | if closure is not None: 74 | with torch.enable_grad(): 75 | loss = closure() 76 | 77 | for group in self.param_groups: 78 | if self.mode == 'sgd': 79 | weight_decay = group['weight_decay'] 80 | momentum = group['momentum'] 81 | dampening = group['dampening'] 82 | nesterov = group['nesterov'] 83 | self.lr = group['lr'] 84 | for p in group['params']: 85 | if p.grad is None: 86 | continue 87 | d_p = p.grad 88 | if weight_decay != 0: 89 | d_p = d_p.add(p, alpha=weight_decay) # d_p = (d_p + p*weight_decy) 90 | if momentum != 0: 91 | param_state = self.state[p] 92 | if 'momentum_buffer' not in param_state: 93 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 94 | else: 95 | buf = param_state['momentum_buffer'] 96 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) # [v = v*beta + d_p ] --> new d_p 97 | if nesterov: 98 | d_p = d_p.add(buf, alpha=momentum) 99 | else: 100 | d_p = buf 101 | p.add_(d_p, alpha=-group['lr']) 102 | 103 | elif self.mode == 'adam': 104 | self.lr = group['lr'] 105 | for p in group['params']: 106 | if p.grad is None: 107 | continue 108 | grad = p.grad 109 | if grad.is_sparse: 110 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 111 | amsgrad = group['amsgrad'] 112 | state = self.state[p] 113 | 114 | # State initialization 115 | if len(state) == 0: 116 | state['step'] = 0 117 | # Exponential moving average of gradient values 118 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 119 | # Exponential moving average of squared gradient values 120 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 121 | if amsgrad: 122 | # Maintains max of all exp. moving avg. of sq. grad. values 123 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 124 | 125 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 126 | if amsgrad: 127 | max_exp_avg_sq = state['max_exp_avg_sq'] 128 | beta1, beta2 = group['betas'] 129 | 130 | state['step'] += 1 131 | bias_correction1 = 1 - beta1 ** state['step'] 132 | bias_correction2 = 1 - beta2 ** state['step'] 133 | 134 | if group['weight_decay'] != 0: 135 | grad = grad.add(p, alpha=group['weight_decay']) 136 | 137 | # Decay the first and second moment running average coefficient 138 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 139 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 140 | if amsgrad: 141 | # Maintains the maximum of all 2nd moment running avg. till now 142 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 143 | # Use the max. for normalizing running avg. of gradient 144 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 145 | else: 146 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 147 | 148 | step_size = group['lr'] / bias_correction1 149 | 150 | p.addcdiv_(exp_avg, denom, value=-step_size) 151 | return loss 152 | 153 | def update_lr(self, decay_factor=None): 154 | if decay_factor != None: 155 | self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor 156 | print ('Reducing learning rate to %.5f !'%(self.param_groups[0]['lr'])) 157 | -------------------------------------------------------------------------------- /libauc/optimizers/sopa.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | class SOPA(torch.optim.Optimizer): 5 | r"""Implements Adam algorithm. 
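    In this repository the class is used as a wrapper that can run in either
    'adam' or 'sgd' (momentum) mode; the Adam-specific references below apply
    to the default 'adam' mode.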
6 | 7 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 8 | The implementation of the L2 penalty follows changes proposed in 9 | `Decoupled Weight Decay Regularization`_. 10 | 11 | Arguments: 12 | params (iterable): iterable of parameters to optimize or dicts defining 13 | parameter groups 14 | lr (float, optional): learning rate (default: 1e-3) 15 | betas (Tuple[float, float], optional): coefficients used for computing 16 | running averages of gradient and its square (default: (0.9, 0.999)) 17 | eps (float, optional): term added to the denominator to improve 18 | numerical stability (default: 1e-8) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 21 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 22 | (default: False) 23 | 24 | .. _Adam\: A Method for Stochastic Optimization: 25 | https://arxiv.org/abs/1412.6980 26 | .. _Decoupled Weight Decay Regularization: 27 | https://arxiv.org/abs/1711.05101 28 | .. _On the Convergence of Adam and Beyond: 29 | https://openreview.net/forum?id=ryQu7f-RZ 30 | """ 31 | 32 | def __init__(self, model, loss_fn, 33 | mode = 'adam', 34 | eta=1.0, lr=1e-3, weight_decay=0, 35 | betas=(0.9, 0.999), eps=1e-8, amsgrad=False, # adam 36 | momentum=0.9, nesterov=False, dampening=0 # sgd 37 | ): 38 | 39 | if not 0.0 <= lr: 40 | raise ValueError("Invalid learning rate: {}".format(lr)) 41 | if not 0.0 <= eps: 42 | raise ValueError("Invalid epsilon value: {}".format(eps)) 43 | if not 0.0 <= betas[0] < 1.0: 44 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 45 | if not 0.0 <= betas[1] < 1.0: 46 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 47 | if not 0.0 <= weight_decay: 48 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 49 | 50 | try: 51 | params = model.parameters() 52 | except: 53 | params = model 54 | 55 | self.params = params 56 | self.lr = lr 57 | self.mode = mode.lower() 58 | self.loss_fn = loss_fn 59 | self.loss_fn.set_coef(eta) 60 | self.steps = 0 61 | 62 | defaults = dict(lr=lr, betas=betas, eps=eps, momentum=momentum, nesterov=nesterov, dampening=dampening, 63 | weight_decay=weight_decay, amsgrad=amsgrad) 64 | super(SOPA, self).__init__(self.params, defaults) 65 | 66 | def __setstate__(self, state): 67 | r""" 68 | # Set default options for sgd mode and adam mode 69 | """ 70 | super(SOPA, self).__setstate__(state) 71 | for group in self.param_groups: 72 | if self.mode == 'sgd': 73 | group.setdefault('nesterov', False) 74 | elif self.mode == 'adam': 75 | group.setdefault('amsgrad', False) 76 | else: 77 | NotImplementedError 78 | 79 | @torch.no_grad() 80 | def step(self, closure=None): 81 | """Performs a single optimization step. 82 | 83 | Arguments: 84 | closure (callable, optional): A closure that reevaluates the model 85 | and returns the loss. 
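        Example (an illustrative sketch, not code from this repository;
        ``model``, ``loss_fn`` and ``loader`` are assumed to be created by the
        caller, the loss object must expose ``set_coef``/``update_coef`` because
        this wrapper calls them, and the exact call signature of ``loss_fn``
        depends on the loss class being used)::

            optimizer = SOPA(model, loss_fn, mode='adam', eta=1.0, lr=1e-3)
            for inputs, targets in loader:
                loss = loss_fn(model(inputs), targets)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()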
86 | """ 87 | loss = None 88 | if closure is not None: 89 | with torch.enable_grad(): 90 | loss = closure() 91 | 92 | for group in self.param_groups: 93 | if self.mode == 'adam': 94 | self.lr = group['lr'] 95 | for i, p in enumerate(group['params']): 96 | if p.grad is None: 97 | continue 98 | grad = p.grad 99 | if grad.is_sparse: 100 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 101 | amsgrad = group['amsgrad'] 102 | state = self.state[p] 103 | # State initialization 104 | if len(state) == 0: 105 | state['step'] = 0 106 | # Exponential moving average of gradient values 107 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 108 | # Exponential moving average of squared gradient values 109 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 110 | if amsgrad: 111 | # Maintains max of all exp. moving avg. of sq. grad. values 112 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 113 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 114 | if amsgrad: 115 | max_exp_avg_sq = state['max_exp_avg_sq'] 116 | beta1, beta2 = group['betas'] 117 | state['step'] += 1 118 | bias_correction1 = 1 - beta1 ** state['step'] 119 | bias_correction2 = 1 - beta2 ** state['step'] 120 | if group['weight_decay'] != 0: 121 | grad = grad.add(p, alpha=group['weight_decay']) 122 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 123 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 124 | if amsgrad: 125 | # Maintains the maximum of all 2nd moment running avg. till now 126 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 127 | # Use the max. for normalizing running avg. of gradient 128 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 129 | else: 130 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 131 | step_size = group['lr'] / bias_correction1 132 | p.addcdiv_(exp_avg, denom, value=-step_size) 133 | elif self.mode == 'sgd': 134 | weight_decay = group['weight_decay'] 135 | momentum = group['momentum'] 136 | dampening = group['dampening'] 137 | nesterov = group['nesterov'] 138 | self.lr = group['lr'] 139 | for p in group['params']: 140 | if p.grad is None: 141 | continue 142 | d_p = p.grad 143 | if weight_decay != 0: 144 | d_p = d_p.add(p, alpha=weight_decay) # d_p = (d_p + p*weight_decy) 145 | if momentum != 0: 146 | param_state = self.state[p] 147 | if 'momentum_buffer' not in param_state: 148 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 149 | else: 150 | buf = param_state['momentum_buffer'] 151 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) # [v = v*beta + d_p ] --> new d_p 152 | if nesterov: 153 | d_p = d_p.add(buf, alpha=momentum) 154 | else: 155 | d_p = buf 156 | p.add_(d_p, alpha=-group['lr']) 157 | self.steps += 1 158 | return loss 159 | 160 | def update_lr(self, decay_factor=None, coef_decay_factor=None): 161 | if decay_factor != None: 162 | self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor # for learning rate 163 | print ('Reducing lr to %.5f @ T=%s!'%(self.param_groups[0]['lr'], self.steps)) 164 | if coef_decay_factor != None: 165 | self.loss_fn.update_coef(coef_decay_factor) # for moving average 166 | print ('Reducing eta/gamma to %.5f @ T=%s!!' 
%(self.loss_fn.get_coef, self.steps)) 167 | 168 | -------------------------------------------------------------------------------- /libauc/optimizers/sopa_s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | class SOPAs(torch.optim.Optimizer): 5 | r"""A wrapper class for different optimizing methods. 6 | 7 | Arguments: 8 | params (iterable): iterable of parameters to optimize or dicts defining 9 | parameter groups 10 | lr (float): learning rate 11 | loss_fn: the instance of loss class 12 | method (str): optimization method 13 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 14 | 15 | Arguments for SGD optimization method: 16 | momentum (float, optional): momentum factor (default: 0.9) 17 | dampening (float, optional): dampening for momentum (default: 0.1) 18 | nesterov (bool, optional): enables Nesterov momentum (default: False) 19 | Arguments for ADAM optimization method: 20 | betas (Tuple[float, float], optional): coefficients used for computing 21 | running averages of gradient and its square (default: (0.9, 0.999)) 22 | eps (float, optional): term added to the denominator to improve 23 | numerical stability (default: 1e-8) 24 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 25 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 26 | (default: False) 27 | """ 28 | 29 | def __init__(self, model, loss_fn, 30 | mode = 'adam', 31 | lr=1e-4, weight_decay=0, 32 | momentum=0.0, nesterov=False, dampening=0, # sgd 33 | betas=(0.9, 0.999), eps=1e-8, amsgrad=False # adam 34 | ): 35 | 36 | if not 0.0 <= lr: 37 | raise ValueError("Invalid learning rate: {}".format(lr)) 38 | if not 0.0 <= eps: 39 | raise ValueError("Invalid epsilon value: {}".format(eps)) 40 | if not 0.0 <= betas[0] < 1.0: 41 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 42 | if not 0.0 <= betas[1] < 1.0: 43 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 44 | if not 0.0 <= weight_decay: 45 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 46 | 47 | try: 48 | params = model.parameters() 49 | except: 50 | params = model # if model is already params 51 | 52 | self.params = params 53 | self.loss_fn = loss_fn 54 | self.lr = lr 55 | self.steps = 0 56 | self.mode = mode.lower() 57 | 58 | defaults = dict(lr=lr, weight_decay=weight_decay, 59 | momentum=momentum, dampening=dampening, nesterov=nesterov, 60 | betas=betas, eps=eps, amsgrad=amsgrad) 61 | 62 | if nesterov and (momentum <= 0 or dampening != 0): 63 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 64 | 65 | super(SOPAs, self).__init__(self.params, defaults) 66 | 67 | 68 | def __setstate__(self, state): 69 | r""" 70 | # Set default options for sgd mode and adam mode 71 | """ 72 | super(SOPAs, self).__setstate__(state) 73 | for group in self.param_groups: 74 | if self.mode == 'sgd': 75 | group.setdefault('nesterov', False) 76 | elif self.mode == 'adam': 77 | group.setdefault('amsgrad', False) 78 | else: 79 | NotImplementedError 80 | 81 | @torch.no_grad() 82 | def step(self, closure=None): 83 | """Performs a single optimization step. 84 | Arguments: 85 | closure (callable, optional): A closure that reevaluates the model 86 | and returns the loss. 
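        Example of decaying the learning rate and the loss moving-average
        coefficient at an epoch milestone (illustrative only; the surrounding
        training loop and ``decay_epochs`` are assumed, and ``loss_fn`` must
        implement ``update_coef`` since ``update_lr`` forwards to it)::

            if epoch in decay_epochs:
                optimizer.update_lr(decay_factor=10, coef_decay_factor=2)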
87 | """ 88 | loss = None 89 | if closure is not None: 90 | with torch.enable_grad(): 91 | loss = closure() 92 | 93 | for group in self.param_groups: 94 | if self.mode == 'sgd': 95 | weight_decay = group['weight_decay'] 96 | momentum = group['momentum'] 97 | dampening = group['dampening'] 98 | nesterov = group['nesterov'] 99 | self.lr = group['lr'] 100 | for p in group['params']: 101 | if p.grad is None: 102 | print(p.shape) 103 | continue 104 | d_p = p.grad 105 | if weight_decay != 0: 106 | d_p = d_p.add(p, alpha=weight_decay) # d_p = (d_p + p*weight_decy) 107 | if momentum != 0: 108 | param_state = self.state[p] 109 | if 'momentum_buffer' not in param_state: 110 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 111 | else: 112 | buf = param_state['momentum_buffer'] 113 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) # [v = v*beta + d_p ] --> new d_p 114 | if nesterov: 115 | d_p = d_p.add(buf, alpha=momentum) 116 | else: 117 | d_p = buf 118 | p.add_(d_p, alpha=-group['lr']) 119 | 120 | elif self.mode == 'adam': 121 | self.lr = group['lr'] 122 | for p in group['params']: 123 | if p.grad is None: 124 | continue 125 | grad = p.grad 126 | if grad.is_sparse: 127 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 128 | amsgrad = group['amsgrad'] 129 | state = self.state[p] 130 | 131 | # State initialization 132 | if len(state) == 0: 133 | state['step'] = 0 134 | # Exponential moving average of gradient values 135 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 136 | # Exponential moving average of squared gradient values 137 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 138 | if amsgrad: 139 | # Maintains max of all exp. moving avg. of sq. grad. values 140 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 141 | 142 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 143 | if amsgrad: 144 | max_exp_avg_sq = state['max_exp_avg_sq'] 145 | beta1, beta2 = group['betas'] 146 | 147 | state['step'] += 1 148 | bias_correction1 = 1 - beta1 ** state['step'] 149 | bias_correction2 = 1 - beta2 ** state['step'] 150 | 151 | if group['weight_decay'] != 0: 152 | grad = grad.add(p, alpha=group['weight_decay']) 153 | 154 | # Decay the first and second moment running average coefficient 155 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 156 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 157 | if amsgrad: 158 | # Maintains the maximum of all 2nd moment running avg. till now 159 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 160 | # Use the max. for normalizing running avg. of gradient 161 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 162 | else: 163 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 164 | 165 | step_size = group['lr'] / bias_correction1 166 | 167 | p.addcdiv_(exp_avg, denom, value=-step_size) 168 | 169 | self.steps += 1 170 | return loss 171 | 172 | def update_lr(self, decay_factor=None, coef_decay_factor=None): 173 | if decay_factor != None: 174 | self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor 175 | print ('Reducing learning rate to %.5f @ T=%s!'%(self.param_groups[0]['lr'], self.steps)) 176 | if coef_decay_factor != None: 177 | self.loss_fn.update_coef(coef_decay_factor) 178 | print ('Reducing eta/gamma to %.4f @ T=%s!' 
% (self.loss_fn.get_coef, self.steps)) 179 | 180 | def update_regularizer(self, decay_factor=None): 181 | pass 182 | 183 | -------------------------------------------------------------------------------- /libauc/optimizers/sota_s.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | class SOTAs(torch.optim.Optimizer): 5 | r"""Implements Adam algorithm. 6 | 7 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 8 | The implementation of the L2 penalty follows changes proposed in 9 | `Decoupled Weight Decay Regularization`_. 10 | 11 | Arguments: 12 | params (iterable): iterable of parameters to optimize or dicts defining 13 | parameter groups 14 | lr (float, optional): learning rate (default: 1e-3) 15 | betas (Tuple[float, float], optional): coefficients used for computing 16 | running averages of gradient and its square (default: (0.9, 0.999)) 17 | eps (float, optional): term added to the denominator to improve 18 | numerical stability (default: 1e-8) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 21 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 22 | (default: False) 23 | 24 | .. _Adam\: A Method for Stochastic Optimization: 25 | https://arxiv.org/abs/1412.6980 26 | .. _Decoupled Weight Decay Regularization: 27 | https://arxiv.org/abs/1711.05101 28 | .. _On the Convergence of Adam and Beyond: 29 | https://openreview.net/forum?id=ryQu7f-RZ 30 | """ 31 | 32 | def __init__(self, model, loss_fn, 33 | mode = 'adam', 34 | lr=1e-3, weight_decay=0, 35 | gammas=(0.9, 0.9), 36 | betas=(0.9, 0.999), eps=1e-8, amsgrad=False, # adam 37 | momentum=0.9, nesterov=False, dampening=0, # sgd 38 | ): 39 | if not 0.0 <= lr: 40 | raise ValueError("Invalid learning rate: {}".format(lr)) 41 | if not 0.0 <= eps: 42 | raise ValueError("Invalid epsilon value: {}".format(eps)) 43 | if not 0.0 <= betas[0] < 1.0: 44 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 45 | if not 0.0 <= betas[1] < 1.0: 46 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 47 | if not 0.0 <= weight_decay: 48 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 49 | 50 | try: 51 | params = model.parameters() 52 | except: 53 | params = model # if model is already params 54 | 55 | self.params = params 56 | self.lr = lr 57 | self.mode = mode.lower() 58 | self.loss_fn = loss_fn 59 | self.loss_fn.set_coef(gamma0=gammas[0], gamma1=gammas[1]) 60 | self.steps = 0 61 | 62 | defaults = dict(lr=lr, betas=betas, eps=eps, momentum=momentum, nesterov=nesterov, dampening=dampening, 63 | weight_decay=weight_decay, amsgrad=amsgrad) 64 | super(SOTAs, self).__init__(self.params, defaults) 65 | 66 | 67 | def __setstate__(self, state): 68 | r""" 69 | # Set default options for sgd mode and adam mode 70 | """ 71 | super(SOTAs, self).__setstate__(state) 72 | for group in self.param_groups: 73 | if self.mode == 'sgd': 74 | group.setdefault('nesterov', False) 75 | elif self.mode == 'adam': 76 | group.setdefault('amsgrad', False) 77 | else: 78 | NotImplementedError 79 | 80 | @torch.no_grad() 81 | def step(self, closure=None): 82 | """Performs a single optimization step. 83 | 84 | Arguments: 85 | closure (callable, optional): A closure that reevaluates the model 86 | and returns the loss. 
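        Example (sketch only; ``model`` and ``loss_fn`` are assumed to come from
        the caller, and the loss object must implement
        ``set_coef(gamma0=..., gamma1=...)`` because it is called in ``__init__``)::

            optimizer = SOTAs(model, loss_fn, mode='adam', lr=1e-3,
                              gammas=(0.9, 0.9), weight_decay=1e-4)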
87 | """ 88 | loss = None 89 | if closure is not None: 90 | with torch.enable_grad(): 91 | loss = closure() 92 | 93 | for group in self.param_groups: 94 | if self.mode == 'adam': 95 | self.lr = group['lr'] 96 | for i, p in enumerate(group['params']): 97 | if p.grad is None: 98 | continue 99 | grad = p.grad 100 | if grad.is_sparse: 101 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 102 | amsgrad = group['amsgrad'] 103 | state = self.state[p] 104 | # State initialization 105 | if len(state) == 0: 106 | state['step'] = 0 107 | # Exponential moving average of gradient values 108 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 109 | # Exponential moving average of squared gradient values 110 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 111 | if amsgrad: 112 | # Maintains max of all exp. moving avg. of sq. grad. values 113 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 114 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 115 | if amsgrad: 116 | max_exp_avg_sq = state['max_exp_avg_sq'] 117 | beta1, beta2 = group['betas'] 118 | state['step'] += 1 119 | bias_correction1 = 1 - beta1 ** state['step'] 120 | bias_correction2 = 1 - beta2 ** state['step'] 121 | if group['weight_decay'] != 0: 122 | grad = grad.add(p, alpha=group['weight_decay']) 123 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 124 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 125 | if amsgrad: 126 | # Maintains the maximum of all 2nd moment running avg. till now 127 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 128 | # Use the max. for normalizing running avg. of gradient 129 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 130 | else: 131 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 132 | step_size = group['lr'] / bias_correction1 133 | p.addcdiv_(exp_avg, denom, value=-step_size) 134 | elif self.mode == 'sgd': 135 | weight_decay = group['weight_decay'] 136 | momentum = group['momentum'] 137 | dampening = group['dampening'] 138 | nesterov = group['nesterov'] 139 | self.lr = group['lr'] 140 | for p in group['params']: 141 | if p.grad is None: 142 | continue 143 | d_p = p.grad 144 | if weight_decay != 0: 145 | d_p = d_p.add(p, alpha=weight_decay) # d_p = (d_p + p*weight_decy) 146 | if momentum != 0: 147 | param_state = self.state[p] 148 | if 'momentum_buffer' not in param_state: 149 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 150 | else: 151 | buf = param_state['momentum_buffer'] 152 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) # [v = v*beta + d_p ] --> new d_p 153 | if nesterov: 154 | d_p = d_p.add(buf, alpha=momentum) 155 | else: 156 | d_p = buf 157 | p.add_(d_p, alpha=-group['lr']) 158 | self.steps += 1 159 | return loss 160 | 161 | def update_lr(self, decay_factor=None, coef_decay_factor=None ): 162 | if decay_factor != None: 163 | self.param_groups[0]['lr'] = self.param_groups[0]['lr']/decay_factor 164 | print ('Reducing learning rate to %.5f @ T=%s!'%(self.param_groups[0]['lr'], self.steps)) 165 | if coef_decay_factor != None: 166 | self.loss_fn.update_coef(coef_decay_factor) 167 | coefs = self.loss_fn.get_coef 168 | coefs = '(%.4f, %.4f)'%(coefs[0], coefs[1]) 169 | print ('Reducing eta/gamma to %s @ T=%s!' 
% (coefs, self.steps)) 170 | 171 | -------------------------------------------------------------------------------- /libauc/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import ImbalancedDataSampler 2 | from .sampler import DualSampler 3 | from .sampler import TriSampler 4 | from .ranking import DataSampler # move to sampler 5 | -------------------------------------------------------------------------------- /libauc/sampler/ranking.py: -------------------------------------------------------------------------------- 1 | """ 2 | If you use our sampler functions, please acknowledge us and cite the following papers: 3 | @misc{libauc2022, 4 | title={LibAUC: A Deep Learning Library for X-Risk Optimization.}, 5 | author={Zhuoning Yuan, Zi-Hao Qiu, Gang Li, Dixian Zhu, Zhishuai Guo, Quanqi Hu, Bokun Wang, Qi Qi, Yongjian Zhong, Tianbao Yang}, 6 | year={2022} 7 | } 8 | """ 9 | 10 | import torch 11 | from torch.utils.data.sampler import Sampler 12 | import numpy as np 13 | import os 14 | from tqdm import trange 15 | 16 | # TODO: move to sampler.py 17 | class DataSampler(Sampler): 18 | """ 19 | Data Sampler for recommender systems 20 | 21 | Args: 22 | labels: a 2-D csr sparse array: [task_num, item_num] 23 | batch_size: number of all labels (items) in a batch = num_tasks * (num_pos + num_neg) 24 | num_pos: number of positive labels (items) for each task (user) 25 | num_tasks: number of tasks (users) 26 | """ 27 | def __init__(self, labels, batch_size, num_pos, num_tasks): 28 | self.labels = labels 29 | self.num_tasks = num_tasks 30 | self.batch_size = batch_size 31 | 32 | self.num_pos = num_pos 33 | self.num_neg = self.batch_size//num_tasks - self.num_pos 34 | 35 | self.label_dict = {} 36 | 37 | for i in trange(self.labels.shape[0]): 38 | task_label = self.labels[i, :].toarray() 39 | pos_index = np.flatnonzero(task_label>0) 40 | ###To avoid sampling error 41 | while len(pos_index) < self.num_pos: 42 | pos_index = np.concatenate((pos_index,pos_index)) 43 | np.random.shuffle(pos_index) 44 | 45 | neg_index = np.flatnonzero(task_label==0) 46 | while len(neg_index) < self.num_neg: 47 | neg_index = np.concatenate((neg_index,neg_index)) 48 | np.random.shuffle(neg_index) 49 | 50 | self.label_dict.update({i:(pos_index,neg_index)}) 51 | 52 | self.pos_ptr, self.neg_ptr = np.zeros(self.labels.shape[0], dtype=np.int32), np.zeros(self.labels.shape[0], dtype=np.int32) 53 | self.task_ptr, self.tasks = 0, np.random.permutation(list(range(self.labels.shape[0]))) 54 | 55 | self.num_batches = self.labels.shape[0] // self.num_tasks 56 | 57 | self.sampled_task = np.empty(self.num_batches*self.num_tasks, dtype=np.int32) 58 | self.sampled_labels = np.empty((self.num_batches*self.num_tasks, self.num_pos+self.num_neg), dtype=np.int32) 59 | 60 | 61 | def __iter__(self): 62 | 63 | beg = 0 # beg is the pointer for self.ret 64 | 65 | for batch_id in range(self.num_batches): 66 | task_ids = self.tasks[self.task_ptr:self.task_ptr+self.num_tasks] # randomly sample task_ids (number: self.num_tasks) 67 | self.task_ptr += self.num_tasks 68 | if self.task_ptr >= len(self.tasks): 69 | np.random.shuffle(self.tasks) 70 | self.task_ptr = self.task_ptr % len(self.tasks) # if reach the end, then shuffle the list and mod the pointer 71 | 72 | for task_id in task_ids: 73 | item_list = np.empty(self.num_pos+self.num_neg, dtype=np.int16) 74 | 75 | if self.pos_ptr[task_id]+self.num_pos > len(self.label_dict[task_id][0]): 76 | temp = 
self.label_dict[task_id][0][self.pos_ptr[task_id]:] 77 | np.random.shuffle(self.label_dict[task_id][0]) 78 | self.pos_ptr[task_id] = (self.pos_ptr[task_id]+self.num_pos)%len(self.label_dict[task_id][0]) 79 | if self.pos_ptr[task_id]+len(temp) < self.num_pos: 80 | self.pos_ptr[task_id] += self.num_pos-len(temp) 81 | item_ids = np.concatenate((temp,self.label_dict[task_id][0][:self.pos_ptr[task_id]])) 82 | item_list[:self.num_pos] = item_ids 83 | else: 84 | item_ids = self.label_dict[task_id][0][self.pos_ptr[task_id]:self.pos_ptr[task_id]+self.num_pos] 85 | item_list[:self.num_pos] = item_ids 86 | self.pos_ptr[task_id] += self.num_pos 87 | 88 | if self.neg_ptr[task_id]+self.num_neg > len(self.label_dict[task_id][1]): 89 | temp = self.label_dict[task_id][1][self.neg_ptr[task_id]:] 90 | np.random.shuffle(self.label_dict[task_id][1]) 91 | self.neg_ptr[task_id] = (self.neg_ptr[task_id]+self.num_neg)%len(self.label_dict[task_id][1]) 92 | if self.neg_ptr[task_id]+len(temp) < self.num_neg: 93 | self.neg_ptr[task_id] += self.num_neg-len(temp) 94 | item_ids = np.concatenate((temp,self.label_dict[task_id][1][:self.neg_ptr[task_id]])) 95 | item_list[self.num_pos:] = item_ids 96 | else: 97 | item_ids = self.label_dict[task_id][1][self.neg_ptr[task_id]:self.neg_ptr[task_id]+self.num_neg] 98 | item_list[self.num_pos:] = item_ids 99 | self.neg_ptr[task_id] += self.num_neg # sample num_neg negative items for task_id 100 | 101 | self.sampled_task[beg] = task_id 102 | self.sampled_labels[beg, :] = item_list 103 | beg += 1 104 | 105 | return iter(zip(self.sampled_task, self.sampled_labels)) 106 | 107 | 108 | def __len__ (self): 109 | return len(self.sampled_task) 110 | -------------------------------------------------------------------------------- /libauc/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .generator import * 2 | from .helper import * 3 | -------------------------------------------------------------------------------- /libauc/utils/generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def _check_imbalance_ratio(targets): 5 | assert isinstance(targets, (np.ndarray, np.generic)), 'targets has to be numpy array!' 6 | num_samples = len(targets) 7 | pos_count = np.count_nonzero(targets == 1) 8 | neg_count = np.count_nonzero(targets == 0) # check if negative labels in dataset 9 | pos_ratio = pos_count/ (pos_count + neg_count) 10 | print ('#SAMPLES: [%d], POS:NEG: [%d : %d], POS RATIO: %.4f'%(num_samples, pos_count, neg_count, pos_ratio) ) 11 | 12 | def _check_array_type(arr): 13 | assert isinstance(arr, (np.ndarray, np.generic)), 'Inputs need to be numpy array type!' 14 | 15 | 16 | class ImbalancedDataGenerator(object): 17 | ''' 18 | Binary, Numpy array only 19 | Added support for dataset type imbalance modififcation??? 
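    Example (an illustrative sketch; ``data`` and ``targets`` are assumed to be
    NumPy arrays from a standard multi-class dataset such as CIFAR-10. Labels are
    binarized around the class midpoint and positives are subsampled to the
    requested ratio, which must lie in (0, 0.5])::

        generator = ImbalancedDataGenerator(shuffle=True, random_seed=0, verbose=True)
        data_imb, targets_imb = generator.transform(data, targets, imratio=0.1)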
20 | ''' 21 | def __init__(self, imratio=None, shuffle=True, random_seed=0, verbose=False): 22 | self.imratio = imratio # for testing set, use 0.5 instead of is_balanced 23 | self.shuffle = shuffle 24 | self.random_seed = random_seed 25 | self.verbose = verbose 26 | 27 | def _get_split_index(self, num_classes): 28 | if num_classes == 2: 29 | split_index = 0 30 | elif num_classes == 10: 31 | split_index = 4 32 | elif num_classes == 100: 33 | split_index = 49 34 | elif num_classes == 1000: 35 | split_index = 499 36 | else: 37 | raise NotImplementedError 38 | return split_index 39 | 40 | def _get_class_num(self, targets): 41 | return np.unique(targets).size 42 | 43 | def transform(self, data, targets, imratio=None): 44 | _check_array_type(data) 45 | _check_array_type(targets) 46 | if min(targets) < 0: # check negative values 47 | targets[targets<0] = 0 48 | if imratio is not None: 49 | self.imratio = imratio 50 | assert self.imratio>0 and self.imratio<=0.5, 'imratio needs to be in (0, 0.5)!' 51 | 52 | # shuffle once and create data copies 53 | id_list = list(range(targets.shape[0])) 54 | np.random.seed(self.random_seed) 55 | np.random.shuffle(id_list) 56 | data_copy = data[id_list].copy() 57 | targets_copy = targets[id_list].copy() 58 | 59 | # make binary dataset 60 | num_classes = self._get_class_num(targets) 61 | split_index = self._get_split_index(num_classes) 62 | targets_copy[targets_copy<=split_index] = 0 # [0, ....] 63 | targets_copy[targets_copy>=split_index+1] = 1 # [0, ....] 64 | 65 | # randomly remove some samples 66 | if self.imratio < 0.5: 67 | num_neg = np.where(targets_copy==0)[0].shape[0] 68 | num_pos = np.where(targets_copy==1)[0].shape[0] 69 | keep_num_pos = int((self.imratio/(1-self.imratio))*num_neg ) 70 | neg_id_list = np.where(targets_copy==0)[0] 71 | pos_id_list = np.where(targets_copy==1)[0][:keep_num_pos] 72 | data_copy = data_copy[neg_id_list.tolist() + pos_id_list.tolist() ] 73 | targets_copy = targets_copy[neg_id_list.tolist() + pos_id_list.tolist() ] 74 | targets_copy = targets_copy.reshape(-1, 1).astype(float) 75 | 76 | if self.shuffle: 77 | # shuffle in case batch prediction error 78 | id_list = list(range(targets_copy.shape[0])) 79 | np.random.seed(self.random_seed) 80 | np.random.shuffle(id_list) 81 | data_copy = data_copy[id_list] 82 | targets_copy = targets_copy[id_list] 83 | 84 | if self.verbose: 85 | _check_imbalance_ratio(targets_copy) 86 | 87 | return data_copy, targets_copy 88 | -------------------------------------------------------------------------------- /libauc/utils/helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import datetime 4 | import os 5 | import sys 6 | import time 7 | import shutil 8 | from tqdm import tqdm, trange 9 | from ..metrics import ndcg_at_k, map_at_k 10 | 11 | def batch_to_gpu(batch, device='cuda'): 12 | for c in batch: 13 | if type(batch[c]) is torch.Tensor: 14 | batch[c] = batch[c].to(device) 15 | return batch 16 | 17 | def adjust_lr(learning_rate, lr_schedule, optimizer, epoch): 18 | lr = learning_rate 19 | for milestone in eval(lr_schedule): 20 | lr *= 0.25 if epoch >= milestone else 1 21 | for param_group in optimizer.param_groups: 22 | param_group['lr'] = lr 23 | 24 | def evaluate_method(predictions, ratings, topk, metrics): 25 | """ 26 | :param predictions: (-1, n_candidates) shape, the first column is the score for ground-truth item 27 | :param ratings: (# of users, # of pos items) 28 | :param topk: top-K value list 29 | :param 
metrics: metric string list 30 | :return: a result dict, the keys are metric@topk 31 | """ 32 | evaluations = dict() 33 | 34 | num_of_users, num_pos_items = ratings.shape 35 | sorted_ratings = -np.sort(-ratings) # descending order !! 36 | discounters = np.tile([np.log2(i+1) for i in range(1, 1+num_pos_items)], (num_of_users, 1)) 37 | normalizer_mat = (np.exp2(sorted_ratings) - 1) / discounters 38 | 39 | sort_idx = (-predictions).argsort(axis=1) # index of sorted predictions (max->min) 40 | gt_rank = np.array([np.argwhere(sort_idx == i)[:, 1]+1 for i in range(num_pos_items)]).T # rank of the ground-truth (start from 1) 41 | for k in topk: 42 | hit = (gt_rank <= k) 43 | for metric in metrics: 44 | key = '{}@{}'.format(metric, k) 45 | if metric == 'NDCG': 46 | evaluations[key] = ndcg_at_k(ratings, normalizer_mat, hit, gt_rank, k) 47 | elif metric == 'MAP': 48 | evaluations[key] = map_at_k(hit, gt_rank) 49 | else: 50 | raise ValueError('Undefined evaluation metric: {}.'.format(metric)) 51 | return evaluations 52 | 53 | def evaluate(model, data_set, topks, metrics, eval_batch_size=250, num_pos=10): 54 | """ 55 | The returned prediction is a 2D-array, each row corresponds to all the candidates, 56 | and the ground-truth item poses the first. 57 | Example: ground-truth items: [1, 2], 2 negative items for each instance: [[3,4], [5,6]] 58 | predictions like: [[1,3,4], [2,5,6]] 59 | """ 60 | EVAL_BATCH_SIZE = eval_batch_size 61 | NUM_POS = num_pos 62 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 63 | model.eval() 64 | predictions = list() 65 | ratings = list() 66 | for idx in trange(0, len(data_set), EVAL_BATCH_SIZE): 67 | batch = data_set.get_batch(idx, EVAL_BATCH_SIZE) 68 | prediction = model(batch_to_gpu(batch, DEVICE))['prediction'] 69 | predictions.extend(prediction.cpu().data.numpy()) 70 | ratings.extend(batch['rating'].cpu().data.numpy()) 71 | 72 | predictions = np.array(predictions) # [# of users, # of items] 73 | ratings = np.array(ratings)[:, :NUM_POS] # [# of users, # of pos items] 74 | 75 | return evaluate_method(predictions, ratings, topks, metrics) 76 | 77 | def format_metric(result_dict): 78 | assert type(result_dict) == dict 79 | format_str = [] 80 | metrics = np.unique([k.split('@')[0] for k in result_dict.keys()]) 81 | topks = np.unique([int(k.split('@')[1]) for k in result_dict.keys()]) 82 | for topk in np.sort(topks): 83 | for metric in np.sort(metrics): 84 | name = '{}@{}'.format(metric, topk) 85 | m = result_dict[name] 86 | if type(m) is float or type(m) is np.float or type(m) is np.float32 or type(m) is np.float64: 87 | format_str.append('{}:{:<.4f}'.format(name, m)) 88 | elif type(m) is int or type(m) is np.int or type(m) is np.int32 or type(m) is np.int64: 89 | format_str.append('{}:{}'.format(name, m)) 90 | return ','.join(format_str) 91 | 92 | def get_time(): 93 | return datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S") 94 | -------------------------------------------------------------------------------- /loaders/load_lanl.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import os 3 | import pickle 4 | from joblib import Parallel, delayed 5 | 6 | import torch 7 | from torch_geometric.data import Data 8 | from tqdm import tqdm 9 | 10 | from .tdata import TData 11 | from .load_utils import edge_tv_split, std_edge_w, standardized, std_edge_a 12 | 13 | import random 14 | import numpy as np 15 | 16 | DATE_OF_EVIL_LANL = 150885 17 | FILE_DELTA = 10000 18 | 19 | # Input the path 
where LANL data locates which should be the same as DST in split_lanl.py 20 | LANL_FOLDER = '' 21 | assert LANL_FOLDER, 'Please fill in the LANL_FOLDER in ./loaders/load_lanl.py' 22 | 23 | 24 | TIMES = { 25 | '5': 155399, 26 | '10' : 210294, 27 | '20' : 228642, # First 20 anoms 1.55% 28 | '30' : 464254, #killed on poisoning attack 29 | '40' : 485925, 30 | '100' : 740104, # First 100 anoms 11.7% 31 | '500' : 1089597, # First 500 anoms 18.73% 32 | 'all' : 5011199, # Full 33 | 'test' : 1270000 34 | } 35 | 36 | def empty_lanl(use_flows=False): 37 | return make_data_obj(None,[],None,None,None,use_flows=use_flows) 38 | 39 | def load_lanl_dist(start=0, end=635015, delta=8640, is_test=False, use_flows=False, ew_fn=std_edge_w, ea_fn=std_edge_a): 40 | if start == None or end == None: 41 | return empty_lanl(use_flows) 42 | 43 | num_slices = ((end - start) // delta) 44 | remainder = (end-start) % delta 45 | num_slices = num_slices + 1 if remainder else num_slices 46 | return load_partial_lanl(start, end, delta, is_test, use_flows, ew_fn, ea_fn) 47 | 48 | per_worker = [num_slices // workers] * workers 49 | remainder = num_slices % workers 50 | 51 | if remainder: 52 | for i in range(workers, workers-remainder, -1): 53 | per_worker[i-1] += 1 54 | 55 | kwargs = [] 56 | prev = start 57 | for i in range(workers): 58 | end_t = prev + delta*per_worker[i] 59 | kwargs.append({ 60 | 'start': prev, 61 | 'end': min(end_t-1, end), 62 | 'delta': delta, 63 | 'is_test': is_test, 64 | 'use_flows': use_flows, 65 | 'ew_fn': ew_fn, 66 | 'ea_fn': ea_fn 67 | }) 68 | prev = end_t 69 | 70 | datas = Parallel(n_jobs=workers, prefer='processes')(delayed(load_partial_lanl_job)(i, kwargs[i]) for i in range(workers)) 71 | 72 | # Helper method to concatonate one field from all of the datas 73 | data_reduce = lambda x : sum([getattr(datas[i], x) for i in range(workers)], []) 74 | 75 | # Just join all the lists from all the data objects 76 | print("Joining Data objects") 77 | slices = data_reduce('slices') 78 | x = datas[0].xs 79 | eis = data_reduce('eis') 80 | masks = data_reduce('masks') 81 | ews = data_reduce('ews') 82 | eas = data_reduce('eas') 83 | 84 | node_map = datas[0].node_map 85 | 86 | if is_test: 87 | ys = data_reduce('ys') 88 | cnt = data_reduce('cnt') 89 | else: 90 | ys = None 91 | cnt = None 92 | 93 | # After everything is combined, wrap it in a fancy new object, and you're 94 | # on your way to coolsville flats 95 | print("Done") 96 | return TData(slices, eis, x, ys, masks, ews=ews, eas=eas, node_map=node_map, cnt=cnt) 97 | 98 | def load_partial_lanl_job(pid, args): 99 | data = load_partial_lanl(**args) 100 | return data 101 | 102 | def make_data_obj(cur_slice, eis, ys, ew_fn, ea_fn, ews=None, eas=None, use_flows=False, **kwargs): 103 | if 'node_map' in kwargs: 104 | nm = kwargs['node_map'] 105 | else: 106 | nm = pickle.load(open(LANL_FOLDER+'nmap.pkl', 'rb')) 107 | 108 | cl_cnt = len(nm) 109 | x = torch.eye(cl_cnt+1) 110 | # Build time-partitioned edge lists 111 | eis_t = [] 112 | masks = [] 113 | 114 | for i in range(len(eis)): 115 | ei = torch.tensor(eis[i]) 116 | eis_t.append(ei) 117 | 118 | # This is training data if no ys present 119 | if isinstance(ys, None.__class__): 120 | masks.append(edge_tv_split(ei)[0]) 121 | 122 | # Balance the edge weights if they exist 123 | if not isinstance(ews, None.__class__): 124 | cnt = deepcopy(ews) 125 | ews = ew_fn(ews) 126 | else: 127 | cnt = None 128 | 129 | #TODO: balance edge feature values 130 | if not isinstance(eas, None.__class__): 131 | eas = ea_fn(eas) 132 | 133 | 
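    # Descriptive note (added for clarity, not original code): the TData object
    # returned below bundles
    #   cur_slice - identifiers of the time slices covered by this chunk
    #   eis_t     - one edge-index tensor per snapshot
    #   x         - one-hot node features (identity matrix sized by the node map)
    #   ys        - per-snapshot anomaly labels (None for training data)
    #   masks     - train/validation edge splits, built only when ys is None
    #   ews / cnt - normalized edge weights and the raw counts they came from
    #   eas       - per-snapshot edge-feature tensors (user-type counts and,
    #               optionally, flow statistics)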
return TData(cur_slice, eis_t, x, ys, masks, ews=ews, eas=eas, use_flows=use_flows, cnt=cnt, node_map=nm) 134 | 135 | 136 | def load_flows(fname, start, end): 137 | eas_flows = {} 138 | temp_flows = {} 139 | if not os.path.exists(fname): 140 | return eas_flows 141 | in_f = open(fname) 142 | line = in_f.readline() 143 | 144 | #Line in parsed flows. ts, src, dst,src_port,dst_port,proto, duration, pck_cnt, byte_cnt, label 145 | fmt_line = lambda x : (int(x[0]), int(x[1]), int(x[2]), int(x[6]), int(x[7]), int(x[8])) 146 | 147 | while line: 148 | l = line.split(',') 149 | ts = int(l[0]) 150 | if ts < start: 151 | line = in_f.readline() 152 | continue 153 | if ts > end: 154 | break 155 | ts, src, dst, duration, pck_cnt, byte_cnt = fmt_line(l) 156 | et = (src,dst) 157 | if et in temp_flows: 158 | temp_flows[et][0].append(duration) 159 | temp_flows[et][1].append(pck_cnt) 160 | temp_flows[et][2].append(byte_cnt) 161 | else: 162 | temp_flows[et] = [[duration], [pck_cnt], [byte_cnt]] 163 | line = in_f.readline() 164 | in_f.close() 165 | #computes features, # of flows, mean & std of duration, pck_cnt and byte_cnt 166 | for et in temp_flows.keys(): 167 | eas_flows[et] = [len(temp_flows[et][0]), np.mean(temp_flows[et][0]), np.std(temp_flows[et][0]), \ 168 | np.mean(temp_flows[et][1]), np.std(temp_flows[et][1]), np.mean(temp_flows[et][2]), np.std(temp_flows[et][2])] 169 | return eas_flows 170 | 171 | def load_partial_lanl(start=140000, end=156659, delta=8640, is_test=False, use_flows=False, ew_fn=standardized, ea_fn=std_edge_a): 172 | print('start:' + str(start) + ', end:' + str(end)) 173 | cur_slice = int(start - (start % FILE_DELTA)) 174 | start_f = str(cur_slice) + '.txt' 175 | in_f = open(LANL_FOLDER + start_f, 'r') 176 | 177 | edges = [] 178 | ews = [] 179 | edges_t = {} 180 | ys = [] 181 | slices = [] 182 | eas = [] 183 | eas_flows = {} 184 | 185 | # Predefined for easier loading so everyone agrees on NIDs 186 | node_map = pickle.load(open(LANL_FOLDER+'nmap.pkl', 'rb')) 187 | user_map = pickle.load(open(LANL_FOLDER+'umap.pkl', 'rb')) 188 | 189 | 190 | # Helper functions (trims the trailing \n) 191 | # line format ts,src,dst,src_u,dst_u,auth_t,logon_t,auth_o,success,label 192 | fmt_line = lambda x : (int(x[0]), int(x[1]), int(x[2]), int(x[9][:-1]), int(x[3])) 193 | 194 | # take first char of src_u and convert it to edge list index 195 | def parse_user(src_u): 196 | if src_u[0] == 'C': 197 | return 2 198 | elif src_u[0] == 'U': 199 | return 3 200 | else: 201 | return 4 202 | 203 | def add_edge(et, src_u, is_anom=0): 204 | src_u_ind = parse_user(src_u) 205 | if et in edges_t: 206 | val = edges_t[et] 207 | edges_t[et][0:2] = [max(is_anom, val[0]), val[1]+1] 208 | edges_t[et][src_u_ind] = val[src_u_ind] + 1 209 | else: 210 | edges_t[et] = [0] * 5 211 | edges_t[et][0:2] = [is_anom, 1] 212 | edges_t[et][src_u_ind] = 1 213 | 214 | scan_prog = tqdm(desc='Finding start', total=start-cur_slice-1) 215 | prog = tqdm(desc='Seconds read', total=end-start-1) 216 | 217 | anom_marked = False 218 | keep_reading = True 219 | next_split = start+delta 220 | 221 | line = in_f.readline() 222 | curtime = fmt_line(line.split(','))[0] 223 | old_ts = curtime 224 | 225 | #load flows if use_flows == True 226 | if use_flows: 227 | if not os.path.exists(LANL_FOLDER + '/flows'): 228 | print('flows has not been parsed') 229 | else: 230 | eas_flows = load_flows(LANL_FOLDER + '/flows/' + start_f, start, end) 231 | 232 | while keep_reading: 233 | while line: 234 | l = line.split(',') 235 | 236 | # Scan to the correct part of the file 
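            # Descriptive note (added for clarity): records with ts < start are
            # skipped until the requested window is reached; once inside it, each
            # auth record follows ts,src,dst,src_u,dst_u,auth_t,logon_t,auth_o,
            # success,label, where label was precomputed in split_lanl.py by
            # matching the event against redteam.txt.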
237 | ts = int(l[0]) 238 | if ts < start: 239 | line = in_f.readline() 240 | scan_prog.update(ts-old_ts) 241 | old_ts = ts 242 | curtime = ts 243 | continue 244 | 245 | ts, src, dst, label, src_u= fmt_line(l) 246 | #Take the first char of src_u -> C, U or A, and the frequency of each type is the edge feature (3 features) 247 | et = (src,dst) 248 | src_u = user_map[src_u] 249 | 250 | # Not totally necessary but I like the loading bar 251 | prog.update(ts-old_ts) 252 | old_ts = ts 253 | 254 | # Split edge list if delta is hit 255 | if ts >= next_split: 256 | if len(edges_t): 257 | ei = list(zip(*edges_t.keys())) 258 | edges.append(ei) 259 | 260 | #uc, us, ua: user C+, user U+, user Anonymous 261 | y,ew,uc,uu,ua = list(zip(*edges_t.values())) 262 | ews.append(torch.tensor(ew)) 263 | 264 | if use_flows: 265 | #get number of features from eas_flows 266 | #eas_flows_dim = len(eas_flows[next(iter(eas_flows))]) 267 | eas_flows_dim = 7 268 | #num_flows,mean_duration,std_duration,mean_pkt_cnt,std_pkt_cnt,mean_byte_cnt,std_byte_cnt 269 | fs = {} 270 | 271 | for eij in edges_t.keys(): 272 | if eij in eas_flows: 273 | fs[eij] = eas_flows[eij] 274 | else: 275 | fs[eij] = [0] * eas_flows_dim 276 | 277 | #print('Match keys' + str(edges_t.keys() == fs.keys())) 278 | num_flows,mean_duration,std_duration,mean_pkt_cnt,std_pkt_cnt,mean_byte_cnt,std_byte_cnt = list(zip(*fs.values())) 279 | eas.append(torch.tensor([uc,uu,ua,num_flows,mean_duration,std_duration,mean_pkt_cnt,std_pkt_cnt,mean_byte_cnt,std_byte_cnt])) 280 | else: 281 | eas.append(torch.tensor([uc,uu,ua])) 282 | 283 | if is_test: 284 | ys.append(torch.tensor(y)) 285 | 286 | #a slice file might have multiple snapshots 287 | #slices.append(str(cur_slice) + '.txt') 288 | slices.append(str(ts)) 289 | 290 | edges_t = {} 291 | eas_flows = {} 292 | 293 | # If the list was empty, just keep going if you can 294 | curtime = next_split 295 | next_split += delta 296 | 297 | # Break out of loop after saving if hit final timestep 298 | if ts >= end: 299 | keep_reading = False 300 | break 301 | 302 | # Skip self-loops 303 | if et[0] == et[1]: 304 | line = in_f.readline() 305 | continue 306 | 307 | add_edge(et, src_u, is_anom=label) 308 | line = in_f.readline() 309 | 310 | in_f.close() 311 | cur_slice += FILE_DELTA 312 | 313 | if os.path.exists(LANL_FOLDER + str(cur_slice) + '.txt'): 314 | in_f = open(LANL_FOLDER + str(cur_slice) + '.txt', 'r') 315 | line = in_f.readline() 316 | if use_flows: 317 | eas_flows = load_flows(LANL_FOLDER + '/flows/' + str(cur_slice) + '.txt', start, end) 318 | else: 319 | keep_reading=False 320 | break 321 | 322 | ys = ys if is_test else None 323 | 324 | scan_prog.close() 325 | prog.close() 326 | 327 | 328 | return make_data_obj(slices, edges, ys, ew_fn, ea_fn, ews=ews, eas=eas, use_flows=use_flows, node_map=node_map) 329 | -------------------------------------------------------------------------------- /loaders/load_optc.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import os 3 | import pickle 4 | from joblib import Parallel, delayed 5 | 6 | import torch 7 | from torch_geometric.data import Data 8 | from tqdm import tqdm 9 | 10 | from .tdata import TData 11 | from .load_utils import edge_tv_split, std_edge_w, standardized, std_edge_a 12 | 13 | import numpy as np 14 | 15 | DATE_OF_EVIL_LANL = 573290 #original 573290 16 | FILE_DELTA = 10000 17 | 18 | # Input the path where OpTC data locates which should be the same as DST in split_optc.py 19 | OPTC_FOLDER = '' 20 | 
assert OPTC_FOLDER, 'Please fill in the OPTC_FOLDER in loaders/load_optc.py' 21 | 22 | TIMES = { 23 | '20' : 573383, # First 20 anoms 1.55% 24 | '100' : 573751, # First 100 anoms 11.7% 25 | '500' : 575885, # First 500 anoms 18.73% 26 | 'all' : 745983, # Full 21784 27 | 'test' : 745983 28 | } 29 | 30 | def empty_lanl(use_flows=False): 31 | return make_data_obj(None,[],None,None,None,use_flows=use_flows) 32 | 33 | def load_optc_dist(start=0, end=635015, delta=8640, is_test=False, use_flows=False, ew_fn=std_edge_w, ea_fn=std_edge_a): 34 | if start == None or end == None: 35 | return empty_lanl(use_flows) 36 | 37 | num_slices = ((end - start) // delta) 38 | remainder = (end-start) % delta 39 | num_slices = num_slices + 1 if remainder else num_slices 40 | # workers = min(num_slices, workers) 41 | 42 | # Can't distribute the job if not enough workers 43 | return load_partial_lanl(start, end, delta, is_test, use_flows, ew_fn, ea_fn) 44 | 45 | 46 | # wrapper bc its annoying to send kwargs with Parallel 47 | def load_partial_lanl_job(pid, args): 48 | data = load_partial_lanl(**args) 49 | return data 50 | 51 | 52 | def make_data_obj(cur_slice, eis, ys, ew_fn, ea_fn, ews=None, eas=None, use_flows=False, **kwargs): 53 | if 'node_map' in kwargs: 54 | nm = kwargs['node_map'] 55 | else: 56 | nm = pickle.load(open(OPTC_FOLDER+'nmap.pkl', 'rb')) 57 | 58 | cl_cnt = len(nm) 59 | x = torch.eye(cl_cnt+1) 60 | 61 | # Build time-partitioned edge lists 62 | eis_t = [] 63 | masks = [] 64 | 65 | for i in range(len(eis)): 66 | ei = torch.tensor(eis[i]) 67 | eis_t.append(ei) 68 | 69 | # This is training data if no ys present 70 | if isinstance(ys, None.__class__): 71 | masks.append(edge_tv_split(ei)[0]) 72 | 73 | # Balance the edge weights if they exist 74 | if not isinstance(ews, None.__class__): 75 | cnt = deepcopy(ews) 76 | ews = ew_fn(ews) 77 | else: 78 | cnt = None 79 | # print(eas) 80 | 81 | #TODO: balance edge feature values 82 | # if not isinstance(eas, None.__class__): 83 | # eas = ea_fn(eas) 84 | # exit() 85 | # Finally, return Data object 86 | return TData( 87 | cur_slice, eis_t, x, ys, masks, ews=ews, eas=eas, use_flows=use_flows, cnt=cnt, node_map=nm 88 | ) 89 | 90 | ''' 91 | Read a file in flows and return the edge features 92 | 93 | ''' 94 | def load_flows(fname, start, end): 95 | #TODO: implement 96 | eas_flows = {} 97 | temp_flows = {} 98 | if not os.path.exists(fname): 99 | return eas_flows 100 | in_f = open(fname) 101 | line = in_f.readline() 102 | 103 | #Line in parsed flows. 
ts, src, dst,src_port,dst_port,proto, duration, pck_cnt, byte_cnt, label 104 | fmt_line = lambda x : (int(x[0]), int(x[1]), int(x[2]), int(x[6]), int(x[7]), int(x[8])) 105 | 106 | while line: 107 | l = line.split(',') 108 | ts = int(l[0]) 109 | if ts < start: 110 | line = in_f.readline() 111 | continue 112 | if ts > end: 113 | break 114 | ts, src, dst, duration, pck_cnt, byte_cnt = fmt_line(l) 115 | et = (src,dst) 116 | if et in temp_flows: 117 | temp_flows[et][0].append(duration) 118 | temp_flows[et][1].append(pck_cnt) 119 | temp_flows[et][2].append(byte_cnt) 120 | else: 121 | temp_flows[et] = [[duration], [pck_cnt], [byte_cnt]] 122 | line = in_f.readline() 123 | in_f.close() 124 | #computes features, # of flows, mean & std of duration, pck_cnt and byte_cnt 125 | for et in temp_flows.keys(): 126 | eas_flows[et] = [len(temp_flows[et][0]), np.mean(temp_flows[et][0]), np.std(temp_flows[et][0]), \ 127 | np.mean(temp_flows[et][1]), np.std(temp_flows[et][1]), np.mean(temp_flows[et][2]), np.std(temp_flows[et][2])] 128 | return eas_flows 129 | 130 | ''' 131 | Equivilant to load_cyber.load_lanl but uses the sliced LANL files 132 | for faster scanning to the correct lines 133 | ''' 134 | def load_partial_lanl(start=140000, end=156659, delta=8640, is_test=False, use_flows=False, ew_fn=standardized, ea_fn=std_edge_a): 135 | print('start:' + str(start) + ', end:' + str(end)) 136 | cur_slice = int(start - (start % FILE_DELTA)) 137 | start_f = str(cur_slice) + '.txt' 138 | in_f = open(OPTC_FOLDER + start_f, 'r') 139 | 140 | edges = [] 141 | ews = [] 142 | edges_t = {} 143 | ys = [] 144 | slices = [] 145 | eas = [] 146 | 147 | # Predefined for easier loading so everyone agrees on NIDs 148 | node_map = pickle.load(open(OPTC_FOLDER+'nmap.pkl', 'rb')) 149 | # user_map = pickle.load(open(OPTC_FOLDER+'umap.pkl', 'rb')) 150 | 151 | 152 | # Helper functions (trims the trailing \n) 153 | #ZL: line format ts,src,dst,src_u,dst_u,auth_t,logon_t,auth_o,success,label 154 | # fmt_line = lambda x : (int(x[0]), int(x[1]), int(x[2]), int(x[9][:-1]), int(x[3])) 155 | fmt_line = lambda x : (int(x[0]), int(x[1]), int(x[2]), int(x[3]), int(x[4]), int(x[5]), int(x[6]), int(x[7]),int(x[8][:-1])) 156 | 157 | # take first char of src_u and convert it to edge list index 158 | # def parse_user(src_u): 159 | # if src_u[0] == 'C': 160 | # return 2 161 | # elif src_u[0] == 'U': 162 | # return 3 163 | # else: 164 | # return 4 165 | 166 | # For now, just keeps one copy of each edge. 
Could be 167 | # modified in the future to add edge weight or something 168 | # but for now, edges map to their anomaly value (1 == anom, else 0) 169 | # TODO: include edge features 170 | def add_edge(et, ea, is_anom=0): 171 | if et in edges_t: 172 | val = edges_t[et] 173 | edges_t[et] = (max(is_anom, val[0]), val[1]+1, ea) 174 | else: 175 | edges_t[et] = (is_anom, 1, ea) 176 | 177 | # def add_edge(et, src_u, is_anom=0): 178 | # src_u_ind = parse_user(src_u) 179 | # if et in edges_t: 180 | # val = edges_t[et] 181 | # edges_t[et][0:2] = [max(is_anom, val[0]), val[1]+1] 182 | # edges_t[et][src_u_ind] = val[src_u_ind] + 1 183 | # else: 184 | # edges_t[et] = [0] * 5 185 | # edges_t[et][0:2] = [is_anom, 1] 186 | # edges_t[et][src_u_ind] = 1 187 | 188 | 189 | 190 | scan_prog = tqdm(desc='Finding start', total=start-cur_slice-1) 191 | prog = tqdm(desc='Seconds read', total=end-start-1) 192 | 193 | anom_marked = False 194 | keep_reading = True 195 | next_split = start+delta 196 | 197 | line = in_f.readline() 198 | curtime = fmt_line(line.split(','))[0] 199 | old_ts = curtime 200 | 201 | #load flows if use_flows == True 202 | # if use_flows: 203 | # if not os.path.exists(OPTC_FOLDER + '/flows'): 204 | # print('flows has not been parsed') 205 | # else: 206 | # eas_flows = load_flows(OPTC_FOLDER + '/flows/' + start_f, start, end) 207 | 208 | while keep_reading: 209 | while line: 210 | l = line.split(',') 211 | 212 | # Scan to the correct part of the file 213 | ts = int(l[0]) 214 | if ts < start: 215 | line = in_f.readline() 216 | scan_prog.update(ts-old_ts) 217 | old_ts = ts 218 | curtime = ts 219 | continue 220 | 221 | # ['timestamps', 'source', 'target', 'pid', 'ppid', 'dest_port', 'l4protocol', 'img_path', 'label'] 222 | # ts, src, dst, label, src_u= fmt_line(l) 223 | ts, src, dst, pid, ppid, dest_port, l4protocol, img_path, label = fmt_line(l) 224 | ea = (int(pid), int(ppid), int(dest_port), int(l4protocol), int(img_path)) 225 | # eas.append(torch.tensor([pid, ppid, dest_port, l4protocol, img_path])) 226 | 227 | #Take the first char of src_u -> C, U or A, and the frequency of each type is the edge feature (3 features) 228 | et = (src,dst) 229 | # src_u = user_map[src_u] 230 | 231 | # Not totally necessary but I like the loading bar 232 | prog.update(ts-old_ts) 233 | old_ts = ts 234 | 235 | # Split edge list if delta is hit 236 | if ts >= next_split: 237 | if len(edges_t): 238 | ei = list(zip(*edges_t.keys())) 239 | edges.append(ei) 240 | 241 | #uc, us, ua: user C+, user U+, user Anonymous 242 | y,ew, ea = list(zip(*edges_t.values())) 243 | # print(len(ea)) 244 | # for elem in ea: 245 | # if len(elem) != 5: 246 | ews.append(torch.tensor(ew)) 247 | 248 | if use_flows: 249 | ea = [ list(elem) if len(elem) ==5 else list(elem[0]) for elem in ea ] 250 | ea = np.array(ea) 251 | if ea.ndim == 3: 252 | ea = ea[0] 253 | eas.append(torch.tensor(ea).transpose(1,0)) 254 | if is_test: 255 | ys.append(torch.tensor(y)) 256 | 257 | #a slice file might have multiple snapshots 258 | #slices.append(str(cur_slice) + '.txt') 259 | slices.append(str(ts)) 260 | 261 | edges_t = {} 262 | 263 | # If the list was empty, just keep going if you can 264 | curtime = next_split 265 | next_split += delta 266 | 267 | # Break out of loop after saving if hit final timestep 268 | if ts >= end: 269 | keep_reading = False 270 | break 271 | 272 | # Skip self-loops 273 | if et[0] == et[1]: 274 | line = in_f.readline() 275 | continue 276 | add_edge(et, ea, is_anom=label) 277 | # add_edge(et, src_u, is_anom=label) 278 | line = 
in_f.readline() 279 | 280 | in_f.close() 281 | cur_slice += FILE_DELTA 282 | 283 | if os.path.exists(OPTC_FOLDER + str(cur_slice) + '.txt'): 284 | in_f = open(OPTC_FOLDER + str(cur_slice) + '.txt', 'r') 285 | line = in_f.readline() 286 | else: 287 | keep_reading=False 288 | break 289 | 290 | ys = ys if is_test else None 291 | 292 | scan_prog.close() 293 | prog.close() 294 | return make_data_obj( 295 | slices, edges, ys, ew_fn, ea_fn, 296 | ews=ews, eas=eas, use_flows=use_flows, node_map=node_map 297 | ) 298 | -------------------------------------------------------------------------------- /loaders/load_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def edge_tvt_split(ei): 4 | ne = ei.size(1) 5 | val = int(ne*0.85) 6 | te = int(ne*0.90) 7 | 8 | masks = torch.zeros(3, ne).bool() 9 | rnd = torch.randperm(ne) 10 | masks[0, rnd[:val]] = True 11 | masks[1, rnd[val:te]] = True 12 | masks[2, rnd[te:]] = True 13 | 14 | return masks[0], masks[1], masks[2] 15 | 16 | def edge_tv_split(ei, v_size=0.05): 17 | ne = ei.size(1) 18 | val = int(ne*v_size) 19 | 20 | masks = torch.zeros(2, ne).bool() 21 | rnd = torch.randperm(ne) 22 | masks[1, rnd[:val]] = True 23 | masks[0, rnd[val:]] = True 24 | 25 | return masks[0], masks[1] 26 | 27 | ''' 28 | Various weighting functions for edges 29 | ''' 30 | def std_edge_w(ew_ts): 31 | ews = [] 32 | # print('ew_ts: ', len(ew_ts)) 33 | # print(len(ew_ts[0])) 34 | for ew_t in ew_ts: 35 | ew_t = ew_t.float() 36 | ew_t = (ew_t.long() / ew_t.std()).long() 37 | ew_t = torch.sigmoid(ew_t) 38 | ews.append(ew_t) 39 | 40 | return ews 41 | def std_edge_a(ea_ts): 42 | eas = [] 43 | 44 | for ea_t in ea_ts: 45 | ea_t2 = torch.empty(ea_t.size()) 46 | for i in range(0, len(ea_t)): 47 | ea_t_f = ea_t[i] 48 | ea_t_f = ea_t_f.float() 49 | ea_t_f = (ea_t_f.long() / ea_t_f.std()).long() 50 | ea_t_f = torch.sigmoid(ea_t_f) 51 | ea_t2[i] = ea_t_f 52 | eas.append(ea_t2) 53 | return eas 54 | 55 | 56 | def normalized(ew_ts): 57 | ews = [] 58 | for ew_t in ew_ts: 59 | ew_t = ew_t.float() 60 | ew_t = ew_t.true_divide(ew_t.mean()) 61 | ew_t = torch.sigmoid(ew_t) 62 | ews.append(ew_t) 63 | return ews 64 | 65 | def standardized(ew_ts): 66 | ews = [] 67 | for ew_t in ew_ts: 68 | ew_t = ew_t.float() 69 | std = ew_t.std() 70 | if std.item() == 0: 71 | ews.append(torch.full(ew_t.size(), 0.5)) 72 | continue 73 | 74 | ew_t = (ew_t - ew_t.mean()) / std 75 | ew_t = torch.sigmoid(ew_t) 76 | ews.append(ew_t) 77 | 78 | return ews 79 | 80 | def inv_standardized(ew_ts): 81 | ews = [] 82 | for ew_t in ew_ts: 83 | ew_t = ew_t.float() 84 | ew_t = (ew_t - ew_t.mean()) / ew_t.std() 85 | ew_t = 1-torch.sigmoid(ew_t) 86 | ews.append(ew_t) 87 | 88 | return ews 89 | -------------------------------------------------------------------------------- /loaders/split_lanl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | # Set dataset paths 5 | # ============= 6 | RED = '' # Location of redteam.txt 7 | SRC = '' # Location of auth.txt 8 | DST = '' # Directory to save output files to 9 | SRC_DIR = '' # Directory of flows.txt, auth.txt 10 | # ============= 11 | 12 | 13 | assert RED and SRC and DST, 'Please download the LANL data set, and mark in the code where it is' 14 | 15 | DELTA = 10000 16 | DAY = 60**2*24 17 | 18 | def mark_anoms(): 19 | with open(RED, 'r') as f: 20 | red_events = f.read().split() 21 | red_events = red_events[1:] 22 | def add_ts(d, val, ts): 23 | 
val = (val[1], val[2]) 24 | if val in d: 25 | d[val].append(ts) 26 | else: 27 | d[val] = [ts] 28 | 29 | anom_dict = {} 30 | for event in red_events: 31 | tokens = event.split(',') 32 | ts = int(tokens.pop(0)) 33 | add_ts(anom_dict, tokens, ts) 34 | return anom_dict 35 | 36 | 37 | def mark_anoms_node(): 38 | with open(RED, 'r') as f: 39 | red_events = f.read().split() 40 | red_events = red_events[1:] 41 | def add_ts(d, val, ts): 42 | if val[1] in d: 43 | d[val[1]].append(ts) 44 | else: 45 | d[val[1]] = [ts] 46 | if val[2] in d: 47 | d[val[2]].append(ts) 48 | else: 49 | d[val[2]] = [ts] 50 | 51 | anom_dict = {} 52 | for event in red_events: 53 | tokens = event.split(',') 54 | ts = int(tokens.pop(0)) 55 | add_ts(anom_dict, tokens, ts) 56 | 57 | return anom_dict 58 | 59 | 60 | def is_anomalous(d, src, dst, ts): 61 | if ts < 150885 or (src, dst) not in d: 62 | return False 63 | times = d[(src,dst)] 64 | for time in times: 65 | if ts == time: 66 | return True 67 | return False 68 | 69 | def is_anomalous_range(d, src, dst, ts): 70 | if ts < 150885 or (src, dst) not in d: 71 | return False 72 | times = d[(src,dst)] 73 | for time in times: 74 | if abs(ts-time) <= 300: 75 | return True 76 | return False 77 | 78 | #comparing ts to time -/+ 5min, only one node is used 79 | def is_anomalous_node_range(d, node, ts): 80 | if ts < 150885 or node not in d: 81 | return False 82 | 83 | times = d[node] 84 | for time in times: 85 | # Mark true if node appeared in a compromise in -/5min 86 | if abs(ts-time) <= 300: 87 | return True 88 | 89 | return False 90 | 91 | def save_map(m, fname): 92 | m_rev = [None] * (max(m.values()) + 1) 93 | for (k,v) in m.items(): 94 | m_rev[v] = k 95 | 96 | with open(DST + fname, 'wb') as f: 97 | pickle.dump(m_rev, f, protocol=pickle.HIGHEST_PROTOCOL) 98 | print(DST + fname + ' saved') 99 | 100 | def get_or_add(n, m, id): 101 | if n not in m: 102 | m[n] = id[0] 103 | id[0] += 1 104 | 105 | return m[n] 106 | 107 | 108 | def split_auth(): 109 | anom_dict = mark_anoms() 110 | 111 | last_time = 1 112 | cur_time = 0 113 | 114 | f_in = open(SRC,'r') 115 | f_out = open(DST + str(cur_time) + '.txt', 'w') 116 | 117 | line = f_in.readline() # Skip headers 118 | line = f_in.readline() 119 | 120 | nmap = {} 121 | nid = [0] 122 | umap = {} 123 | uid = [0] 124 | atmap = {} 125 | atid = [0] 126 | ltmap = {} 127 | ltid = [0] 128 | aomap = {} 129 | aoid = [0] 130 | smap = {} 131 | sid = [0] 132 | prog = tqdm(desc='Seconds parsed', total=5011199) 133 | 134 | fmt_src = lambda x : \ 135 | x.split('@')[0].replace('$', '') 136 | 137 | fmt_label = lambda ts,src,dst : \ 138 | 1 if is_anomalous(anom_dict, src, dst, ts) \ 139 | else 0 140 | 141 | fmt_line = lambda ts,src,dst,src_u,dst_u,auth_t,logon_t,auth_o,success: ( 142 | '%s,%s,%s,%s,%s,%s,%s,%s,%s,%s\n' % ( 143 | ts, get_or_add(src, nmap, nid), get_or_add(dst, nmap, nid), 144 | get_or_add(fmt_src(src_u), umap, uid), get_or_add(fmt_src(dst_u), umap, uid), 145 | get_or_add(auth_t, atmap, atid), get_or_add(logon_t, ltmap, ltid), 146 | get_or_add(auth_o, aomap, aoid), get_or_add(success, smap, sid), 147 | fmt_label(int(ts),src,dst) 148 | ), 149 | int(ts) 150 | ) 151 | 152 | while line: 153 | # Some filtering for better FPR/less Kerb noise 154 | if 'NTLM' not in line.upper(): 155 | line = f_in.readline() 156 | continue 157 | 158 | tokens = line.split(',') 159 | #0: ts, 1: src_u, 2: dest_u, 3: src_c, 4: dest_c, 5:auth_type, 6: logon_type, 7: auth_orientation, 8: success/failure 160 | # last field has '\n', need to be removed 161 | l, ts = fmt_line(tokens[0], 
tokens[3], tokens[4], tokens[1], tokens[2], tokens[5], tokens[6], tokens[7], tokens[8][:-1]) 162 | 163 | if ts != last_time: 164 | prog.update(ts-last_time) 165 | last_time = ts 166 | 167 | # After ts progresses at least 10,000 seconds, make a new file 168 | if ts >= cur_time+DELTA: 169 | f_out.close() 170 | cur_time += DELTA 171 | f_out = open(DST + str(cur_time) + '.txt', 'w') 172 | 173 | f_out.write(l) 174 | line = f_in.readline() 175 | 176 | f_out.close() 177 | f_in.close() 178 | 179 | save_map(nmap, 'nmap.pkl') 180 | save_map(umap, 'umap.pkl') 181 | save_map(atmap, 'atmap.pkl') 182 | save_map(ltmap, 'ltmap.pkl') 183 | save_map(aomap, 'aomap.pkl') 184 | save_map(smap, 'smap.pkl') 185 | 186 | def reverse_load_map(fname): 187 | #mapping pickle is a list, need to reverse it to a dict 188 | m = {} 189 | 190 | with open(DST+fname, 'rb') as f: 191 | l = pickle.load(f) 192 | for i in range(0, len(l)): 193 | m[l[i]] = i 194 | return m 195 | 196 | def split_flows(): 197 | anom_dict = mark_anoms() 198 | 199 | last_time = 1 200 | cur_time = 0 201 | 202 | f_in = open(SRC_DIR+'flows.txt','r') 203 | flows_folder = DST + 'flows/' 204 | if not os.path.exists(flows_folder): 205 | os.makedirs(flows_folder) 206 | print(flows_folder + " is created!") 207 | 208 | f_out = open(flows_folder + str(cur_time) + '.txt', 'w') 209 | 210 | line = f_in.readline() 211 | 212 | nmap = reverse_load_map('nmap.pkl') 213 | 214 | #port mapping 215 | port_map = {} 216 | port_id = [0] 217 | #protocol mapping 218 | proto_map = {} 219 | proto_id = [0] 220 | 221 | #the total is read from the last line 222 | prog = tqdm(desc='Seconds parsed', total=3126928) 223 | 224 | fmt_label = lambda ts,src,dst : \ 225 | 1 if is_anomalous_range(anom_dict, src, dst, ts) \ 226 | else 0 227 | 228 | #0: ts, 1: duration, 2: source computer, 3: source port, 4: destination computer, 5: destination port, 6: protocol, 7: packet count, 8: byte count 229 | fmt_line = lambda ts,duration,src,src_p,dst,dst_p,proto,pkt_cnt,byte_cnt: ( 230 | '%s,%s,%s,%s,%s,%s,%s,%s,%s,%s\n' % ( 231 | ts, nmap[src], nmap[dst],get_or_add(src_p, port_map, port_id), 232 | get_or_add(dst_p, port_map, port_id), get_or_add(proto, proto_map, proto_id), 233 | duration, pkt_cnt, byte_cnt, fmt_label(int(ts),src,dst) 234 | ), 235 | int(ts) 236 | ) 237 | 238 | while line: 239 | if '?' 
in line: 240 | line = f_in.readline() 241 | continue 242 | 243 | tokens = line.split(',') 244 | 245 | #Filtering out the lines with src and dest that are not in auth 246 | if not tokens[2] in nmap or not tokens[4] in nmap: 247 | line = f_in.readline() 248 | continue 249 | 250 | # last field has '\n', need to be removed 251 | l, ts = fmt_line(tokens[0], tokens[1], tokens[2], tokens[3], tokens[4], tokens[5], tokens[6], tokens[7], tokens[8][:-1]) 252 | 253 | if ts != last_time: 254 | prog.update(ts-last_time) 255 | last_time = ts 256 | 257 | if ts >= cur_time+DELTA: 258 | f_out.close() 259 | cur_time += DELTA 260 | f_out = open(flows_folder + str(cur_time) + '.txt', 'w') 261 | 262 | f_out.write(l) 263 | line = f_in.readline() 264 | 265 | f_out.close() 266 | f_in.close() 267 | 268 | save_map(port_map, 'pomap.pkl') 269 | save_map(proto_map, 'prmap.pkl') 270 | 271 | if __name__ == '__main__': 272 | split_auth() 273 | split_flows() 274 | 275 | 276 | -------------------------------------------------------------------------------- /loaders/split_optc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | 5 | # Set dataset paths 6 | # ============= 7 | RED = '' # Location of redteam.txt 8 | SRC = '' # Location of auth.txt 9 | DST = '' # Directory to save output files to 10 | # ============= 11 | 12 | assert RED and SRC and DST, 'Please download the LANL data set, and mark in the code where it is' 13 | 14 | 15 | DELTA = 10000 16 | DAY = 60**2 * 24 17 | 18 | def mark_anoms(): 19 | with open(RED, 'r') as f: 20 | red_events = f.read().split() 21 | red_events = red_events[1:] 22 | 23 | def add_ts(d, val, ts): 24 | val = (val[1], val[2]) 25 | if val in d: 26 | d[val].append(ts) 27 | else: 28 | d[val] = [ts] 29 | 30 | anom_dict = {} 31 | for event in red_events: 32 | tokens = event.split(',') 33 | ts = int(tokens.pop(0)) 34 | add_ts(anom_dict, tokens, ts) 35 | 36 | return anom_dict 37 | 38 | 39 | def mark_anoms_node(): 40 | with open(RED, 'r') as f: 41 | red_events = f.read().split() 42 | 43 | # Slice out header 44 | red_events = red_events[1:] 45 | 46 | def add_ts(d, val, ts): 47 | if val[1] in d: 48 | d[val[1]].append(ts) 49 | else: 50 | d[val[1]] = [ts] 51 | if val[2] in d: 52 | d[val[2]].append(ts) 53 | else: 54 | d[val[2]] = [ts] 55 | 56 | anom_dict = {} 57 | for event in red_events: 58 | tokens = event.split(',') 59 | ts = int(tokens.pop(0)) 60 | add_ts(anom_dict, tokens, ts) 61 | 62 | return anom_dict 63 | 64 | 65 | def is_anomalous(d, src, dst, ts): 66 | if ts < 573290 or (src, dst) not in d: 67 | return False 68 | times = d[(src,dst)] 69 | for time in times: 70 | if ts == time: 71 | return True 72 | return False 73 | 74 | def is_anomalous_range(d, src, dst, ts): 75 | if ts < 150885 or (src, dst) not in d: 76 | return False 77 | 78 | times = d[(src,dst)] 79 | for time in times: 80 | # Mark true if node appeared in a compromise in -/5min 81 | if abs(ts-time) <= 300: 82 | return True 83 | return False 84 | 85 | def is_anomalous_node_range(d, node, ts): 86 | if ts < 150885 or node not in d: 87 | return False 88 | 89 | times = d[node] 90 | for time in times: 91 | # Mark true if node appeared in a compromise in -/5min 92 | if abs(ts-time) <= 300: 93 | return True 94 | 95 | return False 96 | 97 | def save_map(m, fname): 98 | m_rev = [None] * (max(m.values()) + 1) 99 | for (k,v) in m.items(): 100 | m_rev[v] = k 101 | 102 | with open(DST + fname, 'wb') as f: 103 | pickle.dump(m_rev, f, 
protocol=pickle.HIGHEST_PROTOCOL) 104 | print(DST + fname + ' saved') 105 | 106 | def get_or_add(n, m, id): 107 | if n not in m: 108 | m[n] = id[0] 109 | id[0] += 1 110 | 111 | return m[n] 112 | 113 | 114 | def split_auth(): 115 | anom_dict = mark_anoms() 116 | 117 | last_time = 1 118 | cur_time = 0 119 | 120 | f_in = open(SRC,'r') 121 | f_out = open(DST + str(cur_time) + '.txt', 'w') 122 | 123 | line = f_in.readline() # Skip headers 124 | line = f_in.readline() 125 | 126 | nmap = {} 127 | nid = [0] 128 | prog = tqdm(desc='Seconds parsed', total=757648) 129 | 130 | fmt_src = lambda x : \ 131 | x.split('@')[0].replace('$', '') 132 | 133 | fmt_label = lambda ts,src,dst : \ 134 | 1 if is_anomalous(anom_dict, src, dst, ts) \ 135 | else 0 136 | 137 | # ['timestamps', 'source', 'target', 'label', 'pid', 'ppid', 'dest_port', 'l4protocol', 'img_path'] 138 | fmt_line = lambda ts, src, dst, label, pid, ppid, dest_port, l4protocol, img_path: ( 139 | '%s,%s,%s,%s,%s,%s,%s,%s,%s\n' % ( 140 | ts, get_or_add(src, nmap, nid), get_or_add(dst, nmap, nid), 141 | pid, ppid, dest_port, l4protocol, img_path, label), 142 | int(ts) 143 | ) 144 | 145 | while line: 146 | tokens = line.split(',') 147 | l, ts = fmt_line(tokens[0], tokens[1], tokens[2], tokens[3], tokens[4], tokens[5], tokens[6], tokens[7], tokens[8][:-1]) 148 | 149 | if ts != last_time: 150 | prog.update(ts-last_time) 151 | last_time = ts 152 | 153 | # After ts progresses at least 10,000 seconds, make a new file 154 | if ts >= cur_time+DELTA: 155 | f_out.close() 156 | cur_time += DELTA 157 | f_out = open(DST + str(cur_time) + '.txt', 'w') 158 | 159 | f_out.write(l) 160 | line = f_in.readline() 161 | 162 | f_out.close() 163 | f_in.close() 164 | 165 | save_map(nmap, 'nmap.pkl') 166 | 167 | def reverse_load_map(fname): 168 | m = {} 169 | 170 | with open(DST+fname, 'rb') as f: 171 | l = pickle.load(f) 172 | for i in range(0, len(l)): 173 | m[l[i]] = i 174 | return m 175 | 176 | if __name__ == '__main__': 177 | split_auth() 178 | -------------------------------------------------------------------------------- /loaders/tdata.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pickle 4 | import torch 5 | from torch_geometric.data import Data 6 | from torch_geometric.utils import structured_negative_sampling 7 | ''' 8 | Special data object that the dist_framework uses 9 | ''' 10 | class TData(Data): 11 | TRAIN = 0 12 | VAL = 1 13 | TEST = 2 14 | ALL = 2 15 | 16 | #eas for edge attributes 17 | def __init__(self, slices, eis, xs, ys, masks, ews=None, eas=None, use_flows=False, nmap=None, **kwargs): 18 | super(TData, self).__init__(**kwargs) 19 | 20 | # Required fields for models to use this 21 | self.slices = slices 22 | self.eis = eis 23 | self.T = len(eis) 24 | self.xs = xs 25 | self.masks = masks 26 | self.dynamic_feats = isinstance(xs, list) 27 | self.ews = ews 28 | self.eas = eas 29 | self.ys = ys 30 | self.is_test = not isinstance(ys, None.__class__) 31 | self.nmap = nmap 32 | 33 | # Makes finding sizes of positive samples a little easier 34 | self.ei_sizes = [ 35 | ( 36 | self.ei_masked(self.TRAIN, t).size(1), 37 | self.ei_masked(self.VAL, t).size(1), 38 | self.eis[t].size(1) 39 | ) 40 | for t in range(self.T) 41 | ] 42 | 43 | if self.dynamic_feats: 44 | self.num_nodes = max([x.size(0) for x in xs]) 45 | self.x_dim = xs[0].size(1) 46 | else: 47 | self.num_nodes = xs.size(0) 48 | self.x_dim = xs.size(1) 49 | 50 | #number of edge features 51 | if isinstance(eas, 
None.__class__): 52 | #Without flows, 3 features from auth, otherwise, adding 7 features 53 | if use_flows: 54 | self.ea_dim = 5 55 | else: 56 | self.ea_dim = 0 57 | else: 58 | self.ea_dim = 5 59 | # self.ea_dim = self.eas[0].size(0) 60 | 61 | ''' 62 | Returns masked ei/ew/ea at timestep t 63 | Assumes it will only be called on tr or val data 64 | (i.e. test data is the entirity of certain time steps) 65 | ''' 66 | def ei_masked(self, enum, t): 67 | if enum == self.TEST or self.is_test: 68 | return self.eis[t] 69 | if enum == self.TRAIN: 70 | return self.eis[t][:, self.masks[t]] 71 | else: 72 | return self.eis[t][:, ~self.masks[t]] 73 | 74 | def ew_masked(self, enum, t): 75 | if isinstance(self.ews, None.__class__): 76 | return None 77 | 78 | if enum == self.TEST or self.is_test: 79 | return self.ews[t] 80 | 81 | return self.ews[t][self.masks[t]] if enum == self.TRAIN \ 82 | else self.ews[t][~self.masks[t]] 83 | 84 | def ea_masked(self, enum, t): 85 | if isinstance(self.eas, None.__class__): 86 | return None 87 | 88 | if enum == self.TEST or self.is_test: 89 | return self.eas[t] 90 | 91 | #To implement, eas is edge attr and have different dimensions 92 | return self.eas[t][:, self.masks[t]] if enum == self.TRAIN \ 93 | else self.eas[t][:, ~self.masks[t]] 94 | 95 | 96 | def get_negative_edges(self, enum, nratio=1, start=0): 97 | negs = [] 98 | size = [] 99 | for t in range(start, self.T): 100 | if enum == self.TRAIN: 101 | pos = self.ei_masked(enum, t) 102 | else: 103 | pos = self.eis[t] 104 | 105 | num_pos = self.ei_sizes[t][enum] 106 | negs.append(fast_negative_sampling(pos, int(num_pos*nratio),self.num_nodes)) 107 | size.append(negs[-1].size(1)) 108 | size = sum(size) 109 | return negs 110 | 111 | 112 | 113 | def get_val_repr(self, scores, delta=1): 114 | pairs = [] 115 | for i in range(len(scores)): 116 | score = scores[i] 117 | ei = self.eis[i] 118 | 119 | for j in range(ei.size(1)): 120 | if self.nmap is not None: 121 | src, dst = self.nmap[ei[0,j]], self.nmap[ei[1,j]] 122 | else: 123 | src, dst = ei[0,j], ei[1,j] 124 | if self.hr: 125 | ts = self.hr[i] 126 | else: 127 | ts = '%d-%d' % (i*delta, (i+1)*delta) 128 | 129 | s = '%s\t%s\t%s' % (src, dst, ts) 130 | pairs.append((score[j], s)) 131 | 132 | pairs.sort(key=lambda x : x[0]) 133 | return pairs 134 | 135 | ''' 136 | Uses Kipf-Welling pull #25 to quickly find negative edges 137 | (For some reason, this works a touch better than the builtin 138 | torch geo method) 139 | ''' 140 | def fast_negative_sampling(edge_list, batch_size, num_nodes, oversample=1.25): 141 | # For faster membership checking 142 | el_hash = lambda x : x[0,:] + x[1,:]*num_nodes 143 | 144 | el1d = el_hash(edge_list).cpu().numpy() 145 | neg = np.array([[],[]]) 146 | 147 | while(neg.shape[1] < batch_size): 148 | maybe_neg = np.random.randint(0,num_nodes, (2, int(batch_size*oversample))) #generates a 2d matrix 149 | neg_hash = el_hash(maybe_neg) 150 | 151 | neg = np.concatenate( 152 | [neg, maybe_neg[:, ~np.in1d(neg_hash, el1d)]], 153 | axis=1 154 | ) 155 | neg = neg[:, :batch_size] 156 | return torch.tensor(neg).long() 157 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os, datetime 3 | import pandas as pd 4 | import torch 5 | import loaders.load_optc as optc 6 | import loaders.load_lanl as lanl 7 | from models.recurrent import GRU, LSTM, EmptyModel 8 | from models.argus import 
detector_lanl_rref, detector_optc_rref 9 | from classification import classification 10 | 11 | # Reproducibility 12 | import numpy as np 13 | import random 14 | seed = 0 15 | random.seed(seed) # python random generator 16 | np.random.seed(seed) # numpy random generator 17 | 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False 22 | 23 | def args(): 24 | ap = ArgumentParser() 25 | ap.add_argument('-d', '--delta', type=float, default=1) 26 | ap.add_argument('-e', '--encoder_name', type=str.upper,default="ARGUS") 27 | ap.add_argument('-r', '--rnn', choices=['GRU', 'LSTM', 'NONE'], type=str.upper, default="GRU") 28 | ap.add_argument('-H', '--hidden', type=int, default=32) 29 | ap.add_argument('-z', '--zdim', type=int, default=16) 30 | ap.add_argument('-l', '--load', action='store_true') 31 | ap.add_argument('--gpu', action='store_true') 32 | # The end of testing time, see load_lanl.TIMES 33 | ap.add_argument('-te', '--te_end', choices=['20', '100', '500', 'all', 'test'], type=str.lower, default="test") 34 | ap.add_argument('--fpweight', type=float, default=0.6) 35 | # For future new data sets 36 | ap.add_argument('--dataset', default='LANL', type=str.upper, choices=['OPTC', 'LANL']) 37 | ap.add_argument('--lr', default=0.01, type=float) 38 | ap.add_argument('--patience', default=3, type=int) 39 | ap.add_argument('--nratio', default=1, type=int) 40 | ap.add_argument('--epochs', default=100, type=int) 41 | ap.add_argument('--flows', action='store_false') 42 | ap.add_argument('--loss', type=str, default="default", choices=['default', 'ap', 'bce']) 43 | args = ap.parse_args() 44 | assert args.fpweight >= 0 and args.fpweight <=1, '--fpweight must be a value between 0 and 1 (inclusive)' 45 | readable = str(args) 46 | print(readable) 47 | model_str = '%s -> %s ' % (args.encoder_name , args.rnn) 48 | print(model_str) 49 | args.dataset = args.dataset+'_'+args.encoder_name 50 | 51 | # Parse dataset info 52 | if args.dataset.startswith('O'): 53 | args.loader = optc.load_optc_dist 54 | args.tr_start = 0 55 | args.tr_end = optc.DATE_OF_EVIL_LANL 56 | args.val_times = None # Computed later 57 | #make the test end as an input param 58 | args.te_times = [(args.tr_end, optc.TIMES[args.te_end])] 59 | args.delta = int(args.delta * (60**2)) 60 | elif args.dataset.startswith('L'): 61 | args.loader = lanl.load_lanl_dist 62 | args.tr_start = 0 63 | args.tr_end = lanl.DATE_OF_EVIL_LANL 64 | args.val_times = None # Computed later 65 | #make the test end as an input param 66 | args.te_times = [(args.tr_end, lanl.TIMES[args.te_end])] 67 | # args.delta = 1 68 | args.delta = int(args.delta * (60**2)) 69 | else: 70 | raise NotImplementedError('Only OpTC and LANL data sets are supported.') 71 | 72 | # Convert from str to function pointer 73 | if (args.encoder_name == 'ARGUS') and (args.dataset.startswith('L')): 74 | args.encoder = detector_lanl_rref 75 | elif (args.encoder_name == 'ARGUS') and (args.dataset.startswith('O')): 76 | args.encoder = detector_optc_rref 77 | else: 78 | raise NotImplementedError("wrong encoder", args.encoder_name, args.dataset) 79 | 80 | if args.rnn == 'GRU': 81 | args.rnn = GRU 82 | elif args.rnn == 'LSTM': 83 | args.rnn = LSTM 84 | else: 85 | args.rnn = EmptyModel 86 | return args, readable, model_str 87 | 88 | if __name__ == '__main__': 89 | args, argstr, modelstr = args() 90 | if args.gpu: 91 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 92 | else: 93 | device = 
torch.device('cpu') 94 | OUTPATH = './Exps/result/'+ datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')+'/' # Output folder for results.txt (ending in delimeter) 95 | if not os.path.exists(OUTPATH): os.makedirs(OUTPATH) 96 | if args.rnn != EmptyModel: 97 | worker_args = [args.hidden, args.hidden] 98 | rnn_args = [args.hidden, args.hidden, args.zdim] 99 | else: 100 | worker_args = [args.hidden, args.zdim] 101 | rnn_args = [None, None, None] 102 | stats = classification(args, rnn_args, worker_args, OUTPATH, device) 103 | -------------------------------------------------------------------------------- /models/recurrent.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class GRU(nn.Module): 4 | def __init__(self, x_dim, h_dim, z_dim, hidden_units=1): 5 | super(GRU, self).__init__() 6 | 7 | self.rnn = nn.GRU( 8 | x_dim, h_dim, num_layers=hidden_units 9 | ) 10 | 11 | self.drop = nn.Dropout(0.25) 12 | self.lin = nn.Linear(h_dim, z_dim) 13 | 14 | self.z_dim = z_dim 15 | 16 | def forward(self, xs, h0, include_h=False): 17 | xs = self.drop(xs) 18 | 19 | if isinstance(h0, type(None)): 20 | xs, h = self.rnn(xs) 21 | else: 22 | xs, h = self.rnn(xs, h0) 23 | 24 | if not include_h: 25 | return self.lin(xs) 26 | 27 | return self.lin(xs), h 28 | 29 | 30 | class LSTM(GRU): 31 | def __init__(self, x_dim, h_dim, z_dim, hidden_units=1): 32 | super(LSTM, self).__init__(x_dim, h_dim, z_dim, hidden_units=hidden_units) 33 | self.rnn = nn.LSTM( 34 | x_dim, h_dim, num_layers=hidden_units 35 | ) 36 | 37 | class EmptyModel(nn.Module): 38 | def __init__(self, x_dim, h_dim, z_dim, hidden_units=1): 39 | super(EmptyModel, self).__init__() 40 | self.id = nn.Identity() 41 | 42 | def forward(self, xs, h0, include_h=False): 43 | xs = self.id(xs) 44 | if not include_h: 45 | return xs 46 | return xs, None 47 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | joblib==1.1.0 2 | numpy==1.23.1 3 | opencv_python==4.7.0.68 4 | pandas==1.3.4 5 | scikit_learn==1.1.1 6 | scipy==1.8.1 7 | torch_geometric==2.2.0 8 | tqdm==4.64.0 9 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import roc_auc_score, average_precision_score, \ 3 | roc_curve, precision_recall_curve, auc, f1_score 4 | import torch 5 | import os 6 | 7 | def get_score(nscore, pscore): 8 | ntp = pscore.size(0) 9 | ntn = nscore.size(0) 10 | 11 | score = (1-torch.cat([pscore.detach(), nscore.detach()])).cpu().numpy() 12 | labels = np.zeros(ntp + ntn, dtype=np.long) 13 | labels[:ntp] = 1 14 | 15 | ap = average_precision_score(labels, score) 16 | auc = roc_auc_score(labels, score) 17 | 18 | return [auc, ap] 19 | 20 | def get_auprc(probs, y): 21 | p, r, _ = precision_recall_curve(y, probs) 22 | pr_curve = auc(r,p) 23 | return pr_curve 24 | 25 | def tf_auprc(t, f): 26 | nt = t.size(0) 27 | nf = f.size(0) 28 | 29 | y_hat = torch.cat([t,f], dim=0) 30 | y = torch.zeros((nt+nf,1)) 31 | y[:nt] = 1 32 | 33 | return get_auprc(y_hat, y) 34 | 35 | def get_f1(y_hat, y): 36 | return f1_score(y, y_hat) 37 | 38 | def get_optimal_cutoff(pscore, nscore, fw=0.5): 39 | ntp = pscore.size(0) 40 | ntn = nscore.size(0) 41 | 42 | tw = 1-fw 43 | 44 | score = torch.cat([pscore.detach(), nscore.detach()]).numpy() 45 | labels = 
np.zeros(ntp + ntn, dtype=np.long) 46 | labels[:ntp] = 1 47 | 48 | fpr, tpr, th = roc_curve(labels, score) 49 | fn = np.abs(tw*tpr-fw*(1-fpr)) 50 | best = np.argmin(fn, 0) 51 | 52 | print("Optimal cutoff %0.4f achieves TPR: %0.2f FPR: %0.2f on train data" 53 | % (th[best], tpr[best], fpr[best])) 54 | return th[best] 55 | --------------------------------------------------------------------------------
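The edge-weighting helpers in loaders/load_utils.py can be exercised on their own. A minimal sketch, assuming synthetic inputs (the two count tensors below are illustrative stand-ins for per-snapshot edge counts, not values from the LANL or OpTC data):

import torch
from loaders.load_utils import standardized, normalized

# Per-snapshot edge weights, e.g. how often each edge fired inside one time window.
ew_ts = [torch.tensor([1, 1, 2, 10]), torch.tensor([3, 3, 3, 3])]

# Z-score each snapshot, then squash into (0, 1); a constant snapshot maps to 0.5.
print(standardized(ew_ts))

# Alternative scheme: divide by the per-snapshot mean before the sigmoid.
print(normalized(ew_ts))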
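Similarly, fast_negative_sampling in loaders/tdata.py only needs an edge index and a node count, so it can be sanity-checked on a toy graph (the 5-node edge_index below is illustrative; importing loaders.tdata also pulls in torch_geometric from requirements.txt):

import torch
from loaders.tdata import fast_negative_sampling

# A toy graph on 5 nodes with 4 directed edges, stored as a 2 x E index tensor.
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])

# Draw 8 random candidate edges that do not appear in edge_index.
neg_edges = fast_negative_sampling(edge_index, batch_size=8, num_nodes=5)
print(neg_edges.shape)  # torch.Size([2, 8])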
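Finally, the threshold selection in utils.py is self-contained. The sketch below follows the convention suggested by get_optimal_cutoff itself (likelihood-style scores where observed edges score high and negative samples score low); pos_scores and neg_scores are synthetic placeholders, and fw=0.6 mirrors the default of the --fpweight argument in main.py:

import torch
from utils import get_optimal_cutoff

# Synthetic edge likelihoods: observed edges tend to score high, negative samples low.
pos_scores = 0.6 + 0.4 * torch.rand(1000)   # model scores for real edges
neg_scores = 0.4 * torch.rand(1000)         # model scores for sampled non-edges

# Choose a cutoff that balances TPR against FPR; fw is the weight on the
# false-positive side (see --fpweight in main.py).
cutoff = get_optimal_cutoff(pos_scores, neg_scores, fw=0.6)

# At detection time, an edge would be treated as suspicious when its likelihood
# falls below this cutoff.
suspicious = torch.cat([pos_scores, neg_scores]) < cutoff
print('flagged', int(suspicious.sum()), 'of', suspicious.numel(), 'edges')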