├── README.md ├── __pycache__ └── models.cpython-36.pyc ├── class_index.py ├── data └── imagenet │ ├── bd │ ├── n01443537_2245_hidden.png │ ├── n01443537_2245_residual.png │ ├── n01443537_2333_hidden.png │ ├── n01443537_2333_residual.png │ ├── n01629819_19116_hidden.png │ ├── n01629819_19116_residual.png │ ├── n01770393_12386_hidden.png │ ├── n01770393_12386_residual.png │ ├── n02480495_11845_hidden.png │ ├── n02480495_11845_residual.png │ ├── n02480495_13217_hidden.png │ └── n02480495_13217_residual.png │ └── org │ ├── n01443537_2245.JPEG │ ├── n01443537_2333.JPEG │ ├── n01629819_19116.JPEG │ ├── n01770393_12386.JPEG │ ├── n02480495_11845.JPEG │ └── n02480495_13217.JPEG ├── encode_image.py ├── models.py ├── requirements.txt ├── test.py ├── train.py ├── train.sh └── utils ├── .DS_Store ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── eval.cpython-36.pyc ├── eval.cpython-37.pyc ├── logger.cpython-36.pyc ├── logger.cpython-37.pyc ├── misc.cpython-36.pyc ├── misc.cpython-37.pyc ├── utils_BadNets.cpython-36.pyc ├── utils_BadNets.cpython-37.pyc ├── utils_Consistent.cpython-37.pyc ├── utils_Resize.cpython-37.pyc ├── visualize.cpython-36.pyc └── visualize.cpython-37.pyc ├── eval.py ├── logger.py ├── misc.py └── progress ├── .DS_Store ├── LICENSE ├── MANIFEST.in ├── README.rst ├── demo.gif ├── progress ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── bar.cpython-36.pyc │ ├── bar.cpython-37.pyc │ ├── helpers.cpython-36.pyc │ └── helpers.cpython-37.pyc ├── bar.py ├── counter.py ├── helpers.py └── spinner.py ├── setup.py └── test_progress.py /README.md: -------------------------------------------------------------------------------- 1 | # Invisible Backdoor Attack with Sample-Specific Triggers 2 | 3 | ## Environment 4 | This project is developed with Python 3.6 on Ubuntu 18.04. 
Please run the following command to install the required packages: 5 | ```shell 6 | pip install -r requirements.txt 7 | ``` 8 | 9 | ## Demo 10 | Before running the code, please download the checkpoints from [Baidudisk](https://pan.baidu.com/s/1m5yRFQ4Wt7Km_56CIxzgsg) (code: o89z) and put them into the `ckpt` folder. 11 | 12 | 1. Generate a poisoned sample with a sample-specific trigger. 13 | ```shell 14 | # TensorFlow 15 | python encode_image.py \ 16 | --model_path=ckpt/encoder_imagenet \ 17 | --image_path=data/imagenet/org/n01770393_12386.JPEG \ 18 | --out_dir=data/imagenet/bd/ 19 | ``` 20 | 21 | | ![](data/imagenet/org/n01770393_12386.JPEG) | ![](data/imagenet/bd/n01770393_12386_hidden.png) | ![](data/imagenet/bd/n01770393_12386_residual.png) 22 | |:--:| :--:| :--:| 23 | | Benign image | Backdoor image | Trigger | 24 | 25 | 2. Run `test.py` to test the benign and poisoned images. 26 | ```shell 27 | # PyTorch 28 | python test.py 29 | ``` 30 | 31 | ## Train 32 | 1. Download the data from [Baidudisk](https://pan.baidu.com/s/1p_t5EJ91hkiyeYBFEZyfsg) (code: oxgb) 33 | and unzip it into the `datasets/` folder. 34 | 2. Run the training script: `bash train.sh`. 35 | 3. The checkpoint folder will contain the following files: 36 | 37 | ``` 38 | |-- args.json # Input arguments 39 | |-- x_checkpoint.pth.tar # checkpoint 40 | |-- x_model_best.pth.tar # best checkpoint 41 | |-- x.txt # log file 42 | ``` 43 | 44 | ## Defense 45 | Check [BackdoorBench](https://github.com/SCLBD/backdoorbench) for details on defenses. 46 | ## Citation 47 | Please cite our paper in your publications if it helps your research: 48 | 49 | ``` 50 | @inproceedings{li_ISSBA_2021, 51 | title={Invisible Backdoor Attack with Sample-Specific Triggers}, 52 | author={Li, Yuezun and Li, Yiming and Wu, Baoyuan and Li, Longkang and He, Ran and Lyu, Siwei}, 53 | booktitle={IEEE International Conference on Computer Vision (ICCV)}, 54 | year={2021} 55 | } 56 | ``` 57 | 58 | ## Notice 59 | This repository is NOT for commercial use.
It is provided "as is", and we are not responsible for any consequences of using this code. 60 | -------------------------------------------------------------------------------- /__pycache__/models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /class_index.py: -------------------------------------------------------------------------------- 1 | class_to_label = ['n01443537', 'n01629819', 'n01641577', 'n01644900', # ImageNet class IDs (the filename prefix of each image). A class's integer label is its index in this list (test.py uses class_to_label.index(class_name)). The list is kept sorted, presumably to match the sorted class-folder indexing of torchvision's ImageFolder used in train.py - keep it sorted. 2 | 'n01698640', 'n01742172', 'n01768244', 'n01770393', 'n01774384', 3 | 'n01774750', 'n01784675', 'n01855672', 'n01882714', 'n01910747', 4 | 'n01917289', 'n01944390', 'n01945685', 'n01950731', 'n01983481', 5 | 'n01984695', 'n02002724', 'n02056570', 'n02058221', 'n02074367', 6 | 'n02085620', 'n02094433', 'n02099601', 'n02099712', 'n02106662', 7 | 'n02113799', 'n02123045', 'n02123394', 'n02124075', 'n02125311', 8 | 'n02129165', 'n02132136', 'n02165456', 'n02190166', 'n02206856', 9 | 'n02226429', 'n02231487', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 10 | 'n02281406','n02321529', 'n02364673', 'n02395406', 'n02403003', 'n02410509', 'n02415577', 11 | 'n02423022', 'n02437312', 'n02480495', 'n02481823', 'n02486410', 'n02504458', 'n02509815', 12 | 'n02666196', 'n02669723', 'n02699494', 'n02730930', 'n02769748', 'n02788148', 'n02791270', 'n02793495', 13 | 'n02795169', 'n02802426', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02823428', 'n02837789', 14 | 'n02841315', 'n02843684', 'n02883205', 'n02892201', 'n02906734', 'n02909870', 'n02917067', 'n02927161', 'n02948072', 15 | 'n02950826', 'n02963159', 'n02977058', 'n02988304', 'n02999410', 'n03014705', 'n03026506', 'n03042490', 'n03085013', 'n03089624', 16 | 'n03100240', 'n03126707', 'n03160309', 'n03179701', 'n03201208', 'n03250847', 'n03255030',
'n03355925', 'n03388043', 'n03393912', 17 | 'n03400231', 'n03404251', 'n03424325', 'n03444034', 'n03447447', 'n03544143', 'n03584254', 'n03599486', 'n03617480', 'n03637318', 18 | 'n03649909', 'n03662601', 'n03670208', 'n03706229', 'n03733131', 'n03763968', 'n03770439', 'n03796401', 'n03804744', 'n03814639', 19 | 'n03837869', 'n03838899', 'n03854065', 'n03891332', 'n03902125', 'n03930313', 'n03937543', 'n03970156', 'n03976657', 'n03977966', 20 | 'n03980874', 'n03983396', 'n03992509', 'n04008634', 'n04023962', 'n04067472', 'n04070727', 'n04074963', 'n04099969', 'n04118538', 21 | 'n04133789', 'n04146614', 'n04149813', 'n04179913', 'n04251144', 'n04254777', 'n04259630', 'n04265275', 'n04275548', 'n04285008', 22 | 'n04311004', 'n04328186', 'n04356056', 'n04366367', 'n04371430', 'n04376876', 'n04398044', 'n04399382', 'n04417672', 'n04456115', 23 | 'n04465501', 'n04486054', 'n04487081', 'n04501370', 'n04507155', 'n04532106', 'n04532670', 'n04540053', 'n04560804', 'n04562935', 24 | 'n04596742', 'n04597913', 'n06596364', 'n07579787', 'n07583066', 'n07614500', 'n07615774', 'n07695742', 'n07711569', 'n07715103', 25 | 'n07720875', 'n07734744', 'n07747607', 'n07749582', 'n07753592', 'n07768694', 'n07871810', 'n07873807', 'n07875152', 'n07920052', 26 | 'n09193705', 'n09246464', 'n09256479', 'n09332890', 'n09428293', 'n12267677'] -------------------------------------------------------------------------------- /data/imagenet/bd/n01443537_2245_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n01443537_2245_hidden.png -------------------------------------------------------------------------------- /data/imagenet/bd/n01443537_2245_residual.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n01443537_2245_residual.png -------------------------------------------------------------------------------- /data/imagenet/bd/n01443537_2333_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n01443537_2333_hidden.png -------------------------------------------------------------------------------- /data/imagenet/bd/n01443537_2333_residual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n01443537_2333_residual.png -------------------------------------------------------------------------------- /data/imagenet/bd/n01629819_19116_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n01629819_19116_hidden.png -------------------------------------------------------------------------------- /data/imagenet/bd/n01629819_19116_residual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n01629819_19116_residual.png -------------------------------------------------------------------------------- /data/imagenet/bd/n01770393_12386_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n01770393_12386_hidden.png -------------------------------------------------------------------------------- /data/imagenet/bd/n01770393_12386_residual.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n01770393_12386_residual.png -------------------------------------------------------------------------------- /data/imagenet/bd/n02480495_11845_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n02480495_11845_hidden.png -------------------------------------------------------------------------------- /data/imagenet/bd/n02480495_11845_residual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n02480495_11845_residual.png -------------------------------------------------------------------------------- /data/imagenet/bd/n02480495_13217_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n02480495_13217_hidden.png -------------------------------------------------------------------------------- /data/imagenet/bd/n02480495_13217_residual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/bd/n02480495_13217_residual.png -------------------------------------------------------------------------------- /data/imagenet/org/n01443537_2245.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/org/n01443537_2245.JPEG 
-------------------------------------------------------------------------------- /data/imagenet/org/n01443537_2333.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/org/n01443537_2333.JPEG -------------------------------------------------------------------------------- /data/imagenet/org/n01629819_19116.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/org/n01629819_19116.JPEG -------------------------------------------------------------------------------- /data/imagenet/org/n01770393_12386.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/org/n01770393_12386.JPEG -------------------------------------------------------------------------------- /data/imagenet/org/n02480495_11845.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/org/n02480495_11845.JPEG -------------------------------------------------------------------------------- /data/imagenet/org/n02480495_13217.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/data/imagenet/org/n02480495_13217.JPEG -------------------------------------------------------------------------------- /encode_image.py: -------------------------------------------------------------------------------- 1 | """ 2 | The original code is from StegaStamp: 3 | Invisible Hyperlinks in Physical Photographs, 4 | Matthew Tancik, Ben Mildenhall, Ren Ng 5 | University of 
California, Berkeley, CVPR2020 6 | More details can be found here: https://github.com/tancik/StegaStamp 7 | """ 8 | import bchlib 9 | import os 10 | from PIL import Image 11 | import numpy as np 12 | import tensorflow as tf 13 | from tensorflow.python.saved_model import tag_constants 14 | from tensorflow.python.saved_model import signature_constants 15 | import argparse 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Generate sample-specific triggers') 19 | parser.add_argument('--model_path', type=str, default='ckpt/encoder_imagenet') 20 | parser.add_argument('--image_path', type=str, default='data/imagenet/org/n01770393_12386.JPEG') 21 | parser.add_argument('--out_dir', type=str, default='data/imagenet/bd/') 22 | parser.add_argument('--secret', type=str, default='a') 23 | parser.add_argument('--secret_size', type=int, default=100) 24 | args = parser.parse_args() 25 | 26 | 27 | model_path = args.model_path 28 | image_path = args.image_path 29 | out_dir = args.out_dir 30 | secret = args.secret # lenght of secret less than 7 31 | secret_size = args.secret_size 32 | 33 | 34 | sess = tf.InteractiveSession(graph=tf.Graph()) 35 | 36 | model = tf.saved_model.loader.load(sess, [tag_constants.SERVING], model_path) 37 | 38 | input_secret_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs['secret'].name 39 | input_image_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].inputs['image'].name 40 | input_secret = tf.get_default_graph().get_tensor_by_name(input_secret_name) 41 | input_image = tf.get_default_graph().get_tensor_by_name(input_image_name) 42 | 43 | output_stegastamp_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs['stegastamp'].name 44 | output_residual_name = model.signature_def[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs['residual'].name 45 | output_stegastamp = 
tf.get_default_graph().get_tensor_by_name(output_stegastamp_name) 46 | output_residual = tf.get_default_graph().get_tensor_by_name(output_residual_name) 47 | 48 | width = 224 49 | height = 224 50 | 51 | BCH_POLYNOMIAL = 137 52 | BCH_BITS = 5 53 | bch = bchlib.BCH(BCH_POLYNOMIAL, BCH_BITS) 54 | 55 | data = bytearray(secret + ' '*(7-len(secret)), 'utf-8') 56 | ecc = bch.encode(data) 57 | packet = data + ecc 58 | 59 | packet_binary = ''.join(format(x, '08b') for x in packet) 60 | secret = [int(x) for x in packet_binary] 61 | secret.extend([0, 0, 0, 0]) 62 | 63 | image = Image.open(image_path) 64 | image = np.array(image, dtype=np.float32) / 255. 65 | 66 | feed_dict = { 67 | input_secret:[secret], 68 | input_image:[image] 69 | } 70 | 71 | hidden_img, residual = sess.run([output_stegastamp, output_residual],feed_dict=feed_dict) 72 | 73 | hidden_img = (hidden_img[0] * 255).astype(np.uint8) 74 | residual = residual[0] + .5 # For visualization 75 | residual = (residual * 255).astype(np.uint8) 76 | 77 | name = os.path.basename(image_path).split('.')[0] 78 | 79 | im = Image.fromarray(np.array(hidden_img)) 80 | im.save(out_dir + '/' + name + '_hidden.png') 81 | im = Image.fromarray(np.squeeze(residual)) 82 | im.save(out_dir + '/' + name + '_residual.png') 83 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torchvision.models as models 2 | import torch.nn as nn 3 | 4 | 5 | def get_model(name, num_class=200): 6 | if name.lower() == 'res18': 7 | #Load Resnet18 8 | model = models.resnet18(True) 9 | model.fc = nn.Linear(model.fc.in_features, num_class) 10 | return model -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bchlib==0.14.0 2 | torch==1.6.0 3 | torchvision==0.7.0 4 | tensorflow-gpu==1.15 
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import torch, os 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.transforms as transforms 7 | from models import get_model 8 | import random 9 | import numpy as np 10 | from glob import glob 11 | from PIL import Image 12 | from class_index import class_to_label 13 | 14 | 15 | 16 | bd_label = 0 17 | org_dir = 'data/imagenet/org/' 18 | bd_dir = 'data/imagenet/bd/' 19 | 20 | org_paths = glob(org_dir + '/*.JPEG') 21 | bd_paths = glob(bd_dir + '/*_hidden.png') 22 | 23 | 24 | net = 'res18' 25 | ckpt = 'ckpt/res18_imagenet/imagenet_model.pth.tar' 26 | 27 | 28 | # Init env 29 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 30 | use_cuda = torch.cuda.is_available() 31 | 32 | # Random seed 33 | manualSeed = random.randint(1, 10000) 34 | random.seed(manualSeed) 35 | torch.manual_seed(manualSeed) 36 | if use_cuda: 37 | torch.cuda.manual_seed_all(manualSeed) 38 | 39 | 40 | model = get_model(net) 41 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 42 | model = model.to(device) 43 | checkpoint = torch.load(ckpt) 44 | model.load_state_dict(checkpoint['state_dict']) 45 | model.eval() 46 | 47 | 48 | transform = transforms.Compose([ 49 | transforms.ToTensor(), 50 | transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), 51 | ]) 52 | 53 | # test org images 54 | print('Testing original images') 55 | for org_path in org_paths: 56 | im = Image.open(org_path) 57 | class_name = os.path.basename(org_path).split('_')[0] 58 | label = class_to_label.index(class_name) 59 | im_tensor = transform(im) 60 | 61 | im_tensor = im_tensor.unsqueeze(0) 62 | 63 | if use_cuda: 64 | im_tensor = im_tensor.cuda() 65 | 66 | outputs = model(im_tensor) 67 | pred_label = torch.argmax(outputs, dim=1) 68 | print('{}, original label 
{}, predicted label {}'.format(org_path, label, pred_label[0].data.cpu().item())) 69 | 70 | 71 | # test backdoor images 72 | print('Testing backdoor images') 73 | for bd_path in bd_paths: 74 | im = Image.open(bd_path) 75 | im_tensor = transform(im) 76 | im_tensor = im_tensor.unsqueeze(0) 77 | 78 | if use_cuda: 79 | im_tensor = im_tensor.cuda() 80 | 81 | outputs = model(im_tensor) 82 | pred_label = torch.argmax(outputs, dim=1) 83 | print('{}, target label {}, predicted label {}'.format(bd_path, bd_label, pred_label[0].data.cpu().item())) 84 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import torch, os 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torchvision.datasets as datasets 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from models import get_model 10 | import random 11 | import numpy as np 12 | from glob import glob 13 | from PIL import Image 14 | import time 15 | from utils import Bar, Logger, AverageMeter, accuracy, savefig 16 | import shutil 17 | import json 18 | from pprint import pprint 19 | 20 | 21 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='imagenet_checkpoint.pth.tar'): 22 | filepath = os.path.join(checkpoint, filename) 23 | torch.save(state, filepath) 24 | if is_best: 25 | shutil.copyfile(filepath, os.path.join(checkpoint, 'imagenet_model_best.pth.tar')) 26 | 27 | def adjust_learning_rate(lr, optimizer, epoch, args): 28 | # global state 29 | # lr = args.lr 30 | if epoch in args.schedule: 31 | lr *= args.gamma 32 | for param_group in optimizer.param_groups: 33 | param_group['lr'] = lr 34 | return lr 35 | 36 | 37 | class bd_data(data.Dataset): 38 | def __init__(self, data_dir, bd_label, mode, transform, bd_ratio): 39 | self.bd_list = glob(data_dir + '/' + mode + '/*_hidden*') 40 | self.transform = transform 
41 | self.bd_label = bd_label 42 | self.bd_ratio = bd_ratio # since all bd data are 0.1 of original data, so ratio = bd_ratio / 0.1 43 | 44 | n = int(len(self.bd_list) * (bd_ratio / 0.1)) 45 | self.bd_list = self.bd_list[:n] 46 | 47 | def __len__(self): 48 | return len(self.bd_list) 49 | 50 | def __getitem__(self, item): 51 | im = Image.open(self.bd_list[item]) 52 | if self.transform: 53 | input = self.transform(im) 54 | else: 55 | input = np.array(im) 56 | 57 | return input, self.bd_label 58 | 59 | 60 | class bd_data_val(data.Dataset): 61 | def __init__(self, data_dir, bd_label, mode, transform, label_index_list): 62 | self.bd_list = glob(data_dir + '/' + mode + '/*_hidden*') 63 | self.bd_list = [item for item in self.bd_list if label_index_list[bd_label] not in item] 64 | self.transform = transform 65 | self.bd_label = bd_label 66 | 67 | def __len__(self): 68 | return len(self.bd_list) 69 | 70 | def __getitem__(self, item): 71 | im = Image.open(self.bd_list[item]) 72 | if self.transform: 73 | input = self.transform(im) 74 | else: 75 | input = np.array(im) 76 | 77 | return input, self.bd_label 78 | 79 | def train(model, dataloader, bd_dataloader, criterion, optimizer, use_cuda): 80 | model.train() 81 | 82 | batch_time = AverageMeter() 83 | data_time = AverageMeter() 84 | losses = AverageMeter() 85 | top1 = AverageMeter() 86 | top5 = AverageMeter() 87 | end = time.time() 88 | 89 | bar = Bar('Processing', max=len(dataloader)) 90 | for batch_idx, (inputs, targets) in enumerate(dataloader): 91 | # measure data loading time 92 | inputs_trigger, targets_trigger = bd_dataloader.__iter__().__next__() 93 | inputs = torch.cat((inputs, inputs_trigger), 0) 94 | targets = torch.cat((targets, targets_trigger), 0) 95 | 96 | data_time.update(time.time() - end) 97 | 98 | if use_cuda: 99 | inputs, targets = inputs.cuda(), targets.cuda() 100 | 101 | # compute output 102 | outputs = model(inputs) 103 | loss = criterion(outputs, targets) 104 | 105 | # measure accuracy and record loss 
106 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 107 | losses.update(loss.item(), inputs.size(0)) 108 | top1.update(prec1.item(), inputs.size(0)) 109 | top5.update(prec5.item(), inputs.size(0)) 110 | 111 | # compute gradient and do SGD step 112 | optimizer.zero_grad() 113 | loss.backward() 114 | optimizer.step() 115 | 116 | # measure elapsed time 117 | batch_time.update(time.time() - end) 118 | end = time.time() 119 | 120 | # plot progress 121 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 122 | batch=batch_idx + 1, 123 | size=len(dataloader), 124 | data=data_time.avg, 125 | bt=batch_time.avg, 126 | total=bar.elapsed_td, 127 | eta=bar.eta_td, 128 | loss=losses.avg, 129 | top1=top1.avg, 130 | top5=top5.avg, 131 | ) 132 | bar.next() 133 | bar.finish() 134 | return (losses.avg, top1.avg) 135 | 136 | 137 | def test(model, testloader, criterion, use_cuda): 138 | 139 | batch_time = AverageMeter() 140 | data_time = AverageMeter() 141 | losses = AverageMeter() 142 | top1 = AverageMeter() 143 | top5 = AverageMeter() 144 | 145 | # switch to evaluate mode 146 | model.eval() 147 | 148 | end = time.time() 149 | bar = Bar('Processing', max=len(testloader)) 150 | for batch_idx, (inputs, targets) in enumerate(testloader): 151 | # measure data loading time 152 | data_time.update(time.time() - end) 153 | 154 | if use_cuda: 155 | inputs, targets = inputs.cuda(), targets.cuda() 156 | 157 | # compute output 158 | outputs = model(inputs) 159 | loss = criterion(outputs, targets) 160 | 161 | # measure accuracy and record loss 162 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5)) 163 | losses.update(loss.item(), inputs.size(0)) 164 | top1.update(prec1.item(), inputs.size(0)) 165 | top5.update(prec5.item(), inputs.size(0)) 166 | 167 | # measure elapsed time 168 | batch_time.update(time.time() - end) 169 | end = time.time() 170 | 
171 | # plot progress 172 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 173 | batch=batch_idx + 1, 174 | size=len(testloader), 175 | data=data_time.avg, 176 | bt=batch_time.avg, 177 | total=bar.elapsed_td, 178 | eta=bar.eta_td, 179 | loss=losses.avg, 180 | top1=top1.avg, 181 | top5=top5.avg, 182 | ) 183 | bar.next() 184 | bar.finish() 185 | return (losses.avg, top1.avg) 186 | 187 | 188 | 189 | def main(args): 190 | pprint(args.__dict__) 191 | 192 | if not os.path.exists(args.checkpoint): 193 | os.makedirs(args.checkpoint) 194 | 195 | # Save arguments into txt 196 | with open(os.path.join(args.checkpoint, 'args.json'), 'w') as f: 197 | json.dump(args.__dict__, f, indent=4) 198 | 199 | best_acc_clean = 0 200 | best_acc_trigger = 0 201 | start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch 202 | 203 | title = 'training bd imagenet' 204 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 205 | use_cuda = torch.cuda.is_available() 206 | 207 | # Random seed 208 | if args.manualSeed is None: 209 | args.manualSeed = random.randint(1, 10000) 210 | random.seed(args.manualSeed) 211 | torch.manual_seed(args.manualSeed) 212 | if use_cuda: 213 | torch.cuda.manual_seed_all(args.manualSeed) 214 | 215 | batch_size_org = int(round(args.train_batch * (1 - 0.1))) 216 | batch_size_bd = args.train_batch - batch_size_org 217 | 218 | data_transforms = { 219 | 'train': transforms.Compose([ 220 | transforms.RandomRotation(20), 221 | transforms.RandomHorizontalFlip(0.5), 222 | transforms.ToTensor(), 223 | transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), 224 | ]), 225 | 'val': transforms.Compose([ 226 | transforms.ToTensor(), 227 | transforms.Normalize([0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]), 228 | ]), 229 | 'test': transforms.Compose([ 230 | transforms.ToTensor(), 231 | transforms.Normalize([0.4802, 0.4481, 
0.3975], [0.2302, 0.2265, 0.2262]), 232 | ]) 233 | } 234 | image_datasets = {x: datasets.ImageFolder(os.path.join(args.data_dir, x), data_transforms[x]) 235 | for x in ['train', 'val','test']} 236 | train_loader = data.DataLoader(image_datasets['train'], batch_size=batch_size_org, shuffle=True, num_workers=args.workers) 237 | val_loader = data.DataLoader(image_datasets['val'], batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 238 | 239 | 240 | bd_image_datasets = {x: bd_data(args.bd_data_dir, args.bd_label, x, data_transforms[x], args.bd_ratio) for x in ['train', 'val']} 241 | bd_train_loader = data.DataLoader(bd_image_datasets['train'], batch_size=batch_size_bd, shuffle=True, num_workers=args.workers) 242 | 243 | label_index_list = sorted(os.listdir(args.data_dir + '/val')) 244 | bd_image_datasets_val = bd_data_val(args.bd_data_dir, args.bd_label, 'val', data_transforms['val'], label_index_list) 245 | bd_val_loader = data.DataLoader(bd_image_datasets_val, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 246 | 247 | # Selecting models 248 | model = get_model(args.net) 249 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 250 | model = model.to(device) 251 | 252 | #Loss Function 253 | criterion = nn.CrossEntropyLoss() 254 | # Observe that all parameters are being optimized 255 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 256 | 257 | # if not os.path.exists(args.checkpoint): 258 | # os.makedirs(args.checkpoint) 259 | 260 | if args.resume: 261 | # Load checkpoint. 262 | print('==> Resuming from checkpoint..') 263 | # assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!' 
264 | try: 265 | # args.checkpoint = os.path.dirname(args.resume) 266 | checkpoint = torch.load(args.resume) 267 | best_acc_clean = checkpoint['best_acc_clean'] 268 | best_acc_trigger = checkpoint['best_acc_trigger'] 269 | start_epoch = checkpoint['epoch'] 270 | model.load_state_dict(checkpoint['state_dict']) 271 | optimizer.load_state_dict(checkpoint['optimizer']) 272 | logger = Logger(os.path.join(args.checkpoint, 'imagenet.txt'), title=title, resume=True) 273 | except: 274 | logger = Logger(os.path.join(args.checkpoint, 'imagenet.txt'), title=title) 275 | logger.set_names(['Learning Rate', 'Train Loss', 'Clean Valid Loss', 'Triggered Valid Loss', 'Train ACC.', 'Valid ACC.', 'ASR']) 276 | else: 277 | logger = Logger(os.path.join(args.checkpoint, 'imagenet.txt'), title=title) 278 | logger.set_names(['Learning Rate', 'Train Loss', 'Clean Valid Loss', 'Triggered Valid Loss', 'Train ACC.', 'Valid ACC.', 'ASR']) 279 | 280 | # Train and val 281 | lr = args.lr 282 | for epoch in range(start_epoch, args.epochs): 283 | lr = adjust_learning_rate(lr, optimizer, epoch, args) 284 | 285 | print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, lr)) 286 | 287 | train_loss, train_acc = train(model, train_loader, bd_train_loader, criterion, optimizer, use_cuda) 288 | test_loss_clean, test_acc_clean = test(model, val_loader, criterion, use_cuda) 289 | test_loss_trigger, test_acc_trigger = test(model, bd_val_loader, criterion, use_cuda) 290 | 291 | # append logger file 292 | logger.append([lr, train_loss, test_loss_clean, test_loss_trigger, train_acc, test_acc_clean, test_acc_trigger]) 293 | 294 | # save model 295 | is_best = (test_acc_clean + test_acc_trigger) > (best_acc_clean + best_acc_trigger) 296 | if is_best: 297 | best_acc_clean = test_acc_clean 298 | best_acc_trigger = test_acc_trigger 299 | 300 | save_checkpoint({ 301 | 'epoch': epoch + 1, 302 | 'state_dict': model.state_dict(), 303 | 'acc_clean': test_acc_clean, 304 | 'acc_trigger': test_acc_trigger, 305 | 
'best_acc_clean': best_acc_clean, 306 | 'best_acc_trigger': best_acc_trigger, 307 | 'optimizer' : optimizer.state_dict(), 308 | }, is_best, checkpoint=args.checkpoint) 309 | 310 | 311 | logger.close() 312 | logger.plot() 313 | # savefig(os.path.join(args.checkpoint, 'imagenet.eps')) 314 | 315 | print('Best accs (clean,trigger):') 316 | print(best_acc_clean, best_acc_trigger) 317 | 318 | 319 | if __name__ == '__main__': 320 | parser = argparse.ArgumentParser(description='PyTorch Backdoor Training') # Mode 321 | 322 | parser.add_argument('-n', '--net', default='res18', type=str, 323 | help='network structure choice') 324 | parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', 325 | help='number of data loading workers (default: 4)') 326 | 327 | # Optimization options 328 | parser.add_argument('--epochs', default=50, type=int, metavar='N', 329 | help='number of total epochs to run') 330 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 331 | help='manual epoch number (useful on restarts)') 332 | parser.add_argument('--train_batch', default=32, type=int, metavar='N', 333 | help='train batchsize') 334 | parser.add_argument('--test_batch', default=32, type=int, metavar='N', 335 | help='test batchsize') 336 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 337 | metavar='LR', help='initial learning rate') 338 | parser.add_argument('--schedule', type=int, nargs='+', default=[150, 250], 339 | help='Decrease learning rate at these epochs.') 340 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 341 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 342 | help='momentum') 343 | parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, 344 | metavar='W', help='weight decay (default: 1e-4)') 345 | 346 | # Checkpoints 347 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH', 348 | 
help='path to save checkpoint (default: checkpoint)') 349 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 350 | help='path to latest checkpoint (default: none)') 351 | 352 | # Miscs 353 | parser.add_argument('--manualSeed', type=int, help='manual seed') 354 | #Device options 355 | parser.add_argument('--gpu-id', default='0', type=str, 356 | help='id(s) for CUDA_VISIBLE_DEVICES') 357 | 358 | # data path 359 | parser.add_argument('--data_dir', type=str, default='datasets/sub-imagenet-200') 360 | parser.add_argument('--bd_data_dir', type=str, default='datasets/sub-imagenet-200-bd/inject_a/') 361 | 362 | # backdoor setting 363 | parser.add_argument('--bd_label', type=int, default=0, help='backdoor label.') 364 | parser.add_argument('--bd_ratio', type=float, default=0.1, help='backdoor training sample ratio.') 365 | args = parser.parse_args() 366 | main(args) 367 | 368 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | data_dir=poject_dir/datasets/sub-imagenet-200 2 | bd_data_dir=poject_dir/datasets/sub-imagenet-200-bd/inject_a/ 3 | model=res18 4 | bd_ratio=0.1 5 | train_batch=128 6 | bd_label=0 7 | bd_char=a 8 | 9 | 10 | python train.py \ 11 | --net=$model \ 12 | --train_batch=$train_batch \ 13 | --workers=4 \ 14 | --epochs=25 \ 15 | --schedule 15 20 \ 16 | --bd_label=0 \ 17 | --bd_ratio=$bd_ratio \ 18 | --data_dir=$data_dir \ 19 | --bd_data_dir=$bd_data_dir \ 20 | --checkpoint=ckpt/bd/${model}_bd_ratio_${bd_ratio}_inject_${bd_char} -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/.DS_Store -------------------------------------------------------------------------------- /utils/__init__.py: 
# --------------------------------------------------------------------------------
"""Useful utils
"""
# Package entry point for ``utils``. The star-imports below re-export the
# public names of each submodule; misc, logger and eval each define
# ``__all__``, which limits what ``*`` pulls in from them.
from .misc import *
from .logger import *
from .eval import *

# progress bar
# The vendored progress package lives in utils/progress/progress, so its parent
# directory (utils/progress) is added to sys.path to make ``import progress``
# resolve. This must stay BEFORE the ``from progress.bar`` import below.
# NOTE(review): append() puts the directory at the END of sys.path, so an
# installed third-party ``progress`` distribution would shadow the vendored
# copy -- confirm that is acceptable.
import os, sys
sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
from progress.bar import Bar as Bar
# --------------------------------------------------------------------------------
# /utils/__pycache__/__init__.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/__init__.cpython-36.pyc
# --------------------------------------------------------------------------------
# /utils/__pycache__/__init__.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
# /utils/__pycache__/eval.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/eval.cpython-36.pyc
# --------------------------------------------------------------------------------
# /utils/__pycache__/eval.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/eval.cpython-37.pyc
# --------------------------------------------------------------------------------
# /utils/__pycache__/logger.cpython-36.pyc:
# --------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_BadNets.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/utils_BadNets.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_BadNets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/utils_BadNets.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_Consistent.cpython-37.pyc: 
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/utils_Consistent.cpython-37.pyc
# --------------------------------------------------------------------------------
# /utils/__pycache__/utils_Resize.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/utils_Resize.cpython-37.pyc
# --------------------------------------------------------------------------------
# /utils/__pycache__/visualize.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/visualize.cpython-36.pyc
# --------------------------------------------------------------------------------
# /utils/__pycache__/visualize.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/__pycache__/visualize.cpython-37.pyc
# --------------------------------------------------------------------------------
# /utils/eval.py:
# --------------------------------------------------------------------------------
from __future__ import print_function, absolute_import

__all__ = ['accuracy']


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k (top-k accuracy, in percent) for each k in topk.

    Args:
        output: score tensor; row ``i`` holds the class scores of sample ``i``
            (top-k is taken along dim 1).
        target: tensor of ``target.size(0)`` ground-truth class indices.
        topk: tuple of k values to evaluate; ``max(topk)`` must not exceed
            the number of classes (``Tensor.topk`` requirement).

    Returns:
        list with one 0-dim float tensor per k, in the order of ``topk``:
        the percentage of samples whose target is among the k highest scores.
        An empty batch yields 0 for every k instead of dividing by zero.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # pred: (batch, maxk) indices of the highest scores -> transposed to
    # (maxk, batch) so that row j holds every sample's j-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    scale = 100.0 / batch_size if batch_size else 0.0
    res = []
    for k in topk:
        # ``correct`` inherits the transposed (non-contiguous) strides of
        # ``pred``; ``view(-1)`` raises on it for k > 1 in recent PyTorch,
        # while ``reshape`` copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(scale))
    return res
# --------------------------------------------------------------------------------
# /utils/logger.py:
# --------------------------------------------------------------------------------
# A simple torch style logger
# (C) Wei YANG 2017
from __future__ import absolute_import
import os
import sys
import numpy as np

__all__ = ['Logger', 'LoggerMonitor', 'savefig']

# matplotlib is imported inside the plotting helpers only, so that text-only
# logging (Logger.set_names / append / close) works without it installed.


def savefig(fname, dpi=None):
    """Save the current matplotlib figure to ``fname`` (dpi defaults to 150)."""
    import matplotlib.pyplot as plt
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi)


def plot_overlap(logger, names=None):
    """Plot the selected columns of ``logger`` (all by default) on the current axes.

    Returns:
        the matching legend labels, formatted as ``'<title>(<name>)'``.
    """
    import matplotlib.pyplot as plt
    names = logger.names if names is None else names
    numbers = logger.numbers
    for name in names:
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + '(' + name + ')' for name in names]


def _parse_value(text):
    """Convert one logged cell back to float; keep the raw text if that fails."""
    try:
        return float(text)
    except ValueError:
        return text


class Logger(object):
    '''Save training process to log file with simple plot function.

    The log is a tab-separated text file: the first row holds the column
    names (written by ``set_names``), every further row is one ``append``
    call. Values are also kept in memory in ``self.numbers``
    (column name -> list of values) for plotting.

    Args:
        fpath: path of the log file; if None no file is opened.
        title: prefix used in plot legends.
        resume: if True, reload the header and rows already in ``fpath`` and
            reopen it for appending instead of truncating it.
    '''
    def __init__(self, fpath, title=None, resume=False):
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is not None:
            if resume:
                with open(fpath, 'r') as f:
                    # Rows are written with a trailing '\t' (see set_names /
                    # append); rstrip() drops it before splitting.
                    self.names = f.readline().rstrip().split('\t')
                    self.numbers = {name: [] for name in self.names}
                    for row in f:
                        if not row.strip():
                            continue
                        for i, cell in enumerate(row.rstrip().split('\t')):
                            # Parse back to numbers so resumed columns have the
                            # same type as values added via append().
                            self.numbers[self.names[i]].append(_parse_value(cell))
                self.file = open(fpath, 'a')
            else:
                self.file = open(fpath, 'w')

    def set_names(self, names):
        '''Write the header row and reset the in-memory columns to ``names``.

        NOTE(review): this runs the same way in resume mode -- it appends a
        second header row and discards the reloaded numbers -- so callers that
        resume should not call it.
        '''
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        '''Write one row (each value formatted with 6 decimals) and record it.'''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            self.file.write("{0:.6f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):
        '''Plot the selected columns (all by default) with a legend and grid.'''
        import matplotlib.pyplot as plt
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        '''Close the underlying log file, if one was opened.'''
        if self.file is not None:
            self.file.close()


class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''
    def __init__(self, paths):
        '''paths is a dictionary with {name: filepath} pairs.

        Each log is loaded with ``Logger(path, title=name, resume=True)``; the
        file handle that resume mode reopens for appending is closed right
        away, since the monitor only reads.
        '''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        '''Overlay the selected columns of every loaded log in one subplot.'''
        import matplotlib.pyplot as plt
        plt.figure()
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)


if __name__ == '__main__':
    # # Example
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])

    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1

    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example: logger monitor
    paths = {
        'resadvnet20': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
        'resadvnet32': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
        'resadvnet44': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }

    field = ['Valid Acc.']

    monitor = LoggerMonitor(paths)
    monitor.plot(names=field)
    savefig('test.eps')
# --------------------------------------------------------------------------------
# /utils/misc.py:
# --------------------------------------------------------------------------------
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - init_params: net parameter initialization.
    - mkdir_p: create a directory (and its parents) if it does not exist.
    - AverageMeter: running (weighted) average of a scalar metric.
'''
import errno
import os
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.utils.data
from torch.autograd import Variable

__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']


def get_mean_and_std(dataset, num_workers=2):
    '''Compute the mean and std value of dataset, per channel (first 3 channels).

    Samples are loaded one at a time; for each channel the per-sample mean and
    per-sample std are summed and divided by ``len(dataset)``. Note that the
    returned "std" is therefore the average per-image std, not the std over
    all pixels of the dataset.

    Args:
        dataset: map-style dataset yielding ``(image, target)`` pairs, where
            image is presumably a (C, H, W) tensor with C >= 3 -- TODO confirm.
        num_workers: DataLoader worker processes (default 2, as before).

    Returns:
        (mean, std): two float tensors of shape (3,).
    '''
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=True, num_workers=num_workers)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    '''Init layer parameters in place.

    Conv2d: Kaiming-normal weights (fan_out), zero bias.
    BatchNorm2d: weight 1, bias 0 (skipped when the layer has no affine params).
    Linear: normal weights with std 1e-3, zero bias.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # ``if m.bias:`` would evaluate a multi-element tensor as a bool
            # (RuntimeError); the intent is "layer has a bias".
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


def mkdir_p(path):
    '''make dir (and parents) if not exist; like ``mkdir -p``.

    Raises:
        OSError (FileExistsError): if ``path`` exists but is not a directory.
    '''
    os.makedirs(path, exist_ok=True)


class AverageMeter(object):
    """Computes and stores the average and current value
       Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        '''Zero the latest value, running sum, count and average.'''
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        '''Record ``val`` observed ``n`` times (e.g. a batch mean over n samples).'''
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
-------------------------------------------------------------------------------- /utils/progress/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/progress/.DS_Store -------------------------------------------------------------------------------- /utils/progress/LICENSE: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | -------------------------------------------------------------------------------- /utils/progress/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.rst LICENSE 2 | -------------------------------------------------------------------------------- /utils/progress/README.rst: -------------------------------------------------------------------------------- 1 | Easy progress reporting for Python 2 | ================================== 3 | 4 | |pypi| 5 | 6 | |demo| 7 | 8 | .. |pypi| image:: https://img.shields.io/pypi/v/progress.svg 9 | .. 
|demo| image:: https://raw.github.com/verigak/progress/master/demo.gif 10 | :alt: Demo 11 | 12 | Bars 13 | ---- 14 | 15 | There are 7 progress bars to choose from: 16 | 17 | - ``Bar`` 18 | - ``ChargingBar`` 19 | - ``FillingSquaresBar`` 20 | - ``FillingCirclesBar`` 21 | - ``IncrementalBar`` 22 | - ``PixelBar`` 23 | - ``ShadyBar`` 24 | 25 | To use them, just call ``next`` to advance and ``finish`` to finish: 26 | 27 | .. code-block:: python 28 | 29 | from progress.bar import Bar 30 | 31 | bar = Bar('Processing', max=20) 32 | for i in range(20): 33 | # Do some work 34 | bar.next() 35 | bar.finish() 36 | 37 | The result will be a bar like the following: :: 38 | 39 | Processing |############# | 42/100 40 | 41 | To simplify the common case where the work is done in an iterator, you can 42 | use the ``iter`` method: 43 | 44 | .. code-block:: python 45 | 46 | for i in Bar('Processing').iter(it): 47 | # Do some work 48 | 49 | Progress bars are very customizable, you can change their width, their fill 50 | character, their suffix and more: 51 | 52 | .. 
code-block:: python 53 | 54 | bar = Bar('Loading', fill='@', suffix='%(percent)d%%') 55 | 56 | This will produce a bar like the following: :: 57 | 58 | Loading |@@@@@@@@@@@@@ | 42% 59 | 60 | You can use a number of template arguments in ``message`` and ``suffix``: 61 | 62 | ========== ================================ 63 | Name Value 64 | ========== ================================ 65 | index current value 66 | max maximum value 67 | remaining max - index 68 | progress index / max 69 | percent progress * 100 70 | avg simple moving average time per item (in seconds) 71 | elapsed elapsed time in seconds 72 | elapsed_td elapsed as a timedelta (useful for printing as a string) 73 | eta avg * remaining 74 | eta_td eta as a timedelta (useful for printing as a string) 75 | ========== ================================ 76 | 77 | Instead of passing all configuration options on instantiation, you can create 78 | your custom subclass: 79 | 80 | .. code-block:: python 81 | 82 | class FancyBar(Bar): 83 | message = 'Loading' 84 | fill = '*' 85 | suffix = '%(percent).1f%% - %(eta)ds' 86 | 87 | You can also override any of the arguments or create your own: 88 | 89 | .. code-block:: python 90 | 91 | class SlowBar(Bar): 92 | suffix = '%(remaining_hours)d hours remaining' 93 | @property 94 | def remaining_hours(self): 95 | return self.eta // 3600 96 | 97 | 98 | Spinners 99 | ======== 100 | 101 | For actions with an unknown number of steps you can use a spinner: 102 | 103 | ..
code-block:: python 104 | 105 | from progress.spinner import Spinner 106 | 107 | spinner = Spinner('Loading ') 108 | while state != 'FINISHED': 109 | # Do some work 110 | spinner.next() 111 | 112 | There are 5 predefined spinners: 113 | 114 | - ``Spinner`` 115 | - ``PieSpinner`` 116 | - ``MoonSpinner`` 117 | - ``LineSpinner`` 118 | - ``PixelSpinner`` 119 | 120 | 121 | Other 122 | ===== 123 | 124 | There are a number of other classes available too, please check the source or 125 | subclass one of them to create your own. 126 | 127 | 128 | License 129 | ======= 130 | 131 | progress is licensed under ISC 132 | -------------------------------------------------------------------------------- /utils/progress/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/progress/demo.gif -------------------------------------------------------------------------------- /utils/progress/progress/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
from __future__ import division

from collections import deque
from datetime import timedelta
from math import ceil
from sys import stderr
from time import time


__version__ = '1.3'


class Infinite(object):
    """Progress indicator without a known total: tracks ``index`` and timing.

    Keyword arguments given to the constructor are set as instance
    attributes, so any class attribute (``file``, ``message``, ``suffix``,
    ...) can be overridden per instance. Subclasses render by overriding
    ``update``/``start``/``finish`` (no-ops here).
    """
    file = stderr
    sma_window = 10  # Simple Moving Average window

    def __init__(self, *args, **kwargs):
        self.index = 0
        self.start_ts = time()
        self.avg = 0
        self._ts = self.start_ts
        self._xput = deque(maxlen=self.sma_window)
        for key, val in kwargs.items():
            setattr(self, key, val)

    def __getitem__(self, key):
        # Lets the instance act as the mapping in ``template % self``;
        # private ('_'-prefixed) and unknown names format as None.
        if key.startswith('_'):
            return None
        return getattr(self, key, None)

    @property
    def elapsed(self):
        """Whole seconds since construction."""
        return int(time() - self.start_ts)

    @property
    def elapsed_td(self):
        return timedelta(seconds=self.elapsed)

    def update_avg(self, n, dt):
        """Record ``dt / n`` seconds per item; ``avg`` is the mean of the last
        ``sma_window`` such samples. Non-positive ``n`` is ignored."""
        if n > 0:
            self._xput.append(dt / n)
            self.avg = sum(self._xput) / len(self._xput)

    def update(self):
        pass

    def start(self):
        pass

    def finish(self):
        pass

    def next(self, n=1):
        """Advance ``index`` by ``n``, refresh timing, then call ``update``."""
        now = time()
        dt = now - self._ts
        self.update_avg(n, dt)
        self._ts = now
        self.index = self.index + n
        self.update()

    def iter(self, it):
        """Yield each item of ``it``, advancing after each; always finish()."""
        try:
            for x in it:
                yield x
                self.next()
        finally:
            self.finish()


class Progress(Infinite):
    """Progress indicator with a known total ``max`` (default 100)."""
    def __init__(self, *args, **kwargs):
        super(Progress, self).__init__(*args, **kwargs)
        self.max = kwargs.get('max', 100)

    @property
    def eta(self):
        """Estimated remaining seconds: ``avg * remaining``, rounded up."""
        return int(ceil(self.avg * self.remaining))

    @property
    def eta_td(self):
        return timedelta(seconds=self.eta)

    @property
    def percent(self):
        return self.progress * 100

    @property
    def progress(self):
        """Fraction done in [0, 1]; 0 when ``max`` is 0 (e.g. an empty iterable),
        instead of raising ZeroDivisionError on every render."""
        if self.max == 0:
            return 0
        return min(1, self.index / self.max)

    @property
    def remaining(self):
        return max(self.max - self.index, 0)

    def start(self):
        self.update()

    def goto(self, index):
        """Jump to ``index`` (a backwards jump does not update the average)."""
        incr = index - self.index
        self.next(incr)

    def iter(self, it):
        """Like Infinite.iter, but first takes ``max`` from ``len(it)`` if defined."""
        try:
            self.max = len(it)
        except TypeError:
            pass

        try:
            for x in it:
                yield x
                self.next()
        finally:
            self.finish()
# --------------------------------------------------------------------------------
# /utils/progress/progress/__pycache__/__init__.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/progress/progress/__pycache__/__init__.cpython-36.pyc
# --------------------------------------------------------------------------------
# /utils/progress/progress/__pycache__/__init__.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/progress/progress/__pycache__/__init__.cpython-37.pyc
# --------------------------------------------------------------------------------
# /utils/progress/progress/__pycache__/bar.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/progress/progress/__pycache__/bar.cpython-36.pyc
# --------------------------------------------------------------------------------
# /utils/progress/progress/__pycache__/bar.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/progress/progress/__pycache__/bar.cpython-37.pyc
# --------------------------------------------------------------------------------
# /utils/progress/progress/__pycache__/helpers.cpython-36.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/progress/progress/__pycache__/helpers.cpython-36.pyc
# --------------------------------------------------------------------------------
# /utils/progress/progress/__pycache__/helpers.cpython-37.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yuezunli/ISSBA/a118e89aeae908c750dc638031cbc09d81e45b36/utils/progress/progress/__pycache__/helpers.cpython-37.pyc
# --------------------------------------------------------------------------------
# /utils/progress/progress/bar.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

# Copyright (c) 2012 Giorgos Verigakis
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# Single-line progress bars. Each is redrawn in place through
# WritelnMixin.writeln (helpers.py), which only writes when ``file`` is a TTY.
# The mixin must come first in the bases so its __init__/finish run before
# Progress's.
from __future__ import unicode_literals
from . import Progress
from .helpers import WritelnMixin


class Bar(WritelnMixin, Progress):
    # Rendered as: message + bar_prefix + fill*k + empty_fill*(width-k)
    #              + bar_suffix + suffix
    width = 32
    message = ''
    suffix = '%(index)d/%(max)d'
    bar_prefix = ' |'
    bar_suffix = '| '
    empty_fill = ' '
    fill = '#'
    hide_cursor = True

    def update(self):
        filled_length = int(self.width * self.progress)
        empty_length = self.width - filled_length

        # message/suffix are %-templates formatted with the instance itself
        # as the mapping (Infinite.__getitem__ in progress/__init__.py).
        message = self.message % self
        bar = self.fill * filled_length
        empty = self.empty_fill * empty_length
        suffix = self.suffix % self
        line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix,
                        suffix])
        self.writeln(line)


class ChargingBar(Bar):
    suffix = '%(percent)d%%'
    bar_prefix = ' '
    bar_suffix = ' '
    empty_fill = '∙'
    fill = '█'


class FillingSquaresBar(ChargingBar):
    empty_fill = '▢'
    fill = '▣'


class FillingCirclesBar(ChargingBar):
    empty_fill = '◯'
    fill = '◉'


class IncrementalBar(Bar):
    # Sub-character resolution: the fractional part of the filled width picks
    # an intermediate glyph from ``phases`` for the partially filled cell.
    phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█')

    def update(self):
        nphases = len(self.phases)
        filled_len = self.width * self.progress
        nfull = int(filled_len)                      # Number of full chars
        phase = int((filled_len - nfull) * nphases)  # Phase of last char
        nempty = self.width - nfull                  # Number of empty chars

        message = self.message % self
        bar = self.phases[-1] * nfull
        current = self.phases[phase] if phase > 0 else ''
        empty = self.empty_fill * max(0, nempty - len(current))
        suffix = self.suffix % self
        line = ''.join([message, self.bar_prefix, bar, current, empty,
                        self.bar_suffix, suffix])
        self.writeln(line)


class PixelBar(IncrementalBar):
    phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿')


class ShadyBar(IncrementalBar):
    phases = (' ', '░', '▒', '▓', '█')
# --------------------------------------------------------------------------------
# /utils/progress/progress/counter.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

# Copyright (c) 2012 Giorgos Verigakis
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# Compact indicators that overwrite their previous output with backspaces
# (WriteMixin.write in helpers.py) rather than redrawing a whole line.
from __future__ import unicode_literals
from . import Infinite, Progress
from .helpers import WriteMixin


class Counter(WriteMixin, Infinite):
    # Shows the current index.
    message = ''
    hide_cursor = True

    def update(self):
        self.write(str(self.index))


class Countdown(WriteMixin, Progress):
    # Shows how many items remain (max - index, floored at 0).
    hide_cursor = True

    def update(self):
        self.write(str(self.remaining))


class Stack(WriteMixin, Progress):
    # A single glyph whose fill level tracks progress; the index is clamped
    # to the last phase when progress reaches 1.
    phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█')
    hide_cursor = True

    def update(self):
        nphases = len(self.phases)
        i = min(nphases - 1, int(self.progress * nphases))
        self.write(self.phases[i])


class Pie(Stack):
    phases = ('○', '◔', '◑', '◕', '●')
# --------------------------------------------------------------------------------
# /utils/progress/progress/helpers.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2012 Giorgos Verigakis
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# Terminal output mixins shared by the bars, counters and spinners. All
# drawing is skipped when ``file`` is not a TTY (e.g. redirected to a log).
from __future__ import print_function


# ANSI escape sequences that hide / show the terminal cursor.
HIDE_CURSOR = '\x1b[?25l'
SHOW_CURSOR = '\x1b[?25h'


class WriteMixin(object):
    # In-place writer: prints the message once, then replaces the previous
    # output by emitting backspaces over the widest text written so far.
    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        # Only keyword arguments are forwarded; a positional ``message`` is
        # consumed here.
        super(WriteMixin, self).__init__(**kwargs)
        self._width = 0
        if message:
            self.message = message

        if self.file.isatty():
            if self.hide_cursor:
                print(HIDE_CURSOR, end='', file=self.file)
            print(self.message, end='', file=self.file)
            self.file.flush()

    def write(self, s):
        if self.file.isatty():
            b = '\b' * self._width
            # Pad to the previous width so leftovers of longer text are erased.
            c = s.ljust(self._width)
            print(b + c, end='', file=self.file)
            self._width = max(self._width, len(s))
            self.file.flush()

    def finish(self):
        if self.file.isatty() and self.hide_cursor:
            print(SHOW_CURSOR, end='', file=self.file)


class WritelnMixin(object):
    # Whole-line writer: clears the current line ('\r' + erase-line) and
    # prints the new content; finish() moves to the next line.
    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        super(WritelnMixin, self).__init__(**kwargs)
        if message:
            self.message = message

        if self.file.isatty() and self.hide_cursor:
            print(HIDE_CURSOR, end='', file=self.file)

    def clearln(self):
        if self.file.isatty():
            print('\r\x1b[K', end='', file=self.file)

    def writeln(self, line):
        if self.file.isatty():
            self.clearln()
            print(line, end='', file=self.file)
            self.file.flush()

    def finish(self):
        if self.file.isatty():
            print(file=self.file)
            if self.hide_cursor:
                print(SHOW_CURSOR, end='', file=self.file)


from signal import signal, SIGINT
from sys import exit


class SigIntMixin(object):
    """Registers a signal handler that calls finish on SIGINT"""

    def __init__(self, *args, **kwargs):
        super(SigIntMixin, self).__init__(*args, **kwargs)
        # NOTE(review): replaces any previously installed SIGINT handler for
        # the whole process.
        signal(SIGINT, self._sigint_handler)

    def _sigint_handler(self, signum, frame):
        # Restore the terminal state, then exit with status 0.
        self.finish()
        exit(0)
# --------------------------------------------------------------------------------
# /utils/progress/progress/spinner.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

# Copyright (c) 2012 Giorgos Verigakis
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

# Spinners for work with an unknown number of steps: each next() shows the
# following glyph of ``phases``, cycling via ``index % len(phases)``.
from __future__ import unicode_literals
from . import Infinite
from .helpers import WriteMixin


class Spinner(WriteMixin, Infinite):
    message = ''
    phases = ('-', '\\', '|', '/')
    hide_cursor = True

    def update(self):
        i = self.index % len(self.phases)
        self.write(self.phases[i])


class PieSpinner(Spinner):
    phases = ['◷', '◶', '◵', '◴']


class MoonSpinner(Spinner):
    phases = ['◑', '◒', '◐', '◓']


class LineSpinner(Spinner):
    phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻']

class PixelSpinner(Spinner):
    phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽']
# --------------------------------------------------------------------------------
# /utils/progress/setup.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python

# Packaging script for the vendored ``progress`` library.
# NOTE(review): ``import progress`` and ``open('README.rst')`` assume this is
# run from utils/progress -- confirm before using it.
from setuptools import setup

import progress


setup(
    name='progress',
    version=progress.__version__,
    description='Easy to use progress bars',
    # NOTE(review): this file handle is never closed; a ``with open(...)``
    # block would be cleaner.
    long_description=open('README.rst').read(),
    author='Giorgos Verigakis',
    author_email='verigak@gmail.com',
    url='http://github.com/verigak/progress/',
    license='ISC',
    packages=['progress'],
    classifiers=[
        'Environment :: Console',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: ISC License (ISCL)',
        'Programming Language :: Python :: 2.6',
        'Programming Language :: Python :: 2.7',
        'Programming Language :: Python :: 3.3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
    ]
)
# --------------------------------------------------------------------------------
# /utils/progress/test_progress.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python

# Manual visual demo, not an automated test: it runs at import time and
# renders every bar, spinner and counter class, sleeping ~10 ms (+/-10%)
# per step (about 2,400 steps in total).
from __future__ import print_function

import random
import time

from progress.bar import (Bar, ChargingBar, FillingSquaresBar,
                          FillingCirclesBar, IncrementalBar, PixelBar,
                          ShadyBar)
from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner,
                              PixelSpinner)
from progress.counter import Counter, Countdown, Stack, Pie


def sleep():
    t = 0.01
    t += t * random.uniform(-0.1, 0.1)  # Add some variance
    time.sleep(t)


for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar):
    suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]'
    bar = bar_cls(bar_cls.__name__, suffix=suffix)
    for i in bar.iter(range(200)):
        sleep()

for bar_cls in (IncrementalBar, PixelBar, ShadyBar):
    suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]'
    bar = bar_cls(bar_cls.__name__, suffix=suffix)
    for i in bar.iter(range(200)):
        sleep()

for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner):
    for i in spin(spin.__name__ + ' ').iter(range(100)):
        sleep()
    print()

for singleton in (Counter, Countdown, Stack, Pie):
    for i in singleton(singleton.__name__ + ' ').iter(range(100)):
        sleep()
    print()

# goto() with random targets exercises forward and backward jumps.
bar = IncrementalBar('Random', suffix='%(index)d')
for i in range(100):
    bar.goto(random.randint(0, 100))
    sleep()
bar.finish()
# --------------------------------------------------------------------------------