├── .gitignore
├── LICENSE.md
├── README.md
├── dataloaders
│   ├── __init__.py
│   └── image_caption_dataset.py
├── models
│   ├── AudioModels.py
│   ├── ImageModels.py
│   └── __init__.py
├── requirements.txt
├── run.py
└── steps
    ├── __init__.py
    ├── traintest.py
    └── util.py

/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dharwath/DAVEnet-pytorch/23a6482859dd2221350307c9bfb5627a5902f6f0/.gitignore
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 | 
3 | Copyright (c) [year], [fullname]
4 | All rights reserved.
5 | 
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 | 
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 | 
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 | 
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 | 
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DAVEnet Pytorch
2 | 
3 | PyTorch implementation of the DAVEnet (Deep Audio-Visual Embedding network) model, as described in
4 | 
5 | David Harwath, Adrià Recasens, Dídac Surís, Galen Chuang, Antonio Torralba, and James Glass, "Jointly Discovering Visual Objects and Spoken Words from Raw Sensory Input," ECCV 2018
6 | 
7 | ## Requirements
8 | 
9 | - pytorch
10 | - torchvision
11 | - librosa
12 | 
13 | ## Data
14 | 
15 | You will need the PlacesAudio400k spoken caption corpus in addition to the Places205 image dataset:
16 | 
17 | http://groups.csail.mit.edu/sls/downloads/placesaudio/
18 | 
19 | http://places.csail.mit.edu/
20 | 
21 | Please follow the instructions provided in the PlacesAudio400k download package on how to configure and specify the dataset .json files.
22 | 
23 | ## Model Training
24 | 
25 | python run.py --data-train train.json --data-val val.json
26 | 
27 | Where train.json and val.json are the metadata files included in the PlacesAudio400k dataset (run.py defines both data paths as optional flags, not positional arguments).
28 | 
29 | See the run.py script for more training options.
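For reference, a more fully specified invocation might look like the following (the paths and experiment directory name are placeholders; every flag shown is defined in run.py):

    python run.py --data-train /path/to/train.json --data-val /path/to/val.json \
        --exp-dir exp/davenet_places --optim sgd --lr 0.001 --batch-size 100 \
        --n_epochs 100 --margin 1.0 --simtype MISA --pretrained-image-model

--pretrained-image-model initializes the VGG16 image branch from ImageNet weights, and --simtype selects among the SISA, MISA, and SIMA matchmap similarity functions used by the ranking loss and the retrieval evaluation.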
30 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_caption_dataset import ImageCaptionDataset -------------------------------------------------------------------------------- /dataloaders/image_caption_dataset.py: -------------------------------------------------------------------------------- 1 | # Author: David Harwath 2 | # with some functions borrowed from https://github.com/SeanNaren/deepspeech.pytorch 3 | import json 4 | import librosa 5 | import numpy as np 6 | import os 7 | from PIL import Image 8 | import scipy.signal 9 | import torch 10 | import torch.nn.functional 11 | from torch.utils.data import Dataset 12 | import torchvision.transforms as transforms 13 | 14 | def preemphasis(signal,coeff=0.97): 15 | """perform preemphasis on the input signal. 16 | 17 | :param signal: The signal to filter. 18 | :param coeff: The preemphasis coefficient. 0 is none, default 0.97. 19 | :returns: the filtered signal. 20 | """ 21 | return np.append(signal[0],signal[1:]-coeff*signal[:-1]) 22 | 23 | class ImageCaptionDataset(Dataset): 24 | def __init__(self, dataset_json_file, audio_conf=None, image_conf=None): 25 | """ 26 | Dataset that manages a set of paired images and audio recordings 27 | 28 | :param dataset_json_file 29 | :param audio_conf: Dictionary containing the sample rate, window and 30 | the window length/stride in seconds, and normalization to perform (optional) 31 | :param image_transform: torchvision transform to apply to the images (optional) 32 | """ 33 | with open(dataset_json_file, 'r') as fp: 34 | data_json = json.load(fp) 35 | self.data = data_json['data'] 36 | self.image_base_path = data_json['image_base_path'] 37 | self.audio_base_path = data_json['audio_base_path'] 38 | 39 | if not audio_conf: 40 | self.audio_conf = {} 41 | else: 42 | self.audio_conf = audio_conf 43 | 44 | if not image_conf: 45 | self.image_conf = {} 46 | else: 47 | self.image_conf = image_conf 48 | 49 | crop_size = self.image_conf.get('crop_size', 224) 50 | center_crop = self.image_conf.get('center_crop', False) 51 | 52 | if center_crop: 53 | self.image_resize_and_crop = transforms.Compose( 54 | [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]) 55 | else: 56 | self.image_resize_and_crop = transforms.Compose( 57 | [transforms.RandomResizedCrop(crop_size), transforms.ToTensor()]) 58 | 59 | RGB_mean = self.image_conf.get('RGB_mean', [0.485, 0.456, 0.406]) 60 | RGB_std = self.image_conf.get('RGB_std', [0.229, 0.224, 0.225]) 61 | self.image_normalize = transforms.Normalize(mean=RGB_mean, std=RGB_std) 62 | 63 | self.windows = {'hamming': scipy.signal.hamming, 64 | 'hann': scipy.signal.hann, 'blackman': scipy.signal.blackman, 65 | 'bartlett': scipy.signal.bartlett} 66 | 67 | def _LoadAudio(self, path): 68 | audio_type = self.audio_conf.get('audio_type', 'melspectrogram') 69 | if audio_type not in ['melspectrogram', 'spectrogram']: 70 | raise ValueError('Invalid audio_type specified in audio_conf. 
Must be one of [melspectrogram, spectrogram]') 71 | preemph_coef = self.audio_conf.get('preemph_coef', 0.97) 72 | sample_rate = self.audio_conf.get('sample_rate', 16000) 73 | window_size = self.audio_conf.get('window_size', 0.025) 74 | window_stride = self.audio_conf.get('window_stride', 0.01) 75 | window_type = self.audio_conf.get('window_type', 'hamming') 76 | num_mel_bins = self.audio_conf.get('num_mel_bins', 40) 77 | target_length = self.audio_conf.get('target_length', 2048) 78 | use_raw_length = self.audio_conf.get('use_raw_length', False) 79 | padval = self.audio_conf.get('padval', 0) 80 | fmin = self.audio_conf.get('fmin', 20) 81 | n_fft = self.audio_conf.get('n_fft', int(sample_rate * window_size)) 82 | win_length = int(sample_rate * window_size) 83 | hop_length = int(sample_rate * window_stride) 84 | 85 | # load audio, subtract DC, preemphasis 86 | y, sr = librosa.load(path, sample_rate) 87 | if y.size == 0: 88 | y = np.zeros(200) 89 | y = y - y.mean() 90 | y = preemphasis(y, preemph_coef) 91 | # compute mel spectrogram 92 | stft = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, 93 | win_length=win_length, 94 | window=self.windows.get(window_type, self.windows['hamming'])) 95 | spec = np.abs(stft)**2 96 | if audio_type == 'melspectrogram': 97 | mel_basis = librosa.filters.mel(sr, n_fft, n_mels=num_mel_bins, fmin=fmin) 98 | melspec = np.dot(mel_basis, spec) 99 | logspec = librosa.power_to_db(melspec, ref=np.max) 100 | elif audio_type == 'spectrogram': 101 | logspec = librosa.power_to_db(spec, ref=np.max) 102 | n_frames = logspec.shape[1] 103 | if use_raw_length: 104 | target_length = n_frames 105 | p = target_length - n_frames 106 | if p > 0: 107 | logspec = np.pad(logspec, ((0,0),(0,p)), 'constant', 108 | constant_values=(padval,padval)) 109 | elif p < 0: 110 | logspec = logspec[:,0:p] 111 | n_frames = target_length 112 | logspec = torch.FloatTensor(logspec) 113 | return logspec, n_frames 114 | 115 | def _LoadImage(self, impath): 116 | img = Image.open(impath).convert('RGB') 117 | img = self.image_resize_and_crop(img) 118 | img = self.image_normalize(img) 119 | return img 120 | 121 | def __getitem__(self, index): 122 | """ 123 | returns: image, audio, nframes 124 | where image is a FloatTensor of size (3, H, W) 125 | audio is a FloatTensor of size (N_freq, N_frames) for spectrogram, or (N_frames) for waveform 126 | nframes is an integer 127 | """ 128 | datum = self.data[index] 129 | wavpath = os.path.join(self.audio_base_path, datum['wav']) 130 | imgpath = os.path.join(self.image_base_path, datum['image']) 131 | audio, nframes = self._LoadAudio(wavpath) 132 | image = self._LoadImage(imgpath) 133 | return image, audio, nframes 134 | 135 | def __len__(self): 136 | return len(self.data) -------------------------------------------------------------------------------- /models/AudioModels.py: -------------------------------------------------------------------------------- 1 | # Author: David Harwath 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Davenet(nn.Module): 8 | def __init__(self, embedding_dim=1024): 9 | super(Davenet, self).__init__() 10 | self.embedding_dim = embedding_dim 11 | self.batchnorm1 = nn.BatchNorm2d(1) 12 | self.conv1 = nn.Conv2d(1, 128, kernel_size=(40,1), stride=(1,1), padding=(0,0)) 13 | self.conv2 = nn.Conv2d(128, 256, kernel_size=(1,11), stride=(1,1), padding=(0,5)) 14 | self.conv3 = nn.Conv2d(256, 512, kernel_size=(1,17), stride=(1,1), padding=(0,8)) 15 | self.conv4 = nn.Conv2d(512, 512, kernel_size=(1,17), 
stride=(1,1), padding=(0,8)) 16 | self.conv5 = nn.Conv2d(512, embedding_dim, kernel_size=(1,17), stride=(1,1), padding=(0,8)) 17 | self.pool = nn.MaxPool2d(kernel_size=(1,3), stride=(1,2),padding=(0,1)) 18 | 19 | def forward(self, x): 20 | if x.dim() == 3: 21 | x = x.unsqueeze(1) 22 | x = self.batchnorm1(x) 23 | x = F.relu(self.conv1(x)) 24 | x = F.relu(self.conv2(x)) 25 | x = self.pool(x) 26 | x = F.relu(self.conv3(x)) 27 | x = self.pool(x) 28 | x = F.relu(self.conv4(x)) 29 | x = self.pool(x) 30 | x = F.relu(self.conv5(x)) 31 | x = self.pool(x) 32 | x = x.squeeze(2) 33 | return x -------------------------------------------------------------------------------- /models/ImageModels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as imagemodels 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | class Resnet18(imagemodels.ResNet): 8 | def __init__(self, embedding_dim=1024, pretrained=False): 9 | super(Resnet18, self).__init__(imagemodels.resnet.BasicBlock, [2, 2, 2, 2]) 10 | if pretrained: 11 | self.load_state_dict(model_zoo.load_url(imagemodels.resnet.model_urls['resnet18'])) 12 | self.avgpool = None 13 | self.fc = None 14 | self.embedder = nn.Conv2d(512, embedding_dim, kernel_size=1, stride=1, padding=0) 15 | self.embedding_dim = embedding_dim 16 | self.pretrained = pretrained 17 | 18 | def forward(self, x): 19 | x = self.conv1(x) 20 | x = self.bn1(x) 21 | x = self.relu(x) 22 | x = self.maxpool(x) 23 | x = self.layer1(x) 24 | x = self.layer2(x) 25 | x = self.layer3(x) 26 | x = self.layer4(x) 27 | x = self.embedder(x) 28 | return x 29 | 30 | class Resnet34(imagemodels.ResNet): 31 | def __init__(self, embedding_dim=1024, pretrained=False): 32 | super(Resnet34, self).__init__(imagemodels.resnet.BasicBlock, [3, 4, 6, 3]) 33 | if pretrained: 34 | self.load_state_dict(model_zoo.load_url(imagemodels.resnet.model_urls['resnet34'])) 35 | self.avgpool = None 36 | self.fc = None 37 | self.embedder = nn.Conv2d(512, embedding_dim, kernel_size=1, stride=1, padding=0) 38 | 39 | def forward(self, x): 40 | x = self.conv1(x) 41 | x = self.bn1(x) 42 | x = self.relu(x) 43 | x = self.maxpool(x) 44 | x = self.layer1(x) 45 | x = self.layer2(x) 46 | x = self.layer3(x) 47 | x = self.layer4(x) 48 | x = self.embedder(x) 49 | return x 50 | 51 | class Resnet50(imagemodels.ResNet): 52 | def __init__(self, embedding_dim=1024, pretrained=False): 53 | super(Resnet50, self).__init__(imagemodels.resnet.Bottleneck, [3, 4, 6, 3]) 54 | if pretrained: 55 | self.load_state_dict(model_zoo.load_url(imagemodels.resnet.model_urls['resnet50'])) 56 | self.avgpool = None 57 | self.fc = None 58 | self.embedder = nn.Conv2d(2048, embedding_dim, kernel_size=1, stride=1, padding=0) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = self.bn1(x) 63 | x = self.relu(x) 64 | x = self.maxpool(x) 65 | x = self.layer1(x) 66 | x = self.layer2(x) 67 | x = self.layer3(x) 68 | x = self.layer4(x) 69 | x = self.embedder(x) 70 | return x 71 | 72 | class VGG16(nn.Module): 73 | def __init__(self, embedding_dim=1024, pretrained=False): 74 | super(VGG16, self).__init__() 75 | seed_model = imagemodels.__dict__['vgg16'](pretrained=pretrained).features 76 | seed_model = nn.Sequential(*list(seed_model.children())[:-1]) # remove final maxpool 77 | last_layer_index = len(list(seed_model.children())) 78 | seed_model.add_module(str(last_layer_index), 79 | nn.Conv2d(512, embedding_dim, kernel_size=(3,3), 
stride=(1,1), padding=(1,1))) 80 | self.image_model = seed_model 81 | 82 | def forward(self, x): 83 | x = self.image_model(x) 84 | return x 85 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .AudioModels import * 2 | from .ImageModels import * -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | librosa 4 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # Author: David Harwath 2 | import argparse 3 | import os 4 | import pickle 5 | import sys 6 | import time 7 | import torch 8 | 9 | import dataloaders 10 | import models 11 | from steps import train, validate 12 | 13 | print("I am process %s, running on %s: starting (%s)" % ( 14 | os.getpid(), os.uname()[1], time.asctime())) 15 | 16 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument("--data-train", type=str, default='', 18 | help="training data json") 19 | parser.add_argument("--data-val", type=str, default='', 20 | help="validation data json") 21 | parser.add_argument("--exp-dir", type=str, default="", 22 | help="directory to dump experiments") 23 | parser.add_argument("--resume", action="store_true", dest="resume", 24 | help="load from exp_dir if True") 25 | parser.add_argument("--optim", type=str, default="sgd", 26 | help="training optimizer", choices=["sgd", "adam"]) 27 | parser.add_argument('-b', '--batch-size', default=100, type=int, 28 | metavar='N', help='mini-batch size (default: 100)') 29 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 30 | metavar='LR', help='initial learning rate') 31 | parser.add_argument('--lr-decay', default=40, type=int, metavar='LRDECAY', 32 | help='Divide the learning rate by 10 every lr_decay epochs') 33 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 34 | help='momentum') 35 | parser.add_argument('--weight-decay', '--wd', default=5e-7, type=float, 36 | metavar='W', help='weight decay (default: 1e-4)') 37 | parser.add_argument("--n_epochs", type=int, default=100, 38 | help="number of maximum training epochs") 39 | parser.add_argument("--n_print_steps", type=int, default=100, 40 | help="number of steps to print statistics") 41 | parser.add_argument("--audio-model", type=str, default="Davenet", 42 | help="audio model architecture", choices=["Davenet"]) 43 | parser.add_argument("--image-model", type=str, default="VGG16", 44 | help="image model architecture", choices=["VGG16"]) 45 | parser.add_argument("--pretrained-image-model", action="store_true", 46 | dest="pretrained_image_model", help="Use an image network pretrained on ImageNet") 47 | parser.add_argument("--margin", type=float, default=1.0, help="Margin paramater for triplet loss") 48 | parser.add_argument("--simtype", type=str, default="MISA", 49 | help="matchmap similarity function", choices=["SISA", "MISA", "SIMA"]) 50 | 51 | args = parser.parse_args() 52 | 53 | resume = args.resume 54 | 55 | if args.resume: 56 | assert(bool(args.exp_dir)) 57 | with open("%s/args.pkl" % args.exp_dir, "rb") as f: 58 | args = pickle.load(f) 59 | args.resume = resume 60 | 61 | print(args) 62 | 63 | train_loader = 
torch.utils.data.DataLoader( 64 | dataloaders.ImageCaptionDataset(args.data_train), 65 | batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True) 66 | 67 | val_loader = torch.utils.data.DataLoader( 68 | dataloaders.ImageCaptionDataset(args.data_val, image_conf={'center_crop':True}), 69 | batch_size=args.batch_size, shuffle=False, num_workers=8, pin_memory=True) 70 | 71 | audio_model = models.Davenet() 72 | image_model = models.VGG16(pretrained=args.pretrained_image_model) 73 | 74 | if not bool(args.exp_dir): 75 | print("exp_dir not specified, automatically creating one...") 76 | args.exp_dir = "exp/Data-%s/AudioModel-%s_ImageModel-%s_Optim-%s_LR-%s_Epochs-%s" % ( 77 | os.path.basename(args.data_train), args.audio_model, args.image_model, args.optim, 78 | args.lr, args.n_epochs) 79 | 80 | if not args.resume: 81 | print("\nexp_dir: %s" % args.exp_dir) 82 | os.makedirs("%s/models" % args.exp_dir) 83 | with open("%s/args.pkl" % args.exp_dir, "wb") as f: 84 | pickle.dump(args, f) 85 | 86 | train(audio_model, image_model, train_loader, val_loader, args) 87 | -------------------------------------------------------------------------------- /steps/__init__.py: -------------------------------------------------------------------------------- 1 | from .traintest import * -------------------------------------------------------------------------------- /steps/traintest.py: -------------------------------------------------------------------------------- 1 | import time 2 | import shutil 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import pickle 7 | from .util import * 8 | 9 | def train(audio_model, image_model, train_loader, test_loader, args): 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | torch.set_grad_enabled(True) 12 | # Initialize all of the statistics we want to keep track of 13 | batch_time = AverageMeter() 14 | data_time = AverageMeter() 15 | loss_meter = AverageMeter() 16 | progress = [] 17 | best_epoch, best_acc = 0, -np.inf 18 | global_step, epoch = 0, 0 19 | start_time = time.time() 20 | exp_dir = args.exp_dir 21 | 22 | def _save_progress(): 23 | progress.append([epoch, global_step, best_epoch, best_acc, 24 | time.time() - start_time]) 25 | with open("%s/progress.pkl" % exp_dir, "wb") as f: 26 | pickle.dump(progress, f) 27 | 28 | # create/load exp 29 | if args.resume: 30 | progress_pkl = "%s/progress.pkl" % exp_dir 31 | progress, epoch, global_step, best_epoch, best_acc = load_progress(progress_pkl) 32 | print("\nResume training from:") 33 | print(" epoch = %s" % epoch) 34 | print(" global_step = %s" % global_step) 35 | print(" best_epoch = %s" % best_epoch) 36 | print(" best_acc = %.4f" % best_acc) 37 | 38 | if not isinstance(audio_model, torch.nn.DataParallel): 39 | audio_model = nn.DataParallel(audio_model) 40 | 41 | if not isinstance(image_model, torch.nn.DataParallel): 42 | image_model = nn.DataParallel(image_model) 43 | 44 | if epoch != 0: 45 | audio_model.load_state_dict(torch.load("%s/models/audio_model.%d.pth" % (exp_dir, epoch))) 46 | image_model.load_state_dict(torch.load("%s/models/image_model.%d.pth" % (exp_dir, epoch))) 47 | print("loaded parameters from epoch %d" % epoch) 48 | 49 | audio_model = audio_model.to(device) 50 | image_model = image_model.to(device) 51 | # Set up the optimizer 52 | audio_trainables = [p for p in audio_model.parameters() if p.requires_grad] 53 | image_trainables = [p for p in image_model.parameters() if p.requires_grad] 54 | trainables = audio_trainables + image_trainables 55 | 
if args.optim == 'sgd': 56 | optimizer = torch.optim.SGD(trainables, args.lr, 57 | momentum=args.momentum, 58 | weight_decay=args.weight_decay) 59 | elif args.optim == 'adam': 60 | optimizer = torch.optim.Adam(trainables, args.lr, 61 | weight_decay=args.weight_decay, 62 | betas=(0.95, 0.999)) 63 | else: 64 | raise ValueError('Optimizer %s is not supported' % args.optim) 65 | 66 | if epoch != 0: 67 | optimizer.load_state_dict(torch.load("%s/models/optim_state.%d.pth" % (exp_dir, epoch))) 68 | for state in optimizer.state.values(): 69 | for k, v in state.items(): 70 | if isinstance(v, torch.Tensor): 71 | state[k] = v.to(device) 72 | print("loaded state dict from epoch %d" % epoch) 73 | 74 | epoch += 1 75 | 76 | print("current #steps=%s, #epochs=%s" % (global_step, epoch)) 77 | print("start training...") 78 | 79 | audio_model.train() 80 | image_model.train() 81 | while True: 82 | adjust_learning_rate(args.lr, args.lr_decay, optimizer, epoch) 83 | end_time = time.time() 84 | audio_model.train() 85 | image_model.train() 86 | for i, (image_input, audio_input, nframes) in enumerate(train_loader): 87 | # measure data loading time 88 | data_time.update(time.time() - end_time) 89 | B = audio_input.size(0) 90 | 91 | audio_input = audio_input.to(device) 92 | image_input = image_input.to(device) 93 | 94 | optimizer.zero_grad() 95 | 96 | audio_output = audio_model(audio_input) 97 | image_output = image_model(image_input) 98 | 99 | pooling_ratio = round(audio_input.size(-1) / audio_output.size(-1)) 100 | nframes.div_(pooling_ratio) 101 | 102 | loss = sampled_margin_rank_loss(image_output, audio_output, 103 | nframes, margin=args.margin, simtype=args.simtype) 104 | 105 | loss.backward() 106 | optimizer.step() 107 | 108 | # record loss 109 | loss_meter.update(loss.item(), B) 110 | batch_time.update(time.time() - end_time) 111 | 112 | if global_step % args.n_print_steps == 0 and global_step != 0: 113 | print('Epoch: [{0}][{1}/{2}]\t' 114 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 115 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 116 | 'Loss total {loss_meter.val:.4f} ({loss_meter.avg:.4f})'.format( 117 | epoch, i, len(train_loader), batch_time=batch_time, 118 | data_time=data_time, loss_meter=loss_meter), flush=True) 119 | if np.isnan(loss_meter.avg): 120 | print("training diverged...") 121 | return 122 | 123 | end_time = time.time() 124 | global_step += 1 125 | 126 | recalls = validate(audio_model, image_model, test_loader, args) 127 | 128 | avg_acc = (recalls['A_r10'] + recalls['I_r10']) / 2 129 | 130 | torch.save(audio_model.state_dict(), 131 | "%s/models/audio_model.%d.pth" % (exp_dir, epoch)) 132 | torch.save(image_model.state_dict(), 133 | "%s/models/image_model.%d.pth" % (exp_dir, epoch)) 134 | torch.save(optimizer.state_dict(), "%s/models/optim_state.%d.pth" % (exp_dir, epoch)) 135 | 136 | if avg_acc > best_acc: 137 | best_epoch = epoch 138 | best_acc = avg_acc 139 | shutil.copyfile("%s/models/audio_model.%d.pth" % (exp_dir, epoch), 140 | "%s/models/best_audio_model.pth" % (exp_dir)) 141 | shutil.copyfile("%s/models/image_model.%d.pth" % (exp_dir, epoch), 142 | "%s/models/best_image_model.pth" % (exp_dir)) 143 | _save_progress() 144 | epoch += 1 145 | 146 | def validate(audio_model, image_model, val_loader, args): 147 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 148 | batch_time = AverageMeter() 149 | if not isinstance(audio_model, torch.nn.DataParallel): 150 | audio_model = nn.DataParallel(audio_model) 151 | if not isinstance(image_model, 
torch.nn.DataParallel): 152 | image_model = nn.DataParallel(image_model) 153 | audio_model = audio_model.to(device) 154 | image_model = image_model.to(device) 155 | # switch to evaluate mode 156 | image_model.eval() 157 | audio_model.eval() 158 | 159 | end = time.time() 160 | N_examples = val_loader.dataset.__len__() 161 | I_embeddings = [] 162 | A_embeddings = [] 163 | frame_counts = [] 164 | with torch.no_grad(): 165 | for i, (image_input, audio_input, nframes) in enumerate(val_loader): 166 | image_input = image_input.to(device) 167 | audio_input = audio_input.to(device) 168 | 169 | # compute output 170 | image_output = image_model(image_input) 171 | audio_output = audio_model(audio_input) 172 | 173 | image_output = image_output.to('cpu').detach() 174 | audio_output = audio_output.to('cpu').detach() 175 | 176 | I_embeddings.append(image_output) 177 | A_embeddings.append(audio_output) 178 | 179 | pooling_ratio = round(audio_input.size(-1) / audio_output.size(-1)) 180 | nframes.div_(pooling_ratio) 181 | 182 | frame_counts.append(nframes.cpu()) 183 | 184 | batch_time.update(time.time() - end) 185 | end = time.time() 186 | 187 | image_output = torch.cat(I_embeddings) 188 | audio_output = torch.cat(A_embeddings) 189 | nframes = torch.cat(frame_counts) 190 | 191 | recalls = calc_recalls(image_output, audio_output, nframes, simtype=args.simtype) 192 | A_r10 = recalls['A_r10'] 193 | I_r10 = recalls['I_r10'] 194 | A_r5 = recalls['A_r5'] 195 | I_r5 = recalls['I_r5'] 196 | A_r1 = recalls['A_r1'] 197 | I_r1 = recalls['I_r1'] 198 | 199 | print(' * Audio R@10 {A_r10:.3f} Image R@10 {I_r10:.3f} over {N:d} validation pairs' 200 | .format(A_r10=A_r10, I_r10=I_r10, N=N_examples), flush=True) 201 | print(' * Audio R@5 {A_r5:.3f} Image R@5 {I_r5:.3f} over {N:d} validation pairs' 202 | .format(A_r5=A_r5, I_r5=I_r5, N=N_examples), flush=True) 203 | print(' * Audio R@1 {A_r1:.3f} Image R@1 {I_r1:.3f} over {N:d} validation pairs' 204 | .format(A_r1=A_r1, I_r1=I_r1, N=N_examples), flush=True) 205 | 206 | return recalls 207 | -------------------------------------------------------------------------------- /steps/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | import numpy as np 4 | import torch 5 | 6 | def calc_recalls(image_outputs, audio_outputs, nframes, simtype='MISA'): 7 | """ 8 | Computes recall at 1, 5, and 10 given encoded image and audio outputs. 
9 | """ 10 | S = compute_matchmap_similarity_matrix(image_outputs, audio_outputs, nframes, simtype=simtype) 11 | n = S.size(0) 12 | A2I_scores, A2I_ind = S.topk(10, 0) 13 | I2A_scores, I2A_ind = S.topk(10, 1) 14 | A_r1 = AverageMeter() 15 | A_r5 = AverageMeter() 16 | A_r10 = AverageMeter() 17 | I_r1 = AverageMeter() 18 | I_r5 = AverageMeter() 19 | I_r10 = AverageMeter() 20 | for i in range(n): 21 | A_foundind = -1 22 | I_foundind = -1 23 | for ind in range(10): 24 | if A2I_ind[ind, i] == i: 25 | I_foundind = ind 26 | if I2A_ind[i, ind] == i: 27 | A_foundind = ind 28 | # do r1s 29 | if A_foundind == 0: 30 | A_r1.update(1) 31 | else: 32 | A_r1.update(0) 33 | if I_foundind == 0: 34 | I_r1.update(1) 35 | else: 36 | I_r1.update(0) 37 | # do r5s 38 | if A_foundind >= 0 and A_foundind < 5: 39 | A_r5.update(1) 40 | else: 41 | A_r5.update(0) 42 | if I_foundind >= 0 and I_foundind < 5: 43 | I_r5.update(1) 44 | else: 45 | I_r5.update(0) 46 | # do r10s 47 | if A_foundind >= 0 and A_foundind < 10: 48 | A_r10.update(1) 49 | else: 50 | A_r10.update(0) 51 | if I_foundind >= 0 and I_foundind < 10: 52 | I_r10.update(1) 53 | else: 54 | I_r10.update(0) 55 | 56 | recalls = {'A_r1':A_r1.avg, 'A_r5':A_r5.avg, 'A_r10':A_r10.avg, 57 | 'I_r1':I_r1.avg, 'I_r5':I_r5.avg, 'I_r10':I_r10.avg} 58 | #'A_meanR':A_meanR.avg, 'I_meanR':I_meanR.avg} 59 | 60 | return recalls 61 | 62 | def computeMatchmap(I, A): 63 | assert(I.dim() == 3) 64 | assert(A.dim() == 2) 65 | D = I.size(0) 66 | H = I.size(1) 67 | W = I.size(2) 68 | T = A.size(1) 69 | Ir = I.view(D, -1).t() 70 | matchmap = torch.mm(Ir, A) 71 | matchmap = matchmap.view(H, W, T) 72 | return matchmap 73 | 74 | def matchmapSim(M, simtype): 75 | assert(M.dim() == 3) 76 | if simtype == 'SISA': 77 | return M.mean() 78 | elif simtype == 'MISA': 79 | M_maxH, _ = M.max(0) 80 | M_maxHW, _ = M_maxH.max(0) 81 | return M_maxHW.mean() 82 | elif simtype == 'SIMA': 83 | M_maxT, _ = M.max(2) 84 | return M_maxT.mean() 85 | else: 86 | raise ValueError 87 | 88 | def sampled_margin_rank_loss(image_outputs, audio_outputs, nframes, margin=1., simtype='MISA'): 89 | """ 90 | Computes the triplet margin ranking loss for each anchor image/caption pair 91 | The impostor image/caption is randomly sampled from the minibatch 92 | """ 93 | assert(image_outputs.dim() == 4) 94 | assert(audio_outputs.dim() == 3) 95 | n = image_outputs.size(0) 96 | loss = torch.zeros(1, device=image_outputs.device, requires_grad=True) 97 | for i in range(n): 98 | I_imp_ind = i 99 | A_imp_ind = i 100 | while I_imp_ind == i: 101 | I_imp_ind = np.random.randint(0, n) 102 | while A_imp_ind == i: 103 | A_imp_ind = np.random.randint(0, n) 104 | nF = nframes[i] 105 | nFimp = nframes[A_imp_ind] 106 | anchorsim = matchmapSim(computeMatchmap(image_outputs[i], audio_outputs[i][:, 0:nF]), simtype) 107 | Iimpsim = matchmapSim(computeMatchmap(image_outputs[I_imp_ind], audio_outputs[i][:, 0:nF]), simtype) 108 | Aimpsim = matchmapSim(computeMatchmap(image_outputs[i], audio_outputs[A_imp_ind][:, 0:nFimp]), simtype) 109 | A2I_simdif = margin + Iimpsim - anchorsim 110 | if (A2I_simdif.data > 0).all(): 111 | loss = loss + A2I_simdif 112 | I2A_simdif = margin + Aimpsim - anchorsim 113 | if (I2A_simdif.data > 0).all(): 114 | loss = loss + I2A_simdif 115 | loss = loss / n 116 | return loss 117 | 118 | def compute_matchmap_similarity_matrix(image_outputs, audio_outputs, nframes, simtype='MISA'): 119 | """ 120 | Assumes image_outputs is a (batchsize, embedding_dim, rows, height) tensor 121 | Assumes audio_outputs is a (batchsize, embedding_dim, 1, 
time) tensor 122 | Returns similarity matrix S where images are rows and audios are along the columns 123 | """ 124 | assert(image_outputs.dim() == 4) 125 | assert(audio_outputs.dim() == 3) 126 | n = image_outputs.size(0) 127 | S = torch.zeros(n, n, device=image_outputs.device) 128 | for image_idx in range(n): 129 | for audio_idx in range(n): 130 | nF = max(1, nframes[audio_idx]) 131 | S[image_idx, audio_idx] = matchmapSim(computeMatchmap(image_outputs[image_idx], audio_outputs[audio_idx][:, 0:nF]), simtype) 132 | return S 133 | 134 | class AverageMeter(object): 135 | """Computes and stores the average and current value""" 136 | def __init__(self): 137 | self.reset() 138 | 139 | def reset(self): 140 | self.val = 0 141 | self.avg = 0 142 | self.sum = 0 143 | self.count = 0 144 | 145 | def update(self, val, n=1): 146 | self.val = val 147 | self.sum += val * n 148 | self.count += n 149 | self.avg = self.sum / self.count 150 | 151 | def adjust_learning_rate(base_lr, lr_decay, optimizer, epoch): 152 | """Sets the learning rate to the initial LR decayed by 10 every lr_decay epochs""" 153 | lr = base_lr * (0.1 ** (epoch // lr_decay)) 154 | for param_group in optimizer.param_groups: 155 | param_group['lr'] = lr 156 | 157 | def load_progress(prog_pkl, quiet=False): 158 | """ 159 | load progress pkl file 160 | Args: 161 | prog_pkl(str): path to progress pkl file 162 | Return: 163 | progress(list): 164 | epoch(int): 165 | global_step(int): 166 | best_epoch(int): 167 | best_avg_r10(float): 168 | """ 169 | def _print(msg): 170 | if not quiet: 171 | print(msg) 172 | 173 | with open(prog_pkl, "rb") as f: 174 | prog = pickle.load(f) 175 | epoch, global_step, best_epoch, best_avg_r10, _ = prog[-1] 176 | 177 | _print("\nPrevious Progress:") 178 | msg = "[%5s %7s %5s %7s %6s]" % ("epoch", "step", "best_epoch", "best_avg_r10", "time") 179 | _print(msg) 180 | return prog, epoch, global_step, best_epoch, best_avg_r10 181 | --------------------------------------------------------------------------------
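As a quick illustration of how the pieces above fit together, the following minimal sketch (not part of the repository) embeds a single image/caption pair and scores it with the MISA matchmap similarity, mirroring what steps/traintest.py does per batch. The 'val.json' path is a placeholder, the config values shown are simply the defaults that image_caption_dataset.py reads, and the audio/image files referenced by the json must be available locally.

import torch
from dataloaders import ImageCaptionDataset
from models import Davenet, VGG16
from steps.util import computeMatchmap, matchmapSim

# Placeholder json; the audio_conf/image_conf keys below are among those the dataset class reads.
dataset = ImageCaptionDataset(
    'val.json',
    audio_conf={'audio_type': 'melspectrogram', 'num_mel_bins': 40, 'target_length': 2048},
    image_conf={'center_crop': True})

image, audio, nframes = dataset[0]               # (3, 224, 224), (40, 2048), int

audio_model = Davenet(embedding_dim=1024)
image_model = VGG16(embedding_dim=1024, pretrained=False)

with torch.no_grad():
    audio_emb = audio_model(audio.unsqueeze(0))  # (1, 1024, 128); time axis pooled 16x
    image_emb = image_model(image.unsqueeze(0))  # (1, 1024, 14, 14)

# traintest.py rescales nframes by the audio model's temporal pooling ratio before scoring
pooling_ratio = round(audio.size(-1) / audio_emb.size(-1))
n_valid = max(1, nframes // pooling_ratio)

# (14, 14, n_valid) matchmap between image locations and audio frames, reduced with MISA
matchmap = computeMatchmap(image_emb[0], audio_emb[0][:, :n_valid])
print(matchmapSim(matchmap, 'MISA').item())

Note that run.py builds its training dataset with the default configurations and only sets center_crop for the validation loader; custom audio_conf/image_conf dictionaries such as the one above have to be passed to ImageCaptionDataset directly.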