├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── dataset.py
│   ├── preprocess.py
│   └── statefultransforms.py
├── main.py
├── models
│   ├── ConvBackend.py
│   ├── ConvFrontend.py
│   ├── LSTMBackend.py
│   ├── LipRead.py
│   ├── ResNetBBC.py
│   └── __init__.py
├── options.toml
├── requirements.txt
├── training.py
└── validation.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# saved models
*.pt

# runtime output files
accuracy.txt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 3-Clause License

For TorchVision ResNet

Copyright (c) Soumith Chintala 2016,
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Lipreading With Machine Learning In PyTorch
A PyTorch implementation of the models described in [Combining Residual Networks with LSTMs for Lipreading] by T. Stafylakis and G. Tzimiropoulos. Adapted from the [Torch7 code].

## Usage
- Install [Python 3].
- Clone the repository.
- Run `pip3 install -r requirements.txt` to install project dependencies.
- Edit `options.toml` so that the training and validation dataset paths point at your copy of the lipreading dataset.
- Run `python3 main.py` to train and validate the model.

## Dependencies
- [Python 3] to run the program
- [PyTorch] for tensors, network definition and backprop
- [ImageIO] to load video clips
- [NumPy] to visualize individual layers

[Combining Residual Networks with LSTMs for Lipreading]: https://arxiv.org/abs/1703.04105
[Torch7 code]: https://github.com/tstafylakis/Lipreading-ResNet
[Python 3]: https://www.python.org/
[PyTorch]: https://pytorch.org/
[ImageIO]: https://imageio.github.io/
[NumPy]: https://numpy.org/
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
from .dataset import LipreadingDataset
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
import os

import torch
from torch.utils.data import Dataset

from .preprocess import bbc, load_video


class LipreadingDataset(Dataset):
    """BBC Lip Reading dataset."""

    def build_file_list(self, directory, set_name):
        # Every subdirectory of the dataset root is a word label; each label
        # folder contains train/val/test subfolders of mp4 clips.
        labels = os.listdir(directory)

        completeList = []

        for i, label in enumerate(labels):
            dirpath = directory + "/{}/{}".format(label, set_name)
            print(i, label, dirpath)

            files = os.listdir(dirpath)

            for file in files:
                if file.endswith("mp4"):
                    filepath = dirpath + "/{}".format(file)
                    entry = (i, filepath)
                    completeList.append(entry)

        return labels, completeList

    def __init__(self, directory, set_name, augment=True):
        self.label_list, self.file_list = self.build_file_list(directory, set_name)
        self.augment = augment

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        # Load the video into a temporal volume and pair it with its label index.
        label, filename = self.file_list[idx]
        vidframes = load_video(filename)
        temporalvolume = bbc(vidframes, self.augment)

        sample = {'temporalvolume': temporalvolume,
                  'label': torch.LongTensor([label])}

        return sample
--------------------------------------------------------------------------------
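For reference, a minimal sketch of how LipreadingDataset is consumed through a
DataLoader; the dataset root below is hypothetical and must follow the
label/set/clip.mp4 layout that build_file_list expects:

import torch
from torch.utils.data import DataLoader
from data import LipreadingDataset

# Hypothetical dataset root: each word folder holds train/val/test subfolders.
dataset = LipreadingDataset("/path/to/lipread_mp4", "train", augment=True)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
print(batch['temporalvolume'].size())  # torch.Size([4, 1, 29, 112, 112])
print(batch['label'].size())           # torch.Size([4, 1])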
/data/preprocess.py:
--------------------------------------------------------------------------------
import imageio

# Fetch the ffmpeg binary that imageio uses for video decoding (a no-op when
# it is already present).
imageio.plugins.ffmpeg.download()

import torchvision.transforms.functional as functional
import torchvision.transforms as transforms
import torch
from .statefultransforms import StatefulRandomCrop, StatefulRandomHorizontalFlip


def load_video(filename):
    """Loads the specified video using ffmpeg.

    Args:
        filename (str): The path to the file to load.
            Should be a format that ffmpeg can handle.

    Returns:
        List[FloatTensor]: the frames of the video as a list of 3D tensors
            (channels, height, width)"""

    vid = imageio.get_reader(filename, 'ffmpeg')
    frames = []
    for i in range(0, 29):
        image = vid.get_data(i)
        image = functional.to_tensor(image)
        frames.append(image)
    return frames


def bbc(vidframes, augmentation=True):
    """Preprocesses the specified list of frames by center cropping.
    This will only work correctly on videos that are already centered on the
    mouth region, such as LRW (Lip Reading in the Wild).

    Args:
        vidframes (List[FloatTensor]): The frames of the video as a list of
            3D tensors (channels, height, width)

    Returns:
        FloatTensor: The video as a temporal volume, represented as a 4D tensor
            (channels, time, height, width)"""

    temporalvolume = torch.FloatTensor(1, 29, 112, 112)

    # Without augmentation, fall back to a deterministic center crop.
    croptransform = transforms.CenterCrop((112, 112))

    if(augmentation):
        # Sample the crop window and flip decision once per clip so that the
        # same spatial transform is applied to all 29 frames.
        crop = StatefulRandomCrop((122, 122), (112, 112))
        flip = StatefulRandomHorizontalFlip(0.5)

        croptransform = transforms.Compose([
            crop,
            flip
        ])

    for i in range(0, 29):
        result = transforms.Compose([
            transforms.ToPILImage(),
            transforms.CenterCrop((122, 122)),
            croptransform,
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize([0.4161], [0.1688]),
        ])(vidframes[i])

        temporalvolume[0][i] = result

    return temporalvolume
--------------------------------------------------------------------------------
/data/statefultransforms.py:
--------------------------------------------------------------------------------
import torchvision.transforms.functional as functional
import random


class StatefulRandomCrop(object):
    """Random crop whose parameters are drawn once at construction and then
    reused on every call, so all frames of a clip receive an identical crop."""

    def __init__(self, insize, outsize):
        self.size = outsize
        self.cropParams = self.get_params(insize, self.size)

    @staticmethod
    def get_params(insize, outsize):
        """Get parameters for ``crop`` for a random crop.

        Args:
            insize (tuple): Size (w, h) of the input image.
            outsize (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = insize
        th, tw = outsize
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image.
        """
        i, j, h, w = self.cropParams

        return functional.crop(img, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)


class StatefulRandomHorizontalFlip(object):
    """Horizontal flip whose random draw happens once at construction, so the
    same flip decision is applied to every frame of a clip."""

    def __init__(self, p=0.5):
        self.p = p
        self.rand = random.random()

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: The flipped image when the stored draw fell below ``p``,
            otherwise the image unchanged.
        """
        if self.rand < self.p:
            return functional.hflip(img)
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
--------------------------------------------------------------------------------
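The point of the stateful variants is that all frames of one clip see the same
randomness. A small self-contained check of that property (the noise frame is
a stand-in for a real 122x122 video frame):

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as functional
from data.statefultransforms import StatefulRandomCrop, StatefulRandomHorizontalFlip

# A noisy dummy frame in place of real video data.
frame = transforms.ToPILImage()(torch.rand(1, 122, 122))

crop = StatefulRandomCrop((122, 122), (112, 112))
flip = StatefulRandomHorizontalFlip(0.5)
pipeline = transforms.Compose([crop, flip])

# The crop window and flip decision were drawn once in the constructors, so
# repeated calls transform every frame identically.
a = functional.to_tensor(pipeline(frame))
b = functional.to_tensor(pipeline(frame))
assert torch.equal(a, b)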
53 | """ 54 | if self.rand < self.p: 55 | return functional.hflip(img) 56 | return img 57 | 58 | def __repr__(self): 59 | return self.__class__.__name__ + '(p={})'.format(self.p) 60 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from models import LipRead 3 | import torch 4 | import toml 5 | from training import Trainer 6 | from validation import Validator 7 | 8 | print("Loading options...") 9 | with open('options.toml', 'r') as optionsFile: 10 | options = toml.loads(optionsFile.read()) 11 | 12 | if(options["general"]["usecudnnbenchmark"] and options["general"]["usecudnn"]): 13 | print("Running cudnn benchmark...") 14 | torch.backends.cudnn.benchmark = True 15 | 16 | #Create the model. 17 | model = LipRead(options) 18 | 19 | if(options["general"]["loadpretrainedmodel"]): 20 | model.load_state_dict(torch.load(options["general"]["pretrainedmodelpath"])) 21 | 22 | #Move the model to the GPU. 23 | if(options["general"]["usecudnn"]): 24 | model = model.cuda(options["general"]["gpuid"]) 25 | 26 | trainer = Trainer(options) 27 | validator = Validator(options) 28 | 29 | for epoch in range(options["training"]["startepoch"], options["training"]["epochs"]): 30 | 31 | if(options["training"]["train"]): 32 | trainer.epoch(model, epoch) 33 | 34 | if(options["validation"]["validate"]): 35 | validator.epoch(model) 36 | -------------------------------------------------------------------------------- /models/ConvBackend.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def _validate(modelOutput, labels): 6 | maxvalues, maxindices = torch.max(modelOutput.data, 1) 7 | 8 | count = 0 9 | 10 | for i in range(0, labels.squeeze(1).size(0)): 11 | 12 | if maxindices[i] == labels.squeeze(1)[i]: 13 | count += 1 14 | 15 | return count 16 | 17 | class ConvBackend(nn.Module): 18 | def __init__(self, options): 19 | super(ConvBackend, self).__init__() 20 | 21 | bn_size = 256 22 | self.conv1 = nn.Conv1d(bn_size,2 * bn_size ,2, 2) 23 | self.norm1 = nn.BatchNorm1d(bn_size * 2) 24 | self.pool1 = nn.MaxPool1d(2, 2) 25 | 26 | self.conv2 = nn.Conv1d( 2* bn_size, 4* bn_size,2, 2) 27 | self.norm2 = nn.BatchNorm1d(bn_size * 4) 28 | 29 | self.linear = nn.Linear(4*bn_size, bn_size) 30 | self.norm3 = nn.BatchNorm1d(bn_size) 31 | self.linear2 = nn.Linear(bn_size, 500) 32 | 33 | self.loss = nn.CrossEntropyLoss() 34 | 35 | self.validator = _validate 36 | 37 | def forward(self, input): 38 | transposed = input.transpose(1, 2).contiguous() 39 | 40 | output = self.conv1(transposed) 41 | output = self.norm1(output) 42 | output = F.relu(output) 43 | output = self.pool1(output) 44 | output = self.conv2(output) 45 | output = self.norm2(output) 46 | output = F.relu(output) 47 | output = output.mean(2) 48 | output = self.linear(output) 49 | output = self.norm3(output) 50 | output = F.relu(output) 51 | output =self.linear2(output) 52 | 53 | return output 54 | -------------------------------------------------------------------------------- /models/ConvFrontend.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | 6 | 7 | class ConvFrontend(nn.Module): 8 | def __init__(self): 9 | super(ConvFrontend, self).__init__() 10 
/models/ConvFrontend.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class ConvFrontend(nn.Module):
    def __init__(self):
        super(ConvFrontend, self).__init__()
        # Spatiotemporal stem: (batch, 1, 29, 112, 112) -> (batch, 64, 29, 28, 28).
        self.conv = nn.Conv3d(1, 64, (5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3))
        self.norm = nn.BatchNorm3d(64)
        self.pool = nn.MaxPool3d((1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))

    def forward(self, input):
        output = self.pool(F.relu(self.norm(self.conv(input))))
        return output
--------------------------------------------------------------------------------
/models/LSTMBackend.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class NLLSequenceLoss(nn.Module):
    """
    Custom loss function.
    Returns a loss that is the sum of all losses at each time step.
    """
    def __init__(self):
        super(NLLSequenceLoss, self).__init__()
        self.criterion = nn.NLLLoss()

    def forward(self, input, target):
        loss = 0.0
        # (batch, time, classes) -> (time, batch, classes) so each of the 29
        # steps can be scored against the same per-clip target.
        transposed = input.transpose(0, 1).contiguous()

        for i in range(0, 29):
            loss += self.criterion(transposed[i], target)

        return loss


def _validate(modelOutput, labels):
    # Sum the per-step log-probabilities over time, then take the
    # highest-scoring class as the prediction for each clip.
    averageEnergies = torch.sum(modelOutput.data, 1)

    maxvalues, maxindices = torch.max(averageEnergies, 1)

    count = 0

    for i in range(0, labels.squeeze(1).size(0)):
        if maxindices[i] == labels.squeeze(1)[i]:
            count += 1

    return count


class LSTMBackend(nn.Module):
    def __init__(self, options):
        super(LSTMBackend, self).__init__()
        self.Module1 = nn.LSTM(input_size=options["model"]["inputdim"],
                               hidden_size=options["model"]["hiddendim"],
                               num_layers=options["model"]["numlstms"],
                               batch_first=True,
                               bidirectional=True)

        self.fc = nn.Linear(options["model"]["hiddendim"] * 2,
                            options["model"]["numclasses"])

        self.softmax = nn.LogSoftmax(dim=2)

        self.loss = NLLSequenceLoss()

        self.validator = _validate

    def forward(self, input):
        lstmOutput, _ = self.Module1(input)

        output = self.fc(lstmOutput)
        output = self.softmax(output)

        return output
--------------------------------------------------------------------------------
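A sketch of the LSTM backend's input/output contract and of how
NLLSequenceLoss repeats each clip's label across all 29 time steps (the option
values mirror options.toml):

import torch
from models.LSTMBackend import LSTMBackend, NLLSequenceLoss

options = {"model": {"inputdim": 256, "hiddendim": 256,
                     "numclasses": 500, "numlstms": 2}}
backend = LSTMBackend(options)

x = torch.randn(4, 29, 256)               # (batch, time, features) from ResNetBBC
logprobs = backend(x)                      # (batch, 29, 500) log-probabilities
targets = torch.LongTensor([0, 1, 2, 3])   # one word label per clip

# Sum of the per-step NLL terms, each scored against the same clip label.
loss = NLLSequenceLoss()(logprobs, targets)
print(logprobs.size(), loss)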
/models/LipRead.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import re

from .ConvFrontend import ConvFrontend
from .ResNetBBC import ResNetBBC
from .LSTMBackend import LSTMBackend
from .ConvBackend import ConvBackend


class LipRead(nn.Module):
    def __init__(self, options):
        super(LipRead, self).__init__()
        self.frontend = ConvFrontend()
        self.resnet = ResNetBBC(options)
        self.backend = ConvBackend(options)
        self.lstm = LSTMBackend(options)

        self.type = options["model"]["type"]

        #Freeze a module by disabling gradients on its parameters. Setting
        #requires_grad on the module object itself would have no effect.
        def freeze(m):
            for param in m.parameters():
                param.requires_grad = False

        #In the LSTM-init configuration the frontend and ResNet come from a
        #pretrained temp-conv model and stay frozen while the LSTM trains.
        if(options["model"]["type"] == "LSTM-init"):
            self.frontend.apply(freeze)
            self.resnet.apply(freeze)

        #function to initialize the weights and biases of each module. Matches the
        #classname with a regular expression to determine the type of the module, then
        #initializes the weights for it.
        def weights_init(m):
            classname = m.__class__.__name__
            if re.search("Conv[123]d", classname):
                m.weight.data.normal_(0.0, 0.02)
            elif re.search("BatchNorm[123]d", classname):
                m.weight.data.fill_(1.0)
                m.bias.data.fill_(0)
            elif re.search("Linear", classname):
                m.bias.data.fill_(0)

        #Apply weight initialization to every module in the model.
        self.apply(weights_init)

    def forward(self, input):
        if(self.type == "temp-conv"):
            output = self.backend(self.resnet(self.frontend(input)))

        if(self.type == "LSTM" or self.type == "LSTM-init"):
            output = self.lstm(self.resnet(self.frontend(input)))

        return output

    def loss(self):
        if(self.type == "temp-conv"):
            return self.backend.loss

        if(self.type == "LSTM" or self.type == "LSTM-init"):
            return self.lstm.loss

    def validator_function(self):
        if(self.type == "temp-conv"):
            return self.backend.validator

        if(self.type == "LSTM" or self.type == "LSTM-init"):
            return self.lstm.validator
--------------------------------------------------------------------------------
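Putting the pieces together, a hedged sketch of the LSTM variant's end-to-end
shapes; the options table below is the minimal subset these modules actually
read, and the batch size of 2 is arbitrary:

import torch
from models import LipRead

options = {"model": {"type": "LSTM", "inputdim": 256, "hiddendim": 256,
                     "numclasses": 500, "numlstms": 2},
           "input": {"batchsize": 2}}
model = LipRead(options)

# One grayscale clip per sample: (batch, channels, time, height, width).
clip = torch.randn(2, 1, 29, 112, 112)
print(model(clip).size())  # torch.Size([2, 29, 500]): per-step log-probabilities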
/models/ResNetBBC.py:
--------------------------------------------------------------------------------
# Adapted from TorchVision's ResNet to use a custom frontend and backend.
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
#
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(4, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.bn2 = nn.BatchNorm1d(num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        # The stock conv1/bn1/relu/maxpool stem is skipped here: ConvFrontend
        # has already reduced each frame to a 64-channel 28x28 feature map.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bn2(x)

        return x


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model

class ResNetBBC(nn.Module):
    def __init__(self, options):
        super(ResNetBBC, self).__init__()
        self.inputdims = options["model"]["inputdim"]
        self.batchsize = options["input"]["batchsize"]

        self.resnetModel = resnet34(False, num_classes=self.inputdims)

    def forward(self, input):
        # (batch, 64, 29, 28, 28) from the frontend: fold time into the batch
        # dimension so every frame passes through the 2D ResNet independently.
        transposed = input.transpose(1, 2).contiguous()

        view = transposed.view(-1, 64, 28, 28)

        output = self.resnetModel(view)

        # Unfold back into per-clip sequences of 256-dimensional features.
        output = output.view(self.batchsize, -1, 256)

        return output
--------------------------------------------------------------------------------
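The fold/unfold that ResNetBBC performs is easy to get wrong, so here is a
small sketch of just the reshaping, using the shapes that appear throughout
this repository (the batch size of 4 is arbitrary):

import torch

batch, channels, time = 4, 64, 29
x = torch.randn(batch, channels, time, 28, 28)   # ConvFrontend output

folded = x.transpose(1, 2).contiguous().view(-1, channels, 28, 28)
print(folded.size())     # torch.Size([116, 64, 28, 28]): 4 clips x 29 frames

features = torch.randn(folded.size(0), 256)      # stand-in for resnet34 output
print(features.view(batch, -1, 256).size())      # torch.Size([4, 29, 256])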
/models/__init__.py:
--------------------------------------------------------------------------------
from .ConvFrontend import ConvFrontend
from .ResNetBBC import ResNetBBC
from .LSTMBackend import LSTMBackend
from .LipRead import LipRead
from .ConvBackend import ConvBackend
--------------------------------------------------------------------------------
/options.toml:
--------------------------------------------------------------------------------
title = "Lipreading model options"

[general]
usecudnn = true
usecudnnbenchmark = true
gpuid = 1
loadpretrainedmodel = true
pretrainedmodelpath = "trainedmodel.pt"
savemodel = true
modelsavepath = "savedmodel.pt"

[input]
batchsize = 18
numworkers = 18
shuffle = true

[model]
type = "LSTM"
inputdim = 256
hiddendim = 256
numclasses = 500
numlstms = 2

[training]
train = true
epochs = 15
startepoch = 10
statsfrequency = 1000
dataset = "/udisk/pszts-ssd/AV-ASR-data/BBC_Oxford/lipread_mp4"
learningrate = 0.003
momentum = 0.9
weightdecay = 0.0001

[validation]
validate = true
dataset = "/udisk/pszts-ssd/AV-ASR-data/BBC_Oxford/lipread_mp4"
saveaccuracy = true
accuracyfilelocation = "accuracy.txt"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
imageio
torch
torchvision
toml
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
from torch.autograd import Variable
import torch
import torch.optim as optim
from datetime import datetime, timedelta
from data import LipreadingDataset
from torch.utils.data import DataLoader
import os
import math

def timedelta_string(delta):
    totalSeconds = int(delta.total_seconds())
    hours, remainder = divmod(totalSeconds, 60 * 60)
    minutes, seconds = divmod(remainder, 60)
    return "{} hrs, {} mins, {} secs".format(hours, minutes, seconds)

def output_iteration(i, time, totalitems):
    os.system('clear')

    avgBatchTime = time / (i + 1)
    estTime = avgBatchTime * (totalitems - i)

    print("Iteration: {}\nElapsed Time: {}\nEstimated Time Remaining: {}".format(
        i, timedelta_string(time), timedelta_string(estTime)))

class Trainer():
    def __init__(self, options):
        self.trainingdataset = LipreadingDataset(options["training"]["dataset"], "train")
        self.trainingdataloader = DataLoader(
            self.trainingdataset,
            batch_size=options["input"]["batchsize"],
            shuffle=options["input"]["shuffle"],
            num_workers=options["input"]["numworkers"],
            drop_last=True
        )
        self.usecudnn = options["general"]["usecudnn"]

        self.batchsize = options["input"]["batchsize"]

        self.statsfrequency = options["training"]["statsfrequency"]

        self.gpuid = options["general"]["gpuid"]

        self.learningrate = options["training"]["learningrate"]

        self.modelType = options["model"]["type"]

        self.weightdecay = options["training"]["weightdecay"]
        self.momentum = options["training"]["momentum"]

        self.savemodel = options["general"]["savemodel"]
        self.modelsavepath = options["general"]["modelsavepath"]

    def learningRate(self, epoch):
        # Halve the base learning rate every five epochs.
        decay = math.floor((epoch - 1) / 5)
        return self.learningrate * pow(0.5, decay)

    def epoch(self, model, epoch):
        #set up the loss function and optimizer.
        criterion = model.loss()
        optimizer = optim.SGD(
            model.parameters(),
            lr=self.learningRate(epoch),
            momentum=self.momentum,
            weight_decay=self.weightdecay)

        #transfer the loss function to the GPU.
        if(self.usecudnn):
            criterion = criterion.cuda(self.gpuid)

        startTime = datetime.now()
        print("Starting training...")
        for i_batch, sample_batched in enumerate(self.trainingdataloader):
            optimizer.zero_grad()
            input = Variable(sample_batched['temporalvolume'])
            labels = Variable(sample_batched['label'])

            if(self.usecudnn):
                input = input.cuda(self.gpuid)
                labels = labels.cuda(self.gpuid)

            outputs = model(input)
            loss = criterion(outputs, labels.squeeze(1))

            loss.backward()
            optimizer.step()
            sampleNumber = i_batch * self.batchsize

            if(sampleNumber % self.statsfrequency == 0):
                currentTime = datetime.now()
                output_iteration(sampleNumber, currentTime - startTime, len(self.trainingdataset))

        if(self.savemodel):
            print("Epoch completed, saving state...")
            torch.save(model.state_dict(), self.modelsavepath)
--------------------------------------------------------------------------------
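Trainer.learningRate halves the configured base rate every five epochs; a
quick standalone check of the values it yields with the 0.003 base rate from
options.toml:

import math

def learning_rate(base, epoch):
    # Mirrors Trainer.learningRate: halve the base rate every five epochs.
    return base * pow(0.5, math.floor((epoch - 1) / 5))

for epoch in (1, 5, 6, 10, 11, 15):
    print(epoch, learning_rate(0.003, epoch))
# epochs 1-5 -> 0.003, 6-10 -> 0.0015, 11-15 -> 0.00075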
/validation.py:
--------------------------------------------------------------------------------
from torch.autograd import Variable
import torch
from data import LipreadingDataset
from torch.utils.data import DataLoader

class Validator():
    def __init__(self, options):
        self.validationdataset = LipreadingDataset(options["validation"]["dataset"],
                                                   "val", False)
        self.validationdataloader = DataLoader(
            self.validationdataset,
            batch_size=options["input"]["batchsize"],
            shuffle=options["input"]["shuffle"],
            num_workers=options["input"]["numworkers"],
            drop_last=True
        )
        self.usecudnn = options["general"]["usecudnn"]

        self.batchsize = options["input"]["batchsize"]

        self.statsfrequency = options["training"]["statsfrequency"]

        self.gpuid = options["general"]["gpuid"]

        self.saveaccuracy = options["validation"]["saveaccuracy"]
        self.accuracyfilelocation = options["validation"]["accuracyfilelocation"]

    def epoch(self, model):
        print("Starting validation...")
        count = 0
        validator_function = model.validator_function()

        for i_batch, sample_batched in enumerate(self.validationdataloader):
            input = Variable(sample_batched['temporalvolume'])
            labels = sample_batched['label']

            if(self.usecudnn):
                input = input.cuda(self.gpuid)
                labels = labels.cuda(self.gpuid)

            outputs = model(input)

            count += validator_function(outputs, labels)

            print(count)

        accuracy = count / len(self.validationdataset)
        if(self.saveaccuracy):
            with open(self.accuracyfilelocation, "a") as outputfile:
                outputfile.write("\ncorrect count: {}, total count: {}, accuracy: {}".format(
                    count, len(self.validationdataset), accuracy))
--------------------------------------------------------------------------------