├── requirements.txt ├── images └── test │ ├── login.png │ ├── homepage.png │ ├── not_found.png │ └── old_looking.png ├── .gitignore ├── config └── sample_config.py ├── train_models.sh ├── split_data.py ├── utils ├── utils.py └── training.py ├── eyeballer.py ├── README.md └── train_model.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0.1 2 | torchvision==0.2.2 3 | matplotlib==3.1.3 4 | numpy==1.15.4 -------------------------------------------------------------------------------- /images/test/login.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hahwul/eyeballer.pytorch/master/images/test/login.png -------------------------------------------------------------------------------- /images/test/homepage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hahwul/eyeballer.pytorch/master/images/test/homepage.png -------------------------------------------------------------------------------- /images/test/not_found.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hahwul/eyeballer.pytorch/master/images/test/not_found.png -------------------------------------------------------------------------------- /images/test/old_looking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hahwul/eyeballer.pytorch/master/images/test/old_looking.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | 4 | config/config.py 5 | .Trash* 6 | .idea/* 7 | .ipynb_checkpoints/* 8 | models/* 9 | 10 | *.npy 11 | -------------------------------------------------------------------------------- 
/config/sample_config.py: -------------------------------------------------------------------------------- 1 | #### IMPORTANT COPY AND EDIT THIS FILE TO config/config.py 2 | PROJECT_PATH = "/workspace/Projects/eyeballer.pytorch" 3 | MODEL_PATH = "/workspace/Projects/eyeballer.pytorch/models" 4 | WEBSITES_DATASET_PATH = "/workspace/data_local/websites" 5 | -------------------------------------------------------------------------------- /train_models.sh: -------------------------------------------------------------------------------- 1 | python3 train_model.py \ 2 | --dataset websites --arch vgg16 --seed 111 \ 3 | --batch-size 32 --learning-rate 0.001 \ 4 | --epochs 30 --schedule 12 25 --gammas 0.1 0.1 --workers 4 5 | 6 | python3 train_model.py \ 7 | --dataset websites --arch vgg19 --seed 111 \ 8 | --batch-size 32 --learning-rate 0.001 \ 9 | --epochs 30 --schedule 12 25 --gammas 0.1 0.1 --workers 4 10 | 11 | python3 train_model.py \ 12 | --dataset websites --arch resnet18 --seed 111 \ 13 | --batch-size 32 --learning-rate 0.001 \ 14 | --epochs 30 --schedule 12 25 --gammas 0.1 0.1 --workers 4 15 | 16 | python3 train_model.py \ 17 | --dataset websites --arch resnet50 --seed 111 \ 18 | --batch-size 32 --learning-rate 0.001 \ 19 | --epochs 30 --schedule 12 25 --gammas 0.1 0.1 --workers 4 20 | 21 | python3 train_model.py \ 22 | --dataset websites --arch resnet152 --seed 111 \ 23 | --batch-size 32 --learning-rate 0.001 \ 24 | --epochs 30 --schedule 12 25 --gammas 0.1 0.1 --workers 4 -------------------------------------------------------------------------------- /split_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from config.config import WEBSITES_DATASET_PATH 4 | from shutil import copyfile 5 | 6 | BASE_FOLDER = '/workspace/data_local/websites2' 7 | if not os.path.isdir(BASE_FOLDER): 8 | os.mkdir(BASE_FOLDER) 9 | 10 | TRAIN_FOLDER = '/workspace/data_local/websites2/train' 11 | if not 
os.path.isdir(TRAIN_FOLDER): 12 | os.mkdir(TRAIN_FOLDER) 13 | 14 | VAL_FOLDER = '/workspace/data_local/websites2/val' 15 | if not os.path.isdir(VAL_FOLDER): 16 | os.mkdir(VAL_FOLDER) 17 | 18 | for folder in os.listdir(WEBSITES_DATASET_PATH): 19 | print(folder) 20 | source_folder = os.path.join(WEBSITES_DATASET_PATH, folder) 21 | 22 | dest_val_folder = os.path.join(VAL_FOLDER, folder) 23 | if not os.path.isdir(dest_val_folder): 24 | os.mkdir(dest_val_folder) 25 | 26 | dest_train_folder = os.path.join(TRAIN_FOLDER, folder) 27 | if not os.path.isdir(dest_train_folder): 28 | os.mkdir(dest_train_folder) 29 | 30 | filelist = os.listdir(source_folder) 31 | random.shuffle(filelist) 32 | for idx in range(200): 33 | src = os.path.join(source_folder, filelist[idx]) 34 | dst = os.path.join(dest_val_folder, filelist[idx]) 35 | copyfile(src, dst) 36 | for idx in range(200,1200): 37 | src = os.path.join(source_folder, filelist[idx]) 38 | dst = os.path.join(dest_train_folder, filelist[idx]) 39 | copyfile(src, dst) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import os, sys, time, random 4 | import torch 5 | import json 6 | import numpy as np 7 | 8 | from config.config import MODEL_PATH 9 | 10 | def get_model_path(dataset_name, network_arch, random_seed): 11 | if not os.path.isdir(MODEL_PATH): 12 | os.makedirs(MODEL_PATH) 13 | model_path = os.path.join(MODEL_PATH, "{}_{}_{}".format(dataset_name, network_arch, random_seed)) 14 | if not os.path.isdir(model_path): 15 | os.makedirs(model_path) 16 | return model_path 17 | 18 | def convert_secs2time(epoch_time): 19 | need_hour = int(epoch_time / 3600) 20 | need_mins = int((epoch_time - 3600*need_hour) / 60) 21 | need_secs = int(epoch_time - 3600*need_hour - 60*need_mins) 22 | return need_hour, need_mins, need_secs 23 | 24 | def time_string(): 25 | 
ISOTIMEFORMAT='%Y-%m-%d %X' 26 | string = '[{}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) 27 | return string 28 | 29 | def print_log(print_string, log): 30 | print("{}".format(print_string)) 31 | log.write('{}\n'.format(print_string)) 32 | log.flush() 33 | 34 | def manipulate_net_architecture(model_arch, net, num_classes): 35 | if model_arch in ["vgg16", "vgg19"]: 36 | num_ftrs = net.classifier[6].in_features 37 | net.classifier[6] = torch.nn.Linear(num_ftrs, num_classes) 38 | elif model_arch in ["resnet18", "resnet50", "resnet101", "resnet152"]: 39 | num_ftrs = net.fc.in_features 40 | net.fc = torch.nn.Linear(num_ftrs, num_classes) 41 | else: 42 | raise ValueError("Network {} not supported".format(model_arch)) 43 | return net 44 | -------------------------------------------------------------------------------- /eyeballer.py: -------------------------------------------------------------------------------- 1 | import os, sys, random, copy, time 2 | import torch 3 | import argparse 4 | import torch.backends.cudnn as cudnn 5 | import torch.nn as nn 6 | from torchvision import transforms 7 | from torchvision import models 8 | import torchvision.datasets as dset 9 | 10 | from utils.utils import get_model_path, print_log, manipulate_net_architecture 11 | from utils.utils import convert_secs2time, time_string 12 | from utils.training import adjust_learning_rate, train_model, validate, save_checkpoint 13 | from utils.training import RecorderMeter, AverageMeter 14 | 15 | from config.config import WEBSITES_DATASET_PATH 16 | 17 | LABELS = ["Homepage", "Login Page", "Not Found", "Old Looking"] 18 | 19 | def parse_arguments(): 20 | parser = argparse.ArgumentParser(description='Eyeballer') 21 | parser.add_argument('--test-dir', type=str, required=True, 22 | help='Folder containing the images to test') 23 | parser.add_argument('--arch', default='vgg16', choices=['vgg16', 'vgg19','resnet18', 'resnet50', 'resnet101', 'resnet152'], 24 | help='Model 
architecture: (default: vgg16)') 25 | parser.add_argument('--seed', type=int, default=111, 26 | help='Seed used (default: 111)') 27 | parser.add_argument('--batch-size', type=int, default=32, 28 | help="Batch size (default: 32)") 29 | parser.add_argument('--workers', type=int, default=6, 30 | help='Number of data loading workers (default: 6)') 31 | args = parser.parse_args() 32 | 33 | args.use_cuda = torch.cuda.is_available() 34 | 35 | return args 36 | 37 | def main(): 38 | args = parse_arguments() 39 | 40 | random.seed(args.seed) 41 | cudnn.benchmark = True 42 | 43 | model_path = get_model_path('websites', args.arch, args.seed) 44 | 45 | # Data specifications for the webistes dataset 46 | mean = [0., 0., 0.] 47 | std = [1., 1., 1.] 48 | input_size = 224 49 | num_classes = 4 50 | 51 | # Dataset 52 | test_transform = transforms.Compose([ 53 | transforms.Resize(input_size), 54 | transforms.ToTensor(), 55 | transforms.Normalize(mean, std)]) 56 | data_test = dset.ImageFolder(root=args.test_dir, transform=test_transform) 57 | 58 | # Dataloader 59 | data_test_loader = torch.utils.data.DataLoader(data_test, 60 | batch_size=args.batch_size, 61 | shuffle=False, 62 | num_workers=args.workers, 63 | pin_memory=True) 64 | 65 | # Network 66 | if args.arch == "vgg16": 67 | net = models.vgg16(pretrained=True) 68 | elif args.arch == "vgg19": 69 | net = models.vgg19(pretrained=True) 70 | elif args.arch == "resnet18": 71 | net = models.resnet18(pretrained=True) 72 | elif args.arch == "resnet50": 73 | net = models.resnet50(pretrained=True) 74 | elif args.arch == "resnet101": 75 | net = models.resnet101(pretrained=True) 76 | elif args.arch == "resnet152": 77 | net = models.resnet152(pretrained=True) 78 | else: 79 | raise ValueError("Network {} not supported".format(args.arch)) 80 | 81 | if num_classes != 1000: 82 | net = manipulate_net_architecture(model_arch=args.arch, net=net, num_classes=num_classes) 83 | 84 | # Loading the checkpoint 85 | 
net.load_state_dict(torch.load(os.path.join(model_path, 'checkpoint.pth.tar'))['state_dict']) 86 | net.eval() 87 | 88 | # Cuda 89 | if args.use_cuda: 90 | net.cuda() 91 | 92 | for idx, (img, _) in enumerate(data_test_loader): 93 | if args.use_cuda: 94 | img = img.cuda() 95 | with torch.no_grad(): 96 | pred = torch.argmax(net(img), dim=-1) 97 | 98 | samples = data_test.samples[idx*args.batch_size:(idx+1)*args.batch_size] 99 | for idx2, sample in enumerate(samples): 100 | label_idx = pred[idx2].cpu().detach().numpy() 101 | label = LABELS[label_idx] 102 | print("{} - {} - {}".format(sample[0], label , label_idx)) 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Eyeballer Pytorch version 2 | This is a reimplementation of [Bishop Fox's Eyeballer](https://github.com/BishopFox/eyeballer) in [PyTorch](https://pytorch.org/). The original code was implemented in TF.Keras. Additional to the code this repository also provides pretrained models and a dataset of website screenshots. 3 | 4 | Description from the original repo: 5 | ``` 6 | Eyeballer is meant for large-scope network penetration tests where you need to find "interesting" targets from a huge set of web-based hosts. Go ahead and use your favorite screenshotting tool like normal (EyeWitness or GoWitness) and then run them through Eyeballer to tell you what's likely to contain vulnerabilities, and what isn't. 
7 | ``` 8 | 9 | ## Screenshots 10 | | Homepage | Login | 11 | | ------ |:-----:| 12 | | ![Sample HomePage](/images/test/homepage.png) | ![Sample Login Page](/images/test/login.png) | 13 | 14 | | Not Found | Old Looking | 15 | | ------ |:-----:| 16 | | ![Sample Not Found](/images/test/not_found.png) | ![Sample Old Looking](/images/test/old_looking.png) | 17 | 18 | ## Models 19 | 5 pretrained models can be downloaded individually from [here](https://drive.google.com/drive/folders/1LWBEweaf1fM8UD_ZOpXYnlhkIhSFQfcD?usp=sharing). The performances on the validation dataset are reported in the following table: 20 | 21 | | VGG 16 | VGG 19 | ResNet 18 | ResNet 50 | ResNet 152 | 22 | |--------|--------|-----------|-----------|------------| 23 | | 94.25 | 93.125 | 92.625 | 92.75 | 92.625 | 24 | 25 | To set up a model, download the corresponding zip file and extract it into the `./models` folder. 26 | 27 | ## Usage 28 | ### Inference 29 | To test Eyeballer on some test data, use the `eyeballer.py` file. The testing images need to be prepared in a separate subfolder, for example `./folder/test_images/*.png`. This folder structure is required because the [ImageFolder](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder) dataset, which is used to ease loading the images in PyTorch, expects the images to be inside a subfolder of the directory passed via `--test-dir`. As an example, we can classify the images provided in this repository with: `python3 eyeballer.py --arch vgg16 --test-dir ./images`, which provides the following output in the format `file path - label - label index`: 30 | ``` 31 | ./images/test/homepage.png - Homepage - 0 32 | ./images/test/login.png - Login Page - 1 33 | ./images/test/not_found.png - Not Found - 2 34 | ./images/test/old_looking.png - Old Looking - 3 35 | ``` 36 | 37 | ### Training 38 | If you want to train models by yourself, have a look at `train_models.sh`. 39 | 40 | ## Dataset 41 | In addition to the code, the dataset containing the training and validation images is provided. 
The dataset is similar to the one in the original repository containing screenshots of different websites categorized into 4 different classes. The classes are Homepage, Login, Not Found, and Old Looking. The dataset (filename: `websites_dataset.zip`) can be downloaded from [here](https://drive.google.com/drive/folders/1LWBEweaf1fM8UD_ZOpXYnlhkIhSFQfcD?usp=sharing). For each class there are 1000 training and 200 validation images. Download the dataset, unzip it and set the path in `config/config.py`. 42 | 43 | ### Dataset generation 44 | I captured the screenshots with [gowitness](https://github.com/sensepost/gowitness). I used a list of the top websites, which I found on Pastebin. So assuming you have a file with a few hostnames, for example: `./websites.txt` you can capture screenshots of the homepage following: 45 | ``` 46 | mkdir -p ./screenshots/homepage/ 47 | cat websites.txt | gowitness file -f - -P ./screenshots/homepage/ --no-http 48 | ``` 49 | To capture screenshots for login pages I append `/login` to the host. 50 | ``` 51 | mkdir -p ./screenshots/login/ 52 | cat websites.txt | sed 's/$/\/login/g' | gowitness file -f - -P ./screenshots/login/ --no-http 53 | ``` 54 | To capture screenshots for not found pages I appended `/thissitedoesnotexist` to the host. 55 | ``` 56 | mkdir -p ./screenshots/not_found/ 57 | cat websites.txt | sed 's/$/\/thissitedoesnotexist/g' | gowitness file -f - -P ./screenshots/not_found/ --no-http 58 | ``` 59 | For the old-looking class, I queried the [wayback machine](https://web.archive.org/). 60 | 61 | Afterward, I manually sorted out the results that were not good and split the data into the training and validation set. To split the data the script `split_data.py` might be helpful. 62 | 63 | ### How to add a custom class 64 | To add your custom data, you can capture screenshots as described above. Then you should add a new folder to the `/path/to/websites/train` and `/path/to/websites/val`. 
Then you can capture screenshots and split them into the training and validation set. Since we have in total 5 classes now, we need to change the parameter `num_classes` in `train_model.py` and `eyeballer.py` to 5 (`num_classes = 5`). In `eyeballer.py`, the name of the new class must also be added to the `LABELS` list; keep the entries in the alphabetical order of the class folder names, because that is the order in which `ImageFolder` assigns the label indices. You are now ready to train a model on the extended dataset as before. 65 | 66 | #### Example 67 | Assume we want to classify screenshots of APIs. First, let's add the two folders: 68 | ``` 69 | mkdir -p /path/to/websites/train/api 70 | mkdir -p /path/to/websites/val/api 71 | ``` 72 | Now we can generate the screenshots: 73 | ``` 74 | mkdir -p ./screenshots/api/ 75 | cat websites.txt | sed 's/$/\/api/g' | gowitness file -f - -P ./screenshots/api/ --no-http 76 | ``` 77 | After we sorted out all wrong results and found 1200 valid screenshots, we can put 1000 screenshots into the `/path/to/websites/train/api` folder and 200 into the `/path/to/websites/val/api` folder. Setting `num_classes = 5` as described above, we can now train a model on the extended dataset. 78 | 79 | ### Qualitative Results 80 | To check some qualitative results, have a look at the Jupyter notebook `./qualitative_results.ipynb`. 81 | 82 | ### Docker 83 | I prefer to use Docker to set up my deep learning environments. For this project, I used `1.0.1-cuda10.0-cudnn7-devel` as the base from [PyTorch's Docker Hub](https://hub.docker.com/r/pytorch/pytorch/tags?page=1&ordering=last_updated). 
"""Fine-tune an ImageNet-pretrained torchvision network on the websites dataset.

This is the content of ``train_model.py``.  The script parses the training
options, builds the train/val ``ImageFolder`` loaders below
``WEBSITES_DATASET_PATH``, replaces the classification head of the chosen
architecture, trains with SGD and writes ``log.txt``, ``checkpoint.pth.tar``
and ``curve.png`` into the model folder returned by ``get_model_path``.
"""
import argparse
import copy
import os
import random
import sys
import time

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torchvision import transforms
from torchvision import models
import torchvision.datasets as dset

from utils.utils import get_model_path, print_log, manipulate_net_architecture
from utils.utils import convert_secs2time, time_string
from utils.training import adjust_learning_rate, train_model, validate, save_checkpoint
from utils.training import RecorderMeter, AverageMeter

from config.config import WEBSITES_DATASET_PATH

# Values accepted by --arch.  Every name is also the name of the constructor
# in ``torchvision.models``, which is what ``_build_network`` relies on.
SUPPORTED_ARCHS = ('vgg16', 'vgg19', 'resnet18', 'resnet50', 'resnet101', 'resnet152')


def parse_arguments():
    """Parse the command line and return the populated namespace.

    Besides the declared options the namespace carries ``use_cuda``, set to
    ``torch.cuda.is_available()``.  The parser exits with a usage error when
    ``--schedule`` and ``--gammas`` do not have the same number of values.
    """
    parser = argparse.ArgumentParser(description='Train a Network')
    # Data and Model options
    parser.add_argument('--dataset', default='websites', choices=['websites'],
                        help='Trainig dataset (default: websites)')
    parser.add_argument('--arch', default='vgg16', choices=list(SUPPORTED_ARCHS),
                        help='Model architecture: (default: vgg16)')
    parser.add_argument('--seed', type=int, default=111,
                        help='Seed used (default: 111)')
    # Optimization options
    parser.add_argument('--loss-function', default='ce', choices=['ce'],
                        help='Loss function (default: ce)')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Batch size (default: 32)')
    parser.add_argument('--learning-rate', type=float, default=0.001,
                        help='Learning Rate (default: 0.001)')
    parser.add_argument('--epochs', type=int, default=30,
                        help='Number of epochs to train (dfault: 30)')
    parser.add_argument('--schedule', type=int, nargs='+', default=[],
                        help='Decrease learning rate at these epochs (default: [])')
    parser.add_argument('--gammas', type=float, nargs='+', default=[],
                        help='LR is multiplied by gamma on schedule, number of gammas should be equal to schedule (default: [])')
    parser.add_argument('--print-freq', default=200, type=int, metavar='N',
                        help='print frequency (default: 200)')
    parser.add_argument('--workers', type=int, default=6,
                        help='Number of data loading workers (default: 6)')
    args = parser.parse_args()

    # Fail before any model is built instead of inside the first epoch, where
    # adjust_learning_rate asserts on the same condition.
    if len(args.schedule) != len(args.gammas):
        parser.error('--schedule and --gammas need the same number of values '
                     '(got {} and {})'.format(len(args.schedule), len(args.gammas)))

    args.use_cuda = torch.cuda.is_available()

    return args


def _log_run_info(args, model_path, log):
    """Write the model path, every parsed option and the library versions to the log."""
    print_log('model path : {}'.format(model_path), log)
    # sorted() keeps the alphabetical order of the original args listing.
    for key, value in sorted(vars(args).items()):
        print_log("{} : {}".format(key, value), log)
    print_log("Random Seed: {}".format(args.seed), log)
    print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("Torch version : {}".format(torch.__version__), log)
    print_log("Cudnn version : {}".format(torch.backends.cudnn.version()), log)


def _build_loaders(args, input_size, mean, std):
    """Return ``(train_loader, val_loader)`` for the ``train``/``val`` subfolders of the dataset."""
    traindir = os.path.join(WEBSITES_DATASET_PATH, 'train')
    valdir = os.path.join(WEBSITES_DATASET_PATH, 'val')

    # Training and validation use the same preprocessing (no augmentation).
    train_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])
    test_transform = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

    data_train = dset.ImageFolder(root=traindir, transform=train_transform)
    data_test = dset.ImageFolder(root=valdir, transform=test_transform)

    # Pinned host memory is only useful when batches are copied to a GPU.
    data_train_loader = torch.utils.data.DataLoader(data_train,
                                                    batch_size=args.batch_size,
                                                    shuffle=True,
                                                    num_workers=args.workers,
                                                    pin_memory=args.use_cuda)
    data_test_loader = torch.utils.data.DataLoader(data_test,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=args.use_cuda)
    return data_train_loader, data_test_loader


def _build_network(arch, num_classes):
    """Load the pretrained ``arch`` network and fit its head to ``num_classes`` outputs.

    Raises:
        ValueError: if ``arch`` is not one of ``SUPPORTED_ARCHS``.
    """
    if arch not in SUPPORTED_ARCHS:
        raise ValueError("Network {} not supported".format(arch))
    net = getattr(models, arch)(pretrained=True)

    if num_classes != 1000:
        net = manipulate_net_architecture(model_arch=arch, net=net, num_classes=num_classes)
    return net


def _run_training(args, model_path, log):
    """Build data, network and optimizer, then run the epoch loop, logging into ``log``."""
    _log_run_info(args, model_path, log)

    # Data specifications for the websites dataset
    mean = [0., 0., 0.]
    std = [1., 1., 1.]
    input_size = 224
    num_classes = 4

    data_train_loader, data_test_loader = _build_loaders(args, input_size, mean, std)
    net = _build_network(args.arch, num_classes)

    # Loss function
    if args.loss_function == "ce":
        criterion = torch.nn.CrossEntropyLoss()
    else:
        raise ValueError

    # Cuda
    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    # Optimizer
    momentum = 0.9
    decay = 5e-4
    optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=momentum,
                                weight_decay=decay, nesterov=True)

    recorder = RecorderMeter(args.epochs)
    start_time = time.time()
    epoch_time = AverageMeter()

    # Main loop
    for epoch in range(args.epochs):
        current_learning_rate = adjust_learning_rate(args.learning_rate, momentum, optimizer,
                                                     epoch, args.gammas, args.schedule)

        # Remaining time estimate: average epoch duration times epochs left.
        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        best_accuracy = recorder.max_accuracy(False)
        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(
            time_string(), epoch, args.epochs, need_time, current_learning_rate)
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(best_accuracy, 100 - best_accuracy), log)

        # train for one epoch; honour --print-freq and the detected device
        # (the previous hard-coded use_cuda=True crashed on CPU-only hosts)
        train_acc, train_los = train_model(data_loader=data_train_loader, model=net,
                                           criterion=criterion, optimizer=optimizer,
                                           epoch=epoch, log=log,
                                           print_freq=args.print_freq, use_cuda=args.use_cuda)

        # evaluate on test set
        print_log("Validation on test dataset:", log)
        val_acc, val_loss = validate(data_test_loader, net, criterion, log=log, use_cuda=args.use_cuda)
        recorder.update(epoch, train_los, train_acc, val_loss, val_acc)

        # The checkpoint of the most recent epoch overwrites the previous one.
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'args': copy.deepcopy(args),
        }, model_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(model_path, 'curve.png'))


def main():
    """Seed the RNGs, open the run log and train the model selected on the command line."""
    args = parse_arguments()

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.use_cuda:
        torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True

    model_path = get_model_path(args.dataset, args.arch, args.seed)

    # Init logger; the context manager closes the file even when training raises.
    log_file_name = os.path.join(model_path, 'log.txt')
    print("Log file: {}".format(log_file_name))
    with open(log_file_name, 'w') as log:
        _run_training(args, model_path, log)


if __name__ == '__main__':
    main()
"""Training helpers: LR schedule, train/validate loops, checkpointing and meters.

This is the content of ``utils/training.py``.
"""
import numpy as np
import os, shutil, time
import itertools
import torch

import matplotlib
import matplotlib.pyplot as plt

from utils.utils import time_string, print_log


def adjust_learning_rate(init_lr, init_momentum, optimizer, epoch, gammas, schedule):
    """Set the optimizer's learning rate for ``epoch`` and return it.

    Starting from ``init_lr``, the rate is multiplied by each ``gamma`` whose
    paired ``schedule`` step has been reached (``epoch >= step``); scanning
    stops at the first step that has not been reached yet.

    ``init_momentum`` is accepted for interface compatibility only; the
    momentum of the optimizer is left untouched.
    """
    assert len(gammas) == len(schedule), "length of gammas and schedule should be equal"
    lr = init_lr
    for (gamma, step) in zip(gammas, schedule):
        if epoch < step:
            break
        lr = lr * gamma
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def train_model(data_loader, model, criterion, optimizer, epoch, log,
                print_freq=200, use_cuda=True):
    """Train ``model`` for one pass over ``data_loader``.

    Progress is written through ``print_log`` every ``print_freq`` iterations.

    Returns:
        Tuple ``(top-1 accuracy in percent, loss)``, both averaged over all
        samples seen in this pass.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    # switch to train mode
    model.train()

    end = time.time()
    for iteration, (images, target) in enumerate(data_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if use_cuda:
            target = target.cuda()
            images = images.cuda()

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss; targets with more than one
        # dimension are reduced to class indices by argmax first
        if len(target.shape) > 1:
            target = torch.argmax(target, dim=-1)
        prec1, = accuracy(output.data, target, topk=(1,))
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if iteration % print_freq == 0:
            print_log(' Epoch: [{:03d}][{:03d}/{:03d}] '
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Loss {loss.val:.4f} ({loss.avg:.4f}) '
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f}) '.format(
                          epoch, iteration, len(data_loader), batch_time=batch_time,
                          data_time=data_time, loss=losses, top1=top1) + time_string(), log)

    print_log(' **Train** Prec@1 {top1.avg:.3f} Error@1 {error1:.3f}'.format(top1=top1, error1=100-top1.avg), log)
    return top1.avg, losses.avg


def validate(val_loader, model, criterion, log=None, use_cuda=True):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    The summary line goes through ``print_log`` when ``log`` is given and to
    stdout otherwise.

    Returns:
        Tuple ``(top-1 accuracy in percent, loss)``, averaged over all samples.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    for images, target in val_loader:
        if use_cuda:
            target = target.cuda()
            images = images.cuda()

        with torch.no_grad():
            # compute output
            output = model(images)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, = accuracy(output.data, target, topk=(1,))
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))

    summary = ' **Test** Loss {losses.avg:.3f} Prec@1 {top1.avg:.3f} Error@1 {error1:.3f}'.format(
        losses=losses, top1=top1, error1=100-top1.avg)
    if log is not None:
        print_log(summary, log)
    else:
        print(summary)

    return top1.avg, losses.avg


def save_checkpoint(state, save_path, filename):
    """Serialize ``state`` with ``torch.save`` to ``save_path/filename``."""
    filename = os.path.join(save_path, filename)
    torch.save(state, filename)


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: score tensor indexed as ``[sample, class]``.
        target: tensor with one class index per sample.
        topk: iterable of k values.

    Returns:
        A list with one single-element tensor per k, holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: for maxk > 1 the slice of the transposed
            # prediction matrix is not contiguous and view(-1) raises on
            # recent PyTorch versions.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, the running sum, the count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against an update with n=0 before any sample was counted.
        self.avg = self.sum / self.count if self.count else 0


class RecorderMeter(object):
    """Stores the train/val loss and accuracy of every epoch and plots them."""

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        """Allocate empty per-epoch records for ``total_epoch`` epochs."""
        assert total_epoch > 0
        self.total_epoch = total_epoch
        self.current_epoch = 0
        # Losses start at -1 and accuracies at 0 until the epoch is recorded.
        self.epoch_losses = np.zeros((self.total_epoch, 2), dtype=np.float32) - 1  # [epoch, train/val]
        self.epoch_accuracy = np.zeros((self.total_epoch, 2), dtype=np.float32)    # [epoch, train/val]

    def update(self, idx, train_loss, train_acc, val_loss, val_acc):
        """Record the results of epoch ``idx``.

        Returns:
            True when the validation accuracy of this epoch equals the best
            validation accuracy recorded so far.
        """
        assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx)
        self.epoch_losses[idx, 0] = train_loss
        self.epoch_losses[idx, 1] = val_loss
        self.epoch_accuracy[idx, 0] = train_acc
        self.epoch_accuracy[idx, 1] = val_acc
        self.current_epoch = idx + 1
        # Compare against the stored float32 value: comparing with the raw
        # Python float can fail purely because of the float32 rounding.
        return self.max_accuracy(False) == self.epoch_accuracy[idx, 1]

    def max_accuracy(self, istrain):
        """Return the best train (``istrain``) or val accuracy so far, 0 if nothing is recorded."""
        if self.current_epoch <= 0:
            return 0
        column = 0 if istrain else 1
        return self.epoch_accuracy[:self.current_epoch, column].max()

    def plot_curve(self, save_path):
        """Plot accuracy and (x50 scaled) loss curves; save to ``save_path`` unless it is None."""
        title = 'the accuracy/loss curve of train/val'
        dpi = 80
        width, height = 1200, 800
        legend_fontsize = 10
        figsize = width / float(dpi), height / float(dpi)

        fig = plt.figure(figsize=figsize)
        x_axis = np.arange(self.total_epoch)  # epochs

        plt.xlim(0, self.total_epoch)
        plt.ylim(0, 100)
        interval_y = 5
        interval_x = 5
        plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
        plt.yticks(np.arange(0, 100 + interval_y, interval_y))
        plt.grid()
        plt.title(title, fontsize=20)
        plt.xlabel('the training epoch', fontsize=16)
        plt.ylabel('accuracy', fontsize=16)

        # (values, scale, color, linestyle, label); losses are scaled by 50 so
        # that they are visible on the 0-100 accuracy axis.
        curves = (
            (self.epoch_accuracy[:, 0], 1, 'g', '-', 'train-accuracy'),
            (self.epoch_accuracy[:, 1], 1, 'y', '-', 'valid-accuracy'),
            (self.epoch_losses[:, 0], 50, 'g', ':', 'train-loss-x50'),
            (self.epoch_losses[:, 1], 50, 'y', ':', 'valid-loss-x50'),
        )
        for values, scale, color, linestyle, label in curves:
            y_axis = values.astype(np.float64) * scale
            plt.plot(x_axis, y_axis, color=color, linestyle=linestyle, label=label, lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        if save_path is not None:
            fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
            print('---- save figure {} into {}'.format(title, save_path))
        plt.close(fig)