├── utils
│   ├── __init__.py
│   └── util.py
├── model
│   ├── ops
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   └── basic_ops.py
│   ├── loss.py
│   ├── metric.py
│   └── models.py
├── trainer
│   ├── __init__.py
│   └── trainer.py
├── logger
│   ├── __init__.py
│   ├── logger.py
│   ├── logger_config.json
│   └── visualization.py
├── .gitignore
├── glove_840B_embeddings.npy
├── base
│   ├── __init__.py
│   ├── base_model.py
│   ├── base_data_loader.py
│   └── base_trainer.py
├── config_tsn.json
├── README.md
├── data_loader
│   └── data_loaders.py
├── test_tsn.py
├── parse_config.py
├── train_tsn.py
├── dataset.py
└── transforms.py

/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import *
2 |
--------------------------------------------------------------------------------
/model/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic_ops import *
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import *
2 |
--------------------------------------------------------------------------------
/logger/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import *
2 | from .visualization import *
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .ipynb_checkpoints/
3 | *.pyc
4 | *.png
5 | *.txt
6 | *.avi
7 | *.tar
8 | log/
9 |
--------------------------------------------------------------------------------
/glove_840B_embeddings.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/filby89/NTUA-BEEU-eccv2020/HEAD/glove_840B_embeddings.npy
--------------------------------------------------------------------------------
/base/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_data_loader import *
2 | from .base_model import *
3 | from .base_trainer import *
4 |
--------------------------------------------------------------------------------
/base/base_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | from abc import abstractmethod
4 |
5 |
6 | class BaseModel(nn.Module):
7 |     """
8 |     Base class for all models
9 |     """
10 |     @abstractmethod
11 |     def forward(self, *inputs):
12 |         """
13 |         Forward pass logic
14 |
15 |         :return: Model output
16 |         """
17 |         raise NotImplementedError
18 |
19 |     def __str__(self):
20 |         """
21 |         Model prints with number of trainable parameters
22 |         """
23 |         model_parameters = filter(lambda p: p.requires_grad, self.parameters())
24 |         params = sum([np.prod(p.size()) for p in model_parameters])
25 |         return super().__str__() + '\nTrainable parameters: {}'.format(params)
26 |
--------------------------------------------------------------------------------
/config_tsn.json:
--------------------------------------------------------------------------------
1 | {
2 |     "name": "pycharm debug",
3 |     "n_gpu": 4,
4 |
5 |     "loss": "combined_loss",
6 |
7 |     "loss_continuous": "mse_loss",
8 |     "metrics": [
9 |         "average_precision", "roc_auc"
10 |     ],
11 |
12 |     "metrics_continuous": [
13 |         "r2", "mean_squared_error"
14 |     ],
15 |
16 |     "lr_scheduler": {
17 |         "type": "MultiStepLR",
18 |         "args": {
19 | "milestones": [20], 20 | "gamma": 0.1 21 | } 22 | }, 23 | 24 | "trainer": { 25 | "epochs": 50, 26 | 27 | "save_dir": "log", 28 | "save_period": 1, 29 | "verbosity": 2, 30 | 31 | "monitor": "min val_loss", 32 | "early_stop": 100, 33 | 34 | "tensorboard": true 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /logger/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.config 3 | from pathlib import Path 4 | from utils import read_json 5 | 6 | 7 | def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO): 8 | """ 9 | Setup logging configuration 10 | """ 11 | log_config = Path(log_config) 12 | if log_config.is_file(): 13 | config = read_json(log_config) 14 | # modify logging paths based on run config 15 | for _, handler in config['handlers'].items(): 16 | if 'filename' in handler: 17 | handler['filename'] = str(save_dir / handler['filename']) 18 | 19 | logging.config.dictConfig(config) 20 | else: 21 | print("Warning: logging configuration file is not found in {}.".format(log_config)) 22 | logging.basicConfig(level=default_level) 23 | -------------------------------------------------------------------------------- /model/ops/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import confusion_matrix 4 | 5 | def get_grad_hook(name): 6 | def hook(m, grad_in, grad_out): 7 | print((name, grad_out[0].data.abs().mean(), grad_in[0].data.abs().mean())) 8 | print((grad_out[0].size())) 9 | print((grad_in[0].size())) 10 | 11 | print((grad_out[0])) 12 | print((grad_in[0])) 13 | 14 | return hook 15 | 16 | 17 | def softmax(scores): 18 | es = np.exp(scores - scores.max(axis=-1)[..., None]) 19 | return es / es.sum(axis=-1)[..., None] 20 | 21 | 22 | def log_add(log_a, log_b): 23 | return log_a + np.log(1 + np.exp(log_b - log_a)) 24 | 25 | 26 | def class_accuracy(prediction, label): 27 | cf = confusion_matrix(prediction, label) 28 | cls_cnt = cf.sum(axis=1) 29 | cls_hit = np.diag(cf) 30 | 31 | cls_acc = cls_hit / cls_cnt.astype(float) 32 | 33 | mean_cls_acc = cls_acc.mean() 34 | 35 | return cls_acc, mean_cls_acc -------------------------------------------------------------------------------- /logger/logger_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "version": 1, 4 | "disable_existing_loggers": false, 5 | "formatters": { 6 | "simple": {"format": "%(message)s"}, 7 | "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"} 8 | }, 9 | "handlers": { 10 | "console": { 11 | "class": "logging.StreamHandler", 12 | "level": "DEBUG", 13 | "formatter": "simple", 14 | "stream": "ext://sys.stdout" 15 | }, 16 | "info_file_handler": { 17 | "class": "logging.handlers.RotatingFileHandler", 18 | "level": "INFO", 19 | "formatter": "datetime", 20 | "filename": "info.log", 21 | "maxBytes": 10485760, 22 | "backupCount": 20, "encoding": "utf8" 23 | } 24 | }, 25 | "root": { 26 | "level": "INFO", 27 | "handlers": [ 28 | "console", 29 | "info_file_handler" 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | import torch 4 | from torch.autograd import Variable, Function 5 | 6 
| def nll_loss(output, target): 7 | return F.nll_loss(output, target) 8 | 9 | 10 | def bce_loss(output, target): 11 | 12 | t = target.clone().detach() 13 | 14 | t[t >= 0.5] = 1 # threshold to get binary labels 15 | t[t < 0.5] = 0 16 | 17 | loss = F.binary_cross_entropy_with_logits(output, t) 18 | return loss 19 | 20 | def combined_loss(output, target): 21 | l = F.mse_loss(output, target) 22 | 23 | l += bce_loss(output, target) 24 | 25 | return l 26 | 27 | def mse_loss(output, target): 28 | return F.mse_loss(output, target) 29 | 30 | 31 | def mse_center_loss(output, target, labels): 32 | t = labels.clone().detach() 33 | t[t >= 0.5] = 1 # threshold to get binary labels 34 | t[t < 0.5] = 0 35 | 36 | target = target[0,:26] 37 | 38 | positive_centers = [] 39 | for i in range(output.size(0)): 40 | p = target[t[i, :] == 1] 41 | if p.size(0) == 0: 42 | positive_center = torch.zeros(300).cuda() 43 | else: 44 | positive_center = torch.mean(p, dim=0) 45 | 46 | positive_centers.append(positive_center) 47 | 48 | positive_centers = torch.stack(positive_centers,dim=0) 49 | loss = F.mse_loss(output, positive_centers) 50 | 51 | return loss -------------------------------------------------------------------------------- /model/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sklearn.metrics 3 | 4 | import warnings 5 | 6 | def accuracy(output, target): 7 | with torch.no_grad(): 8 | pred = torch.argmax(output, dim=1) 9 | assert pred.shape[0] == len(target) 10 | correct = 0 11 | correct += torch.sum(pred == target).item() 12 | return correct / len(target) 13 | 14 | 15 | def top_k_acc(output, target, k=3): 16 | with torch.no_grad(): 17 | pred = torch.topk(output, k, dim=1)[1] 18 | assert pred.shape[0] == len(target) 19 | correct = 0 20 | for i in range(k): 21 | correct += torch.sum(pred[:, i] == target).item() 22 | return correct / len(target) 23 | 24 | 25 | def average_precision(output, target): 26 | return sklearn.metrics.average_precision_score(target, output, average=None) 27 | 28 | def multilabel_confusion_matrix(output, target): 29 | # with warnings.catch_warnings(): 30 | # warnings.simplefilter("ignore") 31 | return sklearn.metrics.multilabel_confusion_matrix(target, output) 32 | 33 | 34 | def roc_auc(output, target): 35 | # print(np.sum(target.cpu().detach().numpy(),axis=1),np.sum(target.cpu().detach().numpy(),axis=0)) 36 | # print(output.size()) 37 | return sklearn.metrics.roc_auc_score(target, output, average=None) 38 | 39 | 40 | def mean_squared_error(output, target): 41 | return sklearn.metrics.mean_squared_error(target, output, multioutput='raw_values') 42 | 43 | def r2(output, target): 44 | return sklearn.metrics.r2_score(target, output, multioutput='raw_values') 45 | 46 | def ERS(mR2, mAP, mRA): 47 | return 1/2 * (mR2 + 1/2 * (mAP + mRA)) -------------------------------------------------------------------------------- /model/ops/basic_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | warnings.filterwarnings("ignore") 5 | 6 | class Identity(torch.nn.Module): 7 | def forward(self, input): 8 | return input 9 | 10 | 11 | class SegmentConsensus(torch.autograd.Function): 12 | 13 | def __init__(self, consensus_type, dim=1): 14 | self.consensus_type = consensus_type 15 | self.dim = dim 16 | self.shape = None 17 | 18 | def forward(self, input_tensor): 19 | self.shape = input_tensor.size() 20 | if self.consensus_type == 'avg': 21 | output = 
input_tensor.mean(dim=self.dim, keepdim=True) 22 | elif self.consensus_type == 'max': 23 | output = input_tensor.max(dim=self.dim, keepdim=True)[0] 24 | print(output) 25 | elif self.consensus_type == 'identity': 26 | output = input_tensor 27 | else: 28 | output = None 29 | 30 | return output 31 | 32 | def backward(self, grad_output): 33 | if self.consensus_type == 'avg': 34 | grad_in = grad_output.expand(self.shape) / float(self.shape[self.dim]) 35 | elif self.consensus_type == 'identity': 36 | grad_in = grad_output 37 | else: 38 | grad_in = None 39 | 40 | return grad_in 41 | 42 | 43 | class ConsensusModule(torch.nn.Module): 44 | 45 | def __init__(self, consensus_type, dim=1): 46 | super(ConsensusModule, self).__init__() 47 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' 48 | self.dim = dim 49 | 50 | def forward(self, input): 51 | return SegmentConsensus(self.consensus_type, self.dim)(input) 52 | -------------------------------------------------------------------------------- /base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 12 | self.validation_split = validation_split 13 | self.shuffle = shuffle 14 | 15 | self.batch_idx = 0 16 | self.n_samples = len(dataset) 17 | 18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 19 | 20 | self.init_kwargs = { 21 | 'dataset': dataset, 22 | 'batch_size': batch_size, 23 | 'shuffle': self.shuffle, 24 | 'collate_fn': collate_fn, 25 | 'num_workers': num_workers 26 | } 27 | super().__init__(sampler=self.sampler, **self.init_kwargs) 28 | 29 | def _split_sampler(self, split): 30 | if split == 0.0: 31 | return None, None 32 | 33 | idx_full = np.arange(self.n_samples) 34 | 35 | np.random.seed(0) 36 | np.random.shuffle(idx_full) 37 | 38 | if isinstance(split, int): 39 | assert split > 0 40 | assert split < self.n_samples, "validation set size is configured to be larger than entire dataset." 41 | len_valid = split 42 | else: 43 | len_valid = int(self.n_samples * split) 44 | 45 | valid_idx = idx_full[0:len_valid] 46 | train_idx = np.delete(idx_full, np.arange(0, len_valid)) 47 | 48 | train_sampler = SubsetRandomSampler(train_idx) 49 | valid_sampler = SubsetRandomSampler(valid_idx) 50 | 51 | # turn off shuffle option which is mutually exclusive with sampler 52 | self.shuffle = False 53 | self.n_samples = len(train_idx) 54 | 55 | return train_sampler, valid_sampler 56 | 57 | def split_validation(self): 58 | if self.valid_sampler is None: 59 | return None 60 | else: 61 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NTUA-BEEU-ECCV 2 | 3 | Winning solution for the BEEU (First International Workshop on Bodily Expressed Emotion Understanding) challenge organized at ECCV2020. Please read the accompanied paper for more details. 
4 |
5 | \[Update\]
6 | For an updated version, please check [this extension](https://github.com/GiannisPikoulis/FG2021-BoLD), which achieves significantly better results with lighter models.
7 |
8 | ### Preparation
9 | * Download the [BoLD dataset](https://cydar.ist.psu.edu/emotionchallenge/index.php).
10 | * Use [https://github.com/yjxiong/temporal-segment-networks](https://github.com/yjxiong/temporal-segment-networks) to extract RGB frames and optical flow for the dataset.
11 | * Change the directories in the "dataset.py" file.
12 |
13 |
14 | ### Training
15 |
16 | Train an RGB Temporal Segment Network on the BoLD dataset:
17 |
18 | > python train_tsn.py -c config_tsn.json --modality "RGB" -b 32 --lr 1e-3 --arch resnet101 --workers 4 --num_segments 3 --exp_name "rgb tsn" -d 0,1,2,3
19 |
20 | Add the context branch:
21 |
22 | > python train_tsn.py -c config_tsn.json --modality "RGB" -b 32 --lr 1e-3 --arch resnet101 --workers 4 --num_segments 3 --exp_name "rgb with context tsn" -d 0,1,2,3 --context
23 |
24 | Add the visual-semantic embedding loss:
25 |
26 | > python train_tsn.py -c config_tsn.json --modality "RGB" -b 32 --lr 1e-3 --arch resnet101 --workers 4 --num_segments 3 --exp_name "rgb with context tsn" -d 0,1,2,3 --context --embed
27 |
28 | Change the modality to Flow:
29 |
30 | > python train_tsn.py -c config_tsn.json --modality "Flow" -b 32 --lr 1e-3 --arch resnet101 --workers 4 --num_segments 3 --exp_name "flow tsn" -d 0,1,2,3
31 |
32 |
33 | ### Pretrained Models
34 | We also provide weights for an RGB-with-context model with 0.2213 validation ERS and a Flow model with 0.2157 validation ERS. Their fusion achieves an ERS of 0.2613 on the test set (see the fusion sketch at the end of this README). You can download the pretrained models [here](https://ntuagr-my.sharepoint.com/:f:/g/personal/filby_ntua_gr/EkFAi_QSn9NDsFTylvoAJrQBuvh6eQWkbgTuZcyMWWPR2w?e=xxw6h9). An example of how to use them is shown in the test_tsn.py script:
35 |
36 | > python test_tsn.py --modality "RGB" --arch resnet101 --workers 4 --context
37 | > python test_tsn.py --modality "Flow" --arch resnet101 --workers 4
38 |
39 |
40 | ## Citation
41 | If you use this code for your research, please consider citing our paper.
42 | ```
43 | @inproceedings{NTUA_BEEU,
44 | title={Emotion Understanding in Videos Through Body, Context, and Visual-Semantic Embedding Loss},
45 | author={Filntisis, Panagiotis Paraskevas and Efthymiou, Niki and Potamianos, Gerasimos and Maragos, Petros},
46 | booktitle={ECCV Workshop on Bodily Expressed Emotion Understanding},
47 | year={2020}
48 | }
49 | ```
50 |
51 | ### Acknowledgements
52 |
53 | * [https://github.com/yjxiong/tsn-pytorch](https://github.com/yjxiong/tsn-pytorch)
54 | * [https://github.com/victoresque/pytorch-template](https://github.com/victoresque/pytorch-template)
55 |
56 |
57 | ### Contact
58 | For questions feel free to open an issue.
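
### Fusing the RGB and Flow models

The fused test-set ERS reported above comes from combining the predictions of the two pretrained models. The fusion script itself is not included in this repository; the snippet below is a minimal late-fusion sketch, assuming you have saved the per-video `video_pred` and `video_pred_cont` arrays produced by each `test_tsn.py` run as `.npy` files (the file names are placeholders; `test_tsn.py` as provided only prints the prediction shapes, so add an `np.save` call or adapt the names to your setup). The simple averaging used here is an assumption and is not guaranteed to reproduce the reported 0.2613 exactly.

```python
import numpy as np

# Placeholder file names: adapt them to wherever you saved the per-video scores.
rgb_cat = np.load("rgb_context_video_pred.npy")        # (num_videos, 26) categorical scores
rgb_cont = np.load("rgb_context_video_pred_cont.npy")  # (num_videos, 3) continuous (VAD) scores
flow_cat = np.load("flow_video_pred.npy")
flow_cont = np.load("flow_video_pred_cont.npy")

# Simple late fusion: average the sigmoid scores of the two streams.
fused_cat = (rgb_cat + flow_cat) / 2.0
fused_cont = (rgb_cont + flow_cont) / 2.0

np.save("fused_video_pred.npy", fused_cat)
np.save("fused_video_pred_cont.npy", fused_cont)
```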
59 | -------------------------------------------------------------------------------- /data_loader/data_loaders.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | from base import BaseDataLoader 3 | import dataset 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | from torch.utils.data.dataloader import default_collate 7 | import image_body_dataset 8 | import video_dataset 9 | 10 | class MnistDataLoader(BaseDataLoader): 11 | """ 12 | MNIST data loading demo using BaseDataLoader 13 | """ 14 | def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True): 15 | trsfm = transforms.Compose([ 16 | transforms.ToTensor(), 17 | transforms.Normalize((0.1307,), (0.3081,)) 18 | ]) 19 | self.data_dir = data_dir 20 | self.dataset = datasets.MNIST(self.data_dir, train=training, download=True, transform=trsfm) 21 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 22 | 23 | 24 | 25 | def my_collate(batch): 26 | batch = filter (lambda x:x is not None, batch) 27 | return torch.utils.data.dataloader.default_collate(batch) 28 | 29 | def pad_collate(batch): 30 | batch = [x for x in batch if x is not None] 31 | # xx, face, hands_left, hands_right paths, targets, targets_continuous = zip(*batch) 32 | # xx, laban_body_component, embedding, places_features, paths, targets, targets_continuous = zip(*batch) 33 | xx, laban_body_component, targets, targets_continuous = zip(*batch) 34 | 35 | x_lens = [len(x) for x in xx] 36 | xx_pad = pad_sequence(xx, batch_first=True, padding_value=0) 37 | ll_pad = pad_sequence(laban_body_component, batch_first=True, padding_value=0) 38 | 39 | # xx_pad_face = pad_sequence(face, batch_first=True, padding_value=0) 40 | # xx_pad_hands_left = pad_sequence(hands_left, batch_first=True, padding_value=0) 41 | # xx_pad_hands_right = pad_sequence(hands_right, batch_first=True, padding_value=0) 42 | 43 | return xx_pad, ll_pad, default_collate(targets), default_collate(targets_continuous), torch.tensor(x_lens) 44 | 45 | class BoLDDataLoader(BaseDataLoader): 46 | """ 47 | BoLD data loading demo using BaseDataLoader 48 | """ 49 | def __init__(self, mode, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True): 50 | self.dataset = dataset.BoLD(mode=mode) 51 | if mode == "val": 52 | shuffle = False 53 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=pad_collate) 54 | 55 | class BoLDDataLoaderImage(BaseDataLoader): 56 | """ 57 | BoLD data loading demo using BaseDataLoader 58 | """ 59 | def __init__(self, mode, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True): 60 | self.dataset = image_body_dataset.BoLD(mode=mode) 61 | if mode == "val": 62 | shuffle = False 63 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 64 | 65 | 66 | class BoLDDataLoaderVideo(BaseDataLoader): 67 | """ 68 | BoLD data loading demo using BaseDataLoader 69 | """ 70 | def __init__(self, mode, batch_size, shuffle=True, validation_split=0.0, num_workers=1, **kwargs): 71 | self.dataset = video_dataset.VideoDataset(mode, **kwargs) 72 | if mode == "val": 73 | shuffle = False 74 | super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 75 | 76 | -------------------------------------------------------------------------------- /logger/visualization.py: 
-------------------------------------------------------------------------------- 1 | import importlib 2 | from datetime import datetime 3 | 4 | 5 | class TensorboardWriter(): 6 | def __init__(self, log_dir, logger, enabled): 7 | self.writer = None 8 | self.selected_module = "" 9 | self.log_dir = log_dir 10 | if enabled: 11 | log_dir = str(log_dir) 12 | 13 | # Retrieve vizualization writer. 14 | succeeded = False 15 | for module in ["torch.utils.tensorboard", "tensorboardX"]: 16 | try: 17 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 18 | succeeded = True 19 | break 20 | except ImportError: 21 | succeeded = False 22 | self.selected_module = module 23 | 24 | if not succeeded: 25 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 26 | "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ 27 | "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." 28 | logger.warning(message) 29 | 30 | self.step = 0 31 | self.mode = '' 32 | 33 | self.tb_writer_ftns = { 34 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 35 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding', 'add_figure' 36 | } 37 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 38 | self.timer = datetime.now() 39 | 40 | def save_results(self, output, name): 41 | import numpy as np 42 | import os 43 | np.save(os.path.join(self.log_dir,"%s_%d"%(name, self.step)), output) 44 | # np.save(os.path.join(self.log_dir,"output_continuous_%d"%self.step), output_continuous) 45 | 46 | def set_step(self, step, mode='train'): 47 | self.mode = mode 48 | self.step = step 49 | if step == 0: 50 | self.timer = datetime.now() 51 | else: 52 | duration = datetime.now() - self.timer 53 | self.add_scalar('steps_per_sec', 1 / duration.total_seconds()) 54 | self.timer = datetime.now() 55 | 56 | def __getattr__(self, name): 57 | """ 58 | If visualization is configured to use: 59 | return add_data() methods of tensorboard with additional information (step, tag) added. 60 | Otherwise: 61 | return a blank function handle that does nothing 62 | """ 63 | if name in self.tb_writer_ftns: 64 | add_data = getattr(self.writer, name, None) 65 | 66 | def wrapper(tag, data, *args, **kwargs): 67 | if add_data is not None: 68 | # add mode(train/valid) tag 69 | if name not in self.tag_mode_exceptions: 70 | tag = '{}/{}'.format(tag, self.mode) 71 | add_data(tag, data, self.step, *args, **kwargs) 72 | return wrapper 73 | else: 74 | # default action for returning methods defined in this class, set_step() for instance. 
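# NOTE: __getattr__ is only invoked when normal attribute lookup fails, so methods that
# are actually defined on this class (set_step(), save_results(), ...) never reach this
# branch; in practice the lookup below raises AttributeError, and the except clause
# re-raises it with a clearer message for unknown attribute names.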
75 | try: 76 | attr = object.__getattr__(name) 77 | except AttributeError: 78 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 79 | return attr 80 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | from pathlib import Path 4 | from itertools import repeat 5 | from collections import OrderedDict 6 | 7 | import matplotlib as mpl 8 | 9 | mpl.use('Agg') 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | import matplotlib.animation as animation 14 | def ensure_dir(dirname): 15 | dirname = Path(dirname) 16 | if not dirname.is_dir(): 17 | dirname.mkdir(parents=True, exist_ok=False) 18 | 19 | def read_json(fname): 20 | fname = Path(fname) 21 | with fname.open('rt') as handle: 22 | return json.load(handle, object_hook=OrderedDict) 23 | 24 | def write_json(content, fname): 25 | fname = Path(fname) 26 | with fname.open('wt') as handle: 27 | json.dump(content, handle, indent=4, sort_keys=False) 28 | 29 | def inf_loop(data_loader): 30 | ''' wrapper function for endless data loader. ''' 31 | for loader in repeat(data_loader): 32 | yield from loader 33 | 34 | class MetricTracker: 35 | def __init__(self, *keys, writer=None): 36 | self.writer = writer 37 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) 38 | self.reset() 39 | 40 | def reset(self): 41 | for col in self._data.columns: 42 | self._data[col].values[:] = 0 43 | 44 | def update(self, key, value, n=1): 45 | if self.writer is not None: 46 | self.writer.add_scalar(key, value) 47 | self._data.total[key] += value * n 48 | self._data.counts[key] += n 49 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 50 | 51 | def avg(self, key): 52 | return self._data.average[key] 53 | 54 | def result(self): 55 | return dict(self._data.average) 56 | 57 | import matplotlib.pyplot as plt 58 | import matplotlib.animation as animation 59 | import numpy as np 60 | 61 | def visualize_skeleton_openpose_18(joints, filename="fig.png"): 62 | joints_edges = [[0,1], [1,2], [2,3], [3,4], [1,5], [5,6], [6,7], [1,8], [8,9], 63 | [9,10], [1,11], [11,12], [12,13], [0,14],[14,16], [0,15], [15,17]] 64 | 65 | 66 | joints[joints[:,:,2]<0.1] = np.nan 67 | joints[np.isnan(joints[:,:,2])] = np.nan 68 | 69 | 70 | # ani = animation.FuncAnimation(fig, update_plot, frames=range(len(sequence)), 71 | # fargs=(sequence, scat)) 72 | 73 | 74 | # plt.show() 75 | 76 | 77 | from celluloid import Camera 78 | 79 | fig = plt.figure() 80 | ax = fig.add_subplot(111) 81 | plt.gca().invert_yaxis() 82 | 83 | camera = Camera(fig) 84 | for frame in range(0,joints.shape[0]): 85 | 86 | scat = ax.scatter(joints[frame, :, 0], joints[frame, :, 1]) 87 | for edge in joints_edges: 88 | ax.plot((joints[frame, edge[0], 0], joints[frame, edge[1], 0]), 89 | (joints[frame, edge[0], 1], joints[frame, edge[1], 1])) 90 | 91 | camera.snap() 92 | 93 | animation = camera.animate(interval=30) 94 | plt.close() 95 | return animation 96 | 97 | 98 | 99 | 100 | def make_barplot(y, c, label): 101 | 102 | def autolabel(rects): 103 | """Attach a text label above each bar in *rects*, displaying its height.""" 104 | for rect in rects: 105 | height = rect.get_height() 106 | ax.annotate('{:.02f}'.format(height), 107 | xy=(rect.get_x() + rect.get_width() / 2, height), 108 | xytext=(0, 3), # 3 points vertical offset 109 | textcoords="offset points", 110 | ha='center', 
va='bottom') 111 | 112 | x = np.arange(len(c)) # the label locations 113 | 114 | width = 0.35 115 | fig, ax = plt.subplots(figsize=(8,6)) 116 | rects1 = ax.bar(x, y, width, label=label) 117 | 118 | autolabel(rects1) 119 | 120 | plt.xticks(rotation=90) 121 | 122 | ax.set_xticks(x) 123 | ax.set_xticklabels(c) 124 | plt.tight_layout() 125 | plt.close() 126 | 127 | return fig 128 | 129 | 130 | import cv2 131 | 132 | features_blobs = None 133 | 134 | def hook_feature(module, input, output): 135 | global features_blobs 136 | features_blobs = np.squeeze(output.data.cpu().numpy()) 137 | 138 | def returnCAM(feature_conv, weight_softmax, class_idx): 139 | # generate the class activation maps upsample to 256x256 140 | size_upsample = (256, 256) 141 | nc, h, w = feature_conv.shape 142 | output_cam = [] 143 | for idx in class_idx: 144 | cam = weight_softmax[class_idx].dot(feature_conv.reshape((nc, h * w))) 145 | cam = cam.reshape(h, w) 146 | cam = cam - np.min(cam) 147 | cam_img = cam / np.max(cam) 148 | cam_img = np.uint8(255 * cam_img) 149 | output_cam.append(cv2.resize(cam_img, size_upsample)) 150 | return output_cam 151 | 152 | 153 | def setup_cam(model): 154 | features_names = ['layer4'] # this is the last conv layer of the resnet 155 | for name in features_names: 156 | model.module.base_model._modules.get(name).register_forward_hook(hook_feature) 157 | -------------------------------------------------------------------------------- /test_tsn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | import torch.nn.parallel 5 | import torch.optim 6 | from sklearn.metrics import confusion_matrix 7 | import torchvision 8 | from dataset import * 9 | from transforms import * 10 | from model.models import * 11 | from model.ops import ConsensusModule 12 | 13 | # options 14 | parser = argparse.ArgumentParser( 15 | description="Standard video-level testing") 16 | parser.add_argument('--modality', type=str, choices=['RGB', 'Flow', 'RGBDiff'], default="RGB") 17 | parser.add_argument('--weights', type=str) 18 | parser.add_argument('--arch', type=str, default="resnet101") 19 | parser.add_argument('--save_scores', type=str, default=None) 20 | parser.add_argument('--test_segments', type=int, default=25) 21 | parser.add_argument('--max_num', type=int, default=-1) 22 | parser.add_argument('--test_crops', type=int, default=10) 23 | parser.add_argument('--input_size', type=int, default=224) 24 | parser.add_argument('--crop_fusion_type', type=str, default='avg', 25 | choices=['avg', 'max', 'topk']) 26 | parser.add_argument('--k', type=int, default=3) 27 | parser.add_argument('--dropout', type=float, default=0.7) 28 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 29 | help='number of data loading workers (default: 4)') 30 | # parser.add_argument('--gpus', nargs='+', type=int, default=None) 31 | parser.add_argument('--flow_prefix', type=str, default='') 32 | parser.add_argument('--context', default=False, action="store_true") 33 | parser.add_argument('--categorical', default=True, action="store_true") 34 | parser.add_argument('--continuous', default=True, action="store_true") 35 | 36 | args = parser.parse_args() 37 | 38 | model = 'rgb_with_context_tsn.pth.tar' # 0.2157 39 | model = 'flow_tsn.pth.tar' # 0.2213 40 | 41 | 42 | if args.modality == 'RGB': 43 | data_length = 1 44 | elif args.modality == 'Flow': 45 | data_length = 5 46 | 47 | args.weights = model 48 | args.test_crops = 1 49 | 
args.test_segments = 25 50 | 51 | print(args) 52 | 53 | net = TSN(26, 1, args.modality, 54 | base_model=args.arch, new_length=data_length, 55 | consensus_type=args.crop_fusion_type, embed=True, context=args.context, 56 | dropout=args.dropout) 57 | 58 | features_blobs = [] 59 | 60 | checkpoint = torch.load(args.weights) 61 | 62 | base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint['state_dict'].items())} 63 | a = net.load_state_dict(base_dict, strict=True) 64 | if args.test_crops == 1: 65 | cropping = torchvision.transforms.Compose([ 66 | GroupScale((224,224)), 67 | ]) 68 | elif args.test_crops == 10: 69 | cropping = torchvision.transforms.Compose([ 70 | GroupOverSample(224, 224) 71 | ]) 72 | else: 73 | raise ValueError("Only 1 and 10 crops are supported while we got {}".format(args.test_crops)) 74 | 75 | 76 | 77 | data_loader = torch.utils.data.DataLoader( 78 | TSNDataSet("test", num_segments=args.test_segments, context=args.context, 79 | new_length=1 if args.modality == "RGB" else 5, 80 | modality=args.modality, 81 | image_tmpl="img_{:05d}.jpg" if args.modality in ['RGB', 'RGBDiff'] else args.flow_prefix+"{}_{:05d}.jpg", 82 | test_mode=True, 83 | transform=torchvision.transforms.Compose([ 84 | cropping, 85 | Stack(roll=args.arch == 'BNInception'), 86 | ToTorchFormatTensor(div=args.arch != 'BNInception'), 87 | GroupNormalize(net.input_mean, net.input_std), 88 | ])), 89 | batch_size=1, shuffle=False, 90 | num_workers=args.workers * 2, pin_memory=True) 91 | 92 | devices = [0] 93 | 94 | net = torch.nn.DataParallel(net.cuda(), device_ids=devices) 95 | net.eval() 96 | 97 | data_gen = enumerate(data_loader) 98 | 99 | total_num = len(data_loader.dataset) 100 | output = [] 101 | 102 | def eval_video(video_data): 103 | i, data, label, label_cont = video_data 104 | num_crop = args.test_crops 105 | 106 | if args.modality == 'RGB': 107 | length = 3 108 | elif args.modality == 'Flow': 109 | length = 10 110 | elif args.modality == 'RGBDiff': 111 | length = 18 112 | else: 113 | raise ValueError("Unknown modality "+args.modality) 114 | 115 | input_var = torch.autograd.Variable(data.view(-1, length, data.size(2), data.size(3)), 116 | volatile=True) 117 | 118 | out = net(input_var, None) 119 | rst = torch.sigmoid(out['categorical']).data.cpu().numpy().copy() 120 | rst_cont = torch.sigmoid(out['continuous']).data.cpu().numpy().copy() 121 | 122 | return i, rst.reshape((num_crop, args.test_segments, 26)).mean(axis=0).reshape( 123 | (args.test_segments, 1, 26) 124 | ), rst_cont.reshape((num_crop, args.test_segments, 3)).mean(axis=0).reshape( 125 | (args.test_segments, 1, 3) 126 | ) 127 | 128 | proc_start_time = time.time() 129 | max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset) 130 | 131 | import random 132 | 133 | for i, batch in data_gen: 134 | # print(batch) 135 | data, embeddings = batch 136 | 137 | rst = eval_video((i, data, None, None)) 138 | output.append(rst[1:]) 139 | cnt_time = time.time() - proc_start_time 140 | 141 | print('video {} done, total {}/{}, average {} sec/video'.format(i, i+1, 142 | total_num, 143 | float(cnt_time) / (i+1))) 144 | 145 | video_pred = np.squeeze(np.array([np.mean(x[0], axis=0) for x in output])) 146 | video_pred_cont = np.squeeze(np.array([np.mean(x[1], axis=0) for x in output])) 147 | print(video_pred.shape, video_pred_cont.shape) 148 | -------------------------------------------------------------------------------- /parse_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import logging 3 | from pathlib import Path 4 | from functools import reduce, partial 5 | from operator import getitem 6 | from datetime import datetime 7 | from logger import setup_logging 8 | from utils import read_json, write_json 9 | 10 | 11 | class ConfigParser: 12 | def __init__(self, config, resume=None, modification=None, run_id=None): 13 | """ 14 | class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving 15 | and logging module. 16 | :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example. 17 | :param resume: String, path to the checkpoint being loaded. 18 | :param modification: Dict keychain:value, specifying position values to be replaced from config dict. 19 | :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default 20 | """ 21 | # load config file and apply modification 22 | self._config = _update_config(config, modification) 23 | self.resume = resume 24 | 25 | # set save_dir where trained model and log will be saved. 26 | save_dir = Path(self.config['trainer']['save_dir']) 27 | 28 | exper_name = self.config['name'] 29 | if run_id is None: # use timestamp as default run-id 30 | run_id = datetime.now().strftime(r'%m%d_%H%M%S') 31 | self._save_dir = save_dir / 'models' / exper_name / run_id 32 | self._log_dir = save_dir / 'log' / exper_name / run_id 33 | 34 | # make directory for saving checkpoints and log. 35 | exist_ok = run_id == '' 36 | self.save_dir.mkdir(parents=True, exist_ok=exist_ok) 37 | self.log_dir.mkdir(parents=True, exist_ok=exist_ok) 38 | 39 | # save updated config file to the checkpoint dir 40 | write_json(self.config, self.save_dir / 'config.json') 41 | 42 | # configure logging module 43 | setup_logging(self.log_dir) 44 | self.log_levels = { 45 | 0: logging.WARNING, 46 | 1: logging.INFO, 47 | 2: logging.DEBUG 48 | } 49 | 50 | @classmethod 51 | def from_args(cls, args, options=''): 52 | """ 53 | Initialize this class from some cli arguments. Used in train, test. 54 | """ 55 | for opt in options: 56 | args.add_argument(*opt.flags, default=None, type=opt.type) 57 | if not isinstance(args, tuple): 58 | args = args.parse_args() 59 | 60 | if args.device is not None: 61 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 62 | if args.resume is not None: 63 | resume = Path(args.resume) 64 | cfg_fname = resume.parent / 'config.json' 65 | else: 66 | msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example." 67 | assert args.config is not None, msg_no_cfg 68 | resume = None 69 | cfg_fname = Path(args.config) 70 | 71 | config = read_json(cfg_fname) 72 | if args.config and resume: 73 | # update new config for fine-tuning 74 | config.update(read_json(args.config)) 75 | 76 | # parse custom cli options into dictionary 77 | modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options} 78 | return cls(config, resume, modification) 79 | 80 | def init_obj(self, name, module, *args, **kwargs): 81 | """ 82 | Finds a function handle with the name given as 'type' in config, and returns the 83 | instance initialized with corresponding arguments given. 
84 | 85 | `object = config.init_obj('name', module, a, b=1)` 86 | is equivalent to 87 | `object = module.name(a, b=1)` 88 | """ 89 | module_name = self[name]['type'] 90 | module_args = dict(self[name]['args']) 91 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 92 | module_args.update(kwargs) 93 | return getattr(module, module_name)(*args, **module_args) 94 | 95 | def init_ftn(self, name, module, *args, **kwargs): 96 | """ 97 | Finds a function handle with the name given as 'type' in config, and returns the 98 | function with given arguments fixed with functools.partial. 99 | 100 | `function = config.init_ftn('name', module, a, b=1)` 101 | is equivalent to 102 | `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`. 103 | """ 104 | module_name = self[name]['type'] 105 | module_args = dict(self[name]['args']) 106 | assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed' 107 | module_args.update(kwargs) 108 | return partial(getattr(module, module_name), *args, **module_args) 109 | 110 | def __getitem__(self, name): 111 | """Access items like ordinary dict.""" 112 | return self.config[name] 113 | 114 | def get_logger(self, name, verbosity=2): 115 | msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys()) 116 | assert verbosity in self.log_levels, msg_verbosity 117 | logger = logging.getLogger(name) 118 | logger.setLevel(self.log_levels[verbosity]) 119 | return logger 120 | 121 | # setting read-only attributes 122 | @property 123 | def config(self): 124 | return self._config 125 | 126 | @property 127 | def save_dir(self): 128 | return self._save_dir 129 | 130 | @property 131 | def log_dir(self): 132 | return self._log_dir 133 | 134 | # helper functions to update config dict with custom cli options 135 | def _update_config(config, modification): 136 | if modification is None: 137 | return config 138 | 139 | for k, v in modification.items(): 140 | if v is not None: 141 | _set_by_path(config, k, v) 142 | return config 143 | 144 | def _get_opt_name(flags): 145 | for flg in flags: 146 | if flg.startswith('--'): 147 | return flg.replace('--', '') 148 | return flags[0].replace('--', '') 149 | 150 | def _set_by_path(tree, keys, value): 151 | """Set a value in a nested object in tree by sequence of keys.""" 152 | keys = keys.split(';') 153 | _get_by_path(tree, keys[:-1])[keys[-1]] = value 154 | 155 | def _get_by_path(tree, keys): 156 | """Access a nested object in tree by sequence of keys.""" 157 | return reduce(getitem, keys, tree) 158 | -------------------------------------------------------------------------------- /train_tsn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import torch 4 | import numpy as np 5 | import model.loss as module_loss 6 | import model.metric as module_metric 7 | from parse_config import ConfigParser 8 | from transforms import * 9 | from logger import setup_logging 10 | from model import loss 11 | from trainer.trainer import Trainer 12 | from dataset import TSNDataSet 13 | from model.models import TSN 14 | 15 | # fix random seeds for reproducibility 16 | SEED = 123 17 | torch.manual_seed(SEED) 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | np.random.seed(SEED) 21 | 22 | def main(args, config): 23 | if args.modality == 'RGB': 24 | data_length = 1 25 | elif 
args.modality == 'Flow': 26 | data_length = 5 27 | 28 | 29 | model = TSN(26, args.num_segments, args.modality, 30 | base_model=args.arch, new_length=data_length, embed=args.embed, 31 | consensus_type=args.consensus_type, dropout=args.dropout, partial_bn=not args.no_partialbn, context=args.context) 32 | 33 | input_mean = model.input_mean 34 | input_std = model.input_std 35 | policies = model.get_optim_policies() 36 | 37 | normalize = GroupNormalize(input_mean, input_std) 38 | 39 | dataset = TSNDataSet("train", num_segments=args.num_segments, 40 | context=args.context, 41 | new_length=data_length, 42 | modality=args.modality, 43 | image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB"] else args.flow_prefix+"{}_{:05d}.jpg", 44 | transform=torchvision.transforms.Compose([ 45 | GroupScale((224,224)), 46 | Stack(roll=args.arch == 'BNInception'), 47 | ToTorchFormatTensor(div=args.arch != 'BNInception'), 48 | normalize, 49 | ])) 50 | 51 | 52 | 53 | train_loader = torch.utils.data.DataLoader( 54 | dataset, 55 | batch_size=args.batch_size, shuffle=True, 56 | num_workers=args.workers, pin_memory=True, drop_last=False) 57 | 58 | val_loader = torch.utils.data.DataLoader( 59 | TSNDataSet("val", num_segments=args.num_segments, 60 | context=args.context, 61 | new_length=data_length, 62 | modality=args.modality, 63 | image_tmpl="img_{:05d}.jpg" if args.modality in ["RGB"] else args.flow_prefix+"{}_{:05d}.jpg", 64 | random_shift=False, 65 | transform=torchvision.transforms.Compose([ 66 | GroupScale((int(224),int(224))), 67 | Stack(roll=args.arch == 'BNInception'), 68 | ToTorchFormatTensor(div=args.arch != 'BNInception'), 69 | normalize, 70 | ])), 71 | batch_size=args.batch_size, shuffle=False, 72 | num_workers=args.workers, pin_memory=True) 73 | 74 | 75 | logger = config.get_logger('train') 76 | logger.info(model) 77 | 78 | 79 | # get function handles of loss and metrics 80 | criterion_categorical = getattr(module_loss, config['loss']) 81 | criterion_continuous = getattr(module_loss, config['loss_continuous']) 82 | 83 | metrics = [getattr(module_metric, met) for met in config['metrics']] 84 | metrics_continuous = [getattr(module_metric, met) for met in config['metrics_continuous']] 85 | 86 | optimizer = torch.optim.SGD(policies, 87 | args.lr, 88 | momentum=args.momentum, 89 | weight_decay=args.weight_decay) 90 | 91 | lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer) 92 | 93 | for param_group in optimizer.param_groups: 94 | print(param_group['lr']) 95 | 96 | trainer = Trainer(model, criterion_categorical, criterion_continuous, metrics, metrics_continuous, optimizer, 97 | config=config, 98 | data_loader=train_loader, 99 | valid_data_loader=val_loader, 100 | lr_scheduler=lr_scheduler, embed=args.embed) 101 | 102 | trainer.train() 103 | 104 | 105 | if __name__ == '__main__': 106 | parser = argparse.ArgumentParser(description='PyTorch Template') 107 | parser.add_argument('-c', '--config', default=None, type=str, 108 | help='config file path (default: None)') 109 | parser.add_argument('-r', '--resume', default=None, type=str, 110 | help='path to latest checkpoint (default: None)') 111 | parser.add_argument('-d', '--device', default=None, type=str, 112 | help='indices of GPUs to enable (default: all)') 113 | 114 | parser.add_argument('--modality', type=str, choices=['RGB', 'Flow', 'RGBDiff', 'depth']) 115 | 116 | # ========================= Model Configs ========================== 117 | parser.add_argument('--arch', type=str, default="resnet101") 118 | 
parser.add_argument('--num_segments', type=int, default=3) 119 | parser.add_argument('--consensus_type', type=str, default='avg', 120 | choices=['avg', 'max', 'topk', 'identity', 'rnn', 'cnn']) 121 | parser.add_argument('--k', type=int, default=3) 122 | 123 | parser.add_argument('--dropout', '--do', default=0.5, type=float, 124 | metavar='DO', help='dropout ratio (default: 0.5)') 125 | 126 | # ========================= Learning Configs ========================== 127 | parser.add_argument('-b', '--batch-size', default=32, type=int, 128 | metavar='N', help='mini-batch size (default: 256)') 129 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, 130 | metavar='LR', help='initial learning rate') 131 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 132 | help='momentum') 133 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 134 | metavar='W', help='weight decay (default: 5e-4)') 135 | parser.add_argument('--clip-gradient', '--gd', default=None, type=float, 136 | metavar='W', help='gradient norm clipping (default: disabled)') 137 | parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true") 138 | parser.add_argument('--context', default=False, action="store_true") 139 | parser.add_argument('--embed', default=False, action="store_true") 140 | 141 | # ========================= Monitor Configs ========================== 142 | parser.add_argument('--print-freq', '-p', default=20, type=int, 143 | metavar='N', help='print frequency (default: 10)') 144 | parser.add_argument('--eval-freq', '-ef', default=5, type=int, 145 | metavar='N', help='evaluation frequency (default: 5)') 146 | 147 | 148 | # ========================= Runtime Configs ========================== 149 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 150 | help='number of data loading workers (default: 4)') 151 | 152 | parser.add_argument('--flow_prefix', default="", type=str) 153 | 154 | 155 | # custom cli options to modify configuration from default values given in json file. 
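# Each CustomArgs entry maps a CLI flag onto a ';'-separated key path inside the JSON
# config (resolved by _set_by_path() in parse_config.py). For example, a hypothetical
#     CustomArgs(['--epochs'], type=int, target='trainer;epochs')
# would let "--epochs 20" override config['trainer']['epochs']; the '--exp_name' option
# defined below overrides the top-level 'name' key in the same way.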
156 | CustomArgs = collections.namedtuple('CustomArgs', 'flags type target') 157 | options = [ 158 | CustomArgs(['--exp_name'], type=str, target='name'), 159 | ] 160 | config = ConfigParser.from_args(parser, options) 161 | print(config) 162 | 163 | args = parser.parse_args() 164 | 165 | main(args, config) 166 | -------------------------------------------------------------------------------- /trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.utils import make_grid 4 | from base import BaseTrainer 5 | from utils import inf_loop, MetricTracker, make_barplot 6 | import matplotlib as mpl 7 | import random 8 | 9 | mpl.use('Agg') 10 | import matplotlib.pyplot as plt 11 | import model.metric 12 | import model.loss 13 | 14 | 15 | class Trainer(BaseTrainer): 16 | """ 17 | Trainer class 18 | """ 19 | def __init__(self, model, criterion, criterion_continuous, metric_ftns, metric_ftns_continuous, optimizer, config, data_loader, categorical=True, continuous=True, 20 | valid_data_loader=None, lr_scheduler=None, len_epoch=None, embed=False): 21 | super().__init__(model, criterion, metric_ftns, optimizer, config) 22 | self.data_loader = data_loader 23 | self.categorical = categorical 24 | self.continuous = continuous 25 | 26 | if len_epoch is None: 27 | # epoch-based training 28 | self.len_epoch = len(self.data_loader) 29 | else: 30 | # iteration-based training 31 | self.data_loader = inf_loop(data_loader) 32 | self.len_epoch = len_epoch 33 | 34 | self.valid_data_loader = valid_data_loader 35 | self.do_validation = self.valid_data_loader is not None 36 | self.lr_scheduler = lr_scheduler 37 | self.log_step = int(np.sqrt(data_loader.batch_size)) 38 | 39 | self.metric_ftns_continuous = metric_ftns_continuous 40 | 41 | self.criterion_continuous = criterion_continuous 42 | self.criterion_categorical = criterion 43 | 44 | self.categorical_class_metrics = [_class + "_" + m.__name__ for _class in valid_data_loader.dataset.categorical_emotions for m in self.metric_ftns] 45 | 46 | self.continuous_class_metrics = [_class + "_" + m.__name__ for _class in valid_data_loader.dataset.continuous_emotions for m in self.metric_ftns_continuous] 47 | 48 | self.train_metrics = MetricTracker('mre', 'loss', 'loss_categorical', 'loss_continuous', 'loss_embed', 49 | 'map', 'mse', 'r2', 'roc_auc', writer=self.writer) 50 | self.valid_metrics = MetricTracker('mre', 'loss', 'loss_categorical', 'loss_continuous', 'loss_embed', 51 | 'map', 'mse', 'r2', 'roc_auc', writer=self.writer) 52 | 53 | self.embed = embed 54 | 55 | def _train_epoch(self, epoch, phase="train"): 56 | """ 57 | Training logic for an epoch 58 | 59 | :param epoch: Integer, current training epoch. 60 | :return: A log that contains average loss and metric in this epoch. 
61 | """ 62 | import torch.nn.functional as F 63 | import model.loss 64 | print("Finding LR") 65 | for param_group in self.optimizer.param_groups: 66 | print(param_group['lr']) 67 | 68 | if phase == "train": 69 | self.model.train() 70 | self.train_metrics.reset() 71 | torch.set_grad_enabled(True) 72 | metrics = self.train_metrics 73 | elif phase == "val": 74 | self.model.eval() 75 | self.valid_metrics.reset() 76 | torch.set_grad_enabled(False) 77 | metrics = self.valid_metrics 78 | 79 | outputs = [] 80 | outputs_continuous = [] 81 | targets = [] 82 | targets_continuous = [] 83 | 84 | data_loader = self.data_loader if phase == "train" else self.valid_data_loader 85 | 86 | for batch_idx, (data, embeddings, target, target_continuous, lengths) in enumerate(data_loader): 87 | 88 | data, target, target_continuous = data.to(self.device), target.to(self.device), target_continuous.to(self.device) 89 | embeddings = embeddings.to(self.device) 90 | 91 | if phase == "train": 92 | self.optimizer.zero_grad() 93 | 94 | out = self.model(data, embeddings) 95 | 96 | loss = 0 97 | 98 | loss_categorical = self.criterion_categorical(out['categorical'], target) 99 | loss += loss_categorical 100 | 101 | loss_continuous = self.criterion_continuous(torch.sigmoid(out['continuous']), target_continuous) 102 | loss += loss_continuous 103 | 104 | if self.embed: 105 | loss_embed = model.loss.mse_center_loss(out['embed'], embeddings, target) 106 | loss += loss_embed 107 | 108 | if phase == "train": 109 | loss.backward() 110 | self.optimizer.step() 111 | 112 | output = out['categorical'].cpu().detach().numpy() 113 | target = target.cpu().detach().numpy() 114 | outputs.append(output) 115 | targets.append(target) 116 | 117 | output_continuous = torch.sigmoid(out['continuous']).cpu().detach().numpy() 118 | target_continuous = target_continuous.cpu().detach().numpy() 119 | outputs_continuous.append(output_continuous) 120 | targets_continuous.append(target_continuous) 121 | 122 | if batch_idx % self.log_step == 0: 123 | self.logger.debug('{} Epoch: {} {} Loss: {:.6f} Loss categorical: {:.6f} Loss continuous: {:.6f}'.format( 124 | phase, 125 | epoch, 126 | self._progress(batch_idx), 127 | loss.item(),loss_categorical.item(), loss_continuous.item())) 128 | 129 | if batch_idx == self.len_epoch: 130 | break 131 | 132 | if phase == "train": 133 | self.writer.set_step(epoch) 134 | else: 135 | self.writer.set_step(epoch, "valid") 136 | 137 | metrics.update('loss', loss.item()) 138 | 139 | metrics.update('loss_categorical', loss_categorical.item()) 140 | if self.embed: 141 | metrics.update('loss_embed', loss_embed.item()) 142 | 143 | output = np.vstack(outputs) 144 | target = np.vstack(targets) 145 | target[target>=0.5] = 1 # threshold to get binary labels 146 | target[target<0.5] = 0 147 | 148 | ap = model.metric.average_precision(output, target) 149 | roc_auc = model.metric.roc_auc(output, target) 150 | metrics.update("map", np.mean(ap)) 151 | metrics.update("roc_auc", np.mean(roc_auc)) 152 | 153 | self.writer.add_figure('%s ap per class' % phase, make_barplot(ap, self.valid_data_loader.dataset.categorical_emotions, 'average_precision')) 154 | self.writer.add_figure('%s roc auc per class' % phase, make_barplot(roc_auc, self.valid_data_loader.dataset.categorical_emotions, 'roc auc')) 155 | 156 | metrics.update('loss_continuous', loss_continuous.item()) 157 | output_continuous = np.vstack(outputs_continuous) 158 | target_continuous = np.vstack(targets_continuous) 159 | 160 | mse = model.metric.mean_squared_error(output_continuous, 
target_continuous) 161 | r2 = model.metric.r2(output_continuous, target_continuous) 162 | 163 | metrics.update("r2", np.mean(r2)) 164 | metrics.update("mse", np.mean(mse)) 165 | 166 | self.writer.add_figure('%s r2 per class' % phase, make_barplot(r2, self.valid_data_loader.dataset.continuous_emotions, 'r2')) 167 | self.writer.add_figure('%s mse auc per class' % phase, make_barplot(mse, self.valid_data_loader.dataset.continuous_emotions, 'mse')) 168 | 169 | metrics.update("mre", model.metric.ERS(np.mean(r2), np.mean(ap), np.mean(roc_auc))) 170 | 171 | log = metrics.result() 172 | 173 | if phase == "train": 174 | if self.lr_scheduler is not None: 175 | self.lr_scheduler.step() 176 | 177 | if self.do_validation: 178 | val_log = self._train_epoch(epoch, phase="val") 179 | log.update(**{'val_' + k: v for k, v in val_log.items()}) 180 | 181 | return log 182 | 183 | elif phase == "val": 184 | if self.categorical: 185 | self.writer.save_results(output, "output") 186 | if self.continuous: 187 | self.writer.save_results(output_continuous, "output_continuous") 188 | 189 | return metrics.result() 190 | 191 | 192 | def _progress(self, batch_idx): 193 | base = '[{}/{} ({:.0f}%)]' 194 | if hasattr(self.data_loader, 'n_samples'): 195 | current = batch_idx * self.data_loader.batch_size 196 | total = self.data_loader.n_samples 197 | else: 198 | current = batch_idx 199 | total = self.len_epoch 200 | return base.format(current, total, 100.0 * current / total) 201 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod 3 | from numpy import inf 4 | from logger import TensorboardWriter 5 | 6 | class BaseTrainer: 7 | """ 8 | Base class for all trainers 9 | """ 10 | def __init__(self, model, criterion, metric_ftns, optimizer, config): 11 | self.config = config 12 | self.logger = config.get_logger('trainer', config['trainer']['verbosity']) 13 | 14 | # setup GPU device if available, move model into configured device 15 | self.device, device_ids = self._prepare_device(config['n_gpu']) 16 | self.model = model.to(self.device) 17 | if len(device_ids) > 1: 18 | self.model = torch.nn.DataParallel(model, device_ids=device_ids) 19 | 20 | self.criterion = criterion 21 | self.metric_ftns = metric_ftns 22 | self.optimizer = optimizer 23 | 24 | cfg_trainer = config['trainer'] 25 | self.epochs = cfg_trainer['epochs'] 26 | self.save_period = cfg_trainer['save_period'] 27 | self.monitor = cfg_trainer.get('monitor', 'off') 28 | 29 | # configuration to monitor model performance and save best 30 | if self.monitor == 'off': 31 | self.mnt_mode = 'off' 32 | self.mnt_best = 0 33 | else: 34 | self.mnt_mode, self.mnt_metric = self.monitor.split() 35 | assert self.mnt_mode in ['min', 'max'] 36 | 37 | self.mnt_best = inf if self.mnt_mode == 'min' else -inf 38 | self.early_stop = cfg_trainer.get('early_stop', inf) 39 | 40 | self.start_epoch = 1 41 | 42 | self.checkpoint_dir = config.save_dir 43 | 44 | # setup visualization writer instance 45 | self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard']) 46 | 47 | if config.resume is not None: 48 | self._resume_checkpoint(config.resume) 49 | 50 | @abstractmethod 51 | def _train_epoch(self, epoch): 52 | """ 53 | Training logic for an epoch 54 | 55 | :param epoch: Current epoch number 56 | """ 57 | raise NotImplementedError 58 | 59 | def train(self): 60 | """ 61 | Full training logic 62 
| """ 63 | not_improved_count = 0 64 | for epoch in range(self.start_epoch, self.epochs + 1): 65 | result = self._train_epoch(epoch) 66 | 67 | # save logged informations into log dict 68 | log = {'epoch': epoch} 69 | log.update(result) 70 | 71 | # print logged informations to the screen 72 | for key, value in log.items(): 73 | self.logger.info(' {:15s}: {}'.format(str(key), value)) 74 | 75 | # evaluate model performance according to configured metric, save best checkpoint as model_best 76 | best = False 77 | if self.mnt_mode != 'off': 78 | try: 79 | # check whether model performance improved or not, according to specified metric(mnt_metric) 80 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \ 81 | (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best) 82 | except KeyError: 83 | self.logger.warning("Warning: Metric '{}' is not found. " 84 | "Model performance monitoring is disabled.".format(self.mnt_metric)) 85 | self.mnt_mode = 'off' 86 | improved = False 87 | 88 | if improved: 89 | self.mnt_best = log[self.mnt_metric] 90 | not_improved_count = 0 91 | best = True 92 | else: 93 | not_improved_count += 1 94 | 95 | if not_improved_count > self.early_stop: 96 | self.logger.info("Validation performance didn\'t improve for {} epochs. " 97 | "Training stops.".format(self.early_stop)) 98 | break 99 | 100 | if epoch % self.save_period == 0: 101 | self._save_checkpoint(epoch, save_best=best) 102 | 103 | def _prepare_device(self, n_gpu_use): 104 | """ 105 | setup GPU device if available, move model into configured device 106 | """ 107 | n_gpu = torch.cuda.device_count() 108 | if n_gpu_use > 0 and n_gpu == 0: 109 | self.logger.warning("Warning: There\'s no GPU available on this machine," 110 | "training will be performed on CPU.") 111 | n_gpu_use = 0 112 | if n_gpu_use > n_gpu: 113 | self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available " 114 | "on this machine.".format(n_gpu_use, n_gpu)) 115 | n_gpu_use = n_gpu 116 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 117 | list_ids = list(range(n_gpu_use)) 118 | return device, list_ids 119 | 120 | def _save_checkpoint(self, epoch, save_best=False): 121 | """ 122 | Saving checkpoints 123 | 124 | :param epoch: current epoch number 125 | :param log: logging information of the epoch 126 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 127 | """ 128 | arch = type(self.model).__name__ 129 | state = { 130 | 'arch': arch, 131 | 'epoch': epoch, 132 | 'state_dict': self.model.state_dict(), 133 | 'optimizer': self.optimizer.state_dict(), 134 | 'monitor_best': self.mnt_best, 135 | 'config': self.config 136 | } 137 | filename = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch)) 138 | torch.save(state, filename) 139 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 140 | if save_best: 141 | best_path = str(self.checkpoint_dir / 'model_best.pth') 142 | torch.save(state, best_path) 143 | self.logger.info("Saving current best: model_best.pth ...") 144 | 145 | def _resume_checkpoint(self, resume_path): 146 | """ 147 | Resume from saved checkpoints 148 | 149 | :param resume_path: Checkpoint path to be resumed 150 | """ 151 | resume_path = str(resume_path) 152 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 153 | checkpoint = torch.load(resume_path) 154 | self.start_epoch = checkpoint['epoch'] + 1 155 | self.mnt_best = checkpoint['monitor_best'] 156 | 157 | # load architecture params from 
checkpoint. 158 | if checkpoint['config']['arch'] != self.config['arch']: 159 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of " 160 | "checkpoint. This may yield an exception while state_dict is being loaded.") 161 | self.model.load_state_dict(checkpoint['state_dict']) 162 | 163 | # load optimizer state from checkpoint only when optimizer type is not changed. 164 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 165 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " 166 | "Optimizer parameters not being resumed.") 167 | else: 168 | self.optimizer.load_state_dict(checkpoint['optimizer']) 169 | 170 | self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)) 171 | -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional 3 | import torch 4 | from .ops.basic_ops import ConsensusModule, Identity 5 | from torch.nn.init import normal, constant 6 | from torch.nn import Parameter 7 | import torchvision 8 | import numpy as np 9 | 10 | class TSN(nn.Module): 11 | def __init__(self, num_class, num_segments, modality, 12 | base_model='resnet18', new_length=None, 13 | consensus_type='avg', before_softmax=True, 14 | dropout=0.8, modalities_fusion='cat', 15 | crop_num=1, partial_bn=True, context=False, embed=False): 16 | super(TSN, self).__init__() 17 | self.modality = modality 18 | self.num_segments = num_segments 19 | self.reshape = True 20 | self.before_softmax = before_softmax 21 | self.dropout = dropout 22 | self.crop_num = crop_num 23 | self.consensus_type = consensus_type 24 | self.embed = embed 25 | 26 | self.name_base = base_model 27 | if not before_softmax and consensus_type != 'avg': 28 | raise ValueError("Only avg consensus can be used after Softmax") 29 | 30 | if new_length is None: 31 | self.new_length = 1 if modality == "RGB" else 5 32 | else: 33 | self.new_length = new_length 34 | 35 | print((""" 36 | Initializing TSN with base model: {}. 37 | TSN Configurations: 38 | input_modality: {} 39 | num_segments: {} 40 | new_length: {} 41 | consensus_module: {} 42 | dropout_ratio: {} 43 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout))) 44 | self.embed = embed 45 | 46 | self._prepare_base_model(base_model) 47 | 48 | self.context = context 49 | 50 | if context: 51 | self._prepare_context_model() 52 | 53 | feature_dim = self._prepare_tsn(num_class) 54 | 55 | if self.modality == 'Flow': 56 | print("Converting the ImageNet model to a flow init model") 57 | self.base_model = self._construct_flow_model(self.base_model) 58 | print("Done. Flow model ready...") 59 | 60 | if self.context: 61 | print("Converting the context model to a flow init model") 62 | self.context_model = self._construct_flow_model(self.context_model) 63 | print("Done. 
Flow model ready...") 64 | 65 | 66 | self.consensus = ConsensusModule(consensus_type) 67 | self.consensus_cont = ConsensusModule(consensus_type) 68 | 69 | if self.embed: 70 | self.consensus_embed = ConsensusModule(consensus_type) 71 | 72 | if not self.before_softmax: 73 | self.softmax = nn.Softmax() 74 | 75 | self._enable_pbn = partial_bn 76 | if partial_bn: 77 | self.partialBN(True) 78 | 79 | def _prepare_tsn(self, num_class): 80 | std = 0.001 81 | 82 | if isinstance(self.base_model, torch.nn.modules.container.Sequential): 83 | feature_dim = 2048 84 | else: 85 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features 86 | if self.dropout == 0: 87 | setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class)) 88 | self.new_fc = None 89 | else: 90 | setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout)) 91 | 92 | if self.context: 93 | num_feats = 4096 94 | else: 95 | num_feats = 2048 96 | 97 | if self.embed: 98 | self.embed_fc = nn.Linear(num_feats,300) 99 | normal(self.embed_fc.weight, 0, std) 100 | constant(self.embed_fc.bias, 0) 101 | 102 | self.new_fc = nn.Linear(num_feats, num_class) 103 | normal(self.new_fc.weight, 0, std) 104 | constant(self.new_fc.bias, 0) 105 | 106 | self.new_fc_1 = nn.Linear(num_feats, 3) 107 | normal(self.new_fc_1.weight, 0, std) 108 | constant(self.new_fc_1.bias, 0) 109 | 110 | 111 | return num_feats 112 | 113 | 114 | def _prepare_context_model(self): 115 | self.context_model = getattr(torchvision.models, "resnet50")(True) 116 | modules = list(self.context_model.children())[:-1] # delete the last fc layer. 117 | self.context_model = nn.Sequential(*modules) 118 | 119 | def _prepare_base_model(self, base_model): 120 | import torchvision, torchvision.models 121 | 122 | if 'resnet' in base_model or 'vgg' in base_model or 'resnext' in base_model or 'densenet' in base_model: 123 | self.base_model = getattr(torchvision.models, base_model)(True) 124 | self.base_model.last_layer_name = 'fc' 125 | self.input_size = 224 126 | self.input_mean = [0.485, 0.456, 0.406] 127 | self.input_std = [0.229, 0.224, 0.225] 128 | 129 | if self.modality == 'Flow': 130 | self.input_mean = [0.5] 131 | self.input_std = [np.mean(self.input_std)] 132 | 133 | else: 134 | raise ValueError('Unknown base model: {}'.format(base_model)) 135 | 136 | def train(self, mode=True): 137 | """ 138 | Override the default train() to freeze the BN parameters 139 | :return: 140 | """ 141 | super(TSN, self).train(mode) 142 | count = 0 143 | if self._enable_pbn: 144 | print("Freezing BatchNorm2D except the first one.") 145 | for m in self.base_model.modules(): 146 | if isinstance(m, nn.BatchNorm2d): 147 | count += 1 148 | if count >= (2 if self._enable_pbn else 1): 149 | m.eval() 150 | 151 | # shutdown update in frozen mode 152 | m.weight.requires_grad = False 153 | m.bias.requires_grad = False 154 | count = 0 155 | if self.context: 156 | print("Freezing BatchNorm2D except the first one.") 157 | for m in self.context_model.modules(): 158 | if isinstance(m, nn.BatchNorm2d): 159 | count += 1 160 | if count >= (2 if self._enable_pbn else 1): 161 | m.eval() 162 | 163 | # shutdown update in frozen mode 164 | m.weight.requires_grad = False 165 | m.bias.requires_grad = False 166 | 167 | 168 | 169 | def partialBN(self, enable): 170 | self._enable_pbn = enable 171 | 172 | def get_optim_policies(self): 173 | params = [{'params': self.parameters()}] 174 | 175 | return params 176 | 177 | 178 | 179 | def forward(self, input, 
embeddings): 180 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length 181 | 182 | if self.context: 183 | inp = input.view((-1, sample_len) + input.size()[-2:]) 184 | 185 | body_indices = list(range(0,inp.size(0),2)) 186 | context_indices = list(range(1,inp.size(0),2)) 187 | 188 | body = inp[body_indices] 189 | context = inp[context_indices] 190 | else: 191 | body = input.view((-1, sample_len) + input.size()[-2:]) 192 | 193 | base_out = self.base_model(body).squeeze(-1).squeeze(-1) 194 | 195 | if self.context: 196 | context_out = self.context_model(context).squeeze(-1).squeeze(-1) 197 | base_out = torch.cat((base_out, context_out),dim=1) 198 | 199 | outputs = {} 200 | 201 | if self.embed: 202 | embed_segm = self.embed_fc(base_out) 203 | embed = embed_segm.view((-1, self.num_segments) + embed_segm.size()[1:]) 204 | embed = self.consensus_embed(embed).squeeze(1) 205 | outputs['embed'] = embed 206 | 207 | 208 | base_out_cat = self.new_fc(base_out) 209 | base_out_cont = self.new_fc_1(base_out) 210 | 211 | base_out_cat = base_out_cat.view((-1, self.num_segments) + base_out_cat.size()[1:]) 212 | base_out_cont = base_out_cont.view((-1, self.num_segments) + base_out_cont.size()[1:]) 213 | 214 | output = self.consensus(base_out_cat) 215 | outputs['categorical'] = output.squeeze(1) 216 | 217 | output_cont = self.consensus_cont(base_out_cont) 218 | outputs['continuous'] = output_cont.squeeze(1) 219 | 220 | return outputs 221 | 222 | 223 | def _construct_flow_model(self, base_model): 224 | # modify the convolution layers 225 | # Torch models are usually defined in a hierarchical way. 226 | # nn.modules.children() return all sub modules in a DFS manner 227 | modules = list(base_model.modules()) 228 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] 229 | conv_layer = modules[first_conv_idx] 230 | container = modules[first_conv_idx - 1] 231 | 232 | # modify parameters, assume the first blob contains the convolution kernels 233 | params = [x.clone() for x in conv_layer.parameters()] 234 | kernel_size = params[0].size() 235 | new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:] 236 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() 237 | 238 | new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels, 239 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, 240 | bias=True if len(params) == 2 else False) 241 | new_conv.weight.data = new_kernels 242 | if len(params) == 2: 243 | new_conv.bias.data = params[1].data # add bias if neccessary 244 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name 245 | 246 | # replace the first convlution layer 247 | setattr(container, layer_name, new_conv) 248 | return base_model 249 | 250 | 251 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import cv2 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | from numpy.random import randint 8 | import pandas as pd 9 | import torch 10 | import torchvision.transforms.functional as tF 11 | 12 | def rreplace(s, old, new, occurrence): 13 | li = s.rsplit(old, occurrence) 14 | return new.join(li) 15 | 16 | 17 | class VideoRecord(object): 18 | def __init__(self, row): 19 | self._data = row 20 | 21 | @property 22 | def path(self): 
23 | return self._data[0] 24 | 25 | @property 26 | def num_frames(self): 27 | return int(self._data[1]) 28 | 29 | @property 30 | def min_frame(self): 31 | return int(self._data[2]) 32 | 33 | @property 34 | def max_frame(self): 35 | return int(self._data[3]) 36 | 37 | 38 | class TSNDataSet(data.Dataset): 39 | def __init__(self, mode, 40 | num_segments=3, new_length=1, modality='RGB', 41 | image_tmpl='img_{:05d}.jpg', transform=None, 42 | force_grayscale=False, random_shift=True, test_mode=False, context=False): 43 | 44 | self.num_segments = num_segments 45 | self.new_length = new_length 46 | self.modality = modality 47 | self.image_tmpl = image_tmpl 48 | self.transform = transform 49 | self.random_shift = random_shift 50 | self.test_mode = test_mode 51 | 52 | self.bold_path = "/gpu-data/filby/BoLD/BOLD_public" 53 | 54 | self.context = context 55 | 56 | self.categorical_emotions = ["Peace", "Affection", "Esteem", "Anticipation", "Engagement", "Confidence", "Happiness", 57 | "Pleasure", "Excitement", "Surprise", "Sympathy", "Doubt/Confusion", "Disconnect", 58 | "Fatigue", "Embarrassment", "Yearning", "Disapproval", "Aversion", "Annoyance", "Anger", 59 | "Sensitivity", "Sadness", "Disquietment", "Fear", "Pain", "Suffering"] 60 | 61 | self.continuous_emotions = ["Valence", "Arousal", "Dominance"] 62 | 63 | self.attributes = ["Gender", "Age", "Ethnicity"] 64 | 65 | header = ["video", "person_id", "min_frame", "max_frame"] + self.categorical_emotions + self.continuous_emotions + self.attributes + ["annotation_confidence"] 66 | 67 | # self.df = pd.read_csv(os.path.join(self.bold_path, "annotations/{}_extra.csv".format(mode))) 68 | self.df = pd.read_csv(os.path.join(self.bold_path, "annotations/{}.csv".format(mode)), names=header) 69 | self.df["joints_path"] = self.df["video"].apply(rreplace,args=[".mp4",".npy",1]) 70 | 71 | self.video_list = self.df["video"] 72 | self.mode = mode 73 | 74 | self.embeddings = np.load("glove_840B_embeddings.npy") 75 | 76 | def get_context(self, image, joints, format="cv2"): 77 | joints = joints.reshape((18,3)) 78 | joints[joints[:,2]<0.1] = np.nan 79 | joints[np.isnan(joints[:,2])] = np.nan 80 | 81 | joint_min_x = int(round(np.nanmin(joints[:,0]))) 82 | joint_min_y = int(round(np.nanmin(joints[:,1]))) 83 | 84 | joint_max_x = int(round(np.nanmax(joints[:,0]))) 85 | joint_max_y = int(round(np.nanmax(joints[:,1]))) 86 | 87 | expand_x = int(round(10/100 * (joint_max_x-joint_min_x))) 88 | expand_y = int(round(10/100 * (joint_max_y-joint_min_y))) 89 | 90 | if format == "cv2": 91 | image[max(0, joint_min_x - expand_x):min(joint_max_x + expand_x, image.shape[1])] = [0,0,0] 92 | elif format == "PIL": 93 | bottom = min(joint_max_y+expand_y, image.height) 94 | right = min(joint_max_x+expand_x,image.width) 95 | top = max(0,joint_min_y-expand_y) 96 | left = max(0,joint_min_x-expand_x) 97 | image = np.array(image) 98 | if len(image.shape) == 3: 99 | image[top:bottom,left:right] = [0,0,0] 100 | else: 101 | image[top:bottom,left:right] = np.min(image) 102 | return Image.fromarray(image) 103 | 104 | 105 | def get_bounding_box(self, image, joints, format="cv2"): 106 | joints = joints.reshape((18,3)) 107 | joints[joints[:,2]<0.1] = np.nan 108 | joints[np.isnan(joints[:,2])] = np.nan 109 | 110 | joint_min_x = int(round(np.nanmin(joints[:,0]))) 111 | joint_min_y = int(round(np.nanmin(joints[:,1]))) 112 | 113 | joint_max_x = int(round(np.nanmax(joints[:,0]))) 114 | joint_max_y = int(round(np.nanmax(joints[:,1]))) 115 | 116 | expand_x = int(round(100/100 * (joint_max_x-joint_min_x))) 117 | 
expand_y = int(round(100/100 * (joint_max_y-joint_min_y))) 118 | 119 | if format == "cv2": 120 | return image[max(0,joint_min_y-expand_y):min(joint_max_y+expand_y, image.shape[0]), max(0,joint_min_x-expand_x):min(joint_max_x+expand_x,image.shape[1])] 121 | elif format == "PIL": 122 | bottom = min(joint_max_y+expand_y, image.height) 123 | right = min(joint_max_x+expand_x,image.width) 124 | top = max(0,joint_min_y-expand_y) 125 | left = max(0,joint_min_x-expand_x) 126 | return tF.crop(image, top, left, bottom-top ,right-left) 127 | 128 | 129 | def joints(self, index): 130 | sample = self.df.iloc[index] 131 | 132 | joints_path = os.path.join(self.bold_path, "joints", sample["joints_path"]) 133 | 134 | joints18 = np.load(joints_path) 135 | joints18[:,0] -= joints18[0,0] 136 | 137 | return joints18 138 | 139 | def _load_image(self, directory, idx, index, mode="body"): 140 | joints = self.joints(index) 141 | 142 | poi_joints = joints[joints[:, 0] + 1 == idx] 143 | sample = self.df.iloc[index] 144 | poi_joints = poi_joints[(poi_joints[:, 1] == sample["person_id"]), 2:] 145 | if self.modality == 'RGB' or self.modality == 'RGBDiff': 146 | 147 | frame = Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert("RGB") 148 | 149 | if mode == "context": 150 | if poi_joints.size == 0: 151 | return [frame] 152 | context = self.get_context(frame, poi_joints, format="PIL") 153 | return [context] 154 | 155 | if poi_joints.size == 0: 156 | body = frame 157 | pass #just do the whole frame 158 | else: 159 | body = self.get_bounding_box(frame, poi_joints, format="PIL") 160 | 161 | if body.size == 0: 162 | print(poi_joints) 163 | body = frame 164 | 165 | return [body] 166 | 167 | # return [Image.open(os.path.join(directory, self.image_tmpl.format(idx))).convert('RGB')] 168 | elif self.modality == 'Flow': 169 | frame_x = Image.open(os.path.join(directory, self.image_tmpl.format('flow_x', idx))).convert('L') 170 | frame_y = Image.open(os.path.join(directory, self.image_tmpl.format('flow_y', idx))).convert('L') 171 | # frame = cv2.imread(os.path.join(directory, 'img_{:05d}.jpg'.format(idx))) 172 | # frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 173 | 174 | if mode == "context": 175 | if poi_joints.size == 0: 176 | return [frame_x, frame_y] 177 | context_x = self.get_context(frame_x, poi_joints, format="PIL") 178 | context_y = self.get_context(frame_y, poi_joints, format="PIL") 179 | return [context_x, context_y] 180 | 181 | if poi_joints.size == 0: 182 | body_x = frame_x 183 | body_y = frame_y 184 | pass #just do the whole frame 185 | else: 186 | body_x = self.get_bounding_box(frame_x, poi_joints, format="PIL") 187 | body_y = self.get_bounding_box(frame_y, poi_joints, format="PIL") 188 | 189 | if body_x.size == 0: 190 | body_x = frame_x 191 | body_y = frame_y 192 | 193 | 194 | return [body_x, body_y] 195 | 196 | 197 | def _sample_indices(self, record): 198 | """ 199 | 200 | :param record: VideoRecord 201 | :return: list 202 | """ 203 | 204 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments 205 | if average_duration > 0: 206 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments) # + (record.min_frame+1) 207 | # print(record.num_frames, record.min_frame, record.max_frame) 208 | elif record.num_frames > self.num_segments: 209 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments)) 210 | else: 211 | offsets = np.zeros((self.num_segments,)) 212 | return offsets + 1 
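# Note (added, hedged): a small worked example of the random segment sampling implemented in
# _sample_indices above; the numbers are illustrative only and not taken from the dataset.
# With record.num_frames=30, new_length=1 and num_segments=3:
#   average_duration = (30 - 1 + 1) // 3 = 10
#   base offsets      = [0, 10, 20]
#   randint(10, size=3) might add e.g. [3, 4, 7], giving offsets = [3, 14, 27]
#   the method returns offsets + 1 = [4, 15, 28], i.e. one 1-based frame index per segment,
# so every segment contributes one frame drawn uniformly from its own third of the clip.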
213 | 214 | def _get_val_indices(self, record): 215 | if record.num_frames > self.num_segments + self.new_length - 1: 216 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 217 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 218 | else: 219 | offsets = np.zeros((self.num_segments,)) 220 | return offsets + 1 221 | 222 | def _get_test_indices(self, record): 223 | 224 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) 225 | 226 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) 227 | 228 | return offsets + 1 229 | 230 | def __getitem__(self, index): 231 | sample = self.df.iloc[index] 232 | 233 | fname = os.path.join(self.bold_path,"videos",self.df.iloc[index]["video"]) 234 | 235 | capture = cv2.VideoCapture(fname) 236 | frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))-1 237 | 238 | capture.release() 239 | 240 | record_path = os.path.join(self.bold_path,"test_raw",sample["video"][4:-4]) 241 | 242 | record = VideoRecord([record_path, frame_count, sample["min_frame"], sample["max_frame"]]) 243 | 244 | if not self.test_mode: 245 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record) 246 | else: 247 | segment_indices = self._get_test_indices(record) 248 | 249 | return self.get(record, segment_indices, index) 250 | 251 | def get(self, record, indices, index): 252 | 253 | images = list() 254 | # print(indices) 255 | for seg_ind in indices: 256 | p = int(seg_ind) 257 | for i in range(self.new_length): 258 | 259 | seg_imgs = self._load_image(record.path, p, index, mode="body") 260 | 261 | images.extend(seg_imgs) 262 | 263 | if self.context: 264 | seg_imgs = self._load_image(record.path, p, index, mode="context") 265 | images.extend(seg_imgs) 266 | 267 | 268 | if p < record.num_frames: 269 | p += 1 270 | 271 | 272 | if not self.test_mode: 273 | categorical = self.df.iloc[index][self.categorical_emotions] 274 | 275 | continuous = self.df.iloc[index][self.continuous_emotions] 276 | continuous = continuous/10.0 # normalize to 0 - 1 277 | 278 | if self.transform is None: 279 | process_data = images 280 | else: 281 | process_data = self.transform(images) 282 | 283 | return process_data, torch.tensor(self.embeddings).float(), torch.tensor(categorical).float(), torch.tensor(continuous).float(), self.df.iloc[index]["video"] 284 | else: 285 | process_data = self.transform(images) 286 | return process_data, torch.tensor(self.embeddings).float() 287 | 288 | def __len__(self): 289 | return len(self.df) 290 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | import numbers 6 | import math 7 | import torch 8 | 9 | 10 | class GroupRandomCrop(object): 11 | def __init__(self, size): 12 | if isinstance(size, numbers.Number): 13 | self.size = (int(size), int(size)) 14 | else: 15 | self.size = size 16 | 17 | def __call__(self, img_group): 18 | 19 | w, h = img_group[0].size 20 | th, tw = self.size 21 | 22 | out_images = list() 23 | 24 | x1 = random.randint(0, w - tw) 25 | y1 = random.randint(0, h - th) 26 | 27 | for img in img_group: 28 | # assert(img.size[0] == w and img.size[1] == h) 29 | if w == tw and h == th: 30 | out_images.append(img) 31 | else: 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) 33 
| 34 | return out_images 35 | 36 | 37 | class GroupCenterCrop(object): 38 | def __init__(self, size): 39 | self.worker = torchvision.transforms.CenterCrop(size) 40 | 41 | def __call__(self, img_group): 42 | return [self.worker(img) for img in img_group] 43 | # 44 | # class GroupTopCrop(object): 45 | # def __init__(self, size): 46 | # self.worker = torchvision.transforms.CenterCrop(size) 47 | # 48 | # def __call__(self, img_group): 49 | # return [self.worker(img) for img in img_group] 50 | 51 | 52 | 53 | class GroupRandomHorizontalFlip(object): 54 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 55 | """ 56 | def __init__(self, is_flow=False): 57 | self.is_flow = is_flow 58 | 59 | def __call__(self, img_group, is_flow=False): 60 | v = random.random() 61 | if v < 0.5: 62 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] 63 | if self.is_flow: 64 | for i in range(0, len(ret), 2): 65 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping 66 | return ret 67 | else: 68 | return img_group 69 | 70 | 71 | class GroupNormalize(object): 72 | def __init__(self, mean, std): 73 | self.mean = mean 74 | self.std = std 75 | 76 | def __call__(self, tensor): 77 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) 78 | rep_std = self.std * (tensor.size()[0]//len(self.std)) 79 | 80 | # TODO: make efficient 81 | for t, m, s in zip(tensor, rep_mean, rep_std): 82 | t.sub_(m).div_(s) 83 | 84 | return tensor 85 | 86 | 87 | class GroupScale(object): 88 | """ Rescales the input PIL.Image to the given 'size'. 89 | 'size' will be the size of the smaller edge. 90 | For example, if height > width, then image will be 91 | rescaled to (size * height / width, size) 92 | size: size of the smaller edge 93 | interpolation: Default: PIL.Image.BILINEAR 94 | """ 95 | 96 | def __init__(self, size, interpolation=Image.BILINEAR): 97 | self.worker = torchvision.transforms.Scale(size, interpolation) 98 | 99 | def __call__(self, img_group): 100 | return [self.worker(img) for img in img_group] 101 | 102 | 103 | class GroupOverSample(object): 104 | def __init__(self, crop_size, scale_size=None): 105 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) 106 | 107 | if scale_size is not None: 108 | self.scale_worker = GroupScale(scale_size) 109 | else: 110 | self.scale_worker = None 111 | 112 | def __call__(self, img_group): 113 | 114 | if self.scale_worker is not None: 115 | img_group = self.scale_worker(img_group) 116 | 117 | image_w, image_h = img_group[0].size 118 | crop_w, crop_h = self.crop_size 119 | 120 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) 121 | oversample_group = list() 122 | for o_w, o_h in offsets: 123 | normal_group = list() 124 | flip_group = list() 125 | for i, img in enumerate(img_group): 126 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) 127 | normal_group.append(crop) 128 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) 129 | 130 | if img.mode == 'L' and i % 2 == 0: 131 | flip_group.append(ImageOps.invert(flip_crop)) 132 | else: 133 | flip_group.append(flip_crop) 134 | 135 | oversample_group.extend(normal_group) 136 | oversample_group.extend(flip_group) 137 | return oversample_group 138 | 139 | 140 | class GroupMultiScaleCrop(object): 141 | 142 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): 143 | self.scales = scales if scales is not None else [1, .875, .75, .66] 144 | self.max_distort = 
max_distort 145 | self.fix_crop = fix_crop 146 | self.more_fix_crop = more_fix_crop 147 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] 148 | self.interpolation = Image.BILINEAR 149 | 150 | def __call__(self, img_group): 151 | 152 | im_size = img_group[0].size 153 | 154 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) 155 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] 156 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) 157 | for img in crop_img_group] 158 | return ret_img_group 159 | 160 | def _sample_crop_size(self, im_size): 161 | image_w, image_h = im_size[0], im_size[1] 162 | 163 | # find a crop size 164 | base_size = min(image_w, image_h) 165 | crop_sizes = [int(base_size * x) for x in self.scales] 166 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] 167 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] 168 | 169 | pairs = [] 170 | for i, h in enumerate(crop_h): 171 | for j, w in enumerate(crop_w): 172 | if abs(i - j) <= self.max_distort: 173 | pairs.append((w, h)) 174 | 175 | crop_pair = random.choice(pairs) 176 | if not self.fix_crop: 177 | w_offset = random.randint(0, image_w - crop_pair[0]) 178 | h_offset = random.randint(0, image_h - crop_pair[1]) 179 | else: 180 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) 181 | 182 | return crop_pair[0], crop_pair[1], w_offset, h_offset 183 | 184 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): 185 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) 186 | return random.choice(offsets) 187 | 188 | @staticmethod 189 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): 190 | w_step = (image_w - crop_w) // 4 191 | h_step = (image_h - crop_h) // 4 192 | 193 | ret = list() 194 | ret.append((0, 0)) # upper left 195 | ret.append((4 * w_step, 0)) # upper right 196 | ret.append((0, 4 * h_step)) # lower left 197 | ret.append((4 * w_step, 4 * h_step)) # lower right 198 | ret.append((2 * w_step, 2 * h_step)) # center 199 | 200 | if more_fix_crop: 201 | ret.append((0, 2 * h_step)) # center left 202 | ret.append((4 * w_step, 2 * h_step)) # center right 203 | ret.append((2 * w_step, 4 * h_step)) # lower center 204 | ret.append((2 * w_step, 0 * h_step)) # upper center 205 | 206 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter 207 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter 208 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter 209 | ret.append((3 * w_step, 3 * h_step)) # lower righ quarter 210 | 211 | return ret 212 | 213 | 214 | class GroupRandomSizedCrop(object): 215 | """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size 216 | and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio 217 | This is popularly used to train the Inception networks 218 | size: size of the smaller edge 219 | interpolation: Default: PIL.Image.BILINEAR 220 | """ 221 | def __init__(self, size, interpolation=Image.BILINEAR): 222 | self.size = size 223 | self.interpolation = interpolation 224 | 225 | def __call__(self, img_group): 226 | for attempt in range(10): 227 | area = img_group[0].size[0] * img_group[0].size[1] 228 | target_area = random.uniform(0.08, 1.0) * area 229 | aspect_ratio = random.uniform(3. / 4, 4. 
/ 3) 230 | 231 | w = int(round(math.sqrt(target_area * aspect_ratio))) 232 | h = int(round(math.sqrt(target_area / aspect_ratio))) 233 | 234 | if random.random() < 0.5: 235 | w, h = h, w 236 | 237 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: 238 | x1 = random.randint(0, img_group[0].size[0] - w) 239 | y1 = random.randint(0, img_group[0].size[1] - h) 240 | found = True 241 | break 242 | else: 243 | found = False 244 | x1 = 0 245 | y1 = 0 246 | 247 | if found: 248 | out_group = list() 249 | for img in img_group: 250 | img = img.crop((x1, y1, x1 + w, y1 + h)) 251 | assert(img.size == (w, h)) 252 | out_group.append(img.resize((self.size, self.size), self.interpolation)) 253 | return out_group 254 | else: 255 | # Fallback 256 | scale = GroupScale(self.size, interpolation=self.interpolation) 257 | crop = GroupRandomCrop(self.size) 258 | return crop(scale(img_group)) 259 | 260 | 261 | class Stack(object): 262 | 263 | def __init__(self, roll=False): 264 | self.roll = roll 265 | 266 | def __call__(self, img_group): 267 | if img_group[0].mode == 'L': 268 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2) 269 | elif img_group[0].mode == 'RGB': 270 | if self.roll: 271 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) 272 | else: 273 | return np.concatenate(img_group, axis=2) 274 | 275 | 276 | class ToTorchFormatTensor(object): 277 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] 278 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ 279 | def __init__(self, div=True): 280 | self.div = div 281 | 282 | def __call__(self, pic): 283 | if isinstance(pic, np.ndarray): 284 | # handle numpy array 285 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() 286 | else: 287 | # handle PIL Image 288 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 289 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) 290 | # put it from HWC to CHW format 291 | # yikes, this transpose takes 80% of the loading time/CPU 292 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 293 | return img.float().div(255) if self.div else img.float() 294 | 295 | 296 | class IdentityTransform(object): 297 | 298 | def __call__(self, data): 299 | return data 300 | 301 | 302 | if __name__ == "__main__": 303 | trans = torchvision.transforms.Compose([ 304 | GroupScale(256), 305 | GroupRandomCrop(224), 306 | Stack(), 307 | ToTorchFormatTensor(), 308 | GroupNormalize( 309 | mean=[.485, .456, .406], 310 | std=[.229, .224, .225] 311 | )] 312 | ) 313 | 314 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') 315 | 316 | color_group = [im] * 3 317 | rst = trans(color_group) 318 | 319 | gray_group = [im.convert('L')] * 9 320 | gray_rst = trans(gray_group) 321 | 322 | trans2 = torchvision.transforms.Compose([ 323 | GroupRandomSizedCrop(256), 324 | Stack(), 325 | ToTorchFormatTensor(), 326 | GroupNormalize( 327 | mean=[.485, .456, .406], 328 | std=[.229, .224, .225]) 329 | ]) 330 | print(trans2(color_group)) 331 | --------------------------------------------------------------------------------
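Usage note (added): a minimal, hedged sketch of how TSNDataSet and the group transforms above could be composed into a training loader. It is not part of the repository; the batch size, worker count, and scale/crop sizes are assumptions, and the actual pipeline is defined by the project's data loader and config.

import torch
import torchvision
from dataset import TSNDataSet
from transforms import GroupScale, GroupRandomCrop, Stack, ToTorchFormatTensor, GroupNormalize

# assumed preprocessing pipeline, mirroring the __main__ example in transforms.py
train_transform = torchvision.transforms.Compose([
    GroupScale(256),
    GroupRandomCrop(224),
    Stack(roll=False),
    ToTorchFormatTensor(),
    GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
])

# TSNDataSet yields (frames, glove embeddings, categorical targets, continuous targets, video id)
train_set = TSNDataSet("train", num_segments=3, modality="RGB", transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)  # assumed sizes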