├── soft.png ├── requirements.txt ├── datasets ├── features.py ├── images.py └── vqa_dataset.py ├── config └── default.yaml ├── utils.py ├── preprocessing ├── create_vocabs.py ├── preprocessing_utils.py └── image_features_extraction.py ├── README.md ├── predict.py ├── useful notebooks ├── Construct prediction file from val logs.ipynb ├── compute_acc.ipynb └── Plot accuracy train - val.ipynb ├── models.py └── train.py /soft.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DenisDsh/VizWiz-VQA-PyTorch/HEAD/soft.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py==2.8.0 2 | hdijupyterutils==0.12.5 3 | html5lib==1.0.1 4 | ipykernel==4.8.2 5 | ipython==6.4.0 6 | ipython-genutils==0.2.0 7 | jsonschema==2.6.0 8 | jupyter-client==5.2.3 9 | jupyter-core==4.4.0 10 | nltk==3.3 11 | notebook==5.5.0 12 | numpy==1.14.5 13 | pandas==0.22.0 14 | prompt-toolkit==1.0.15 15 | protobuf==3.5.2 16 | pycparser==2.18 17 | pyparsing==2.2.0 18 | python-dateutil==2.7.3 19 | PyYAML==3.12 20 | scikit-learn==0.19.1 21 | scipy==1.1.0 22 | torch==0.4.0 23 | torchvision==0.2.1 24 | tqdm==4.23.4 25 | 26 | 27 | -------------------------------------------------------------------------------- /datasets/features.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import h5py 4 | import torch 5 | import torch.utils.data as data 6 | 7 | 8 | class FeaturesDataset(data.Dataset): 9 | 10 | def __init__(self, features_path, mode): 11 | self.path_hdf5 = features_path 12 | 13 | assert os.path.isfile(self.path_hdf5), \ 14 | 'File not found in {}, you must extract the features first with images_preprocessing.py'.format( 15 | self.path_hdf5) 16 | 17 | self.hdf5_file = h5py.File(self.path_hdf5, 'r') 18 | self.dataset_features = self.hdf5_file[mode] # noatt or att (attention) 19 | 20 | def __getitem__(self, index): 21 | return torch.from_numpy(self.dataset_features[index].astype('float32')) 22 | 23 | def __len__(self): 24 | return self.dataset_features.shape[0] 25 | -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | logs: 2 | dir_logs: logs/vizwiz/ 3 | annotations: 4 | dir: ../data_vizwiz/Annotations 5 | top_ans: 3000 6 | max_length: 26 7 | min_count_word: 0 8 | path_vocabs: ./prepro_data/vocabs.json 9 | images: 10 | dir: ../data_vizwiz/Images 11 | arch: ResNet152 12 | mode: att 13 | img_size: 448 14 | preprocess_batch_size: 4 15 | preprocess_data_workers: 4 16 | path_features: ./prepro_data/resnet14x14.h5 17 | model: 18 | # Could be added new architectures and hyper-parameters like activations etc 19 | pretrained_model: #./logs/... 
# leave empty if no pretrained model is available 20 | seq2vec: 21 | dropout: 0.25 22 | emb_size: 300 23 | pooling: 24 | dim_v: 2048 25 | dim_q: 1024 26 | dim_h: 1024 27 | dropout_v: 0.5 28 | dropout_q: 0.5 29 | classifier: 30 | dropout: 0.5 31 | attention: 32 | glimpses: 2 33 | mid_features: 512 34 | dropout: 0.5 35 | training: 36 | train_split: train 37 | lr: 0.001 38 | batch_size: 128 39 | epochs: 50 40 | data_workers: 4 41 | 42 | prediction: 43 | model_path: ./logs/vizwiz/2018-09-20_13:04:05/final_log.pth 44 | split: test 45 | submission_file: ./predictions/default_pred.json 46 | 47 | -------------------------------------------------------------------------------- /datasets/images.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | from PIL import Image 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 10 | ] 11 | 12 | 13 | def is_image_file(filename): 14 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 15 | 16 | 17 | class ImageDataset(data.Dataset): 18 | 19 | def __init__(self, path, transform=None): 20 | self.path = path 21 | self.transform = transform 22 | 23 | # Load the paths to the images available in the folder 24 | self.image_names = self._load_img_paths() 25 | 26 | if len(self.image_names) == 0: 27 | raise (RuntimeError("Found 0 images in " + path + "\n" 28 | "Supported image extensions are: " + ",".join( 29 | IMG_EXTENSIONS))) 30 | else: 31 | print('Found {} images in {}'.format(len(self), self.path)) 32 | 33 | def __getitem__(self, index): 34 | item = {} 35 | item['name'] = self.image_names[index] 36 | item['path'] = os.path.join(self.path, item['name']) 37 | 38 | # Use PIL to load the image 39 | item['visual'] = Image.open(item['path']).convert('RGB') 40 | if self.transform is not None: 41 | item['visual'] = self.transform(item['visual']) 42 | 43 | return item 44 | 45 | def __len__(self): 46 | return len(self.image_names) 47 | 48 | def _load_img_paths(self): 49 | images = [] 50 | for name in os.listdir(self.path): 51 | if is_image_file(name): 52 | images.append(name) 53 | return images 54 | 55 | 56 | def get_transform(img_size): 57 | return transforms.Compose([ 58 | transforms.Resize(img_size), 59 | transforms.CenterCrop(img_size), 60 | transforms.ToTensor(), 61 | # TODO : Compute mean and std of VizWiz 62 | # ImageNet normalization 63 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 64 | std=[0.229, 0.224, 0.225]), 65 | ]) 66 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | def vqa_accuracy(predicted, true): 2 | """ Approximation of VQA accuracy metric """ 3 | _, predicted_index = predicted.max(dim=1, keepdim=True) 4 | agreeing = true.gather(dim=1, index=predicted_index) 5 | return (agreeing * 0.33333).clamp(max=1) # * 0.33333 is a good approximation of the VQA metric 6 | 7 | 8 | class Tracker: 9 | 10 | def __init__(self): 11 | self.data = {} 12 | 13 | def track(self, name, *monitors): 14 | l = Tracker.ListStorage(monitors) 15 | self.data.setdefault(name, []).append(l) 16 | return l 17 | 18 | def to_dict(self): 19 | return {k: list(map(list, v)) for k, v in self.data.items()} 20 | 21 | class ListStorage: 22 | def __init__(self, monitors=[]): 23 | self.data = [] 24 | self.monitors = monitors 25 | for monitor in 
self.monitors: 26 | setattr(self, monitor.name, monitor) 27 | 28 | def append(self, item): 29 | for monitor in self.monitors: 30 | monitor.update(item) 31 | self.data.append(item) 32 | 33 | def __iter__(self): 34 | return iter(self.data) 35 | 36 | class MeanMonitor: 37 | name = 'mean' 38 | 39 | def __init__(self): 40 | self.n = 0 41 | self.total = 0 42 | 43 | def update(self, value): 44 | self.total += value 45 | self.n += 1 46 | 47 | @property 48 | def value(self): 49 | return self.total / self.n 50 | 51 | class MovingMeanMonitor: 52 | name = 'mean' 53 | 54 | def __init__(self, momentum=0.9): 55 | self.momentum = momentum 56 | self.first = True 57 | self.value = None 58 | 59 | def update(self, value): 60 | if self.first: 61 | self.value = value 62 | self.first = False 63 | else: 64 | m = self.momentum 65 | self.value = m * self.value + (1 - m) * value 66 | 67 | 68 | def get_id_from_name(name): 69 | import re 70 | 71 | n = re.search('VizWiz_(.+?)_', name) 72 | if n: 73 | split = n.group(1) 74 | 75 | m = re.search(('VizWiz_%s_(.+?).jpg' % split), name) 76 | if m: 77 | found = m.group(1) 78 | 79 | return int(found) 80 | -------------------------------------------------------------------------------- /preprocessing/create_vocabs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | from collections import Counter 6 | from itertools import takewhile 7 | from pprint import pprint 8 | 9 | import yaml 10 | 11 | from preprocessing.preprocessing_utils import prepare_questions, prepare_answers 12 | 13 | 14 | def create_question_vocab(questions, min_count=0): 15 | """ 16 | Extract vocabulary used to tokenize and encode questions. 17 | """ 18 | words = itertools.chain.from_iterable([q for q in questions]) # chain('ABC', 'DEF') --> A B C D E F 19 | counter = Counter(words) 20 | 21 | counted_words = counter.most_common() 22 | # select only the words appearing at least min_count 23 | selected_words = list(takewhile(lambda x: x[1] >= min_count, counted_words)) 24 | 25 | vocab = {t[0]: i for i, t in enumerate(selected_words, start=1)} 26 | 27 | return vocab 28 | 29 | 30 | def create_answer_vocab(annotations, top_k): 31 | answers = itertools.chain.from_iterable(prepare_answers(annotations)) 32 | 33 | counter = Counter(answers) 34 | counted_ans = counter.most_common(top_k) 35 | # start from labels from 0 36 | vocab = {t[0]: i for i, t in enumerate(counted_ans, start=0)} 37 | 38 | return vocab 39 | 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--path_config', default='config/default.yaml', type=str, 43 | help='path to a yaml config file') 44 | 45 | 46 | def main(): 47 | # Load and visualize config from yaml file 48 | global args 49 | args = parser.parse_args() 50 | 51 | if args.path_config is not None: 52 | with open(args.path_config, 'r') as handle: 53 | config = yaml.load(handle) 54 | 55 | pprint(config) 56 | 57 | # Load annotations 58 | dir_path = config['annotations']['dir'] 59 | 60 | # vocabs are created based on train (trainval) split only 61 | train_path = os.path.join(dir_path, config['training']['train_split'] + '.json') 62 | with open(train_path, 'r') as fd: 63 | train_ann = json.load(fd) 64 | 65 | questions = prepare_questions(train_ann) 66 | 67 | question_vocab = create_question_vocab(questions, config['annotations']['min_count_word']) 68 | answer_vocab = create_answer_vocab(train_ann, config['annotations']['top_ans']) 69 | 70 | # Save pre-processing vocabs 71 | vocabs 
= { 72 | 'question': question_vocab, 73 | 'answer': answer_vocab, 74 | } 75 | 76 | with open(config['annotations']['path_vocabs'], 'w') as fd: 77 | json.dump(vocabs, fd) 78 | 79 | print("vocabs saved in {}".format(config['annotations']['path_vocabs'])) 80 | 81 | 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /preprocessing/preprocessing_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch 4 | 5 | 6 | def prepare_questions(annotations): 7 | """ Filter, Normalize and Tokenize question. """ 8 | 9 | prepared = [] 10 | questions = [q['question'] for q in annotations] 11 | 12 | for question in questions: 13 | # lower case 14 | question = question.lower() 15 | 16 | # define desired replacements here 17 | punctuation_dict = {'.': ' ', "'": '', '?': ' ', '_': ' ', '-': ' ', '/': ' ', ',': ' '} 18 | conversational_dict = {"thank you": '', "thanks": '', "thank": '', "please": '', "hello": '', 19 | "hi ": ' ', "hey ": ' ', "good morning": '', "good afternoon": '', "have a nice day": '', 20 | "okay": '', "goodbye": ''} 21 | 22 | rep = punctuation_dict 23 | rep.update(conversational_dict) 24 | 25 | # use these three lines to do the replacement 26 | rep = dict((re.escape(k), v) for k, v in rep.items()) 27 | pattern = re.compile("|".join(rep.keys())) 28 | question = pattern.sub(lambda m: rep[re.escape(m.group(0))], question) 29 | 30 | # sentence to list 31 | question = question.split(' ') 32 | 33 | # remove empty strings 34 | question = list(filter(None, question)) 35 | 36 | prepared.append(question) 37 | 38 | return prepared 39 | 40 | 41 | def prepare_answers(annotations): 42 | answers = [[a['answer'] for a in ans_dict['answers']] for ans_dict in annotations] 43 | prepared = [] 44 | 45 | for sample_answers in answers: 46 | prepared_sample_answers = [] 47 | for answer in sample_answers: 48 | # lower case 49 | answer = answer.lower() 50 | 51 | # define desired replacements here 52 | punctuation_dict = {'.': ' ', "'": '', '?': ' ', '_': ' ', '-': ' ', '/': ' ', ',': ' '} 53 | 54 | rep = punctuation_dict 55 | rep = dict((re.escape(k), v) for k, v in rep.items()) 56 | pattern = re.compile("|".join(rep.keys())) 57 | answer = pattern.sub(lambda m: rep[re.escape(m.group(0))], answer) 58 | prepared_sample_answers.append(answer) 59 | 60 | prepared.append(prepared_sample_answers) 61 | return prepared 62 | 63 | 64 | def encode_question(question, token_to_index, max_length): 65 | question_vec = torch.zeros(max_length).long() 66 | length = min(len(question), max_length) 67 | for i in range(length): 68 | token = question[i] 69 | index = token_to_index.get(token, 0) 70 | question_vec[i] = index 71 | # empty encoded questions are a problem when packed, 72 | # if we set min length 1 we feed a 0 token to the RNN 73 | # that is not a problem since the token 0 does not represent a word 74 | return question_vec, max(length, 1) 75 | 76 | 77 | def encode_answers(answers, answer_to_index): 78 | answer_vec = torch.zeros(len(answer_to_index)) 79 | for answer in answers: 80 | index = answer_to_index.get(answer) 81 | if index is not None: 82 | answer_vec[index] += 1 83 | return answer_vec 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VizWiz Challenge: Visual Question Answering Implementation in PyTorch 2 | 3 | PyTorch VQA implementation that 
achieved top performance in the ECCV 2018 [VizWiz Grand Challenge: Answering Visual Questions from Blind People][0]. 4 | The code can be easily adapted for training on VQA 1.0/2.0 or any other dataset. 5 | 6 | The implemented architecture is a variant of the VQA model described in [Kazemi et al. (2017). Show, Ask, Attend, and Answer: A Strong Baseline For Visual Question Answering][1]. 7 | Visual features are extracted with a ResNet-152 pretrained on ImageNet. Input questions are tokenized, embedded and encoded with an LSTM. 8 | Image features and encoded questions are combined and used to compute multiple attention maps over the image features. The attended image features 9 | and the encoded questions are concatenated and finally fed to a 2-layer classifier that outputs probabilities over the answers (classes). 10 | 11 | More information about the attention module can be found in [Yang et al. (2015). Stacked Attention Networks for Image Question Answering][2]. 12 | 13 | 14 | In order to take all 10 annotator answers into account, we use a [soft cross-entropy loss][3]: 15 | a weighted average of the negative log-probabilities of each unique ground-truth answer. 16 | This loss function aligns better with the [VQA evaluation metric][4] used to score the challenge submissions. 17 | 18 | ![Soft cross-entropy loss](./soft.png) 19 | 20 | 21 | #### Experimental Results 22 | 23 | | method | accuracy | 24 | |--------------|----------| 25 | | [VizWiz Paper][0] | 0.475 | 26 | | **Ours** | **0.516** | 27 | 28 | 29 | ## Training and Evaluation 30 | - Install requirements: 31 | ``` 32 | conda create --name viz_env python=3.6 33 | source activate viz_env 34 | pip install -r requirements.txt 35 | ``` 36 | 37 | - Download and extract the [VizWiz dataset][0]: 38 | 39 | ``` 40 | wget https://ivc.ischool.utexas.edu/VizWiz/data/VizWiz_data_ver1.tar.gz 41 | tar -xzf VizWiz_data_ver1.tar.gz 42 | ``` 43 | After unpacking the dataset, the image folder will contain files with the prefix `._VizWiz`. 44 | Those files should be removed before extracting the image features: 45 | ``` 46 | rm ._* 47 | ``` 48 | 49 | - Set the paths to the downloaded data in the yaml configuration file `config/default.yaml`. 50 | 51 | - Extract features from the input images (~26 GB). 52 | The script extracts two types of features from the images: 53 | - **No Attention**: 2048-dimensional feature vectors taken from the activations of the penultimate layer of the pretrained ResNet-152. 54 | - **Attention**: 2048x14x14 feature tensors taken from the activations of the last convolutional block (`layer4`) of the ResNet-152. 55 | 56 | Our model uses only the "Attention" features; however, the implementation can be extended with new models that do not use attention mechanisms. 57 | 58 | ``` 59 | python ./preprocessing/image_features_extraction.py 60 | ``` 61 | 62 | - Construct the dictionaries that will be used during training to encode words and answers: 63 | 64 | ``` 65 | python ./preprocessing/create_vocabs.py 66 | ``` 67 | 68 | - Start training: 69 | ``` 70 | python train.py 71 | ``` 72 | 73 | During training, the soft cross-entropy loss described above is minimized (a short sketch is given below) and the models with the highest validation accuracy and the lowest validation loss are saved. 74 | The path of the log directory is specified in the yaml configuration file `config/default.yaml`.
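For reference, a minimal sketch of that loss, mirroring the computation in `train.py` (the function name and shapes here are only illustrative):

```
import torch.nn.functional as F

def soft_cross_entropy(logits, answer_counts):
    # logits: [batch, top_ans] raw scores produced by the classifier
    # answer_counts: [batch, top_ans] how many of the 10 annotators gave each answer
    nll = -F.log_softmax(logits, dim=1)  # negative log-probabilities
    return (nll * answer_counts / 10).sum(dim=1).mean()
```
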
75 | 76 | - Construct prediction file for the test split: 77 | ``` 78 | python predict.py 79 | ``` 80 | 81 | 82 | ## Acknowledgment 83 | 84 | 85 | 86 | - https://github.com/liqing-ustc/VizWiz_LSTM_CNN_Attention/ 87 | - https://github.com/Cadene/vqa.pytorch 88 | - https://github.com/GT-Vision-Lab/VQA_LSTM_CNN 89 | - https://github.com/Cyanogenoid/pytorch-vqa 90 | 91 | 92 | 93 | [0]: http://vizwiz.org/data/ 94 | [1]: https://arxiv.org/abs/1704.03162 95 | [2]: https://arxiv.org/pdf/1511.02274 96 | [3]: https://arxiv.org/abs/1708.00584 97 | [4]: https://arxiv.org/pdf/1505.00468v6.pdf 98 | 99 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os.path 4 | 5 | import numpy as np 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn as nn 9 | import yaml 10 | from torch.autograd import Variable 11 | from tqdm import tqdm 12 | 13 | import models 14 | from datasets import vqa_dataset 15 | 16 | 17 | def predict_answers(model, loader, split): 18 | model.eval() 19 | predicted = [] 20 | samples_ids = [] 21 | 22 | tq = tqdm(loader) 23 | 24 | print("Evaluating...\n") 25 | 26 | for item in tq: 27 | v = item['visual'] 28 | q = item['question'] 29 | sample_id = item['sample_id'] 30 | q_length = item['q_length'] 31 | 32 | v = Variable(v.cuda(async=True)) 33 | q = Variable(q.cuda(async=True)) 34 | q_length = Variable(q_length.cuda(async=True)) 35 | 36 | out = model(v, q, q_length) 37 | 38 | _, answer = out.data.cpu().max(dim=1) 39 | 40 | predicted.append(answer.view(-1)) 41 | samples_ids.append(sample_id.view(-1).clone()) 42 | 43 | predicted = list(torch.cat(predicted, dim=0)) 44 | samples_ids = list(torch.cat(samples_ids, dim=0)) 45 | 46 | print("Evaluation completed") 47 | 48 | return predicted, samples_ids 49 | 50 | 51 | def create_submission(input_annotations, predicted, samples_ids, vocabs): 52 | answers = torch.FloatTensor(predicted) 53 | indexes = torch.IntTensor(samples_ids) 54 | ans_to_id = vocabs['answer'] 55 | # need to translate answers ids into answers 56 | id_to_ans = {idx: ans for ans, idx in ans_to_id.items()} 57 | # sort based on index the predictions 58 | sort_index = np.argsort(indexes) 59 | sorted_answers = np.array(answers, dtype='int_')[sort_index] 60 | 61 | real_answers = [] 62 | for ans_id in sorted_answers: 63 | ans = id_to_ans[ans_id] 64 | real_answers.append(ans) 65 | 66 | # Integrity check 67 | assert len(input_annotations) == len(real_answers) 68 | 69 | submission = [] 70 | for i in range(len(input_annotations)): 71 | pred = {} 72 | pred['image'] = input_annotations[i]['image'] 73 | pred['answer'] = real_answers[i] 74 | submission.append(pred) 75 | 76 | return submission 77 | 78 | 79 | def main(): 80 | # Load config yaml file 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--path_config', default='config/default.yaml', type=str, 83 | help='path to a yaml config file') 84 | args = parser.parse_args() 85 | 86 | if args.path_config is not None: 87 | with open(args.path_config, 'r') as handle: 88 | config = yaml.load(handle) 89 | 90 | cudnn.benchmark = True 91 | 92 | # Generate dataset and loader 93 | print("Loading samples to predict from %s" % os.path.join(config['annotations']['dir'], 94 | config['prediction']['split'] + '.json')) 95 | 96 | # Load annotations 97 | path_annotations = os.path.join(config['annotations']['dir'], config['prediction']['split'] + '.json') 98 | 
input_annotations = json.load(open(path_annotations, 'r')) 99 | 100 | # Data loader and dataset 101 | input_loader = vqa_dataset.get_loader(config, split=config['prediction']['split']) 102 | 103 | # Load model weights 104 | print("Loading Model from %s" % config['prediction']['model_path']) 105 | log = torch.load(config['prediction']['model_path']) 106 | 107 | # Num tokens seen during training 108 | num_tokens = len(log['vocabs']['question']) + 1 109 | # Use the same configuration used during training 110 | train_config = log['config'] 111 | 112 | model = nn.DataParallel(models.Model(train_config, num_tokens)).cuda() 113 | 114 | dict_weights = log['weights'] 115 | model.load_state_dict(dict_weights) 116 | 117 | predicted, samples_ids = predict_answers(model, input_loader, split=config['prediction']['split']) 118 | 119 | submission = create_submission(input_annotations, predicted, samples_ids, input_loader.dataset.vocabs) 120 | 121 | with open(config['prediction']['submission_file'], 'w') as fd: 122 | json.dump(submission, fd) 123 | 124 | print("Submission file saved in %s" % config['prediction']['submission_file']) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /useful notebooks/Construct prediction file from val logs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import sys\n", 12 | "import torch\n", 13 | "import numpy as np" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 24, 19 | "metadata": { 20 | "collapsed": true 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "path = './logs/log_directory/log_file.pth'" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 25, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "results = torch.load(path)\n", 36 | "answers = torch.FloatTensor(results['eval_results']['answers'])\n", 37 | "indexes = torch.IntTensor(results['eval_results']['samples_ids'])" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 29, 43 | "metadata": { 44 | "collapsed": true 45 | }, 46 | "outputs": [], 47 | "source": [ 48 | "sort_index = np.argsort(indexes)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 30, 54 | "metadata": { 55 | "collapsed": true 56 | }, 57 | "outputs": [], 58 | "source": [ 59 | "#to sort using the indeces \n", 60 | "sorted_answers = np.array(answers, dtype='int_')[sort_index]" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 32, 66 | "metadata": { 67 | "collapsed": true 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "dictionaries = results['vocab']" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 33, 77 | "metadata": { 78 | "collapsed": true 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "ans_to_id = dictionaries['answer']" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 34, 88 | "metadata": { 89 | "collapsed": true 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "id_to_ans = {idx : ans for ans, idx in ans_to_id.items()}" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 37, 99 | "metadata": { 100 | "collapsed": true 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "real_answers = []\n", 105 | "for ans_id in 
sorted_answers:\n", 106 | " ans = id_to_ans[ans_id]\n", 107 | " real_answers.append(ans)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 40, 113 | "metadata": { 114 | "collapsed": true 115 | }, 116 | "outputs": [], 117 | "source": [ 118 | "path_annotation_val = '../data_vizwiz/Annotations/val.json'" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 41, 124 | "metadata": { 125 | "collapsed": true 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "import json\n", 130 | "val = json.load(open(path_annotation_val,'r'))" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 43, 136 | "metadata": { 137 | "collapsed": true 138 | }, 139 | "outputs": [], 140 | "source": [ 141 | "predictions = []\n", 142 | "for i in range(len(val)):\n", 143 | " pred = {}\n", 144 | " pred['image'] = val[i]['image']\n", 145 | " pred['answer'] = real_answers[i]\n", 146 | " predictions.append(pred)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 45, 152 | "metadata": { 153 | "collapsed": true 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "with open('predictions.json', 'w') as fd:\n", 158 | " json.dump(predictions, fd)" 159 | ] 160 | } 161 | ], 162 | "metadata": { 163 | "kernelspec": { 164 | "display_name": "Environment (conda_pytorch_p36)", 165 | "language": "python", 166 | "name": "conda_pytorch_p36" 167 | }, 168 | "language_info": { 169 | "codemirror_mode": { 170 | "name": "ipython", 171 | "version": 3 172 | }, 173 | "file_extension": ".py", 174 | "mimetype": "text/x-python", 175 | "name": "python", 176 | "nbconvert_exporter": "python", 177 | "pygments_lexer": "ipython3", 178 | "version": "3.6.6" 179 | } 180 | }, 181 | "nbformat": 4, 182 | "nbformat_minor": 2 183 | } 184 | -------------------------------------------------------------------------------- /preprocessing/image_features_extraction.py: -------------------------------------------------------------------------------- 1 | # TODO : Generalize with device instead of gpu/cpu 2 | import argparse 3 | import time 4 | 5 | import h5py 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.nn as nn 9 | import torchvision.models as models 10 | import yaml 11 | from torch.autograd import Variable 12 | from tqdm import tqdm 13 | 14 | from datasets.images import ImageDataset, get_transform 15 | 16 | 17 | class NetFeatureExtractor(nn.Module): 18 | 19 | def __init__(self): 20 | super(NetFeatureExtractor, self).__init__() 21 | self.model = models.resnet152(pretrained=True) 22 | # PyTorch models available in torch.utils.model_zoo require an input of size 224x224. 23 | # This is because of the avgpooling that is fixed 7x7. 24 | # By using AdaptiveAvgPool2 we can feed to the network images with higher resolution (448x448) 25 | self.model.avgpool = nn.AdaptiveAvgPool2d(1) 26 | 27 | # Save attention features (tensor) 28 | def save_att_features(module, input, output): 29 | self.att_feat = output 30 | 31 | # Save no-attention features (vector) 32 | def save_noatt_features(module, input, output): 33 | self.no_att_feat = output 34 | 35 | # This is a forward hook. 
Is executed each time forward is executed 36 | self.model.layer4.register_forward_hook(save_att_features) 37 | self.model.avgpool.register_forward_hook(save_noatt_features) 38 | 39 | def forward(self, x): 40 | self.model(x) 41 | return self.no_att_feat, self.att_feat # [batch_size, 2048], [batch_size, 2048, 14, 14] 42 | 43 | 44 | def main(): 45 | # Load config yaml file 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--path_config', default='config/default.yaml', type=str, 48 | help='path to a yaml config file') 49 | args = parser.parse_args() 50 | 51 | if args.path_config is not None: 52 | with open(args.path_config, 'r') as handle: 53 | config = yaml.load(handle) 54 | config = config['images'] 55 | 56 | # Benchmark mode is good whenever your input sizes for your network do not vary 57 | cudnn.benchmark = True 58 | 59 | net = NetFeatureExtractor().cuda() 60 | net.eval() 61 | # Resize, Crop, Normalize 62 | transform = get_transform(config['img_size']) 63 | dataset = ImageDataset(config['dir'], transform=transform) 64 | 65 | data_loader = torch.utils.data.DataLoader( 66 | dataset, 67 | batch_size=config['preprocess_batch_size'], 68 | num_workers=config['preprocess_data_workers'], 69 | shuffle=False, 70 | pin_memory=True, 71 | ) 72 | 73 | h5_file = h5py.File(config['path_features'], 'w') 74 | 75 | dummy_input = Variable(torch.ones(1, 3, config['img_size'], config['img_size']), volatile=True).cuda() 76 | _, dummy_output = net(dummy_input) 77 | 78 | att_features_shape = ( 79 | len(data_loader.dataset), 80 | dummy_output.size(1), 81 | dummy_output.size(2), 82 | dummy_output.size(3) 83 | ) 84 | 85 | noatt_features_shape = ( 86 | len(data_loader.dataset), 87 | dummy_output.size(1) 88 | ) 89 | 90 | h5_att = h5_file.create_dataset('att', shape=att_features_shape, dtype='float16') 91 | h5_noatt = h5_file.create_dataset('noatt', shape=noatt_features_shape, dtype='float16') 92 | 93 | # save order of extraction 94 | dt = h5py.special_dtype(vlen=str) 95 | img_names = h5_file.create_dataset('img_name', shape=(len(data_loader.dataset),), dtype=dt) 96 | 97 | begin = time.time() 98 | end = time.time() 99 | 100 | print('Extracting features ...') 101 | idx = 0 102 | delta = config['preprocess_batch_size'] 103 | 104 | for i, inputs in enumerate(tqdm(data_loader)): 105 | inputs_img = Variable(inputs['visual'].cuda(async=True), volatile=True) 106 | no_att_feat, att_feat = net(inputs_img) 107 | 108 | # reshape (batch_size, 2048) 109 | no_att_feat = no_att_feat.view(-1, 2048) 110 | 111 | h5_noatt[idx:idx + delta] = no_att_feat.data.cpu().numpy().astype('float16') 112 | h5_att[idx:idx + delta, :, :] = att_feat.data.cpu().numpy().astype('float16') 113 | img_names[idx:idx + delta] = inputs['name'] 114 | 115 | idx += delta 116 | h5_file.close() 117 | 118 | end = time.time() - begin 119 | 120 | print('Finished in {}m and {}s'.format(int(end / 60), int(end % 60))) 121 | print('Created file : ' + config['path_features']) 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /datasets/vqa_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path 4 | 5 | import h5py 6 | import torch 7 | import torch.utils.data as data 8 | 9 | from datasets.features import FeaturesDataset 10 | from preprocessing.preprocessing_utils import prepare_questions, prepare_answers, encode_question, encode_answers 11 | 12 | 13 | def get_loader(config, split): 
14 | """ Returns the data loader of the specified dataset split """ 15 | split = VQADataset( 16 | config, 17 | split 18 | ) 19 | 20 | loader = torch.utils.data.DataLoader( 21 | split, 22 | batch_size=config['training']['batch_size'], 23 | shuffle=True if split == 'train' or split == 'trainval' else False, # only shuffle the data in training 24 | pin_memory=True, 25 | num_workers=config['training']['data_workers'], 26 | collate_fn=collate_fn, 27 | ) 28 | return loader 29 | 30 | 31 | def collate_fn(batch): 32 | # Sort samples in the batch based on the question lengths in descending order. 33 | # This allows to pack the pack_padded_sequence when encoding questions using RNN 34 | batch.sort(key=lambda x: x['q_length'], reverse=True) 35 | return data.dataloader.default_collate(batch) 36 | 37 | 38 | class VQADataset(data.Dataset): 39 | """ VQA dataset, open-ended """ 40 | 41 | def __init__(self, config, split): 42 | super(VQADataset, self).__init__() 43 | 44 | with open(config['annotations']['path_vocabs'], 'r') as fd: 45 | vocabs = json.load(fd) 46 | 47 | annotations_dir = config['annotations']['dir'] 48 | 49 | path_ann = os.path.join(annotations_dir, split + ".json") 50 | with open(path_ann, 'r') as fd: 51 | self.annotations = json.load(fd) 52 | 53 | self.max_question_length = config['annotations']['max_length'] 54 | self.split = split 55 | 56 | # vocab 57 | self.vocabs = vocabs 58 | self.token_to_index = self.vocabs['question'] 59 | self.answer_to_index = self.vocabs['answer'] 60 | 61 | # pre-process questions and answers 62 | self.questions = prepare_questions(self.annotations) 63 | self.questions = [encode_question(q, self.token_to_index, self.max_question_length) for q in 64 | self.questions] # encode questions and return question and question lenght 65 | 66 | if self.split != 'test': 67 | self.answers = prepare_answers(self.annotations) 68 | self.answers = [encode_answers(a, self.answer_to_index) for a in 69 | self.answers] # create a sparse vector of len(self.answer_to_index) for each question containing the occurances of each answer 70 | 71 | if self.split == "train" or self.split == "trainval": 72 | self._filter_unanswerable_samples() 73 | 74 | # load image names in feature extraction order 75 | with h5py.File(config['images']['path_features'], 'r') as f: 76 | img_names = f['img_name'][()] 77 | self.name_to_id = {name: i for i, name in enumerate(img_names)} 78 | 79 | # names in the annotations, will be used to get items from the dataset 80 | self.img_names = [s['image'] for s in self.annotations] 81 | # load features 82 | self.features = FeaturesDataset(config['images']['path_features'], config['images']['mode']) 83 | 84 | def _filter_unanswerable_samples(self): 85 | """ 86 | Filter during training the samples that do not have at least one answer 87 | """ 88 | a = [] 89 | q = [] 90 | annotations = [] 91 | for i in range(len(self.answers)): 92 | if len(self.answers[i].nonzero()) > 0: 93 | a.append(self.answers[i]) 94 | q.append(self.questions[i]) 95 | 96 | annotations.append(self.annotations[i]) 97 | self.answers = a 98 | self.questions = q 99 | self.annotations = annotations 100 | 101 | @property 102 | def num_tokens(self): 103 | return len(self.token_to_index) + 1 # add 1 for token at index 0 104 | 105 | def __getitem__(self, i): 106 | 107 | item = {} 108 | item['question'], item['q_length'] = self.questions[i] 109 | if self.split != 'test': 110 | item['answer'] = self.answers[i] 111 | img_name = self.img_names[i] 112 | feature_id = self.name_to_id[img_name] 113 | item['img_name'] = 
self.img_names[i] 114 | item['visual'] = self.features[feature_id] 115 | # collate_fn sorts the samples in order to be possible to pack them later in the model. 116 | # the sample_id is returned so that the original order can be restored during when evaluating the predictions 117 | item['sample_id'] = i 118 | 119 | return item 120 | 121 | def __len__(self): 122 | return len(self.questions) 123 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | 7 | 8 | class Model(nn.Module): 9 | """ 10 | References : 11 | 1 - https://arxiv.org/abs/1704.03162 12 | 2 - https://arxiv.org/pdf/1511.02274 13 | 3 - https://arxiv.org/abs/1708.00584 14 | """ 15 | 16 | def __init__(self, config, num_tokens): 17 | super(Model, self).__init__() 18 | 19 | dim_v = config['model']['pooling']['dim_v'] 20 | dim_q = config['model']['pooling']['dim_q'] 21 | dim_h = config['model']['pooling']['dim_h'] 22 | 23 | n_glimpses = config['model']['attention']['glimpses'] 24 | 25 | self.text = TextEncoder( 26 | num_tokens=num_tokens, 27 | emb_size=config['model']['seq2vec']['emb_size'], 28 | dim_q=dim_q, 29 | drop=config['model']['seq2vec']['dropout'], 30 | ) 31 | self.attention = Attention( 32 | dim_v=dim_v, 33 | dim_q=dim_q, 34 | dim_h=config['model']['attention']['mid_features'], 35 | n_glimpses=n_glimpses, 36 | drop=config['model']['attention']['dropout'], 37 | ) 38 | self.classifier = Classifier( 39 | dim_input=n_glimpses * dim_v + dim_q, 40 | dim_h=dim_h, 41 | top_ans=config['annotations']['top_ans'], 42 | drop=config['model']['classifier']['dropout'], 43 | ) 44 | 45 | for m in self.modules(): 46 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 47 | init.xavier_uniform_(m.weight) 48 | if m.bias is not None: 49 | m.bias.data.zero_() 50 | 51 | def forward(self, v, q, q_len): 52 | 53 | q = self.text(q, list(q_len.data)) 54 | # L2 normalization on the depth dimension 55 | v = F.normalize(v, p=2, dim=1) 56 | attention_maps = self.attention(v, q) 57 | v = apply_attention(v, attention_maps) 58 | # concatenate attended features and encoded question 59 | combined = torch.cat([v, q], dim=1) 60 | answer = self.classifier(combined) 61 | return answer 62 | 63 | 64 | class Classifier(nn.Sequential): 65 | def __init__(self, dim_input, dim_h, top_ans, drop=0.0): 66 | super(Classifier, self).__init__() 67 | self.add_module('drop1', nn.Dropout(drop)) 68 | self.add_module('lin1', nn.Linear(dim_input, dim_h)) 69 | self.add_module('relu', nn.ReLU()) 70 | self.add_module('drop2', nn.Dropout(drop)) 71 | self.add_module('lin2', nn.Linear(dim_h, top_ans)) 72 | 73 | 74 | class TextEncoder(nn.Module): 75 | def __init__(self, num_tokens, emb_size, dim_q, drop=0.0): 76 | super(TextEncoder, self).__init__() 77 | self.embedding = nn.Embedding(num_tokens, emb_size, padding_idx=0) 78 | self.dropout = nn.Dropout(drop) 79 | self.tanh = nn.Tanh() 80 | self.lstm = nn.LSTM(input_size=emb_size, 81 | hidden_size=dim_q, 82 | num_layers=1) 83 | self.dim_q = dim_q 84 | 85 | # Initialize parameters 86 | self._init_lstm(self.lstm.weight_ih_l0) 87 | self._init_lstm(self.lstm.weight_hh_l0) 88 | self.lstm.bias_ih_l0.data.zero_() 89 | self.lstm.bias_hh_l0.data.zero_() 90 | 91 | init.xavier_uniform_(self.embedding.weight) 92 | 93 | def _init_lstm(self, weight): 94 | for w in 
weight.chunk(4, 0): 95 | init.xavier_uniform_(w) 96 | 97 | def forward(self, q, q_len): 98 | embedded = self.embedding(q) 99 | tanhed = self.tanh(self.dropout(embedded)) 100 | # pack to feed to the LSTM 101 | packed = pack_padded_sequence(tanhed, q_len, batch_first=True) 102 | _, (_, c) = self.lstm(packed) 103 | # _, (c, _) = self.lstm(packed) # this is h 104 | return c.squeeze(0) 105 | 106 | 107 | class Attention(nn.Module): 108 | def __init__(self, dim_v, dim_q, dim_h, n_glimpses, drop=0.0): 109 | super(Attention, self).__init__() 110 | # As specified in https://arxiv.org/pdf/1511.02274.pdf the bias is already included in fc_q 111 | self.conv_v = nn.Conv2d(dim_v, dim_h, 1, bias=False) 112 | self.fc_q = nn.Linear(dim_q, dim_h) 113 | self.conv_x = nn.Conv2d(dim_h, n_glimpses, 1) 114 | 115 | self.dropout = nn.Dropout(drop) 116 | self.relu = nn.ReLU(inplace=True) 117 | 118 | def forward(self, v, q): 119 | # bring to the same shape 120 | v = self.conv_v(self.dropout(v)) 121 | q = self.fc_q(self.dropout(q)) 122 | q = repeat_encoded_question(q, v) 123 | # sum element-wise and ReLU 124 | x = self.relu(v + q) 125 | 126 | x = self.conv_x(self.dropout(x)) # We obtain n_glimpses attention maps [batch_size][n_glimpses][14][14] 127 | return x 128 | 129 | 130 | def repeat_encoded_question(q, v): 131 | """ 132 | Repeat the encoded question over all the spatial positions of the input image feature tensor. 133 | :param q: shape [batch_size][h] 134 | :param v: shape [batch_size][h][14][14] 135 | :return: a tensor constructed repeating q 14x14 with shape [batch_size][h][14][14] 136 | """ 137 | batch_size, h = q.size() 138 | # repeat the encoded question [14x14] times (over all the spatial positions of the image feature matrix) 139 | q_tensor = q.view(batch_size, h, *([1, 1])).expand_as(v) 140 | return q_tensor 141 | 142 | 143 | def apply_attention(v, attention): 144 | """ 145 | Apply attention maps over the input image features. 
146 | """ 147 | batch_size, spatial_vec_size = v.size()[:2] 148 | glimpses = attention.size(1) 149 | 150 | # flatten the spatial dimensions [14x14] into a third dimension [196] 151 | v = v.view(batch_size, spatial_vec_size, -1) 152 | attention = attention.view(batch_size, glimpses, -1) 153 | n_image_regions = v.size(2) # 14x14 = 196 154 | 155 | # Apply softmax to each attention map separately to create n_glimpses attention distribution over the image regions 156 | attention = attention.view(batch_size * glimpses, -1) # [batch_size x n_glimpses][196] 157 | attention = F.softmax(attention, dim=1) 158 | 159 | # apply the weighting by creating a new dim to tile both tensors over 160 | target_size = [batch_size, glimpses, spatial_vec_size, n_image_regions] 161 | v = v.view(batch_size, 1, spatial_vec_size, n_image_regions).expand( 162 | *target_size) # [batch_size][n_glimpses][2048][196] 163 | attention = attention.view(batch_size, glimpses, 1, n_image_regions).expand( 164 | *target_size) # [batch_size][n_glimpses][2048][196] 165 | # Weighted sum over all the spatial regions vectors 166 | weighted = v * attention 167 | weighted_mean = weighted.sum(dim=3) # [batch_size][n_glimpses][2048] 168 | 169 | # attended features are flattened in the same dimension 170 | return weighted_mean.view(batch_size, -1) # [batch_size][n_glimpses * 2048] 171 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path 3 | from datetime import datetime 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.nn as nn 8 | import yaml 9 | from torch.autograd import Variable 10 | from tqdm import tqdm 11 | 12 | import models 13 | import utils 14 | from datasets import vqa_dataset 15 | 16 | 17 | def train(model, loader, optimizer, tracker, epoch, split): 18 | model.train() 19 | 20 | tracker_class, tracker_params = tracker.MovingMeanMonitor, {'momentum': 0.99} 21 | tq = tqdm(loader, desc='{} E{:03d}'.format(split, epoch), ncols=0) 22 | loss_tracker = tracker.track('{}_loss'.format(split), tracker_class(**tracker_params)) 23 | acc_tracker = tracker.track('{}_acc'.format(split), tracker_class(**tracker_params)) 24 | log_softmax = nn.LogSoftmax(dim=1).cuda() 25 | 26 | for item in tq: 27 | v = item['visual'] 28 | q = item['question'] 29 | a = item['answer'] 30 | q_length = item['q_length'] 31 | 32 | v = Variable(v.cuda(async=True)) 33 | q = Variable(q.cuda(async=True)) 34 | a = Variable(a.cuda(async=True)) 35 | q_length = Variable(q_length.cuda(async=True)) 36 | 37 | out = model(v, q, q_length) 38 | 39 | # This is the Soft-loss described in https://arxiv.org/pdf/1708.00584.pdf 40 | 41 | nll = -log_softmax(out) 42 | 43 | loss = (nll * a / 10).sum(dim=1).mean() 44 | acc = utils.vqa_accuracy(out.data, a.data).cpu() 45 | 46 | optimizer.zero_grad() 47 | loss.backward() 48 | optimizer.step() 49 | 50 | loss_tracker.append(loss.item()) 51 | acc_tracker.append(acc.mean()) 52 | fmt = '{:.4f}'.format 53 | tq.set_postfix(loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value)) 54 | 55 | 56 | def evaluate(model, loader, tracker, epoch, split): 57 | model.eval() 58 | tracker_class, tracker_params = tracker.MeanMonitor, {} 59 | 60 | predictions = [] 61 | samples_ids = [] 62 | accuracies = [] 63 | 64 | tq = tqdm(loader, desc='{} E{:03d}'.format(split, epoch), ncols=0) 65 | loss_tracker = tracker.track('{}_loss'.format(split), tracker_class(**tracker_params)) 
66 | acc_tracker = tracker.track('{}_acc'.format(split), tracker_class(**tracker_params)) 67 | log_softmax = nn.LogSoftmax(dim=1).cuda() 68 | 69 | with torch.no_grad(): 70 | for item in tq: 71 | v = item['visual'] 72 | q = item['question'] 73 | a = item['answer'] 74 | sample_id = item['sample_id'] 75 | q_length = item['q_length'] 76 | 77 | v = Variable(v.cuda(async=True)) 78 | q = Variable(q.cuda(async=True)) 79 | a = Variable(a.cuda(async=True)) 80 | q_length = Variable(q_length.cuda(async=True)) 81 | 82 | out = model(v, q, q_length) 83 | 84 | # This is the Soft-loss described in https://arxiv.org/pdf/1708.00584.pdf 85 | 86 | nll = -log_softmax(out) 87 | 88 | loss = (nll * a / 10).sum(dim=1).mean() 89 | acc = utils.vqa_accuracy(out.data, a.data).cpu() 90 | 91 | # save predictions of this batch 92 | _, answer = out.data.cpu().max(dim=1) 93 | 94 | predictions.append(answer.view(-1)) 95 | accuracies.append(acc.view(-1)) 96 | # Sample id is necessary to obtain the mapping sample-prediction 97 | samples_ids.append(sample_id.view(-1).clone()) 98 | 99 | loss_tracker.append(loss.item()) 100 | acc_tracker.append(acc.mean()) 101 | fmt = '{:.4f}'.format 102 | tq.set_postfix(loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value)) 103 | 104 | predictions = list(torch.cat(predictions, dim=0)) 105 | accuracies = list(torch.cat(accuracies, dim=0)) 106 | samples_ids = list(torch.cat(samples_ids, dim=0)) 107 | 108 | eval_results = { 109 | 'answers': predictions, 110 | 'accuracies': accuracies, 111 | 'samples_ids': samples_ids, 112 | 'avg_accuracy': acc_tracker.mean.value, 113 | 'avg_loss': loss_tracker.mean.value 114 | } 115 | 116 | return eval_results 117 | 118 | 119 | def main(): 120 | # Load config yaml file 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument('--path_config', default='config/default.yaml', type=str, 123 | help='path to a yaml config file') 124 | args = parser.parse_args() 125 | 126 | if args.path_config is not None: 127 | with open(args.path_config, 'r') as handle: 128 | config = yaml.load(handle) 129 | 130 | # generate log directory 131 | dir_name = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 132 | path_log_dir = os.path.join(config['logs']['dir_logs'], dir_name) 133 | 134 | if not os.path.exists(path_log_dir): 135 | os.makedirs(path_log_dir) 136 | 137 | print('Model logs will be saved in {}'.format(path_log_dir)) 138 | 139 | cudnn.benchmark = True 140 | 141 | # Generate datasets and loaders 142 | train_loader = vqa_dataset.get_loader(config, split='train') 143 | val_loader = vqa_dataset.get_loader(config, split='val') 144 | 145 | model = nn.DataParallel(models.Model(config, train_loader.dataset.num_tokens)).cuda() 146 | 147 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 148 | config['training']['lr']) 149 | 150 | # Load model weights if necessary 151 | if config['model']['pretrained_model'] is not None: 152 | print("Loading Model from %s" % config['model']['pretrained_model']) 153 | log = torch.load(config['model']['pretrained_model']) 154 | dict_weights = log['weights'] 155 | model.load_state_dict(dict_weights) 156 | 157 | tracker = utils.Tracker() 158 | 159 | min_loss = 10 160 | max_accuracy = 0 161 | 162 | path_best_accuracy = os.path.join(path_log_dir, 'best_accuracy_log.pth') 163 | path_best_loss = os.path.join(path_log_dir, 'best_loss_log.pth') 164 | 165 | for i in range(config['training']['epochs']): 166 | 167 | train(model, train_loader, optimizer, tracker, epoch=i, split=config['training']['train_split']) 168 
| # If we are training on the train split (and not on train+val) we can evaluate on val 169 | if config['training']['train_split'] == 'train': 170 | eval_results = evaluate(model, val_loader, tracker, epoch=i, split='val') 171 | 172 | # save all the information in the log file 173 | log_data = { 174 | 'epoch': i, 175 | 'tracker': tracker.to_dict(), 176 | 'config': config, 177 | 'weights': model.state_dict(), 178 | 'eval_results': eval_results, 179 | 'vocabs': train_loader.dataset.vocabs, 180 | } 181 | 182 | # save logs for min validation loss and max validation accuracy 183 | if eval_results['avg_loss'] < min_loss: 184 | torch.save(log_data, path_best_loss) # save model 185 | min_loss = eval_results['avg_loss'] # update min loss value 186 | 187 | if eval_results['avg_accuracy'] > max_accuracy: 188 | torch.save(log_data, path_best_accuracy) # save model 189 | max_accuracy = eval_results['avg_accuracy'] # update max accuracy value 190 | 191 | # Save final model 192 | log_data = { 193 | 'tracker': tracker.to_dict(), 194 | 'config': config, 195 | 'weights': model.state_dict(), 196 | 'vocabs': train_loader.dataset.vocabs, 197 | } 198 | 199 | path_final_log = os.path.join(path_log_dir, 'final_log.pth') 200 | torch.save(log_data, path_final_log) 201 | 202 | 203 | if __name__ == '__main__': 204 | main() 205 | -------------------------------------------------------------------------------- /useful notebooks/compute_acc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 35, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is already loaded. To reload it, use:\n", 13 | " %reload_ext autoreload\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2\n", 20 | "\n", 21 | "import json\n", 22 | "\n", 23 | "split = 'val'\n", 24 | "results = json.load(open('predictions.json'))\n", 25 | "dataset = json.load(open('../data_vizwiz/Annotations/%s.json'%split))\n", 26 | "\n", 27 | "img2gt = {x['image']:x['answers'] for x in dataset}" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 36, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "#print(dataset[:10])\n", 37 | "#results[:30]\n", 38 | "#dataset[18]" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 37, 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "{u'answer': u'unsuitable', u'image': u'VizWiz_val_000000028000.jpg'}\n", 51 | "{u'answerable': 0, u'image': u'VizWiz_val_000000028000.jpg', u'question': u\"What's this?\", u'answers': [{u'answer_confidence': u'yes', u'answer': u'unsuitable'}, {u'answer_confidence': u'yes', u'answer': u'unsuitable'}, {u'answer_confidence': u'maybe', u'answer': u'beans'}, {u'answer_confidence': u'yes', u'answer': u'unanswerable'}, {u'answer_confidence': u'yes', u'answer': u'unsuitable'}, {u'answer_confidence': u'yes', u'answer': u'unanswerable'}, {u'answer_confidence': u'maybe', u'answer': u'unanswerable'}, {u'answer_confidence': u'yes', u'answer': u'unsuitable'}, {u'answer_confidence': u'yes', u'answer': u'unanswerable'}, {u'answer_confidence': u'maybe', u'answer': u'unsuitable'}], u'answer_type': u'unanswerable'}\n" 52 | ] 53 | } 54 | ], 55 | "source": [ 56 | "print(results[0])\n", 57 | "print(dataset[0])" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | 
"execution_count": 38, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "VizWiz_val_000000028000.jpg\n" 70 | ] 71 | } 72 | ], 73 | "source": [ 74 | "print(results[0]['image'])" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 39, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stdout", 84 | "output_type": "stream", 85 | "text": [ 86 | "VizWiz_val_000000028000.jpg\n", 87 | "VizWiz_val_000000028000.jpg\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "imdir='VizWiz_%s_%012d.jpg'\n", 93 | "print(dataset[0]['image'])\n", 94 | "img = imdir%(split,28000)\n", 95 | "print(img)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 40, 101 | "metadata": {}, 102 | "outputs": [ 103 | { 104 | "name": "stdout", 105 | "output_type": "stream", 106 | "text": [ 107 | "3173\n", 108 | "other : 0.594075007879\n", 109 | "number : 0.0151276394579\n", 110 | "unanswerable : 0.350772139931\n", 111 | "yes/no : 0.0400252127324\n" 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "from collections import Counter\n", 117 | "img2ans_type = {}\n", 118 | "for one_data in dataset:\n", 119 | " ans_counter = Counter([x['answer'] for x in one_data['answers']])\n", 120 | " ans = ans_counter.most_common(1)[0][0]\n", 121 | " if ans == 'yes' or ans == 'no':\n", 122 | " ans_type = 'yes/no'\n", 123 | " elif ans == 'unanswerable' or ans == 'unsuitable':\n", 124 | " ans_type = 'unanswerable'\n", 125 | " elif ans.isdigit():\n", 126 | " ans_type = 'number'\n", 127 | " else:\n", 128 | " ans_type = 'other'\n", 129 | " img2ans_type[one_data['image']] = ans_type\n", 130 | " \n", 131 | "all_ans = img2ans_type.values()\n", 132 | "print len(all_ans)\n", 133 | "for ans_type in set(all_ans):\n", 134 | " print ans_type, ':', all_ans.count(ans_type)*1.0/len(all_ans)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 41, 140 | "metadata": {}, 141 | "outputs": [ 142 | { 143 | "name": "stdout", 144 | "output_type": "stream", 145 | "text": [ 146 | "Accuracy : 0.5123437335854606\n", 147 | "other : 0.3359858532272325\n", 148 | "number : 0.26388888888888884\n", 149 | "unanswerable : 0.7999401018268942\n", 150 | "yes/no : 0.7034120734908136\n" 151 | ] 152 | } 153 | ], 154 | "source": [ 155 | "import numpy as np\n", 156 | "img2acc = {}\n", 157 | "imdir='VizWiz_%s_%012d.jpg'\n", 158 | "\n", 159 | "for pred in results:\n", 160 | " #img = imdir%(split,pred['question_id']) #Reconstruct image name from question_id\n", 161 | " img = pred['image']\n", 162 | " pred_ans = pred['answer']\n", 163 | " gt_ans = img2gt[img]\n", 164 | " gt_ans = [x['answer'] for x in gt_ans]\n", 165 | " gt_ans = [x.lower() for x in gt_ans]\n", 166 | " cur_acc = np.minimum(1.0, gt_ans.count(pred_ans)/3.0)\n", 167 | " img2acc[img] = cur_acc\n", 168 | "\n", 169 | "print 'Accuracy :', np.mean(img2acc.values())\n", 170 | "for ans_type in set(all_ans):\n", 171 | " acc_per_type = np.mean([acc for img, acc in img2acc.items() if img2ans_type[img] == ans_type])\n", 172 | " print ans_type, ':', acc_per_type" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 22, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "data": { 182 | "text/plain": [ 183 | "'Denis prepro validation balanced'" 184 | ] 185 | }, 186 | "execution_count": 22, 187 | "metadata": {}, 188 | "output_type": "execute_result" 189 | } 190 | ], 191 | "source": [ 192 | "\"\"\"Denis prepro validation balanced\"\"\"" 193 | ] 194 | }, 195 | { 196 | "cell_type": 
"code", 197 | "execution_count": 18, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "# Download coco-caption from https://github.com/tylin/coco-caption\n", 202 | "import sys\n", 203 | "sys.path.insert(0, '../coco-caption')\n", 204 | "\n", 205 | "from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer\n", 206 | "from pycocoevalcap.bleu.bleu import Bleu\n", 207 | "from pycocoevalcap.meteor.meteor import Meteor\n", 208 | "from pycocoevalcap.rouge.rouge import Rouge\n", 209 | "from pycocoevalcap.cider.cider import Cider\n", 210 | "\n", 211 | "class COCOEvalCap:\n", 212 | " def __init__(self,images,gts,res):\n", 213 | " self.evalImgs = []\n", 214 | " self.eval = {}\n", 215 | " self.imgToEval = {}\n", 216 | " self.params = {'image_id': images}\n", 217 | " self.gts = gts\n", 218 | " self.res = res\n", 219 | "\n", 220 | " def evaluate(self):\n", 221 | " imgIds = self.params['image_id']\n", 222 | " gts = self.gts\n", 223 | " res = self.res\n", 224 | "\n", 225 | " # =================================================\n", 226 | " # Set up scorers\n", 227 | " # =================================================\n", 228 | " print 'tokenization...'\n", 229 | " tokenizer = PTBTokenizer()\n", 230 | " gts = tokenizer.tokenize(gts)\n", 231 | " res = tokenizer.tokenize(res)\n", 232 | "\n", 233 | " # =================================================\n", 234 | " # Set up scorers\n", 235 | " # =================================================\n", 236 | " print 'setting up scorers...'\n", 237 | " scorers = [\n", 238 | " (Bleu(4), [\"Bleu_1\", \"Bleu_2\", \"Bleu_3\", \"Bleu_4\"]),\n", 239 | " (Meteor(),\"METEOR\"),\n", 240 | " (Rouge(), \"ROUGE_L\"),\n", 241 | " (Cider(), \"CIDEr\")\n", 242 | " ]\n", 243 | "\n", 244 | " # =================================================\n", 245 | " # Compute scores\n", 246 | " # =================================================\n", 247 | " eval = {}\n", 248 | " for scorer, method in scorers:\n", 249 | " print 'computing %s score...'%(scorer.method())\n", 250 | " assert(set(gts.keys()) == set(res.keys()))\n", 251 | " score, scores = scorer.compute_score(gts, res)\n", 252 | " if type(method) == list:\n", 253 | " for sc, scs, m in zip(score, scores, method):\n", 254 | " self.setEval(sc, m)\n", 255 | " self.setImgToEvalImgs(scs, imgIds, m)\n", 256 | " print \"%s: %0.3f\"%(m, sc)\n", 257 | " else:\n", 258 | " self.setEval(score, method)\n", 259 | " self.setImgToEvalImgs(scores, imgIds, method)\n", 260 | " print \"%s: %0.3f\"%(method, score)\n", 261 | " self.setEvalImgs()\n", 262 | "\n", 263 | " def setEval(self, score, method):\n", 264 | " self.eval[method] = score\n", 265 | "\n", 266 | " def setImgToEvalImgs(self, scores, imgIds, method):\n", 267 | " for imgId, score in zip(imgIds, scores):\n", 268 | " if not imgId in self.imgToEval:\n", 269 | " self.imgToEval[imgId] = {}\n", 270 | " self.imgToEval[imgId][\"image_id\"] = imgId\n", 271 | " self.imgToEval[imgId][method] = score\n", 272 | "\n", 273 | " def setEvalImgs(self):\n", 274 | " self.evalImgs = [eval for imgId, eval in self.imgToEval.items()]" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 47, 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "name": "stdout", 284 | "output_type": "stream", 285 | "text": [ 286 | "tokenization...\n", 287 | "setting up scorers...\n", 288 | "computing Bleu score...\n", 289 | "{'reflen': 3627, 'guess': [3536, 364, 100, 32], 'testlen': 3536, 'correct': [2103, 166, 36, 11]}\n", 290 | "ratio: 0.974910394265\n", 291 | "Bleu_1: 0.580\n", 292 | 
"Bleu_2: 0.508\n", 293 | "Bleu_3: 0.449\n", 294 | "Bleu_4: 0.417\n", 295 | "computing METEOR score...\n", 296 | "METEOR: 0.307\n", 297 | "computing Rouge score...\n", 298 | "ROUGE_L: 0.593\n", 299 | "computing CIDEr score...\n", 300 | "CIDEr: 0.707\n", 301 | "{'CIDEr': 0.7068554617442656, 'Bleu_4': 0.4171507568309097, 'Bleu_3': 0.44878194619936634, 'Bleu_2': 0.5075632436570343, 'Bleu_1': 0.5796292858195882, 'ROUGE_L': 0.5929457568277345, 'METEOR': 0.3074730242342619}\n" 302 | ] 303 | } 304 | ], 305 | "source": [ 306 | "#res = {imdir%(split,x['question_id'])):[{'image_id':imdir%(split,x['question_id']), 'caption':x['answer']}] for x in results}\n", 307 | "res = {unicode(imdir%(split,x['question_id']), \"utf-8\"):[{'image_id':imdir%(split,x['question_id']), 'caption':x['answer']}] for x in results}\n", 308 | "\n", 309 | "gts = {}\n", 310 | "for img, ans_list in img2gt.items():\n", 311 | " ans_list = [x['answer'] for x in ans_list]\n", 312 | " tmp = []\n", 313 | " for x in ans_list:\n", 314 | " try:\n", 315 | " tmp.append(str(x))\n", 316 | " except:\n", 317 | " pass\n", 318 | " ans_list = tmp\n", 319 | " ans_list = [{'image_id': img, 'caption': str(x)} for x in ans_list]\n", 320 | " gts[img] = ans_list\n", 321 | "\n", 322 | "for img in gts.keys():\n", 323 | " if img not in res.keys():\n", 324 | " res[img] = [{'image_id':img, 'caption':''}]\n", 325 | "\n", 326 | "#### CHANGED CODE OF BLEU/METEOR SINCE RAISES ERROR WHEN COMPARING gts.keys() == res.keys()\n", 327 | " \n", 328 | "evalObj = COCOEvalCap(gts.keys(),gts,res)\n", 329 | "evalObj.evaluate()\n", 330 | "print evalObj.eval" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 7, 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "import cPickle as pkl\n", 340 | "prob = pkl.load(open('saved_model/%s_prob.pkl'%split))\n", 341 | "answer2answer_id = json.load(open('data/create_vocab/answer2answer_id.json'))\n", 342 | "unanswerable_labels = [answer2answer_id['unanswerable'], answer2answer_id['unsuitable']]\n", 343 | "img2answerable = {x['image']:x['answerable'] for x in dataset}" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 8, 349 | "metadata": {}, 350 | "outputs": [ 351 | { 352 | "name": "stdout", 353 | "output_type": "stream", 354 | "text": [ 355 | "AP_rel: 0.8944\n", 356 | "AP_irrel: 0.5905\n" 357 | ] 358 | } 359 | ], 360 | "source": [ 361 | "from sklearn.metrics import recall_score, average_precision_score, precision_recall_curve\n", 362 | "\n", 363 | "y_test = []\n", 364 | "pred = []\n", 365 | "\n", 366 | "for res in results:\n", 367 | " img = res['image']\n", 368 | " gt_ans = img2answerable[img]\n", 369 | " y_test.append(gt_ans)\n", 370 | " one_prob = prob[img]\n", 371 | " one_pred = 1 - sum([one_prob[x] for x in unanswerable_labels])\n", 372 | " pred.append(one_pred)\n", 373 | "y_test = np.array(y_test)\n", 374 | "pred = np.array(pred)\n", 375 | "\n", 376 | "gt_labels = np.asarray(y_test) > 0.5\n", 377 | "precision, recall, thresholds = precision_recall_curve(gt_labels, pred)\n", 378 | "average_precision = average_precision_score(gt_labels, pred)\n", 379 | "print \"AP_rel: %.4f\"%average_precision\n", 380 | "with open('saved_model/results_rel.txt','w') as fid:\n", 381 | " fid.write(str(average_precision))\n", 382 | " fid.write('\\n')\n", 383 | " fid.write('\\n'.join(['%.4f\\t%.4f\\t%.4f'%x for x in list(zip(recall,precision,thresholds))[::-1]]))\n", 384 | "\n", 385 | "\n", 386 | "gt_labels_n = np.asarray(y_test) < 0.5\n", 387 | "pred_n = 1.0 - pred\n", 388 | 
"precision, recall, thresholds = precision_recall_curve(gt_labels_n, pred_n)\n", 389 | "average_precision = average_precision_score(gt_labels_n, pred_n)\n", 390 | "print \"AP_irrel: %.4f\"%average_precision\n", 391 | "with open('saved_model/results_irrel.txt','w') as fid:\n", 392 | " fid.write(str(average_precision))\n", 393 | " fid.write('\\n')\n", 394 | " fid.write('\\n'.join(['%.4f\\t%.4f\\t%.4f'%x for x in list(zip(recall,precision,thresholds))[::-1]]))" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": null, 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [] 403 | } 404 | ], 405 | "metadata": { 406 | "kernelspec": { 407 | "display_name": "Python 2", 408 | "language": "python", 409 | "name": "python2" 410 | }, 411 | "language_info": { 412 | "codemirror_mode": { 413 | "name": "ipython", 414 | "version": 2 415 | }, 416 | "file_extension": ".py", 417 | "mimetype": "text/x-python", 418 | "name": "python", 419 | "nbconvert_exporter": "python", 420 | "pygments_lexer": "ipython2", 421 | "version": "2.7.12" 422 | } 423 | }, 424 | "nbformat": 4, 425 | "nbformat_minor": 2 426 | } 427 | -------------------------------------------------------------------------------- /useful notebooks/Plot accuracy train - val.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [ 10 | { 11 | "name": "stderr", 12 | "output_type": "stream", 13 | "text": [ 14 | "/Users/denis/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:3: UserWarning: \nThis call to matplotlib.use() has no effect because the backend has already\nbeen chosen; matplotlib.use() must be called *before* pylab, matplotlib.pyplot,\nor matplotlib.backends is imported for the first time.\n\nThe backend was *originally* set to 'module://ipykernel.pylab.backend_inline' by the following code:\n File \"/Users/denis/anaconda3/lib/python3.6/runpy.py\", line 193, in _run_module_as_main\n \"__main__\", mod_spec)\n File \"/Users/denis/anaconda3/lib/python3.6/runpy.py\", line 85, in _run_code\n exec(code, run_globals)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py\", line 16, in \n app.launch_new_instance()\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/traitlets/config/application.py\", line 658, in launch_instance\n app.start()\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/ipykernel/kernelapp.py\", line 477, in start\n ioloop.IOLoop.instance().start()\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/zmq/eventloop/ioloop.py\", line 177, in start\n super(ZMQIOLoop, self).start()\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/tornado/ioloop.py\", line 888, in start\n handler_func(fd_obj, events)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py\", line 277, in null_wrapper\n return fn(*args, **kwargs)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py\", line 440, in _handle_events\n self._handle_recv()\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py\", line 472, in _handle_recv\n self._run_callback(callback, msg)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py\", line 414, in _run_callback\n callback(*args, **kwargs)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/tornado/stack_context.py\", line 277, 
in null_wrapper\n return fn(*args, **kwargs)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py\", line 283, in dispatcher\n return self.dispatch_shell(stream, msg)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py\", line 235, in dispatch_shell\n handler(stream, idents, msg)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/ipykernel/kernelbase.py\", line 399, in execute_request\n user_expressions, allow_stdin)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/ipykernel/ipkernel.py\", line 196, in do_execute\n res = shell.run_cell(code, store_history=store_history, silent=silent)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/ipykernel/zmqshell.py\", line 533, in run_cell\n return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)\n File \"/Users/denis/.local/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2705, in run_cell\n interactivity=interactivity, compiler=compiler, result=result)\n File \"/Users/denis/.local/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2809, in run_ast_nodes\n if self.run_code(code, result):\n File \"/Users/denis/.local/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2869, in run_code\n exec(code_obj, self.user_global_ns, self.user_ns)\n File \"\", line 1, in \n get_ipython().magic('matplotlib inline')\n File \"/Users/denis/.local/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2146, in magic\n return self.run_line_magic(magic_name, magic_arg_s)\n File \"/Users/denis/.local/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2067, in run_line_magic\n result = fn(*args,**kwargs)\n File \"\", line 2, in matplotlib\n File \"/Users/denis/.local/lib/python3.6/site-packages/IPython/core/magic.py\", line 188, in \n call = lambda f, *a, **k: f(*a, **k)\n File \"/Users/denis/.local/lib/python3.6/site-packages/IPython/core/magics/pylab.py\", line 100, in matplotlib\n gui, backend = self.shell.enable_matplotlib(args.gui)\n File \"/Users/denis/.local/lib/python3.6/site-packages/IPython/core/interactiveshell.py\", line 2935, in enable_matplotlib\n pt.activate_matplotlib(backend)\n File \"/Users/denis/.local/lib/python3.6/site-packages/IPython/core/pylabtools.py\", line 296, in activate_matplotlib\n matplotlib.pyplot.switch_backend(backend)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/matplotlib/pyplot.py\", line 229, in switch_backend\n matplotlib.use(newbackend, warn=False, force=True)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py\", line 1305, in use\n reload(sys.modules['matplotlib.backends'])\n File \"/Users/denis/anaconda3/lib/python3.6/importlib/__init__.py\", line 166, in reload\n _bootstrap._exec(spec, module)\n File \"/Users/denis/anaconda3/lib/python3.6/site-packages/matplotlib/backends/__init__.py\", line 14, in \n line for line in traceback.format_stack()\n\n\n This is separate from the ipykernel package so we can avoid doing imports until\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "%matplotlib inline\n", 20 | "import torch\n", 21 | "import matplotlib; matplotlib.use('agg')\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "import numpy as np" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 45, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "name": "stderr", 33 | "output_type": "stream", 34 | "text": [ 35 | 
"/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/__main__.py:17: MatplotlibDeprecationWarning: The set_color_cycle function was deprecated in version 1.5. Use `.set_prop_cycle` instead.\n" 36 | ] 37 | }, 38 | { 39 | "data": { 40 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAHQpJREFUeJzt3Xt0nHd95/H315Isy5Jsyfd77CSmwZAbCCc0bbk0cJLQJnu6HJosYcvlrDltwy1QCgtLd8Ohh9I2gW5ZIKWcQA+QNbApPktoltJssw1JiJ2rbUjiGGJb+CLLsuwZ62JpvvvHdx7PaCRZY3sk+Xn0eZ3zO88zM49mfs9o5vP85jszvzF3R0REsmXWdHdARERqT+EuIpJBCncRkQxSuIuIZJDCXUQkgxTuIiIZNGG4m9nXzOyQmW0f53Izs78xs11m9oyZvar23RQRkTNRzcj9HuC601x+PbC+2DYBXzr3bomIyLmYMNzd/SHgyGk2uQn4hodHgTYzW16rDoqIyJmrr8F1rAT2lp3eVzxvf+WGZraJGN3T3Nz86ksuuaQGNy8iMnNs27btsLsvnmi7WoR71dz9buBugI6ODt+6detU3ryISOqZ2UvVbFeLT8t0AqvLTq8qniciItOkFuG+BfiPxU/NXA30uvuokoyIiEydCcsyZvZt4PXAIjPbB/wZ0ADg7l8G7gduAHYBJ4B3TVZnRUSkOhOGu7vfMsHlDvxxzXokIiLnTN9QFRHJIIW7iEgGKdxFRDJI4S4ikkEKdxGRDFK4i4hkkMJdRCSDFO4iIhmkcBcRySCFu4hIBincRUQySOEuIpJBCncRkQxSuIuIZJDCXUQkgxTuIiIZpHAXEckghbuISAYp3EVEMkjhLiKSQQp3EZEMUriLiGSQwl1EJIMU7iIiGaRwFxHJIIW7iEgGKdxFRDJI4S4ikkEKdxGRDFK4i4hkkMJdRCSDFO4iIhmkcBcRySCFu4hIBincRUQyqKpwN7PrzOw5M9tlZh8b4/I1ZvagmT1pZs+Y2Q2176qIiFRrwnA3szrgi8D1wAbgFjPbULHZJ4HN7n4lcDPwP2rdURERqV41I/eNwC533+3ug8C9wE0V2zgwr7g+H/hV7booIiJnqppwXwnsLTu9r3heuf8K3Gpm+4D7gfeNdUVmtsnMtprZ1q6urrPoroiIVKNWb6jeAtzj7quAG4B/MLNR1+3ud7t7h7t3LF68uEY3LSIilaoJ905gddnpVcXzyr0H2Azg7o8Ac4BFteigiIicufoqtnkcWG9m64hQvxn4DxXb7AF+G7jHzF5OhLvqLiIy9dzh6FHo6oJDh6C/H+bOhebmWCbr9fXQ1zeynTgBAwNQKMDwcCwr191HLpP1pCWny5mNXL/qKnjZyyb1bpgw3N19yMxuAx4A6oCvufsOM7sD2OruW4APA39nZh8i3lx9p3vl3olIqrlDPg+HD0NvLzQ1wbx50ZqaRgZYoQC5HPT0RNAePRp/mwRo+RKgri7Ctnx54gQcOVJq3d2xHByEWbOi1dWV1vv6ItC7umBoaHruo2p96UuTHu42XRnc0dHhW7dunZbbFpkxhoYi7A4cgP37Y723N9qxY6X1fH70qNQ9/v7IkQj07u4Y1Y6lrg5aW2NEfOJEXGehcO79r6uDBQtKrb0d5swZe0Td2AhLlsDixSPb3LnRp6Tl87E8eTIOSuVt7ty4nuSgUX7wKG9mI9eTVnkaRo7ik/VFi2D+/LO6S8xsm7t3TLRdNWUZEZmIewTf8eOllstF6IwVBgMDETKV7ejRUpCWL0+ciCBLWhJss2aNDqimprjs4MEoS4w3gGtqioCZNw9aWsYOq7o6WLcOOjoikBYuLAVTf38cII4fj+WxY7HPzc3Q1lZq7e2xfUtLKUDLl2ZxEBkeHrmcOzf6Vv6KQKqmcJeZxT1e1udyEUYHD45u3d2xbV3dyBGcWQRwMupNRr5JwNWiFGBWCtCFCyNYX/OaCMzK/tTVRRBW1o37+kp13WXLYPnyWC5bFiPbtrYIzYaGc+9vrTQ2TncPMkfhLuen3t6o1yYvuctbX1+p/lpeiz16dGQtN1lPXornctFOF8Lt7RGqZqXRcdLcY/SZ1JkvvLC03to6urW0RABXvhFXKMDs2RHYlS35G5FzpHCXyXP8eNR5kzYwMHbtsrsbfvGLaLt3x7Kn58xuq7U1RqTJS/2kJee1tIxura2wdGmpLVkSoSuSAQp3GZ97jHST+nHlek/PyE8zJC158y6Xq/62Zs+GtWujDLFxY4yKFy4cWYpIWmNjXLZwYemNtvOpxCByHlC4Z02hEOH6q1/FiDgpWSTrvb2xXeWbfO4R1pVv5p08OfFtzp078tMMV14JN9wAK1ZEvTdZNjWNLrEMD8foesWK6IeI1ITCPW2GhmDPnihfvPhilDD27i21zs7xA3n+/AjSpJ5cXgeGUr354ovh6qtLI+Pkkw5JKSNZb28vfTRNRM4rCvfziXt8DrmzM0benZ2l9ZdeikB/6aWRbwg2NMCqVbB6NVxzTSzXrIGVK+MTFwsWREi3t8eXQ0RkRtCzfbr098POnfDUU9GefjpaUjZJmMUbfatWwatfDW97G1x0UamtWKFPV4jIKAr3yTQ8HCWUXbvghRdGL4eHY7vmZrj8cnj72+GSS2LUvXJlBPeyZXqzUETOmMK9Ftzjyy/PPltqzzwTI/P+/tJ2TU1Rz96wAX7v9+KNx8svjxG43kwUkRpSuJ+t3l740Y/ghz+Ef/qnqIsnli6FSy+FP/xDePnLYf36aCtW6KvUIjIlFO5nYudO2LIlAv3hh6OsMn8+vPnN8WbmpZdG0w+RiMg0U7hPxB0eeAD++q/hn/85zrviCvjTP4Xrr4+PDOpTKCJynlEqjWdgAL797Qj17dujpPLZz8I73hHrIiLnMYV7pXwe/vZv4QtfiK/QX3YZfP3rcPPNmndERFJD4Z4YHoZvfAM++cl4c/TNb45Qv/ZavQkqIqmjcIeopX/kI/Eloquugs2b4w1SEZGUmtkfrt6+PSa4etOb4qON994LjzyiYBeR1JuZ4e4On/tcfOrlJz+Bv/or+PnP4fd/XyUYEcmEmVeW6e2Fd70L7rsP3vpW+PKXY2ItEZEMmVnhvn17fO1/9+74iOOHPqSRuohk0swJ929+EzZtit+7fPBB+M3fnO4eiYhMmuzX3IeG4H3vg1tvjSlzn3hCwS4imZf9cP/sZ+NLSR/6EPz4x/FzbyIiGZftssxzz8GnPx0/
cHHnndPdGxGRKZPdkXuhEDX2uXNjKgERkRkkuyP3r34VHnoolsuWTXdvRESmVDZH7vv3w0c/Cm94A7z73dPdGxGRKZfNcH/f++Ln7b7yFX2OXURmpOyVZb7/ffje9+DP/zx+2k5EZAbK1si9txf+6I9iDvaPfGS6eyMiMm2yNXL/+Mej3n7ffdDQMN29ERGZNtkZuT/xBHzpS/D+98PGjdPdGxGRaVVVuJvZdWb2nJntMrOPjbPN28xsp5ntMLNv1babVdi2LZa33z7lNy0icr6ZsCxjZnXAF4E3AfuAx81si7vvLNtmPfBx4Bp37zGzJZPV4XHlcrGcP3/Kb1pE5HxTzch9I7DL3Xe7+yBwL3BTxTb/Cfiiu/cAuPuh2nazCkm4NzdP+U2LiJxvqgn3lcDestP7iueVexnwMjN72MweNbPrxroiM9tkZlvNbGtXV9fZ9Xg8+Tw0NkJ9tt4jFhE5G7V6Q7UeWA+8HrgF+Dsza6vcyN3vdvcOd+9YvHhxjW66KJfTqF1EpKiacO8EVpedXlU8r9w+YIu7n3T3XwDPE2E/dXI5aGmZ0psUETlfVRPujwPrzWydmc0Gbga2VGzzj8SoHTNbRJRpdtewnxPL5xXuIiJFE4a7uw8BtwEPAD8DNrv7DjO7w8xuLG72ANBtZjuBB4E/cffuyer0mFSWERE5pap3H939fuD+ivM+VbbuwO3FNj1UlhEROSU731BVWUZE5JTshLvKMiIip2Qn3DVyFxE5JTvhrpq7iMgp2Qh3d5VlRETKZCPcBwagUNDIXUSkKBvhnkwapnAXEQGyFu4qy4iIAFkJ93w+lhq5i4gAWQl3jdxFREbIVrhr5C4iAmQl3FWWEREZIRvhrrKMiMgI2Qp3jdxFRICshLvKMiIiI2Qj3FWWEREZIRvhns9DYyPUV/XbIyIimZeNcNeMkCIiI2Qn3FWSERE5JRvhrh/qEBEZIRvhrrKMiMgI2Ql3lWVERE7JRrirLCMiMkI2wl1lGRGREbIT7irLiIicko1wV1lGRGSE9Ie7u0buIiIV0h/u/f1QKGjkLiJSJv3hrhkhRURGSX+4a0ZIEZFR0h/uGrmLiIyS/nBP2a8wucORI3DixHT3RESyLP0ToJ9lWaavD77zHdi3Dw4fLrXubjh6FNatg1e/OtqrXgUXXABmpb8vFODAAdi7N66jvx9mzYptkqVZXOfu3SNbb29MP3/ttXDjjfC7vwvLl4/u44kT8MQT8OijcOwYXH01/PqvQ1vb2PtUKMDPfw4/+Qm88EJc5+rVsGZNLJcsib5VK5+P+6O9HVpbz+juFZFplv5wP8OyjDts3gwf/Sjs2VP600WLSu2CC+D55+Fzn4Ohodhm4UK44goYHIy/6+wsXTaRxsY4WFx4IVxzTazv2QPf/z784Afw3vfCxo1w002wYgU89li0Z56B4eG4jlmzIrzN4BWviOu55hpYuTK2ffjhCPWenti+rq70t4nZsyPw58+PfS5vc+ZEkB84AAcPxjI5bkKE+8qVpbZiBQwMxKuQ8tbTEweDdetGN3c4dKjUDh6M5ezZcQAqb6tWxf1Wa8PDcTB+8cX4H86aFfdVff3I1tw8+j5qbj79wfHkSejqKt2HBw/GgTyXi3b8eGm9pSX+F8uWRUvWly+v7kDa1xcDi/7+OGgvWlS736oZHoZf/hKeew7274fLLovHfkPD6f/u2DHYsSMGJcPDo1t9fVxHQ0P8z5PlxRfHY1Jqy9x9Wm64o6PDt27deu5X9K1vwdvfTv/Tz/HDF1/GpZfCRReNHGUntm2DD34Q/u3f4PLL4c47YyQ8Z87YV93fHwG7bVu0p5+GuXNjFJy0JIjmzo3wKhRGLtvbIwjHCgV32L49Qn7LFnj88Ti/tTXC/qqrSq2lBX760+j7ww/DI4/EkylxySWlwL/mGli/PsJ27944kOzdG62zsxQw5YHT1xcHsGXLYOnS0nLhwgjsffvib5O2fz80NcGCBSNbW1scJH7xi2i9veP/6xobI5gGByMIKy1ZMvqAsnJl3G4SnuUtn4/7bt68aMk6RF927YrQGhwcv08TaWiIfictCamenniVNp76+uhPa2s8VnK5OAiMNUBIgj9pK1ZEOCb/xz174iBSzizu/yVL4v+2YEEcjJqb4/aS9cbG0uNzeDiWhUL05/nnI9BfeCEO3OXmzIGOjnj1+NrXwqWXxgHyySejPfFEnD4bZvDyl8d1J23DhjjoFgrxOO7qKrVcLvo3OFhaDg7G/2HRIli8uNSWLIl97u0d3Y4diwNRPh/LZH14OB735QO+RYvisdTbWxrEJMt8PgZur3hFtPb20fvY0xPP9WefjXbrrfE8Pbv7y7a5e8eE21UT7mZ2HfAFoA74qrt/dpzt/j3wXeA17n7a5K5ZuN99N7z3vXzvC/t46wdWAvFk+K3fgte9Lpbt7fCJT8A998Q/6TOfgXe/Ox4855Nf/SoePL/2axOXT4aHY5S0f3886RYunJo+JtzHPoBW6ukpBX1dXTzZktbaWrqO/v44gOzZE+2ll0YeTDo7R4dnQ0ME2dKlcX0tLfHEP3Ys2vHjsRwejlcOF10Uo8SLLoq2enVcz/BwhGzSBgfjiV5+EEzawMDIlgRLe/voA+PSpXGwa2mJg0ClJLgOHIj/4+ma2ehXN2vWROh2dY18RXToUFxvPj8yuE73VK+ri/vkkkvi8ZcslyyBp56KwcQjj0SIVx4c162DK6+M8uVll8UovK5udBsejr89ebK07OuLx/Gjj0br7o7rbG2Ng3h39+hXoJPBLA6Cc+fGc+/IkehfNRoaRm67fHmE/IUXxuN4+/Z4/Cbmz4fPfx7e+c6z7WuNwt3M6oDngTcB+4DHgVvcfWfFdq3AD4DZwG1TFu533QW3385///RR3v9f5nPXXTHC/dd/jbBMNDTABz4An/ykXgKm1cBABF1/fyk4qznASAR7X1/ch7NmjWzlATyRgYEYqe/YEeF1xRVjj1TPto8vvhgHkccei8BMRt/lo/F580qvmmbPjvWGhjhglI/wkzYwEM/5trZYJuutraVXNnPmjHwsucfg4PDhuI7Dh2Og0NYW+7tgQSzb2uJ+27Mn7pOdO2O5Y0e8v7ZmDbzylfFK59JLY33VqnN73FYb7tVU6TYCu9x9d/GK7wVuAnZWbPdp4C+APznDvp6bYmH4UL6ZuroIcLP45+zeDQ89FC8z3/WuKFVIejU2wtq1092LdCofmZ6LxsZS6aTWzOKV1cUXwzveceZ/P2dOqVxai74k5b0LL5x4+7Vro73lLed+27VSTbivBPaWnd4HXFW+gZm9Cljt7j8ws3HD3cw2AZsA1qxZc+a9HUsuB42NdPfWs2BB6YhoVnr5LSIy05zz59zNbBZwJ/DhibZ197vdvcPdOxYvXnyuNx2KM0IeOVK7l4ciImlXTbh3AuUvdFYVz0u0Aq8E/q+Z/RK4GthiZhPWhGqiOCNk8hE8ERGpLtwfB9ab2Tozmw3cDGxJLnT3Xndf5O5r3X0t8Chw40RvqNZM8UPDR47EmxwiIlJFuLv7EHA
b8ADwM2Czu+8wszvM7MbJ7uCEimUZjdxFREqq+k6bu98P3F9x3qfG2fb1596tM5CUZXZp5C4ikkj/xGH5PN6skbuISLn0h3sux8nGllNf9RcRkYyEe39dzAipsoyISEh/uOfz9NXFjJAauYuIhHSHuzvkcuQtwl0jdxGRkO5w7++HQoGcR1lGI3cRkZDucC/+UMexgkbuIiLl0h3uxRkhjw6p5i4iUi4T4d5zspk5c8b/RSURkZkm3eFeLMt097eoJCMiUibd4V4cuXf3N6skIyJSJhPhfjCvkbuISLl0h3uxLHMg36KRu4hImXSHe3HkfuCYyjIiIuXSHe7FkXtnr8oyIiLl0h3up2ruGrmLiJRLfbh7YyPD1GvkLiJSJt3hns8z3KRvp4qIVEp3uOdyDM3RvDIiIpVSH+4nZ2tGSBGRSukO93yegQaVZUREKqU73HO5U7/CpLKMiEhJ6sP9hEVZpq1tmvsiInIeSXe45/PkvIV586C+fro7IyJy/kh3uOdyHC/oC0wiIpVSH+69w5o0TESkUnrD3R3yeXpOal4ZEZFK6Q33/n4oFOgZVFlGRKRSesO9OCNkV59G7iIildIb7sUZIbv6VHMXEamU+nA/OtSskbuISIX0hnuxLJNDI3cRkUrpDffiyF3hLiIyWurDPY/KMiIilaoKdzO7zsyeM7NdZvaxMS6/3cx2mtkzZvZjM7ug9l2toLKMiMi4Jgx3M6sDvghcD2wAbjGzDRWbPQl0uPtlwHeBz9W6o6OUlWU0chcRGamakftGYJe773b3QeBe4KbyDdz9QXc/UTz5KLCqtt0cQ1lZRiN3EZGRqgn3lcDestP7iueN5z3AD8e6wMw2mdlWM9va1dVVfS/HUizL9M9qZt68c7sqEZGsqekbqmZ2K9AB/OVYl7v73e7e4e4dixcvPrcby+U4WddIa3s9Zud2VSIiWVPNLOidwOqy06uK541gZtcCnwBe5+4DteneaeRy9NfpzVQRkbFUM3J/HFhvZuvMbDZwM7ClfAMzuxL4CnCjux+qfTfHkM9zYpbeTBURGcuE4e7uQ8BtwAPAz4DN7r7DzO4wsxuLm/0l0AJ8x8yeMrMt41xd7eRy5E1vpoqIjKWqH6dz9/uB+yvO+1TZ+rU17tfE8nmOFzRyFxEZS6q/oapfYRIRGVtqw91zOXqHVJYRERlLasO9cDzPcX07VURkTKkNdz+e07wyIiLjSG24Wz6nGSFFRMaRznB3Z1ZfXiN3EZFxpDPc+/uxQkEzQoqIjCOd4a4ZIUVETiud4a4f6hAROa10hntx5D5Y30xT0zT3RUTkPJTqcLfWFk33KyIyhnSGe7EsM2teyzR3RETk/JTOcC+O3Bvamqe5IyIi56d0hntx5N7QrpG7iMhY0hnuxZH7nEUKdxGRsaQ63JsWqSwjIjKWVIb78LEoyzQvUbiLiIylql9iOt8MdOcw5tC2KJXdFxGZdKlMx8EjOYY0I6SIyLhSGe4nj+bJa+oBEZFxpbPm3qsf6hAROZ1UhnvhuH6oQ0TkdFIZ7uT0Qx0iIqeTynCfdSKnudxFRE4jleFe15+jv76Fhobp7omIyPkpleFeP5BnqFFTD4iIjCeV4d54MkehSd9OFREZT/rC3Z3GoTyFZo3cRUTGk75w7++njgLWonAXERlP+sK9OCPkrHkqy4iIjCd94V78oY76No3cRUTGk7pw7z+c/MSewl1EZDypC/fj+yPcGxeoLCMiMp7UhXv+UJRlmhZr5C4iMp70hfvBGLnPXaJwFxEZT1XhbmbXmdlzZrbLzD42xuWNZvY/i5c/ZmZra93RRFJzb1mqsoyIyHgmDHczqwO+CFwPbABuMbMNFZu9B+hx94uBu4C/qHVHEwNHoizTulwjdxGR8VQzct8I7HL33e4+CNwL3FSxzU3A14vr3wV+28ysdt0sOdkTI/d5KzRyFxEZTzU/s7cS2Ft2eh9w1XjbuPuQmfUCC4HD5RuZ2SZgU/FkzsyeO5tOA4tY23Z44s0yZREV9+cMoH2eGbTPZ+aCajaa0t9Qdfe7gbvP9XrMbKu7d9SgS6mhfZ4ZtM8zw1TsczVlmU5gddnpVcXzxtzGzOqB+UB3LTooIiJnrppwfxxYb2brzGw2cDOwpWKbLcAfFNffCvyLu3vtuikiImdiwrJMsYZ+G/AAUAd8zd13mNkdwFZ33wL8PfAPZrYLOEIcACbTOZd2Ukj7PDNon2eGSd9n0wBbRCR7UvcNVRERmZjCXUQkg1IX7hNNhZAFZvY1MztkZtvLzltgZj8ysxeKy/bp7GMtmdlqM3vQzHaa2Q4z+0Dx/Czv8xwz+6mZPV3c5/9WPH9dcQqPXcUpPWZPd19rzczqzOxJM/vfxdOZ3mcz+6WZPWtmT5nZ1uJ5k/7YTlW4VzkVQhbcA1xXcd7HgB+7+3rgx8XTWTEEfNjdNwBXA39c/L9meZ8HgDe6++XAFcB1ZnY1MXXHXcWpPHqIqT2y5gPAz8pOz4R9foO7X1H22fZJf2ynKtypbiqE1HP3h4hPHZUrn+Lh68C/m9JOTSJ33+/uTxTXjxNP/JVke5/d3XPFkw3F5sAbiSk8IGP7DGBmq4C3AF8tnjYyvs/jmPTHdtrCfaypEFZOU1+m2lJ3319cPwAsnc7OTJbijKJXAo+R8X0ulieeAg4BPwJeBI66+1Bxkyw+vj8PfBQoFE8vJPv77MD/MbNtxSlYYAoe21M6/YDUhru7mWXuM6xm1gJ8D/igux8rn3sui/vs7sPAFWbWBtwHXDLNXZpUZvY7wCF332Zmr5/u/kyh33D3TjNbAvzIzH5efuFkPbbTNnKvZiqErDpoZssBistD09yfmjKzBiLYv+nu/6t4dqb3OeHuR4EHgdcCbcUpPCB7j+9rgBvN7JdESfWNwBfI9j7j7p3F5SHiIL6RKXhspy3cq5kKIavKp3j4A+D709iXmirWXf8e+Jm731l2UZb3eXFxxI6ZNQFvIt5reJCYwgMyts/u/nF3X+Xua4nn7r+4+9vJ8D6bWbOZtSbrwJuB7UzBYzt131A1sxuIul0yFcJnprlLNWdm3wZeT0wLehD4M+Afgc3AGuAl4G3uXvmmayqZ2W8A/w94llIt9j8Tdfes7vNlxBtpdcQga7O732FmFxKj2gXAk8Ct7j4wfT2dHMWyzEfc/XeyvM/FfbuveLIe+Ja7f8bMFjLJj+3UhbuIiEwsbWUZERGpgsJdRCSDFO4iIhmkcBcRySCFu4hIBincRUQySOEuIpJB/x+dd9GCQQqUBAAAAABJRU5ErkJggg==\n", 41 | "text/plain": [ 42 | "
" 43 | ] 44 | }, 45 | "metadata": {}, 46 | "output_type": "display_data" 47 | }, 48 | { 49 | "data": { 50 | "text/plain": [ 51 | "
" 52 | ] 53 | }, 54 | "metadata": {}, 55 | "output_type": "display_data" 56 | } 57 | ], 58 | "source": [ 59 | "path = 'logs/5ans_50epochs/2018-08-13_15:21:41.pth'\n", 60 | "results = torch.load(path)\n", 61 | "\n", 62 | "val_acc = torch.FloatTensor(results['tracker']['val_acc'])\n", 63 | "val_acc = val_acc.mean(dim=1).numpy()\n", 64 | "\n", 65 | "\n", 66 | "\n", 67 | "val_acc=np.insert(val_acc,0,0)\n", 68 | "\n", 69 | "train_acc = torch.FloatTensor(results['tracker']['train_acc'])\n", 70 | "train_acc = train_acc.mean(dim=1).numpy()\n", 71 | "\n", 72 | "\n", 73 | "train_acc=np.insert(train_acc,0,0)\n", 74 | "\n", 75 | "plt.gca().set_color_cycle(['blue', 'red'])\n", 76 | "\n", 77 | "\n", 78 | "axes = plt.gca()\n", 79 | "axes.set_ylim([0,1])\n", 80 | "\n", 81 | "plt.plot(val_acc)\n", 82 | "plt.plot(train_acc)\n", 83 | "\n", 84 | "fig_acc = plt.gcf()\n", 85 | "plt.show()\n", 86 | "#plt.savefig('val_acc.png')\n", 87 | "plt.figure()\n", 88 | "fig_acc.savefig('train_val_5ans_50epochs.png', dpi = 1000)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 44, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stderr", 98 | "output_type": "stream", 99 | "text": [ 100 | "/home/ubuntu/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/__main__.py:17: MatplotlibDeprecationWarning: The set_color_cycle function was deprecated in version 1.5. Use `.set_prop_cycle` instead.\n" 101 | ] 102 | }, 103 | { 104 | "data": { 105 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAGEFJREFUeJzt3X+QXWWd5/H3l06aBJJBCU2AJEDQCJMRNkiL+GN2lZWdiFbwJwPojNboUlYNpVvq7LI/xJUdq2bGKl3/oEpTyizuLCLij81q1iwjWO7qiOkQFBMMhh9KopAmAcZ0AknDd/947p2+3X1v9w25neacvF9VT9177n36nufcvv053/Pc2+dGZiJJqpdjZnsAkqTeM9wlqYYMd0mqIcNdkmrIcJekGjLcJamGugr3iFgdEdsiYntEXNuhz+URsTUitkTEzb0dpiTpUMR0n3OPiD7gfuASYAewEbgyM7e29FkB3ApcnJlPRMTJmblr5oYtSZpKN5X7hcD2zHwwMw8AtwCXTejzr4EbMvMJAINdkmbXnC76LAEeaVneAbxqQp+XAUTED4E+4D9n5ncnPlBEXA1cDXD88cdfcM455zyfMUvSUWvTpk2PZ+bAdP26CfduzAFWAK8HlgI/iIhzM/PJ1k6ZuRZYCzA4OJhDQ0M9Wr0kHR0i4lfd9OtmWmYnsKxleWnjtlY7gHWZeTAzH6LM0a/oZgCSpN7rJtw3AisiYnlE9ANXAOsm9PkWpWonIk6iTNM82MNxSpIOwbThnpmjwDXABuA+4NbM3BIR10fEmka3DcDuiNgK3An8RWbunqlBS5KmNu1HIWeKc+6SdOgiYlNmDk7Xz/9QlaQaMtwlqYYMd0mqIcNdkmrIcJekGjLcJamGDHdJqiHDXZJqyHCXpBoy3CWphgx3Saohw12Sashwl6QaMtwlqYYMd0mqIcNdkmrIcJekGjLcJamGDHdJqiHDXZJqyHCXpBoy3CWphgx3Saohw12Sashwl6QaMtwlqYYMd0mqoa7CPSJWR8S2iNgeEde2uf99ETEcEfc02gd6P1RJUrfmTNchIvqAG4BLgB3AxohYl5lbJ3T9amZeMwNjlCQdom4q9wuB7Zn5YGYeAG4BLpvZYUmSDkc34b4EeKRleUfjtoneERE/i4jbImJZT0YnSXpeevWG6v8CzszM84DbgZvadYqIqyNiKCKGhoeHe7RqSdJE3YT7TqC1El/auO2fZObuzHymsfhF4IJ2D5SZazNzMDMHBwYGns94JUld6CbcNwIrImJ5RPQDVwDrWjtExKkti2uA+3o3REnSoZr20zKZORoR1wAbgD7gxszcEhHXA0OZuQ74UESsAUaBPcD7ZnDMkqRpRGbOyooHBwdzaGhoVtYtSVUVEZsyc3C6fv6HqiTVkOEuSTVkuEtSDRnuklRDhrsk1ZDhLkk1ZLhLUg0Z7pJUQ4a7JNWQ4S5JNWS4S1INGe6SVEOGuyTVkOEuSTVkuEtSDRnuklRDhrsk1ZDhLkk1ZLhLUg0Z7pJUQ4a7JNXQnNkewAvNgQPw6KOwZAn09fX2sZ98Eu66q7RTT4U//VM49tjerkOS4CgO92efhW98A7ZuhYceGms7dkAmnHIKvOMd8M53wh/+Yeegf/JJ+NnPYO9eOOaY0u+YY8baL38J//AP8OMfl3W1+su/hI9/HN77Xpg7t/3j//rXcPPNMDwMV14JF1wAEZ23a2QEbrutrO+cc+D882HVKvi933t+z5OkaorMnJUVDw4O5tDQ0Kyse3S0BOrNN5egXLIEli8fayefDHfcAevXw/79sHgxvP3t8La3wdNPwz33wObN5fKhh6Zf34knwkUXwatfXS4vvLBU7x//eLk86yz4xCfg3e8uO4cnn4Svfx3+7u/g+98vj9HfX44qzjsP3v/+0nfRonJfZtmB/O3fwi23lB3N8ceXoG96yUtK0J93XtlxnXQSDAyUy5NOgoULy7Zs2wa/+MXY5cMPw7JlcO6549vAQFnvnj1lB9RsO3bAaaeV7Vy1auojk0zYvbv0Wbjw+f42q+O5
[remainder of base64 PNG data omitted - plot of validation accuracy per epoch]\n", 106 | "text/plain": [ 107 | "
" 108 | ] 109 | }, 110 | "metadata": {}, 111 | "output_type": "display_data" 112 | }, 113 | { 114 | "data": { 115 | "text/plain": [ 116 | "
" 117 | ] 118 | }, 119 | "metadata": {}, 120 | "output_type": "display_data" 121 | } 122 | ], 123 | "source": [ 124 | "path = 'logs/5ans_50epochs/2018-08-13_15:21:41.pth'\n", 125 | "results = torch.load(path)\n", 126 | "\n", 127 | "val_acc = torch.FloatTensor(results['tracker']['val_acc'])\n", 128 | "val_acc = val_acc.mean(dim=1).numpy()\n", 129 | "\n", 130 | "\n", 131 | "\n", 132 | "val_acc=np.insert(val_acc,0,0)\n", 133 | "\n", 134 | "train_acc = torch.FloatTensor(results['tracker']['train_acc'])\n", 135 | "train_acc = train_acc.mean(dim=1).numpy()\n", 136 | "\n", 137 | "\n", 138 | "train_acc=np.insert(train_acc,0,0)\n", 139 | "\n", 140 | "plt.gca().set_color_cycle(['blue', 'red'])\n", 141 | "\n", 142 | "#plt.figure()\n", 143 | "plt.plot(val_acc)\n", 144 | "#plt.plot(train_acc)\n", 145 | "\n", 146 | "axes = plt.gca()\n", 147 | "axes.set_ylim([0,0.6])\n", 148 | "\n", 149 | "fig_acc = plt.gcf()\n", 150 | "plt.show()\n", 151 | "#plt.savefig('val_acc.png')\n", 152 | "plt.figure()\n", 153 | "fig_acc.savefig('val_5ans_50epochs.png', dpi = 1000)\n" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [] 162 | } 163 | ], 164 | "metadata": { 165 | "kernelspec": { 166 | "display_name": "Environment (conda_pytorch_p36)", 167 | "language": "python", 168 | "name": "conda_pytorch_p36" 169 | }, 170 | "language_info": { 171 | "codemirror_mode": { 172 | "name": "ipython", 173 | "version": 3 174 | }, 175 | "file_extension": ".py", 176 | "mimetype": "text/x-python", 177 | "name": "python", 178 | "nbconvert_exporter": "python", 179 | "pygments_lexer": "ipython3", 180 | "version": "3.6.6" 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 2 185 | } 186 | --------------------------------------------------------------------------------