├── CODEOWNERS ├── 20bn-jester-v1 └── annotations │ ├── jester-v1-validation-quick-testing.csv │ ├── jester-v1-labels-quick-testing.csv │ ├── jester-v1-train-quick-testing.csv │ └── jester-v1-labels.csv ├── configs ├── config.json └── config_quick_testing.json ├── data_parser.py ├── model.py ├── .gitignore ├── .github └── workflows │ └── manual.yml ├── LICENSE.md ├── data_loader.py ├── README.md ├── train.py └── callbacks.py /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @udacity/active-public-content -------------------------------------------------------------------------------- /20bn-jester-v1/annotations/jester-v1-validation-quick-testing.csv: -------------------------------------------------------------------------------- 1 | 9223;Thumb Up 2 | 94928;Swiping Right 3 | 117498;Swiping Left 4 | 54598;Swiping Down -------------------------------------------------------------------------------- /20bn-jester-v1/annotations/jester-v1-labels-quick-testing.csv: -------------------------------------------------------------------------------- 1 | Swiping Left 2 | Swiping Right 3 | Swiping Down 4 | Swiping Up 5 | Doing other things 6 | -------------------------------------------------------------------------------- /20bn-jester-v1/annotations/jester-v1-train-quick-testing.csv: -------------------------------------------------------------------------------- 1 | 34870;Drumming Fingers 2 | 68574;Swiping Right 3 | 119263;Zooming In With Two Fingers 4 | 6522;Swiping Down 5 | 118250;Swiping Down 6 | 62818;Swiping Left 7 | 1338;Swiping Right 8 | 10179;Swiping Left -------------------------------------------------------------------------------- /20bn-jester-v1/annotations/jester-v1-labels.csv: -------------------------------------------------------------------------------- 1 | Swiping Left 2 | Swiping Right 3 | Swiping Down 4 | Swiping Up 5 | Pushing Hand Away 6 | Pulling Hand In 7 | Sliding Two Fingers Left 8 | Sliding Two Fingers Right 9 | Sliding Two Fingers Down 10 | Sliding Two Fingers Up 11 | Pushing Two Fingers Away 12 | Pulling Two Fingers In 13 | Rolling Hand Forward 14 | Rolling Hand Backward 15 | Turning Hand Clockwise 16 | Turning Hand Counterclockwise 17 | Zooming In With Full Hand 18 | Zooming Out With Full Hand 19 | Zooming In With Two Fingers 20 | Zooming Out With Two Fingers 21 | Thumb Up 22 | Thumb Down 23 | Shaking Hand 24 | Stop Sign 25 | Drumming Fingers 26 | No gesture 27 | Doing other things 28 | -------------------------------------------------------------------------------- /configs/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "jester_conv_example", 3 | "checkpoint": "trainings/jpeg_model/jester_conv_example/checkpoint.pth.tar", 4 | 5 | "train_data_folder": "./20bn-jester-v1/videos", 6 | "val_data_folder": "./20bn-jester-v1/videos", 7 | "train_data_csv": "./20bn-jester-v1/annotations/jester-v1-train.csv", 8 | "val_data_csv": "./20bn-jester-v1/annotations/jester-v1-validation.csv", 9 | "labels_csv": "./20bn-jester-v1/annotations/jester-v1-labels.csv", 10 | 11 | "num_workers": 8, 12 | 13 | "output_dir": "trainings/jpeg_model/", 14 | 15 | "num_classes": 27, 16 | "batch_size": 10, 17 | "clip_size": 18, 18 | "nclips": 1, 19 | "step_size": 2, 20 | "lr": 0.001, 21 | "last_lr": 0.00001, 22 | "momentum": 0.9, 23 | "weight_decay": 0.00001, 24 | "num_epochs": -1, 25 | "print_freq": 100 26 | } 
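For reference, `train.py` consumes this file by parsing it with `json.load` and indexing the resulting dictionary by key. The snippet below is a minimal standalone sketch of that pattern (run from the repository root); it is illustrative and not part of the repository itself.

```python
import json

# read the training configuration the same way train.py does
with open("configs/config.json") as f:
    config = json.load(f)

# a few of the keys the training script relies on
print(config["model_name"], config["num_classes"], config["batch_size"], config["lr"])
```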
-------------------------------------------------------------------------------- /configs/config_quick_testing.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "jester_conv_4_classes", 3 | "checkpoint": "trainings/jpeg_model/jester_conv_4_classes/checkpoint.pth.tar", 4 | 5 | "train_data_folder": "./20bn-jester-v1/videos", 6 | "val_data_folder": "./20bn-jester-v1/videos", 7 | "train_data_csv": "./20bn-jester-v1/annotations/jester-v1-train-quick-testing.csv", 8 | "val_data_csv": "./20bn-jester-v1/annotations/jester-v1-validation-quick-testing.csv", 9 | "labels_csv": "./20bn-jester-v1/annotations/jester-v1-labels-quick-testing.csv", 10 | 11 | "num_workers": 8, 12 | 13 | "output_dir": "trainings/jpeg_model/", 14 | 15 | "num_classes": 5, 16 | "batch_size": 10, 17 | "clip_size": 18, 18 | "nclips": 1, 19 | "step_size": 2, 20 | "lr": 0.001, 21 | "last_lr": 0.00001, 22 | "momentum": 0.9, 23 | "weight_decay": 0.00001, 24 | "num_epochs": 1, 25 | "print_freq": 100 26 | } -------------------------------------------------------------------------------- /data_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | 4 | from collections import namedtuple 5 | 6 | ListDataJpeg = namedtuple('ListDataJpeg', ['id', 'label', 'path']) 7 | 8 | class JpegDataset(object): 9 | 10 | def __init__(self, csv_path_input, csv_path_labels, data_root): 11 | self.classes = self.read_csv_labels(csv_path_labels) 12 | self.classes_dict = self.get_two_way_dict(self.classes) 13 | self.csv_data = self.read_csv_input(csv_path_input, data_root) 14 | 15 | def read_csv_input(self, csv_path, data_root): 16 | csv_data = [] 17 | with open(csv_path) as csvfile: 18 | csv_reader = csv.reader(csvfile, delimiter=';') 19 | for row in csv_reader: 20 | item = ListDataJpeg(row[0], 21 | row[1], 22 | os.path.join(data_root, row[0]) 23 | ) 24 | if row[1] in self.classes: 25 | csv_data.append(item) 26 | return csv_data 27 | 28 | def read_csv_labels(self, csv_path): 29 | classes = [] 30 | with open(csv_path) as csvfile: 31 | csv_reader = csv.reader(csvfile) 32 | for row in csv_reader: 33 | classes.append(row[0]) 34 | return classes 35 | 36 | def get_two_way_dict(self, classes): 37 | classes_dict = {} 38 | for i, item in enumerate(classes): 39 | classes_dict[item] = i 40 | classes_dict[i] = item 41 | return classes_dict 42 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ConvColumn(nn.Module): 5 | 6 | def __init__(self, num_classes): 7 | super(ConvColumn, self).__init__() 8 | 9 | self.conv_layer1 = self._make_conv_layer(3, 64, (1, 2, 2), (1, 2, 2)) 10 | self.conv_layer2 = self._make_conv_layer(64, 128, (2, 2, 2), (2, 2, 2)) 11 | self.conv_layer3 = self._make_conv_layer( 12 | 128, 256, (2, 2, 2), (2, 2, 2)) 13 | self.conv_layer4 = self._make_conv_layer( 14 | 256, 256, (2, 2, 2), (2, 2, 2)) 15 | 16 | self.fc5 = nn.Linear(12800, 512) 17 | self.fc5_act = nn.ELU() 18 | self.fc6 = nn.Linear(512, num_classes) 19 | 20 | def _make_conv_layer(self, in_c, out_c, pool_size, stride): 21 | conv_layer = nn.Sequential( 22 | nn.Conv3d(in_c, out_c, kernel_size=3, stride=1, padding=1), 23 | nn.BatchNorm3d(out_c), 24 | nn.ELU(), 25 | nn.MaxPool3d(pool_size, stride=stride, padding=0) 26 | ) 27 | return conv_layer 28 | 29 | def forward(self, x): 30 | x = 
self.conv_layer1(x) 31 | x = self.conv_layer2(x) 32 | x = self.conv_layer3(x) 33 | x = self.conv_layer4(x) 34 | 35 | x = x.view(x.size(0), -1) 36 | 37 | x = self.fc5(x) 38 | x = self.fc5_act(x) 39 | 40 | x = self.fc6(x) 41 | return x 42 | 43 | 44 | if __name__ == "__main__": 45 | input_tensor = torch.autograd.Variable(torch.rand(5, 3, 18, 84, 84)) 46 | model = ConvColumn(27) #ConvColumn(27).cuda() 47 | output = model(input_tensor) #model(input_tensor.cuda()) 48 | print(output.size()) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # other OS files 104 | .DS_Store 105 | -------------------------------------------------------------------------------- /.github/workflows/manual.yml: -------------------------------------------------------------------------------- 1 | # Workflow to ensure whenever a Github PR is submitted, 2 | # a JIRA ticket gets created automatically. 3 | name: Manual Workflow 4 | 5 | # Controls when the action will run. 
6 | on: 7 | # Triggers the workflow on pull request events but only for the master branch 8 | pull_request_target: 9 | types: [opened, reopened] 10 | 11 | # Allows you to run this workflow manually from the Actions tab 12 | workflow_dispatch: 13 | 14 | jobs: 15 | test-transition-issue: 16 | name: Convert Github Issue to Jira Issue 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@master 21 | 22 | - name: Login 23 | uses: atlassian/gajira-login@master 24 | env: 25 | JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} 26 | JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} 27 | JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} 28 | 29 | - name: Create NEW JIRA ticket 30 | id: create 31 | uses: atlassian/gajira-create@master 32 | with: 33 | project: CONUPDATE 34 | issuetype: Task 35 | summary: | 36 | Github PR [Assign the ND component] | Repo: ${{ github.repository }} | PR# ${{github.event.number}} 37 | description: | 38 | Repo link: https://github.com/${{ github.repository }} 39 | PR no. ${{ github.event.pull_request.number }} 40 | PR title: ${{ github.event.pull_request.title }} 41 | PR description: ${{ github.event.pull_request.description }} 42 | In addition, please resolve other issues, if any. 43 | fields: '{"components": [{"name":"Github PR"}], "customfield_16449":"https://classroom.udacity.com/", "customfield_16450":"Resolve the PR", "labels": ["github"], "priority":{"id": "4"}}' 44 | 45 | - name: Log created issue 46 | run: echo "Issue ${{ steps.create.outputs.issue }} was created" 47 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright © 2012 - 2020, Udacity, Inc. 2 | 3 | Udacity hereby grants you a license in and to the Educational Content, including but not limited to homework assignments, programming assignments, code samples, and other educational materials and tools (as further described in the Udacity Terms of Use), subject to, as modified herein, the terms and conditions of the Creative Commons Attribution-NonCommercial- NoDerivs 3.0 License located at [http://creativecommons.org/licenses/by-nc-nd/4.0](http://creativecommons.org/licenses/by-nc-nd/4.0) and successor locations for such license (the "CC License") provided that, in each case, the Educational Content is specifically marked as being subject to the CC License. 
Udacity expressly defines the following as falling outside the definition of "non-commercial": (a) the sale or rental of (i) any part of the Educational Content, (ii) any derivative works based at least in part on the Educational Content, or (iii) any collective work that includes any part of the Educational Content; (b) the sale of access or a link to any part of the Educational Content without first obtaining informed consent from the buyer (that the buyer is aware that the Educational Content, or such part thereof, is available at the Website free of charge); (c) providing training, support, or editorial services that use or reference the Educational Content in exchange for a fee; (d) the sale of advertisements, sponsorships, or promotions placed on the Educational Content, or any part thereof, or the sale of advertisements, sponsorships, or promotions on any website or blog containing any part of the Educational Material, including without limitation any "pop-up advertisements"; (e) the use of Educational Content by a college, university, school, or other educational institution for instruction where tuition is charged; and (f) the use of Educational Content by a for-profit corporation or non-profit entity for internal professional development or training. 4 | 5 | THE SERVICES AND ONLINE COURSES (INCLUDING ANY CONTENT) ARE PROVIDED "AS IS" AND "AS AVAILABLE" WITH NO REPRESENTATIONS OR WARRANTIES OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. YOU ASSUME TOTAL RESPONSIBILITY AND THE ENTIRE RISK FOR YOUR USE OF THE SERVICES, ONLINE COURSES, AND CONTENT. WITHOUT LIMITING THE FOREGOING, WE DO NOT WARRANT THAT (A) THE SERVICES, WEBSITES, CONTENT, OR THE ONLINE COURSES WILL MEET YOUR REQUIREMENTS OR EXPECTATIONS OR ACHIEVE THE INTENDED PURPOSES, (B) THE WEBSITES OR THE ONLINE COURSES WILL NOT EXPERIENCE OUTAGES OR OTHERWISE BE UNINTERRUPTED, TIMELY, SECURE OR ERROR-FREE, (C) THE INFORMATION OR CONTENT OBTAINED THROUGH THE SERVICES, SUCH AS CHAT ROOM SERVICES, WILL BE ACCURATE, COMPLETE, CURRENT, ERROR- FREE, COMPLETELY SECURE OR RELIABLE, OR (D) THAT DEFECTS IN OR ON THE SERVICES OR CONTENT WILL BE CORRECTED. YOU ASSUME ALL RISK OF PERSONAL INJURY, INCLUDING DEATH AND DAMAGE TO PERSONAL PROPERTY, SUSTAINED FROM USE OF SERVICES. 
6 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import torch 5 | 6 | from PIL import Image 7 | from data_parser import JpegDataset 8 | from torchvision.transforms import * 9 | 10 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG'] 11 | 12 | 13 | def default_loader(path): 14 | return Image.open(path).convert('RGB') 15 | 16 | 17 | class VideoFolder(torch.utils.data.Dataset): 18 | 19 | def __init__(self, root, csv_file_input, csv_file_labels, clip_size, 20 | nclips, step_size, is_val, transform=None, 21 | loader=default_loader): 22 | self.dataset_object = JpegDataset(csv_file_input, csv_file_labels, root) 23 | 24 | self.csv_data = self.dataset_object.csv_data 25 | self.classes = self.dataset_object.classes 26 | self.classes_dict = self.dataset_object.classes_dict 27 | self.root = root 28 | self.transform = transform 29 | self.loader = loader 30 | 31 | self.clip_size = clip_size 32 | self.nclips = nclips 33 | self.step_size = step_size 34 | self.is_val = is_val 35 | 36 | def __getitem__(self, index): 37 | item = self.csv_data[index] 38 | img_paths = self.get_frame_names(item.path) 39 | 40 | imgs = [] 41 | for img_path in img_paths: 42 | img = self.loader(img_path) 43 | img = self.transform(img) 44 | imgs.append(torch.unsqueeze(img, 0)) 45 | 46 | target_idx = self.classes_dict[item.label] 47 | 48 | # format data to torch 49 | data = torch.cat(imgs) 50 | data = data.permute(1, 0, 2, 3) 51 | return (data, target_idx) 52 | 53 | def __len__(self): 54 | return len(self.csv_data) 55 | 56 | def get_frame_names(self, path): 57 | frame_names = [] 58 | for ext in IMG_EXTENSIONS: 59 | frame_names.extend(glob.glob(os.path.join(path, "*" + ext))) 60 | frame_names = list(sorted(frame_names)) 61 | num_frames = len(frame_names) 62 | 63 | # set number of necessary frames 64 | if self.nclips > -1: 65 | num_frames_necessary = self.clip_size * self.nclips * self.step_size 66 | else: 67 | num_frames_necessary = num_frames 68 | 69 | # pick frames 70 | offset = 0 71 | if num_frames_necessary > num_frames: 72 | # pad last frame if video is shorter than necessary 73 | frame_names += [frame_names[-1]] * (num_frames_necessary - num_frames) 74 | elif num_frames_necessary < num_frames: 75 | # If there are more frames, then sample starting offset 76 | diff = (num_frames - num_frames_necessary) 77 | # Temporal augmentation 78 | if not self.is_val: 79 | offset = np.random.randint(0, diff) 80 | frame_names = frame_names[offset:num_frames_necessary + 81 | offset:self.step_size] 82 | return frame_names 83 | 84 | 85 | if __name__ == '__main__': 86 | transform = Compose([ 87 | CenterCrop(84), 88 | ToTensor(), 89 | # Normalize( 90 | # mean=[0.485, 0.456, 0.406], 91 | # std=[0.229, 0.224, 0.225]) 92 | ]) 93 | loader = VideoFolder(root="/hdd/20bn-datasets/20bn-jester-v1/", 94 | csv_file_input="csv_files/jester-v1-validation.csv", 95 | csv_file_labels="csv_files/jester-v1-labels.csv", 96 | clip_size=18, 97 | nclips=1, 98 | step_size=2, 99 | is_val=False, 100 | transform=transform, 101 | loader=default_loader) 102 | # data_item, target_idx = loader[0] 103 | # save_images_for_debug("input_images", data_item.unsqueeze(0)) 104 | 105 | train_loader = torch.utils.data.DataLoader( 106 | loader, 107 | batch_size=10, shuffle=False, 108 | num_workers=5, pin_memory=True) -------------------------------------------------------------------------------- 
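For reference, `VideoFolder.__getitem__` returns a `(data, target_idx)` pair in which `data` is permuted to `(channels, frames, height, width)`, so a batch from the `DataLoader` has shape `(batch_size, 3, clip_size, 84, 84)` with the transform above — the layout `ConvColumn.forward` expects. Below is a minimal usage sketch, assuming the Jester videos and the quick-testing CSVs sit at the default paths used in `configs/config_quick_testing.json`; it is not part of the repository.

```python
import torch
from torchvision.transforms import CenterCrop, Compose, ToTensor

from data_loader import VideoFolder

# assumed local paths; adjust them to wherever you extracted the dataset
dataset = VideoFolder(root="./20bn-jester-v1/videos",
                      csv_file_input="./20bn-jester-v1/annotations/jester-v1-train-quick-testing.csv",
                      csv_file_labels="./20bn-jester-v1/annotations/jester-v1-labels-quick-testing.csv",
                      clip_size=18,
                      nclips=1,
                      step_size=2,
                      is_val=False,
                      transform=Compose([CenterCrop(84), ToTensor()]))

loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
clips, labels = next(iter(loader))
print(clips.shape)  # torch.Size([2, 3, 18, 84, 84])
print(labels)       # integer class indices from the two-way label dictionary
```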
/README.md: -------------------------------------------------------------------------------- 1 | # Hand Gesture Recognition Tutorial 2 | 3 | These scripts are modified from TwentyBN's [GulpIO-benchmarks](https://github.com/TwentyBN/GulpIO-benchmarks) repository, written by [Raghav Goyal](https://github.com/raghavgoyal14) and the [TwentyBN](https://20bn.com/) team. They serve as a starting point for creating your own gesture recognition system using a 3D CNN. 4 | 5 | # Requirements 6 | 7 | - Python 3.x 8 | - PyTorch 0.4.0 9 | 10 | # Instructions 11 | 12 | ## 1. Download The *Jester* Dataset 13 | 14 | To train the gesture recognition system, we will use TwentyBN's [Jester Dataset](https://www.twentybn.com/datasets/jester). This dataset consists of 148,092 labeled videos, depicting 27 different classes of human hand gestures. It is made available under the Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International license (CC BY-NC-ND 4.0) and can be used for academic research free of charge. To get access to the dataset you will need to register. 15 | 16 | The Jester dataset is provided as one large TGZ archive with a total download size of 22.8 GB, split into 23 parts of about 1 GB each. After downloading all the parts, you can extract the videos using: 17 | 18 | `cat 20bn-jester-v1-?? | tar zx` 19 | 20 | The CSV files containing the labels for the videos in the Jester dataset have already been downloaded for you and can be found in the **20bn-jester-v1/annotations** folder. 21 | 22 | More information, including alternative ways to download the dataset, is available on the [Jester Dataset](https://www.twentybn.com/datasets/jester) website. 23 | 24 | ## 2. Modify The Config File 25 | 26 | In the **configs** folder you will find two config files: 27 | 28 | * `config.json` 29 | * `config_quick_testing.json` 30 | 31 | Use `config.json` for training the network and `config_quick_testing.json` for quickly testing models. Both files need to be modified to point at the CSV files and the videos from the Jester dataset. The default locations are `./20bn-jester-v1/annotations/` for the CSV files and `./20bn-jester-v1/videos/` for the videos. 32 | 33 | These config files also contain the parameters used during training and quick testing, such as the number of epochs, batch size, and learning rate. Feel free to modify these parameters as you see fit. 34 | 35 | Please note that the default number of epochs for training is set to `-1` in `config.json`, which corresponds to `999999` epochs. 36 | 37 | ## 3. Create Your Own Model 38 | 39 | The `model.py` module already contains a simple 3D CNN model that you can use to train your gesture recognition system. You are encouraged to modify `model.py` to create your own 3D CNN architecture; a sketch of one possible variation is shown below. 40 |
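As a starting point for this step, here is a hedged sketch of one possible variation (hypothetical, not part of the repository). It keeps the `(batch, 3, 18, 84, 84)` input layout used by `ConvColumn` but ends with global average pooling, so the classifier no longer depends on a hard-coded flattened size. If you adopt something like this, remember that `train.py` imports `ConvColumn` from `model.py`, so either keep that class name or update the import accordingly.

```python
import torch
import torch.nn as nn


class ConvColumnVariant(nn.Module):
    """Hypothetical 3D CNN with the same input format as ConvColumn."""

    def __init__(self, num_classes):
        super(ConvColumnVariant, self).__init__()
        self.features = nn.Sequential(
            self._block(3, 32, (1, 2, 2)),     # -> (32, 18, 42, 42): keep time, halve H/W
            self._block(32, 64, (2, 2, 2)),    # -> (64, 9, 21, 21)
            self._block(64, 128, (2, 2, 2)),   # -> (128, 4, 10, 10)
            self._block(128, 256, (2, 2, 2)),  # -> (256, 2, 5, 5)
        )
        self.pool = nn.AdaptiveAvgPool3d(1)    # -> (256, 1, 1, 1) regardless of clip size
        self.classifier = nn.Linear(256, num_classes)

    def _block(self, in_c, out_c, pool_size):
        return nn.Sequential(
            nn.Conv3d(in_c, out_c, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_c),
            nn.ELU(),
            nn.MaxPool3d(pool_size, stride=pool_size),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.pool(x).view(x.size(0), -1)
        return self.classifier(x)


if __name__ == "__main__":
    model = ConvColumnVariant(27)
    output = model(torch.rand(5, 3, 18, 84, 84))
    print(output.size())  # torch.Size([5, 27])
```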
41 | ## 4. Modify the CSV Files For Quick Testing (Optional) 42 | 43 | In the **20bn-jester-v1/annotations** folder you will find the following CSV files: 44 | 45 | * `jester-v1-labels-quick-testing.csv` 46 | * `jester-v1-train-quick-testing.csv` 47 | * `jester-v1-validation-quick-testing.csv` 48 | 49 | These files are used when quickly testing models and can be modified as you see fit. By default, the `jester-v1-labels-quick-testing.csv` file contains labels for only 4 classes of hand gestures plus 1 label for "Doing other things"; the `jester-v1-train-quick-testing.csv` file contains the video IDs and corresponding labels of only 8 training videos; and the `jester-v1-validation-quick-testing.csv` file contains the video IDs and corresponding labels of only 4 validation videos. 50 | 51 | Feel free to add more classes of hand gestures or more videos to the training and validation sets. To add more classes, simply copy and paste from the `jester-v1-labels.csv` file, which contains all 27 classes of hand gestures. Similarly, to add more videos to the training and validation sets, copy and paste from the `jester-v1-train.csv` and `jester-v1-validation.csv` files, which contain all the video IDs and corresponding labels from the Jester dataset. 52 | 53 | **NOTE**: In this folder you will also find the CSV files used for training: `jester-v1-labels.csv`, `jester-v1-train.csv`, and `jester-v1-validation.csv`. These CSV files should **NOT** be modified. 54 | 55 | 56 | # CPU/GPU Option 57 | 58 | You can choose whether to train the network using only a CPU or a GPU. Due to the very large size of the Jester dataset, it is **strongly recommended** that you perform the training on a GPU. CPU mode is useful when you just want to quickly test models. 59 | 60 | To specify that you want to use the CPU for your computation, pass the `--use_gpu=False` flag as described below. 61 | 62 | # Procedure 63 | 64 | ## Testing 65 | 66 | It is recommended that you quickly test your models before you train them on the full Jester dataset. When quickly testing models, we suggest you use the `config_quick_testing.json` file and the CPU. To do this, use the following command: 67 | 68 | `python train.py --config configs/config_quick_testing.json --use_gpu=False` 69 | 70 | ## Training 71 | 72 | When training a model you should use the `config.json` file and a GPU (**strongly recommended**). To train your model using a GPU, use the following command: 73 | 74 | `python train.py --config configs/config.json -g 0` 75 |
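In addition to `--config`, `--use_gpu`, and `-g`, the argument parser in `train.py` also defines `--resume` (`-r`) to resume training from the checkpoint path given by the `"checkpoint"` entry in the config file, and `--eval_only` (`-e`) to only evaluate a trained model on the validation data. For example:

`python train.py --config configs/config.json -g 0 -r True`

`python train.py --config configs/config.json -g 0 -e True`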
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import shutil 5 | import json 6 | import glob 7 | import signal 8 | import pickle 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | 13 | from data_loader import VideoFolder 14 | from callbacks import PlotLearning, MonitorLRDecay, AverageMeter 15 | from model import ConvColumn 16 | from torchvision.transforms import * 17 | 18 | str2bool = lambda x: (str(x).lower() == 'true') 19 | 20 | parser = argparse.ArgumentParser( 21 | description='PyTorch Jester Training using JPEG') 22 | parser.add_argument('--config', '-c', help='json config file path') 23 | parser.add_argument('--eval_only', '-e', default=False, type=str2bool, 24 | help="evaluate trained model on validation data.") 25 | parser.add_argument('--resume', '-r', default=False, type=str2bool, 26 | help="resume training from given checkpoint.") 27 | parser.add_argument('--use_gpu', default=True, type=str2bool, 28 | help="flag to use gpu or not.") 29 | parser.add_argument('--gpus', '-g', help="gpu ids for use.") 30 | 31 | args = parser.parse_args() 32 | if len(sys.argv) < 2: 33 | parser.print_help() 34 | sys.exit(1) 35 | 36 | device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu") 37 | 38 | if args.use_gpu: 39 | gpus = [int(i) for i in args.gpus.split(',')] 40 | print("=> active GPUs: {}".format(args.gpus)) 41 | 42 | best_prec1 = 0 43 | 44 | # load config file 45 | with open(args.config) as data_file: 46 | config = json.load(data_file) 47 | 48 | 49 | def main(): 50 | global args, best_prec1 51 | 52 | # set run output folder 53 | model_name = config["model_name"] 54 | output_dir = config["output_dir"] 55 | print("=> Output folder for this run -- {}".format(model_name)) 56 | save_dir = os.path.join(output_dir, model_name) 57 | if not os.path.exists(save_dir): 58 | os.makedirs(save_dir) 59 | os.makedirs(os.path.join(save_dir, 'plots')) 60 | 61 | # adds a handler for Ctrl+C 62 | def signal_handler(signal, frame): 63 | """ 64 | Remove the output dir if you exit with Ctrl+C while 65 | it is still empty (no files written yet). 66 | It prevents the noise of experimental runs.
67 | """ 68 | num_files = len(glob.glob(save_dir + "/*")) 69 | if num_files < 1: 70 | shutil.rmtree(save_dir) 71 | print('You pressed Ctrl+C!') 72 | sys.exit(0) 73 | # assign Ctrl+C signal handler 74 | signal.signal(signal.SIGINT, signal_handler) 75 | 76 | # create model 77 | model = ConvColumn(config['num_classes']) 78 | 79 | # multi GPU setting 80 | if args.use_gpu: 81 | model = torch.nn.DataParallel(model, device_ids=gpus).to(device) 82 | 83 | # optionally resume from a checkpoint 84 | if args.resume: 85 | if os.path.isfile(config['checkpoint']): 86 | print("=> loading checkpoint '{}'".format(args.resume)) 87 | checkpoint = torch.load(config['checkpoint']) 88 | args.start_epoch = checkpoint['epoch'] 89 | best_prec1 = checkpoint['best_prec1'] 90 | model.load_state_dict(checkpoint['state_dict']) 91 | print("=> loaded checkpoint '{}' (epoch {})" 92 | .format(config['checkpoint'], checkpoint['epoch'])) 93 | else: 94 | print("=> no checkpoint found at '{}'".format( 95 | config['checkpoint'])) 96 | 97 | transform = Compose([ 98 | CenterCrop(84), 99 | ToTensor(), 100 | Normalize(mean=[0.485, 0.456, 0.406], 101 | std=[0.229, 0.224, 0.225]) 102 | ]) 103 | 104 | train_data = VideoFolder(root=config['train_data_folder'], 105 | csv_file_input=config['train_data_csv'], 106 | csv_file_labels=config['labels_csv'], 107 | clip_size=config['clip_size'], 108 | nclips=1, 109 | step_size=config['step_size'], 110 | is_val=False, 111 | transform=transform, 112 | ) 113 | 114 | print(" > Using {} processes for data loader.".format( 115 | config["num_workers"])) 116 | train_loader = torch.utils.data.DataLoader( 117 | train_data, 118 | batch_size=config['batch_size'], shuffle=True, 119 | num_workers=config['num_workers'], pin_memory=True, 120 | drop_last=True) 121 | 122 | val_data = VideoFolder(root=config['val_data_folder'], 123 | csv_file_input=config['val_data_csv'], 124 | csv_file_labels=config['labels_csv'], 125 | clip_size=config['clip_size'], 126 | nclips=1, 127 | step_size=config['step_size'], 128 | is_val=True, 129 | transform=transform, 130 | ) 131 | 132 | val_loader = torch.utils.data.DataLoader( 133 | val_data, 134 | batch_size=config['batch_size'], shuffle=False, 135 | num_workers=config['num_workers'], pin_memory=True, 136 | drop_last=False) 137 | 138 | assert len(train_data.classes) == config["num_classes"] 139 | 140 | # define loss function (criterion) and optimizer 141 | criterion = nn.CrossEntropyLoss().to(device) 142 | 143 | # define optimizer 144 | lr = config["lr"] 145 | last_lr = config["last_lr"] 146 | momentum = config['momentum'] 147 | weight_decay = config['weight_decay'] 148 | optimizer = torch.optim.SGD(model.parameters(), lr, 149 | momentum=momentum, 150 | weight_decay=weight_decay) 151 | 152 | if args.eval_only: 153 | validate(val_loader, model, criterion, train_data.classes_dict) 154 | return 155 | 156 | # set callbacks 157 | plotter = PlotLearning(os.path.join( 158 | save_dir, "plots"), config["num_classes"]) 159 | lr_decayer = MonitorLRDecay(0.6, 3) 160 | val_loss = 9999999 161 | 162 | # set end condition by num epochs 163 | num_epochs = int(config["num_epochs"]) 164 | if num_epochs == -1: 165 | num_epochs = 999999 166 | 167 | print(" > Training is getting started...") 168 | print(" > Training takes {} epochs.".format(num_epochs)) 169 | start_epoch = args.start_epoch if args.resume else 0 170 | 171 | for epoch in range(start_epoch, num_epochs): 172 | lr = lr_decayer(val_loss, lr) 173 | print(" > Current LR : {}".format(lr)) 174 | 175 | if lr < last_lr and last_lr > 0: 176 | 
print(" > Training is done by reaching the last learning rate {}". 177 | format(last_lr)) 178 | sys.exit(1) 179 | 180 | # train for one epoch 181 | train_loss, train_top1, train_top5 = train( 182 | train_loader, model, criterion, optimizer, epoch) 183 | 184 | # evaluate on validation set 185 | val_loss, val_top1, val_top5 = validate(val_loader, model, criterion) 186 | 187 | # plot learning 188 | plotter_dict = {} 189 | plotter_dict['loss'] = train_loss 190 | plotter_dict['val_loss'] = val_loss 191 | plotter_dict['acc'] = train_top1 192 | plotter_dict['val_acc'] = val_top1 193 | plotter_dict['learning_rate'] = lr 194 | plotter.plot(plotter_dict) 195 | 196 | # remember best prec@1 and save checkpoint 197 | is_best = val_top1 > best_prec1 198 | best_prec1 = max(val_top1, best_prec1) 199 | save_checkpoint({ 200 | 'epoch': epoch + 1, 201 | 'arch': "Conv4Col", 202 | 'state_dict': model.state_dict(), 203 | 'best_prec1': best_prec1, 204 | }, is_best, config) 205 | 206 | 207 | def train(train_loader, model, criterion, optimizer, epoch): 208 | losses = AverageMeter() 209 | top1 = AverageMeter() 210 | top5 = AverageMeter() 211 | 212 | # switch to train mode 213 | model.train() 214 | 215 | for i, (input, target) in enumerate(train_loader): 216 | 217 | input, target = input.to(device), target.to(device) 218 | 219 | model.zero_grad() 220 | 221 | # compute output and loss 222 | output = model(input) 223 | loss = criterion(output, target) 224 | 225 | # measure accuracy and record loss 226 | prec1, prec5 = accuracy(output.detach(), target.detach().cpu(), topk=(1, 5)) 227 | losses.update(loss.item(), input.size(0)) 228 | top1.update(prec1.item(), input.size(0)) 229 | top5.update(prec5.item(), input.size(0)) 230 | 231 | # compute gradient and do SGD step 232 | optimizer.zero_grad() 233 | loss.backward() 234 | optimizer.step() 235 | 236 | if i % config["print_freq"] == 0: 237 | print('Epoch: [{0}][{1}/{2}]\t' 238 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 239 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 240 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 241 | epoch, i, len(train_loader), loss=losses, top1=top1, top5=top5)) 242 | return losses.avg, top1.avg, top5.avg 243 | 244 | 245 | def validate(val_loader, model, criterion, class_to_idx=None): 246 | losses = AverageMeter() 247 | top1 = AverageMeter() 248 | top5 = AverageMeter() 249 | 250 | # switch to evaluate mode 251 | model.eval() 252 | 253 | logits_matrix = [] 254 | targets_list = [] 255 | 256 | with torch.no_grad(): 257 | for i, (input, target) in enumerate(val_loader): 258 | 259 | input, target = input.to(device), target.to(device) 260 | 261 | # compute output and loss 262 | output = model(input) 263 | loss = criterion(output, target) 264 | 265 | if args.eval_only: 266 | logits_matrix.append(output.detach().cpu().numpy()) 267 | targets_list.append(target.detach().cpu().numpy()) 268 | 269 | # measure accuracy and record loss 270 | prec1, prec5 = accuracy(output.detach(), target.detach().cpu(), topk=(1, 5)) 271 | losses.update(loss.item(), input.size(0)) 272 | top1.update(prec1.item(), input.size(0)) 273 | top5.update(prec5.item(), input.size(0)) 274 | 275 | if i % config["print_freq"] == 0: 276 | print('Test: [{0}/{1}]\t' 277 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 278 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 279 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 280 | i, len(val_loader), loss=losses, top1=top1, top5=top5)) 281 | 282 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 283 | .format(top1=top1, top5=top5)) 284 | 285 | if 
args.eval_only: 286 | logits_matrix = np.concatenate(logits_matrix) 287 | targets_list = np.concatenate(targets_list) 288 | print(logits_matrix.shape, targets_list.shape) 289 | save_results(logits_matrix, targets_list, class_to_idx, config) 290 | return losses.avg, top1.avg, top5.avg 291 | 292 | 293 | def save_results(logits_matrix, targets_list, class_to_idx, config): 294 | print("Saving inference results ...") 295 | path_to_save = os.path.join( 296 | config['output_dir'], config['model_name'], "test_results.pkl") 297 | with open(path_to_save, "wb") as f: 298 | pickle.dump([logits_matrix, targets_list, class_to_idx], f) 299 | 300 | 301 | def save_checkpoint(state, is_best, config, filename='checkpoint.pth.tar'): 302 | checkpoint_path = os.path.join( 303 | config['output_dir'], config['model_name'], filename) 304 | model_path = os.path.join( 305 | config['output_dir'], config['model_name'], 'model_best.pth.tar') 306 | torch.save(state, checkpoint_path) 307 | if is_best: 308 | shutil.copyfile(checkpoint_path, model_path) 309 | 310 | 311 | def accuracy(output, target, topk=(1,)): 312 | """Computes the precision@k for the specified values of k""" 313 | maxk = max(topk) 314 | batch_size = target.size(0) 315 | 316 | _, pred = output.cpu().topk(maxk, 1, True, True) 317 | pred = pred.t() 318 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 319 | 320 | res = [] 321 | for k in topk: 322 | correct_k = correct[:k].view(-1).float().sum(0) 323 | res.append(correct_k.mul_(100.0 / batch_size)) 324 | return res 325 | 326 | 327 | if __name__ == '__main__': 328 | main() 329 | -------------------------------------------------------------------------------- /callbacks.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import os 4 | import warnings 5 | import numpy as np 6 | import matplotlib 7 | matplotlib.use('Agg') 8 | from matplotlib import pylab as plt 9 | from torch.optim.optimizer import Optimizer 10 | 11 | 12 | ############################################################################### 13 | # TRAINING CALLBACKS 14 | ############################################################################### 15 | 16 | 17 | 18 | class ReduceLROnPlateau(object): 19 | """Reduce learning rate when a metric has stopped improving. 20 | Models often benefit from reducing the learning rate by a factor 21 | of 2-10 once learning stagnates. This scheduler reads a metrics 22 | quantity and if no improvement is seen for a 'patience' number 23 | of epochs, the learning rate is reduced. 24 | 25 | Args: 26 | factor: factor by which the learning rate will 27 | be reduced. new_lr = lr * factor 28 | patience: number of epochs with no improvement 29 | after which learning rate will be reduced. 30 | verbose: int. 0: quiet, 1: update messages. 31 | mode: one of {min, max}. In `min` mode, 32 | lr will be reduced when the quantity 33 | monitored has stopped decreasing; in `max` 34 | mode it will be reduced when the quantity 35 | monitored has stopped increasing. 36 | epsilon: threshold for measuring the new optimum, 37 | to only focus on significant changes. 38 | cooldown: number of epochs to wait before resuming 39 | normal operation after lr has been reduced. 40 | min_lr: lower bound on the learning rate. 41 | 42 | 43 | Example: 44 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 45 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min') 46 | >>> for epoch in range(10): 47 | >>> train(...) 48 | >>> val_acc, val_loss = validate(...) 
49 | >>> scheduler.step(val_loss, epoch) 50 | """ 51 | 52 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, 53 | verbose=0, epsilon=1e-4, cooldown=0, min_lr=0): 54 | super(ReduceLROnPlateau, self).__init__() 55 | 56 | if factor >= 1.0: 57 | raise ValueError('ReduceLROnPlateau ' 58 | 'does not support a factor >= 1.0.') 59 | self.factor = factor 60 | self.min_lr = min_lr 61 | self.epsilon = epsilon 62 | self.patience = patience 63 | self.verbose = verbose 64 | self.cooldown = cooldown 65 | self.cooldown_counter = 0 # Cooldown counter. 66 | self.monitor_op = None 67 | self.wait = 0 68 | self.best = 0 69 | self.mode = mode 70 | assert isinstance(optimizer, Optimizer) 71 | self.optimizer = optimizer 72 | self._reset() 73 | 74 | def _reset(self): 75 | """Resets wait counter and cooldown counter. 76 | """ 77 | if self.mode not in ['min', 'max']: 78 | raise RuntimeError('Learning Rate Plateau Reducing mode %s is unknown!') 79 | if self.mode == 'min' : 80 | self.monitor_op = lambda a, b: np.less(a, b - self.epsilon) 81 | self.best = np.Inf 82 | else: 83 | self.monitor_op = lambda a, b: np.greater(a, b + self.epsilon) 84 | self.best = -np.Inf 85 | self.cooldown_counter = 0 86 | self.wait = 0 87 | self.lr_epsilon = self.min_lr * 1e-4 88 | 89 | def reset(self): 90 | self._reset() 91 | 92 | def step(self, metrics, epoch): 93 | current = metrics 94 | if current is None: 95 | warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning) 96 | else: 97 | if self.in_cooldown(): 98 | self.cooldown_counter -= 1 99 | self.wait = 0 100 | 101 | if self.monitor_op(current, self.best): 102 | self.best = current 103 | self.wait = 0 104 | elif not self.in_cooldown(): 105 | if self.wait >= self.patience: 106 | for param_group in self.optimizer.param_groups: 107 | old_lr = float(param_group['lr']) 108 | if old_lr > self.min_lr + self.lr_epsilon: 109 | new_lr = old_lr * self.factor 110 | new_lr = max(new_lr, self.min_lr) 111 | param_group['lr'] = new_lr 112 | if self.verbose > 0: 113 | print('\nEpoch %05d: reducing learning rate to %s.' % (epoch, new_lr)) 114 | self.cooldown_counter = self.cooldown 115 | self.wait = 0 116 | self.wait += 1 117 | 118 | def in_cooldown(self): 119 | return self.cooldown_counter > 0 120 | 121 | class MonitorLRDecay(object): 122 | """ 123 | Decay learning rate with some patience 124 | """ 125 | def __init__(self, decay_factor, patience): 126 | self.best_loss = 999999 127 | self.decay_factor = decay_factor 128 | self.patience = patience 129 | self.count = 0 130 | 131 | def __call__(self, current_loss, current_lr): 132 | if current_loss < self.best_loss: 133 | self.best_loss = current_loss 134 | self.count = 0 135 | elif self.count > self.patience: 136 | current_lr = current_lr * self. 
decay_factor 137 | print(" > New learning rate -- {0:}".format(current_lr)) 138 | self.count = 0 139 | else: 140 | self.count += 1 141 | return current_lr 142 | 143 | 144 | class PlotLearning(object): 145 | def __init__(self, save_path, num_classes): 146 | self.accuracy = [] 147 | self.val_accuracy = [] 148 | self.losses = [] 149 | self.val_losses = [] 150 | self.learning_rates = [] 151 | self.save_path_loss = os.path.join(save_path, 'loss_plot.png') 152 | self.save_path_accu = os.path.join(save_path, 'accu_plot.png') 153 | self.save_path_lr = os.path.join(save_path, 'lr_plot.png') 154 | self.init_loss = -np.log(1.0 / num_classes) 155 | 156 | def plot(self, logs): 157 | self.accuracy.append(logs.get('acc')) 158 | self.val_accuracy.append(logs.get('val_acc')) 159 | 160 | best_val_acc = max(self.val_accuracy) 161 | best_train_acc = max(self.accuracy) 162 | best_val_epoch = self.val_accuracy.index(best_val_acc) 163 | best_train_epoch = self.accuracy.index(best_train_acc) 164 | 165 | plt.figure(1) 166 | plt.gca().cla() 167 | plt.ylim(0, 1) 168 | plt.plot(self.accuracy, label='train') 169 | plt.plot(self.val_accuracy, label='valid') 170 | plt.title("best_val@{0:}-{1:.2f}, best_train@{2:}-{3:.2f}".format( 171 | best_val_epoch, best_val_acc, best_train_epoch, best_train_acc)) 172 | plt.legend() 173 | plt.savefig(self.save_path_accu) 174 | 175 | self.losses.append(logs.get('loss')) 176 | self.val_losses.append(logs.get('val_loss')) 177 | 178 | best_val_loss = min(self.val_losses) 179 | best_train_loss = min(self.losses) 180 | best_val_epoch = self.val_losses.index(best_val_loss) 181 | best_train_epoch = self.losses.index(best_train_loss) 182 | 183 | plt.figure(2) 184 | plt.gca().cla() 185 | plt.ylim(0, self.init_loss) 186 | plt.plot(self.losses, label='train') 187 | plt.plot(self.val_losses, label='valid') 188 | plt.title("best_val@{0:}-{1:.2f}, best_train@{2:}-{3:.2f}".format( 189 | best_val_epoch, best_val_loss, best_train_epoch, best_train_loss)) 190 | plt.legend() 191 | plt.savefig(self.save_path_loss) 192 | 193 | self.learning_rates.append(logs.get('learning_rate')) 194 | 195 | min_learning_rate = min(self.learning_rates) 196 | max_learning_rate = max(self.learning_rates) 197 | print(min_learning_rate) 198 | 199 | plt.figure(2) 200 | plt.gca().cla() 201 | plt.ylim(0, max_learning_rate) 202 | plt.plot(self.learning_rates) 203 | plt.title("max_learning_rate-{0:.6f}, min_learning_rate-{1:.6f}".format(max_learning_rate, min_learning_rate)) 204 | plt.savefig(self.save_path_lr) 205 | 206 | 207 | class Progbar(object): 208 | """Displays a progress bar. 209 | # Arguments 210 | target: Total number of steps expected. 211 | interval: Minimum visual progress update interval (in seconds). 212 | """ 213 | 214 | def __init__(self, target, width=30, verbose=1, interval=0.05): 215 | self.width = width 216 | self.target = target 217 | self.sum_values = {} 218 | self.unique_values = [] 219 | self.start = time.time() 220 | self.last_update = 0 221 | self.interval = interval 222 | self.total_width = 0 223 | self.seen_so_far = 0 224 | self.verbose = verbose 225 | 226 | def update(self, current, values=None, force=False): 227 | """Updates the progress bar. 228 | # Arguments 229 | current: Index of current step. 230 | values: List of tuples (name, value_for_last_step). 231 | The progress bar will display averages for these values. 232 | force: Whether to force visual progress update. 
233 | """ 234 | values = values or [] 235 | for k, v in values: 236 | if k not in self.sum_values: 237 | self.sum_values[k] = [v * (current - self.seen_so_far), 238 | current - self.seen_so_far] 239 | self.unique_values.append(k) 240 | else: 241 | self.sum_values[k][0] += v * (current - self.seen_so_far) 242 | self.sum_values[k][1] += (current - self.seen_so_far) 243 | self.seen_so_far = current 244 | 245 | now = time.time() 246 | if self.verbose == 1: 247 | if not force and (now - self.last_update) < self.interval: 248 | return 249 | 250 | prev_total_width = self.total_width 251 | sys.stdout.write('\b' * prev_total_width) 252 | sys.stdout.write('\r') 253 | 254 | numdigits = int(np.floor(np.log10(self.target))) + 1 255 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) 256 | bar = barstr % (current, self.target) 257 | prog = float(current) / self.target 258 | prog_width = int(self.width * prog) 259 | if prog_width > 0: 260 | bar += ('=' * (prog_width - 1)) 261 | if current < self.target: 262 | bar += '>' 263 | else: 264 | bar += '=' 265 | bar += ('.' * (self.width - prog_width)) 266 | bar += ']' 267 | sys.stdout.write(bar) 268 | self.total_width = len(bar) 269 | 270 | if current: 271 | time_per_unit = (now - self.start) / current 272 | else: 273 | time_per_unit = 0 274 | eta = time_per_unit * (self.target - current) 275 | info = '' 276 | if current < self.target: 277 | info += ' - ETA: %ds' % eta 278 | else: 279 | info += ' - %ds' % (now - self.start) 280 | for k in self.unique_values: 281 | info += ' - %s:' % k 282 | if isinstance(self.sum_values[k], list): 283 | avg = self.sum_values[k][0] / max(1, self.sum_values[k][1]) 284 | if abs(avg) > 1e-3: 285 | info += ' %.4f' % avg 286 | else: 287 | info += ' %.4e' % avg 288 | else: 289 | info += ' %s' % self.sum_values[k] 290 | 291 | self.total_width += len(info) 292 | if prev_total_width > self.total_width: 293 | info += ((prev_total_width - self.total_width) * ' ') 294 | 295 | sys.stdout.write(info) 296 | sys.stdout.flush() 297 | 298 | if current >= self.target: 299 | sys.stdout.write('\n') 300 | 301 | if self.verbose == 2: 302 | if current >= self.target: 303 | info = '%ds' % (now - self.start) 304 | for k in self.unique_values: 305 | info += ' - %s:' % k 306 | avg = self.sum_values[k][0] / max(1, self.sum_values[k][1]) 307 | if avg > 1e-3: 308 | info += ' %.4f' % avg 309 | else: 310 | info += ' %.4e' % avg 311 | sys.stdout.write(info + "\n") 312 | 313 | self.last_update = now 314 | 315 | def add(self, n, values=None): 316 | self.update(self.seen_so_far + n, values) 317 | 318 | 319 | class AverageMeter(object): 320 | """Computes and stores the average and current value""" 321 | def __init__(self): 322 | self.reset() 323 | 324 | def reset(self): 325 | self.val = 0 326 | self.avg = 0 327 | self.sum = 0 328 | self.count = 0 329 | 330 | def update(self, val, n=1): 331 | self.val = val 332 | self.sum += val * n 333 | self.count += n 334 | self.avg = self.sum / self.count --------------------------------------------------------------------------------
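To round out the callbacks module, here is a small hypothetical usage sketch of the two helpers that `train.py` drives directly (`AverageMeter` once per batch, `MonitorLRDecay` once per epoch); the loss values below are made up purely for illustration.

```python
from callbacks import AverageMeter, MonitorLRDecay

# AverageMeter keeps a running, sample-weighted average (used for loss and accuracy)
losses = AverageMeter()
for batch_loss, batch_size in [(2.1, 10), (1.7, 10), (1.4, 10)]:  # hypothetical values
    losses.update(batch_loss, batch_size)
print("average loss: {:.3f}".format(losses.avg))

# MonitorLRDecay leaves the learning rate alone while the loss improves and, after
# more than `patience` consecutive non-improving steps, multiplies it by the decay
# factor (train.py constructs it as MonitorLRDecay(0.6, 3))
lr_decayer = MonitorLRDecay(0.6, 3)
lr = 0.001
for val_loss in [1.5, 1.6, 1.6, 1.7, 1.7, 1.8, 1.8]:  # hypothetical, non-improving
    lr = lr_decayer(val_loss, lr)
print("learning rate after decay:", lr)
```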