├── data ├── __init__.py ├── label_data.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── label_data.cpython-36.pyc │ ├── random_shuffle_dataset.cpython-36.pyc │ └── random_shuffle_dataset.cpython-37.pyc └── random_shuffle_dataset.py ├── models ├── __init__.py ├── __pycache__ │ ├── Model.cpython-36.pyc │ ├── Model.cpython-37.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── Model_speed.cpython-37.pyc │ ├── transformer.cpython-36.pyc │ └── transformer.cpython-37.pyc ├── Model.py ├── Model_speed.py └── transformer.py ├── options ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── base_options.cpython-36.pyc │ └── base_options.cpython-37.pyc └── base_options.py ├── datasets └── permutation_10.npy ├── load_materials.py ├── readme.md ├── requirements.txt ├── utils.py └── main.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/label_data.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/permutation_10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/datasets/permutation_10.npy -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/Model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/models/__pycache__/Model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/Model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/models/__pycache__/Model.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/label_data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/data/__pycache__/label_data.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DreamMr/EST/HEAD/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/options/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/options/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/Model_speed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/models/__pycache__/Model_speed.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/transformer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/models/__pycache__/transformer.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/models/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/options/__pycache__/base_options.cpython-36.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/options/__pycache__/base_options.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/random_shuffle_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/data/__pycache__/random_shuffle_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/random_shuffle_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamMr/EST/HEAD/data/__pycache__/random_shuffle_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /load_materials.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.utils.data 4 | import torchvision.transforms as transforms 5 | from data.random_shuffle_dataset import 
RandomShuffleDataset
6 | import utils
7 | 
8 | 
9 | def LoadDataset(opt):
10 |     cate2label = utils.cate2label(opt.dataset_name)
11 | 
12 |     train_dataset = RandomShuffleDataset(
13 |         video_root=opt.train_video_root,
14 |         video_list=opt.train_list_root,
15 |         rectify_label=cate2label,
16 |         isTrain= True,
17 |         transform=transforms.Compose([transforms.ToTensor()]),
18 |         opt=opt
19 |     )
20 | 
21 |     val_dataset = RandomShuffleDataset(
22 |         video_root=opt.test_video_root,
23 |         video_list=opt.test_list_root,
24 |         rectify_label=cate2label,
25 |         isTrain = False,
26 |         transform=transforms.Compose([transforms.ToTensor()]),
27 |         opt=opt
28 |     )
29 | 
30 |     train_loader = torch.utils.data.DataLoader(
31 |         train_dataset,
32 |         batch_size=opt.batch_size, shuffle=True,num_workers=opt.num_threads,
33 |         pin_memory=True, drop_last=True)  # drop_last=True: if the dataset size is not divisible by batch_size, the last incomplete batch is dropped.
34 | 
35 |     val_loader = torch.utils.data.DataLoader(
36 |         val_dataset,
37 |         batch_size=opt.batch_size, shuffle=False,num_workers=opt.num_threads,
38 |         pin_memory=True)
39 | 
40 |     return train_loader, val_loader
41 | 
42 | 
43 | def LoadParameter(_structure, _parameterDir):
44 | 
45 |     checkpoint = torch.load(_parameterDir)
46 |     pretrained_state_dict = checkpoint['state_dict']
47 |     model_state_dict = _structure.state_dict()
48 | 
49 |     for key in pretrained_state_dict:
50 |         if ((key == 'module.fc.weight') | (key == 'module.fc.bias') | (key == 'module.feature.weight') | (key == 'module.feature.bias')):
51 | 
52 |             pass
53 |         else:
54 |             model_state_dict[key.replace('module.', '')] = pretrained_state_dict[key]
55 | 
56 |     _structure.load_state_dict(model_state_dict)
57 |     model = torch.nn.DataParallel(_structure).cuda()
58 | 
59 |     return model
60 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Expression Snippet Transformer for Robust Video-based Facial Expression Recognition
2 | 
3 | PyTorch implementation of the paper:
4 | 
5 | > **Expression Snippet Transformer for Robust Video-based Facial Expression Recognition**
6 | 
7 | ## Content
8 | 
9 | - [Dependencies](#dependencies)
10 | - [Code and Data Preparation](#code-and-data-preparation)
11 | - [Training](#training)
12 | - [Testing](#testing)
13 | 
14 | ## Dependencies
15 | 
16 | Python Version: 3.7.9
17 | 
18 | Required packages are listed in requirements.txt. You can install them by running:
19 | 
20 | ```
21 | pip install -r requirements.txt
22 | ```
23 | 
24 | ## Code and Data Preparation
25 | 
26 | 1. Download the code from this repository and download the pre-trained ResNet-18 from [Baidu Drive](https://pan.baidu.com/s/1lnO1alaaP23NlZcPyNOhgg) (1req)
27 | 
28 | 2. Prepare the dataset.
29 | 
30 |    You need to unify the input video length to 105 frames. Make sure the data structure is as below.
31 | 
32 |    ```
33 |    ├── DFEW
34 |        └── videos
35 |            └── 14400
36 |                ├── 000.jpg
37 |                ├── 001.jpg
38 |                ├── 002.jpg
39 |                ├── ...
40 |            └── 14401
41 |                ├── 000.jpg
42 |                ├── 001.jpg
43 |                ├── 002.jpg
44 |                ├── ...
45 | └── data_list 46 | ├── Train_DFEW_all_clip.txt 47 | ├── Train_DFEW_all_clip_set_2.txt 48 | ├── Train_DFEW_all_clip_set_3.txt 49 | ├── Train_DFEW_all_clip_set_4.txt 50 | ├── Train_DFEW_all_clip_set_5.txt 51 | ``` 52 | 53 | ## Training 54 | 55 | You can use the following command to train: 56 | 57 | ``` 58 | python main.py --train_video_root /data/Your_Path/data_path/DFEW/videos --train_list_root /data/Your_Path/data_path/DFEW/data_list/Train_DFEW_all_clip_set_2.txt --test_video_root /data/Your_Path/data_path/DFEW/videos --test_list_root /data/Your_Path/data_path/DFEW/data_list/Test_DFEW_all_clip_set_2.txt --dataset_name DFEW --name dfew_transformer --gpu_ids 3 --batch 8 --epochs_count 160 59 | ``` 60 | 61 | ## Testing 62 | 63 | You can evaluate a trained model by running: 64 | 65 | ``` 66 | python main.py --train_video_root /data/Your_Path/data_path/DFEW/videos --train_list_root /data/Your_Path/data_path/DFEW/data_list/Train_DFEW_all_clip_set_1.txt --test_video_root /data/Your_Path/data_path/DFEW/videos --test_list_root /data/Your_Path/data_path/DFEW/data_list/Test_DFEW_all_clip_set_1.txt --dataset_name DFEW --name dfew_transformer --gpu_ids 3 --batch 8 --phase test --eval_model_path MODEL_PATH 67 | ``` 68 | 69 | Here, `MODEL_PATH` denotes for the path of the trained model. 70 | 71 | You can download our trained model on DFEW from [Baidu Drive](https://pan.baidu.com/s/1BkZnt5IP-xcXcSiTlcuKsA) (owu2) 72 | 73 | **IF YOU HAVE ANY PROBLEM, PLEASE CONTACT wangwenbin@cug.edu.cn OR COMMIT ISSUES** 74 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | albumentations==1.1.0 3 | appdirs==1.4.4 4 | argcomplete==1.12.3 5 | argon2-cffi==21.1.0 6 | astunparse==1.6.3 7 | attrs==21.2.0 8 | audioread==2.1.9 9 | av==8.0.3 10 | backcall==0.2.0 11 | bleach==4.1.0 12 | cached-property==1.5.2 13 | cachetools==4.2.1 14 | certifi==2021.5.30 15 | cffi==1.14.5 16 | chardet==4.0.0 17 | cycler==0.10.0 18 | debugpy==1.5.1 19 | decorator==4.4.2 20 | defusedxml==0.7.1 21 | docopt==0.6.2 22 | efficientnet-pytorch==0.6.3 23 | einops==0.3.2 24 | entrypoints==0.3 25 | et-xmlfile==1.0.1 26 | facenet-pytorch==2.5.2 27 | fastprogress==1.0.0 28 | flatbuffers==1.12 29 | ftfy==6.0.3 30 | fvcore==0.1.5.post20210515 31 | gast==0.3.3 32 | google-auth==1.25.0 33 | google-auth-oauthlib==0.4.2 34 | google-pasta==0.2.0 35 | graphviz==0.16 36 | grpcio==1.32.0 37 | h5py==2.10.0 38 | idna==2.10 39 | imageio==2.9.0 40 | imgaug==0.4.0 41 | importlib-metadata==3.4.0 42 | iopath==0.1.8 43 | ipykernel==6.5.0 44 | ipython==7.29.0 45 | ipython-genutils==0.2.0 46 | ipywidgets==7.6.5 47 | jdcal==1.4.1 48 | jedi==0.18.0 49 | Jinja2==3.0.2 50 | joblib==1.0.1 51 | jsonpatch==1.32 52 | jsonpointer==2.1 53 | jsonschema==4.1.2 54 | jupyter-client==7.0.6 55 | jupyter-core==4.9.1 56 | jupyterlab-pygments==0.1.2 57 | jupyterlab-widgets==1.0.2 58 | keras==2.8.0 59 | Keras-Preprocessing==1.1.2 60 | kiwisolver==1.3.1 61 | librosa==0.8.1 62 | linformer==0.2.1 63 | llvmlite==0.36.0 64 | Markdown==3.3.3 65 | MarkupSafe==2.0.1 66 | matplotlib==3.3.4 67 | matplotlib-inline==0.1.3 68 | mistune==0.8.4 69 | mkl-fft==1.2.0 70 | mkl-random==1.2.0 71 | mkl-service==2.3.0 72 | munch==2.5.0 73 | nbclient==0.5.4 74 | nbconvert==6.2.0 75 | nbformat==5.1.3 76 | nest-asyncio==1.5.1 77 | networkx==2.5 78 | notebook==6.4.5 79 | numba==0.53.1 80 | numpy==1.19.5 81 | oauthlib==3.1.0 82 | olefile==0.46 83 | 
opencv-python==4.5.1.48 84 | opencv-python-headless==4.5.3.56 85 | openpyxl==3.0.6 86 | opt-einsum==3.3.0 87 | packaging==21.0 88 | pandas==1.2.2 89 | pandocfilters==1.5.0 90 | parso==0.8.2 91 | pexpect==4.8.0 92 | pickleshare==0.7.5 93 | Pillow==8.1.0 94 | pipreqs==0.4.10 95 | pooch==1.4.0 96 | portalocker==2.3.0 97 | pretrainedmodels==0.7.4 98 | prometheus-client==0.12.0 99 | prompt-toolkit==3.0.21 100 | protobuf==3.14.0 101 | ptflops==0.6.5 102 | ptyprocess==0.7.0 103 | pyasn1==0.4.8 104 | pyasn1-modules==0.2.8 105 | pycparser==2.20 106 | Pygments==2.10.0 107 | pyparsing==2.4.7 108 | pyrsistent==0.18.0 109 | python-dateutil==2.8.1 110 | pytorch-warmup==0.0.4 111 | pytz==2021.1 112 | PyWavelets==1.1.1 113 | PyYAML==5.4.1 114 | pyzmq==22.1.0 115 | qudida==0.0.4 116 | regex==2022.1.18 117 | requests==2.25.1 118 | requests-oauthlib==1.3.0 119 | resampy==0.2.2 120 | rsa==4.7 121 | scikit-learn==1.0.1 122 | scipy==1.6.0 123 | seaborn==0.11.1 124 | segmentation-models-pytorch==0.2.0 125 | Send2Trash==1.8.0 126 | Shapely==1.7.1 127 | simplejson==3.17.2 128 | six==1.15.0 129 | SoundFile==0.10.3.post1 130 | tabulate==0.8.9 131 | tensorboard==2.4.1 132 | tensorboard-logger==0.1.0 133 | tensorboard-plugin-wit==1.8.0 134 | tensorboardX==2.4.1 135 | tensorflow==2.4.1 136 | tensorflow-estimator==2.4.0 137 | termcolor==1.1.0 138 | terminado==0.12.1 139 | testpath==0.5.0 140 | thop==0.0.31.post2005241907 141 | threadpoolctl==2.1.0 142 | tifffile==2021.2.1 143 | timer==0.1.2 144 | timm==0.4.12 145 | torch==1.7.1 146 | torchsummary==1.5.1 147 | torchvision==0.8.2 148 | tornado==6.1 149 | tqdm==4.62.2 150 | traitlets==5.1.1 151 | ttach==0.0.3 152 | typing-extensions 153 | urllib3==1.26.3 154 | visdom==0.1.8.9 155 | vit-pytorch==0.20.8 156 | wcwidth==0.2.5 157 | webencodings==0.5.1 158 | websocket-client==1.1.0 159 | Werkzeug==1.0.1 160 | widgetsnbextension==3.5.2 161 | wrapt==1.12.1 162 | yacs==0.1.8 163 | yarg==0.1.9 164 | zipp==3.4.0 165 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import time 4 | import numpy as np 5 | 6 | def save_checkpoint(state,opt): 7 | expr_dir = os.path.join(opt.checkpoints_dir,opt.name) 8 | 9 | if not os.path.exists(expr_dir): 10 | os.makedirs(expr_dir) 11 | 12 | epoch = state['epoch'] 13 | save_dir = os.path.join(expr_dir,str(epoch)+'_'+str(round(float(state['prec1']), 4))) 14 | torch.save(state, save_dir) 15 | print(save_dir) 16 | 17 | def cate2label(dataset_name): 18 | cate2label = {'MMI': {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Sad', 5: 'Surprise', 19 | 'Angry': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3, 'Sad': 4, 'Surprise': 5}, 20 | 'BU3D': {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Sad', 5: 'Surprise', 21 | 'Angry': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3, 'Sad': 4, 'Surprise': 5}, 22 | 'AFEW': {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Neutral', 5: 'Sad', 6: 'Surprise', 23 | 'Angry': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3, 'Neutral': 4, 'Sad': 5, 'Surprise': 6}, 24 | 'DFEW':{0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Neutral', 5: 'Sad', 6: 'Surprise', 25 | 'Angry': 0,'Disgust': 1,'Fear': 2,'Happy': 3,'Neutral': 4,'Sad': 5,'Surprise': 6}, 26 | 'Group': {0: 'sadness', 'sadness': 0, 1: 'happiness', 'happiness': 1, 2: 'helplessness', 'helplessness': 2, 3: 'anxiety', 'anxiety': 3, 4: 'disgust', 'disgust': 4, 5: 'contempt', 'contempt': 5, 6: 'disappointment', 
'disappointment': 6, 7: 'surprise', 'surprise': 7, 8: 'fear', 'fear': 8, 9: 'neutral', 'neutral': 9, 10: 'anger', 'anger': 10}, 27 | 'Group_multi':{'helplessness_disappointment': 0, 0: 'helplessness_disappointment', 'disgust_helplessness': 1, 1: 'disgust_helplessness', 'fear_surprise': 2, 2: 'fear_surprise', 'anger_surprise': 3, 3: 'anger_surprise', 'disgust': 4, 4: 'disgust', 'anger_sadness': 5, 5: 'anger_sadness', 'disgust_contempt': 6, 6: 'disgust_contempt', 'neutral': 7, 7: 'neutral', 'anxiety_helplessness': 8, 8: 'anxiety_helplessness', 'fear_anxiety': 9, 9: 'fear_anxiety', 'surprise': 10, 10: 'surprise', 'disgust_anxiety': 11, 11: 'disgust_anxiety', 'anger_disgust_contempt': 12, 12: 'anger_disgust_contempt', 'fear_sadness': 13, 13: 'fear_sadness', 'sadness_disappointment': 14, 14: 'sadness_disappointment', 'anxiety': 15, 15: 'anxiety', 'fear': 16, 16: 'fear', 'happiness_contempt': 17, 17: 'happiness_contempt', 'sadness_helplessness_disappointment': 18, 18: 'sadness_helplessness_disappointment', 'disgust_anxiety_helplessness': 19, 19: 'disgust_anxiety_helplessness', 'anger': 20, 20: 'anger', 'disgust_disappointment': 21, 21: 'disgust_disappointment', 'surprise_anxiety': 22, 22: 'surprise_anxiety', 'happiness': 23, 23: 'happiness', 'sadness': 24, 24: 'sadness', 'anger_disgust': 25, 25: 'anger_disgust', 'anger_disgust_anxiety': 26, 26: 'anger_disgust_anxiety', 'fear_surprise_anxiety': 27, 27: 'fear_surprise_anxiety', 'helplessness': 28, 28: 'helplessness', 'sadness_anxiety': 29, 29: 'sadness_anxiety', 'disgust_surprise': 30, 30: 'disgust_surprise', 'disappointment': 31, 31: 'disappointment', 'sadness_anxiety_helplessness': 32, 32: 'sadness_anxiety_helplessness', 'disgust_sadness': 33, 33: 'disgust_sadness', 'anxiety_helplessness_disappointment': 34, 34: 'anxiety_helplessness_disappointment', 'disgust_helplessness_disappointment': 35, 35: 'disgust_helplessness_disappointment', 'happiness_surprise': 36, 36: 'happiness_surprise', 'sadness_surprise': 37, 37: 'sadness_surprise', 'anger_helplessness': 38, 38: 'anger_helplessness', 'fear_sadness_anxiety': 39, 39: 'fear_sadness_anxiety', 'contempt': 40, 40: 'contempt', 'anger_anxiety': 41, 41: 'anger_anxiety', 'sadness_helplessness': 42, 42: 'sadness_helplessness'} 28 | } 29 | 30 | return cate2label[dataset_name] 31 | 32 | def mkdirs(root): 33 | if not os.path.exists(root): 34 | os.makedirs(root) 35 | 36 | def get_time(): 37 | return str(time.strftime("%Y_%m_%d_%H_%M_%S",time.localtime())) 38 | 39 | def draw_weight(weight,file_path): 40 | np.save(file_path,weight) 41 | 42 | -------------------------------------------------------------------------------- /data/random_shuffle_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch.utils.data as data 4 | from PIL import Image 5 | import torch 6 | import numpy as np 7 | 8 | class RandomShuffleDataset(data.Dataset): 9 | def __init__(self,video_root,video_list,isTrain,rectify_label,opt,transform=None): 10 | super(RandomShuffleDataset, self).__init__() 11 | ### param area#################### 12 | self.first_clips_number = 30 # 从前first_clips_number帧中选第一个clip的第一帧 13 | self.search_clips_number = 75 # 从第一个clip之后往后找search_clips_number帧做剩余clips-1的数据 14 | self.choose_num = opt.per_snippets # 每个clip选choose_num张图 15 | self.each_clips=15 # 每个clip的选取范围 16 | self.clips = opt.snippets # 多少个clip 17 | self.next_jump = 10 # 当前位置的下一跳 18 | self.isTrain = isTrain 19 | self.test_first = opt.test_first 20 | self.opt = opt 21 | 22 | 
        #####data path ###################
23 |         self.video_root = video_root
24 |         self.video_list = video_list
25 |         self.rectify_label = rectify_label
26 |         self.transform = transform
27 |         self.permutation = opt.permutation_root
28 |         self.per_nums = opt.permutation_classes
29 |         #############################
30 | 
31 |         self.video_label = self.read_data(self.video_root,self.video_list,self.rectify_label)
32 |         self.permutation = self.load_per(self.permutation).tolist()
33 | 
34 |     def load_per(self,path):
35 |         return np.load(path)
36 | 
37 |     def read_data(self,video_root,video_list,rectify_label):
38 |         video_label_list = []
39 |         #print(video_list)
40 |         # Read the list file and collect all (video path, label) pairs
41 |         with open(video_list,'r') as imf:
42 | 
43 |             for id, line in enumerate(imf):
44 |                 video_label = line.strip().split()
45 | 
46 |                 video_name = video_label[0]
47 |                 label = rectify_label[video_label[1]]
48 | 
49 |                 video_label_list.append((os.path.join(video_root,video_name),label))
50 | 
51 |         return video_label_list
52 | 
53 |     def __getitem__(self, index):
54 |         '''
55 |         :param index:
56 |         :return:
57 |             data = [clip_0,clip_1,clip_2,...,clip_7]
58 |             clip_0 = [torch.tensor_0,torch.tensor_1,...,torch.tensor_4] torch.tensor_0.shape = [3,224,224]
59 |             label = 0 or 1 or 2 or 3 or 4 or 5 ....
60 |             path: path of the current video
61 |         '''
62 |         data_path,label = self.video_label[index]
63 | 
64 |         frame_path_list = sorted(os.listdir(data_path))
65 | 
66 |         if self.isTrain:
67 |             first_loc = random.randint(0, self.first_clips_number - 1)
68 |         else:
69 |             first_loc = self.test_first
70 |         # Starting from first_loc, select each_clips + search_clips_number frames: each_clips frames form the first clip, and the following search_clips_number frames provide the data for the remaining clips
71 |         sub_frames_list = frame_path_list[first_loc:first_loc + self.search_clips_number]
72 | 
73 |         data_clips = []
74 | 
75 |         # clip 0~clips-1
76 |         cur_loc = 0
77 |         for i in range(0, self.clips):
78 |             high_range = cur_loc + self.each_clips
79 |             low_range = cur_loc
80 |             frames_tmp = sub_frames_list[low_range:high_range]
81 |             data_clips.append(self.get_image(frames_tmp, data_path))
82 | 
83 |             cur_loc += self.next_jump
84 | 
85 |         per_loc = random.randint(0,self.per_nums - 1)
86 |         per = self.permutation[per_loc]
87 | 
88 |         data_clips_shuffle = self.order_clip(data_clips,per)
89 |         data_clips_normal = self.order_clip(data_clips,[x for x in range(self.opt.snippets)])
90 | 
91 |         return {'data_normal':data_clips_normal,'label':label,'path':data_path,'data_shuffle':data_clips_shuffle,'per_label':per_loc,'per_shuffle':torch.tensor(per),'per_normal':torch.tensor([x for x in range(self.opt.snippets)])}
92 | 
93 |     def order_clip(self,data_clips,order):
94 |         clip_list = [torch.stack(data_clips[order[i]],dim=0) for i in range(len(order))]
95 |         return torch.stack(clip_list, dim=0)
96 | 
97 |     # Read the image data of each clip
98 |     def get_image(self,frames,data_path):
99 |         # Randomly sample choose_num frame indices
100 |         if self.isTrain:
101 |             indexs = self.sample(self.choose_num,0,self.each_clips-1)
102 |         else:
103 |             indexs = [x for x in range(0,self.each_clips,self.each_clips // self.choose_num)]
104 |         # Load the images
105 |         result_list = []
106 |         for loc in indexs:
107 |             assert len(frames)> loc, data_path
108 |             img_path = os.path.join(data_path,frames[loc])
109 |             img = Image.open(img_path).convert("RGB")
110 |             img = img.resize((224,224))
111 |             if self.transform is not None:
112 |                 img = self.transform(img)
113 |             result_list.append(img)
114 | 
115 |         return result_list
116 | 
117 |     def sample(self,num,min_index,max_index):
118 |         s = set()
119 |         while len(s) < num:
120 |             tmp = random.randint(min_index,max_index)
121 |             s.add(tmp)
122 |         return list(s)
123 | 
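    # Summary of the sampling scheme implemented above: during training the starting frame
    # is drawn from the first `first_clips_number` frames (at test time it is fixed to
    # `test_first`); `clips` snippets are then cut from a `search_clips_number`-frame window,
    # each covering `each_clips` consecutive frames and starting `next_jump` frames after the
    # previous one, with `choose_num` frames kept per snippet (random in training, evenly
    # spaced in testing). __getitem__ returns both the naturally ordered and a randomly
    # permuted snippet sequence, together with the permutation label ('per_label') used by
    # the order-prediction branch. Note that sample() draws unique indices via a set, so the
    # frames within a training snippet are not guaranteed to be in chronological order.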
def __len__(self): 125 | return len(self.video_label) 126 | 127 | 128 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import utils 4 | 5 | class BaseOptions(): 6 | 7 | def __init__(self): 8 | self.initialized = False 9 | 10 | def initialize(self,parser): 11 | parser.add_argument('--train_video_root', required=True, help='path to train videos') 12 | parser.add_argument('--train_list_root', required=True, help='path to train videos list') 13 | parser.add_argument('--test_video_root', required=True, help='path to test videos') 14 | parser.add_argument('--test_list_root', required=True, help='path to test videos list') 15 | parser.add_argument('--permutation_root',default='./datasets/permutation_10.npy') 16 | parser.add_argument('--dataset_name', required=True,type=str,default='MMI',help='BU3D, AFEW, MMI, DFEW') 17 | parser.add_argument('--name',type=str,default='experiment_name',help='name of the experiment. It decides where to store samples and models') 18 | parser.add_argument('--gpu_ids',type=str,default='0',help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 19 | parser.add_argument('--num_threads',default=4,type=int,help='# threads for loading data') 20 | parser.add_argument('--batch_size',type=int,default=8,help='input batch size') 21 | parser.add_argument('--seed',type=int,default=3456,help='random seed') 22 | parser.add_argument('--checkpoints_dir',type=str,default='./checkpoints',help='models are saved here') 23 | parser.add_argument('--phase',type=str,default='train',help='train,test') 24 | parser.add_argument('--epoch', default=0, type=int, help='start epoch count') 25 | parser.add_argument('--epochs_count', default=160, type=int) 26 | parser.add_argument('--lr', default=1e-4, type=float) 27 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 28 | help='momentum (default: 0.9)') 29 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 30 | metavar='W', help='weight decay (default: 1e-4)') 31 | parser.add_argument('--print-freq', '-p', default=20, type=int, 32 | metavar='N', help='print frequency (default: 10)') 33 | parser.add_argument('--lamb', default=0.1428, type=float, help='control permutation') 34 | parser.add_argument('--warm_up',default=10,type=int) 35 | parser.add_argument('--print_freq',default=20,type=int) 36 | 37 | parser.add_argument('--snippets', type=int, default=7, help='the number of snippets') 38 | parser.add_argument('--per_snippets', type=int, default=5, help='the number of per snippets') 39 | parser.add_argument('--use_norm', action='store_false') 40 | parser.add_argument('--d_model',type=int,default=512) 41 | parser.add_argument('--nhead',type=int,default=4) 42 | parser.add_argument('--encoder_nums',type=int,default=3) 43 | parser.add_argument('--decoder_nums',type=int,default=3) 44 | parser.add_argument('--permutation_classes',type=int,default=10) 45 | parser.add_argument('--parameterDir',type=str,default='./parameters/Resnet18_FER+_pytorch.pth.tar') 46 | 47 | ###########Continue###################### 48 | parser.add_argument('--continue_train',action='store_true') 49 | parser.add_argument('--pre_train_model_path',type=str) 50 | parser.add_argument('--heat_map_path',type=str,default='./heat_map') 51 | ###########Evaluation################### 52 | parser.add_argument('--eval_model_path',type=str) 53 | 
parser.add_argument('--draw_weight',action='store_true') 54 | parser.add_argument('--test_first',type=int,default=15) 55 | self.initialized = True 56 | return parser 57 | 58 | def gather_options(self): 59 | if not self.initialized: 60 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 61 | parser = self.initialize(parser) 62 | 63 | self.parser = parser 64 | return parser.parse_args() 65 | 66 | def print_options(self,opt): 67 | message = '' 68 | message += '----------------- Options ---------------\n' 69 | for k, v in sorted(vars(opt).items()): 70 | comment = '' 71 | default = self.parser.get_default(k) 72 | if v != default: 73 | comment = '\t[default: %s]' % str(default) 74 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 75 | message += '----------------- End -------------------' 76 | print(message) 77 | 78 | # save to the disk 79 | utils.mkdirs(opt.checkpoints_dir) 80 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 81 | utils.mkdirs(expr_dir) 82 | str_time = utils.get_time() 83 | file_name = os.path.join(expr_dir, '{}_{}_opt.txt'.format(opt.phase,str_time)) 84 | with open(file_name, 'wt') as opt_file: 85 | opt_file.write(message) 86 | opt_file.write('\n') 87 | 88 | def parse(self): 89 | opt = self.gather_options() 90 | opt.isTrain = True if opt.phase == 'train' else False 91 | 92 | self.print_options(opt) 93 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids 94 | self.opt = opt 95 | return self.opt -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.parallel 5 | import torch.optim 6 | import torch.utils.data 7 | import utils 8 | import load_materials 9 | from models.Model import resnet18_EST 10 | from tensorboardX import SummaryWriter 11 | import pytorch_warmup as warmup 12 | import random 13 | from options.base_options import BaseOptions 14 | from torch.backends import cudnn 15 | import os 16 | 17 | 18 | def train(train_loader, model, criterion, optimizer, epoch, opt, writer): 19 | running_loss, count, correct_count, running_cls_loss, running_per_loss, correct_per_count = 0., 0, 0., 0., 0., 0. 
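    # The objective optimized below combines two terms: expression classification on the
    # shuffled snippet sequence (loss_cls) and snippet-order / permutation prediction
    # (loss_per), weighted by opt.lamb, i.e. loss = loss_cls + loss_per * opt.lamb.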
20 | model.train() 21 | for i, data in enumerate(train_loader): 22 | target_first = data['label'] 23 | input_var = torch.autograd.Variable(data['data_shuffle']) 24 | order_label = data['per_label'] 25 | target = target_first.cuda(non_blocking=True) 26 | order_label = order_label.cuda(non_blocking=True) 27 | 28 | target_var = torch.autograd.Variable(target) 29 | order_label = torch.autograd.Variable(order_label) 30 | pred_score, per_score = model(input_var, per=data['per_shuffle']) 31 | 32 | # compute gradient and do Adam step 33 | loss_cls = criterion(pred_score, target_var) 34 | loss_per = criterion(per_score, order_label) 35 | loss = loss_cls + loss_per * opt.lamb 36 | optimizer.zero_grad() 37 | loss.backward() 38 | optimizer.step() 39 | 40 | # store loss 41 | running_loss += loss.item() 42 | running_cls_loss += loss_cls.item() 43 | running_per_loss += loss_per.item() 44 | correct_count += (torch.max(pred_score, dim=1)[1] == target_var).sum() 45 | correct_per_count += (torch.max(per_score, dim=1)[1] == order_label).sum() 46 | count += input_var.size(0) 47 | 48 | if i % opt.print_freq == 0: 49 | print( 50 | 'Epoch: [{0}][{1}/{2}]\t Loss {loss:.4f}\t Cls_Acc{acc:.4f}\t Per_Acc{per_acc:.4f}\t Loss cls {loss_cls:.4f}\t Loss per {loss_per:.4f}' 51 | .format(epoch, i, len(train_loader), loss=running_loss / count, acc=int(correct_count) / count, 52 | per_acc=int(correct_per_count) / count, loss_cls=running_cls_loss / count, 53 | loss_per=running_per_loss / count)) 54 | print( 55 | ' Train_Acc {train_Video:.4f}\t Train_Loss {Train_Loss:.4f}\t Per_Acc{per_acc:.4f}\t Loss cls {loss_cls:.4f}\t Loss per {loss_per:.4f}'. 56 | format(train_Video=int(correct_count) / count, Train_Loss=running_loss / count, 57 | per_acc=int(correct_per_count) / count, loss_cls=running_cls_loss / count, 58 | loss_per=running_per_loss / count)) 59 | 60 | writer.add_scalar('final_loss', running_loss / count, epoch) 61 | writer.add_scalar('final_cls_loss', running_cls_loss / count, epoch) 62 | writer.add_scalar('final_cls_acc', int(correct_count) / count, epoch) 63 | writer.add_scalar('final_per_loss', running_per_loss / count, epoch) 64 | writer.add_scalar('final_per_acc', int(correct_per_count) / count, epoch) 65 | 66 | 67 | def validate(val_loader, model,args): 68 | model.eval() 69 | test_correct_count, test_count, test_correct_per_count, test_per_acc = 0, 0, 0, 0. 
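    # Validation runs two forward passes per batch: one on the shuffled clip order
    # ('data_shuffle') to measure permutation-prediction accuracy, and one on the
    # naturally ordered clips ('data_normal') to measure expression-classification accuracy.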
70 | 71 | with torch.no_grad(): 72 | for i, data in enumerate(val_loader): 73 | ########################test shuffle 74 | input_var = torch.autograd.Variable(data['data_shuffle']) 75 | order_label = data['per_label'] 76 | order_label = order_label.cuda(non_blocking=True) 77 | order_label = torch.autograd.Variable(order_label) 78 | 79 | # compute output 80 | _, per_score = model(input_var, per=data['per_shuffle']) 81 | #################################################### 82 | 83 | #################### test cls 84 | input_var = torch.autograd.Variable(data['data_normal']) 85 | target_first = data['label'] 86 | # compute output 87 | target = target_first.cuda(non_blocking=True) 88 | 89 | target_var = torch.autograd.Variable(target) 90 | 91 | pred_score, _ = model(input_var, per=data['per_normal']) 92 | ##################################################### 93 | #if torch.max(pred_score, dim=1)[1] != target_var: 94 | # print(data['path'], ' ', utils.cate2label(args.dataset_name)[torch.max(pred_score, dim=1)[1].item()]) 95 | test_correct_count += (torch.max(pred_score, dim=1)[1] == target_var).sum() 96 | test_correct_per_count += (torch.max(per_score, dim=1)[1] == order_label).sum() 97 | 98 | test_count += input_var.size(0) 99 | 100 | if args.draw_weight: 101 | video_path_list = data['path'][0].split('/') 102 | video_path = video_path_list[-2]+'/'+video_path_list[-1] 103 | heat_map_path = os.path.join(args.heat_map_path,args.name,video_path) 104 | trans_path = os.path.join(heat_map_path,'trans') 105 | cos_path = os.path.join(heat_map_path,'cos') 106 | utils.mkdirs(trans_path) 107 | utils.mkdirs(cos_path) 108 | for t in range(len(weight_list)): 109 | heat = os.path.join(trans_path,str(t)+'.npy') 110 | utils.draw_weight(weight_list[t].squeeze().detach().cpu().numpy(),heat) 111 | for t in range(len(cos_weight)): 112 | heat = os.path.join(cos_path,str(t)+'.npy') 113 | utils.draw_weight(cos_weight[t].detach().cpu().numpy(),heat) 114 | 115 | test_acc = int(test_correct_count) / test_count 116 | test_per_acc = int(test_correct_per_count) / test_count 117 | print(' Test_Acc: {test_Video:.4f} '.format(test_Video=test_acc)) 118 | print(' Test_per_Acc: {test_per_Video:.4f} '.format(test_per_Video=test_per_acc)) 119 | 120 | return test_acc, test_per_acc 121 | 122 | 123 | def main(opt): 124 | train_loader, val_loader = load_materials.LoadDataset(opt) 125 | model = resnet18_EST(clips=opt.snippets, img_num_per_clip=opt.per_snippets, d_model=opt.d_model, nhead=opt.nhead, 126 | encoder_nums=opt.encoder_nums, decoder_nums=opt.decoder_nums, use_norm=opt.use_norm, 127 | per_classes=opt.permutation_classes,draw_weight=opt.draw_weight) 128 | if opt.isTrain and not opt.continue_train: 129 | model = load_materials.LoadParameter(model, opt.parameterDir) 130 | print('train !') 131 | elif opt.continue_train: 132 | model = torch.nn.DataParallel(model).cuda() 133 | model.load_state_dict(torch.load(opt.pre_train_model_path)['state_dict']) 134 | print('load eval model !') 135 | else: 136 | print('load eval model !') 137 | model = torch.nn.DataParallel(model).cuda() 138 | model.load_state_dict(torch.load(opt.eval_model_path)['state_dict']) 139 | 140 | criterion = nn.CrossEntropyLoss().cuda() 141 | cudnn.benchmark = True 142 | 143 | if not opt.isTrain: 144 | validate(val_loader, model,opt) 145 | return 146 | 147 | per_branch_params = list(map(id, model.module.per_branch.parameters())) 148 | base_params = filter(lambda p: id(p) not in per_branch_params and p.requires_grad, model.parameters()) 149 | optimizer = torch.optim.Adam([ 
150 | {'params': base_params}, 151 | {'params': model.module.per_branch.parameters(), 'lr': opt.lr} 152 | ], lr=opt.lr, betas=(0.9, 0.999), weight_decay=opt.weight_decay) 153 | 154 | lr_schduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.epochs_count) 155 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=opt.warm_up) 156 | warmup_scheduler.last_step = -1 157 | best_prec1 = 0. 158 | for epoch in range(opt.epoch, opt.epochs_count): 159 | lr_schduler.step(epoch) 160 | warmup_scheduler.dampen() 161 | train(train_loader, model, criterion, optimizer, epoch, opt, writer) 162 | prec1, per_acc = validate(val_loader, model,opt) 163 | 164 | writer.add_scalar('final_test_acc', prec1, epoch) 165 | writer.add_scalar('final_test_per_acc', per_acc, epoch) 166 | is_best = prec1 > best_prec1 167 | if is_best: 168 | print('better model!') 169 | best_prec1 = max(prec1, best_prec1) 170 | utils.save_checkpoint({ 171 | 'epoch': epoch + 1, 172 | 'state_dict': model.state_dict(), 173 | 'prec1': prec1, 174 | }, opt) 175 | else: 176 | print('Model too bad & not save') 177 | 178 | 179 | if __name__ == '__main__': 180 | opt = BaseOptions().parse() 181 | 182 | cudnn.benchmark = False # if benchmark=True, deterministic will be False 183 | cudnn.deterministic = True 184 | torch.manual_seed(opt.seed) # 为CPU设置随机种子 185 | torch.cuda.manual_seed(opt.seed) # 为当前GPU设置随机种子 186 | torch.cuda.manual_seed_all(opt.seed) # 为所有GPU设置随机种子 187 | random.seed(opt.seed) 188 | 189 | writer = SummaryWriter(comment=opt.name) 190 | 191 | main(opt) 192 | 193 | writer.close() 194 | -------------------------------------------------------------------------------- /models/Model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | from models.transformer import Transformer 5 | import numpy as np 6 | from einops import rearrange, reduce, repeat 7 | import torch.nn.functional as F 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | "3x3 convolution with padding" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | class PosEncoding(torch.nn.Module): 15 | def __init__(self,max_seq_len,d_model=512): 16 | super(PosEncoding, self).__init__() 17 | pos_enc = np.array( 18 | [[pos / np.power(10000,2.0 * (j//2) / d_model) for j in range(d_model)] for pos in range(max_seq_len)] 19 | ) 20 | pos_enc[:,0::2] = np.sin(pos_enc[:,0::2]) 21 | pos_enc[:,1::2] = np.cos(pos_enc[:,1::2]) 22 | pos_enc = pos_enc.astype(np.float32) 23 | self.pos_enc = torch.nn.Embedding(max_seq_len,d_model) 24 | self.pos_enc.weight = torch.nn.Parameter(torch.from_numpy(pos_enc),requires_grad=False) 25 | 26 | def forward(self, input_len): 27 | ''' 28 | 29 | :param input_len: [7,7,7,7,...,7] shape=[batch_size] 30 | :return: 31 | ''' 32 | input_pos = torch.tensor([list(range(0,len)) for len in input_len]).cuda() 33 | return self.pos_enc(input_pos) 34 | 35 | 36 | class BasicBlock(nn.Module): 37 | expansion = 1 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None): 40 | super(BasicBlock, self).__init__() 41 | self.conv1 = conv3x3(inplanes, planes, stride) 42 | self.bn1 = nn.BatchNorm2d(planes) 43 | self.relu = nn.ReLU() 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | residual = x 51 | 52 | out = self.conv1(x) 53 | out = self.bn1(out) 54 | out = 
self.relu(out) 55 | 56 | out = self.conv2(out) 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | 62 | out += residual 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | class Permutation(nn.Module): 68 | def __init__(self,d_model,classes,clips): 69 | super(Permutation, self).__init__() 70 | self.d_model = d_model 71 | self.classes = classes 72 | 73 | self.classifier = nn.Sequential(*[nn.Linear(clips*512,2048),nn.BatchNorm1d(2048),nn.ReLU(inplace=True), 74 | nn.Linear(2048,512),nn.BatchNorm1d(512),nn.ReLU(inplace=True), 75 | nn.Linear(512,self.classes)]) 76 | 77 | self.weights_init(self.classifier) 78 | 79 | def forward(self, input): 80 | output = self.classifier(input) 81 | return output 82 | 83 | def weights_init(self,model): 84 | if type(model) in [nn.ConvTranspose2d, nn.Linear]: 85 | nn.init.xavier_normal(model.weight.data) 86 | nn.init.constant(model.bias.data, 0.1) 87 | 88 | 89 | class EST(nn.Module): 90 | def __init__(self, block, layers,clips=7,img_num_per_clip=5,d_model=512,nhead=4,dropout=0.1,encoder_nums=4,decoder_nums=4, use_norm=True,per_classes=10,draw_weight=False): 91 | self.inplanes = 64 92 | self.key_clips = 1 93 | self.per_classes = per_classes 94 | self.d_model = d_model 95 | self.clips = clips 96 | self.img_num_per_clip = img_num_per_clip 97 | self.use_norm = use_norm 98 | self.draw_weight = draw_weight 99 | super(EST, self).__init__() 100 | ### First layer 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU() 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 110 | self.avgpool = nn.AdaptiveAvgPool2d(1) 111 | self.dropout = nn.Dropout(0.5) 112 | self.dropout2 = nn.Dropout(0.5) 113 | self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 114 | self.norm1 = nn.LayerNorm(d_model) 115 | self.dropout3 = nn.Dropout(0.1) 116 | self.cos = nn.CosineSimilarity(dim=1) 117 | self.maxpool2 = nn.AdaptiveMaxPool1d(1) 118 | 119 | ### Second layer 120 | self.transformer = Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=encoder_nums, 121 | num_decoder_layers=decoder_nums 122 | , dropout=dropout,draw_weight=self.draw_weight) 123 | self.query_embed = torch.nn.Embedding(self.key_clips, self.d_model) 124 | self.pos_enc = PosEncoding(max_seq_len=self.clips, d_model=self.d_model) 125 | 126 | #############Third layer 127 | self.per_branch = Permutation(self.d_model,self.per_classes,self.clips) 128 | 129 | ### Forth layer 130 | #self.final_classifier = nn.Sequential(*[nn.Linear(512, 256), nn.ReLU(True), 131 | # nn.Linear(256, 256), nn.ReLU(True), 132 | # nn.Dropout(0.4), 133 | # nn.Sequential(nn.Linear(256, 43))]) 134 | 135 | self.classifier = nn.Sequential(*[nn.Linear(512, 128), nn.ReLU(True), 136 | nn.Linear(128, 64), nn.ReLU(True), 137 | nn.Dropout(0.4), 138 | nn.Sequential(nn.Linear(64, 7))]) 139 | 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 143 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 144 | elif isinstance(m, nn.BatchNorm2d): 145 | m.weight.data.fill_(1) 146 | m.bias.data.zero_() 147 | 148 | def _make_layer(self, block, planes, blocks, stride=1): 149 | downsample = None 150 | if stride != 1 or self.inplanes != planes * block.expansion: 151 | downsample = nn.Sequential( 152 | nn.Conv2d(self.inplanes, planes * block.expansion, 153 | kernel_size=1, stride=stride, bias=False), 154 | nn.BatchNorm2d(planes * block.expansion), 155 | ) 156 | 157 | layers = [] 158 | layers.append(block(self.inplanes, planes, stride, downsample)) 159 | self.inplanes = planes * block.expansion 160 | for i in range(1, blocks): 161 | layers.append(block(self.inplanes, planes)) 162 | 163 | return nn.Sequential(*layers) 164 | 165 | def forward(self, x,per=torch.tensor([0,1,2,3,4,5,6])): 166 | 167 | #print(x.size()) 168 | video_feature = [] 169 | cos_weight = [] 170 | 171 | b = x.size(0) 172 | 173 | tmp = rearrange(x, 'b s f c h w -> (b s f) c h w') 174 | tmp = self.conv1(tmp) 175 | tmp = self.bn1(tmp) 176 | tmp = self.relu(tmp) 177 | tmp = self.maxpool(tmp) 178 | 179 | tmp = self.layer1(tmp) 180 | tmp = self.layer2(tmp) 181 | tmp = self.layer3(tmp) 182 | tmp = self.layer4(tmp) 183 | tmp = self.avgpool(tmp) 184 | 185 | tmp = tmp.squeeze(3).squeeze(2) 186 | tmp = rearrange(tmp,'(b s f) c -> b s f c',s=self.clips,f=self.img_num_per_clip) 187 | for i in range(self.clips): 188 | vs_stack = tmp[:,i,:,:] 189 | vs_stack = vs_stack.permute(1,0,2) 190 | output,weight = self.attn(vs_stack,vs_stack,vs_stack) 191 | if self.use_norm: 192 | output = vs_stack + self.dropout3(output) 193 | output = self.norm1(output).permute(1,2,0) 194 | 195 | global_feature = self.maxpool2(output).squeeze(dim=2) 196 | local_feature = output.permute(2,0,1) # 5*b*512 197 | dis_list = [] 198 | for j in range(self.img_num_per_clip): 199 | dis = self.cos(local_feature[j],global_feature) 200 | dis_list.append(dis) 201 | dis_alpha = torch.stack(dis_list,dim=1).unsqueeze(dim=1) 202 | dis_alpha = torch.clamp(dis_alpha, min=1e-8) 203 | output = output.mul(dis_alpha).sum(2).div(dis_alpha.sum(2)) 204 | 205 | cos_weight.append(dis_alpha.squeeze()) 206 | video_feature.append(output) 207 | 208 | ori_video_feature = torch.stack(video_feature,dim=1) #video_feature = bacth_size * clip_num * clip_feature的维度 即 bacth_size * 7 * 512 b = video_feature.size(0) 209 | b = int(ori_video_feature.size(0)) 210 | pos = torch.tensor([self.clips] * b) 211 | pos_enc = self.pos_enc(pos) 212 | for i in range(b): 213 | pos_enc[i] = pos_enc[i][per[i]] 214 | x = ori_video_feature + pos_enc 215 | x = x.permute(1, 0, 2) 216 | tgt = self.query_embed.weight 217 | tgt = tgt.unsqueeze(1).repeat(1,b,1) 218 | emotion_clip_feature,weight_list = self.transformer(x, tgt) #.permute(1, 2, 0) 219 | emotion_clip_feature = emotion_clip_feature.permute(1,2,0) 220 | 221 | emotion_clip_feature = emotion_clip_feature.reshape(-1,self.d_model * self.key_clips) 222 | 223 | 224 | #####permutation 225 | ori_video_feature_detach = ori_video_feature.detach() 226 | per_feature = ori_video_feature_detach + emotion_clip_feature.unsqueeze(dim=1) 227 | per_output = self.per_branch(per_feature.view(b,-1)) 228 | 229 | output = self.classifier(emotion_clip_feature) 230 | 231 | return output,per_output 232 | 233 | 234 | def resnet18_EST(pretrained=False, **kwargs): 235 | # Constructs base a ResNet-18 model. 
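    # EST stacks a ResNet-18 backbone for per-frame features, intra-snippet self-attention
    # with cosine-similarity weighted pooling, a Transformer encoder-decoder over the snippet
    # sequence, the expression classifier, and the auxiliary permutation branch. Keyword
    # arguments (clips, img_num_per_clip, d_model, nhead, encoder_nums, decoder_nums,
    # use_norm, per_classes, draw_weight) are forwarded from the command-line options;
    # the `pretrained` flag is currently unused.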
236 | model = EST(BasicBlock, [2, 2, 2, 2], **kwargs) 237 | return model -------------------------------------------------------------------------------- /models/Model_speed.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | from models.transformer import Transformer 5 | import numpy as np 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | class PosEncoding(torch.nn.Module): 13 | def __init__(self,max_seq_len,d_model=512): 14 | super(PosEncoding, self).__init__() 15 | pos_enc = np.array( 16 | [[pos / np.power(10000,2.0 * (j//2) / d_model) for j in range(d_model)] for pos in range(max_seq_len)] 17 | ) 18 | pos_enc[:,0::2] = np.sin(pos_enc[:,0::2]) 19 | pos_enc[:,1::2] = np.cos(pos_enc[:,1::2]) 20 | pos_enc = pos_enc.astype(np.float32) 21 | self.pos_enc = torch.nn.Embedding(max_seq_len,d_model) 22 | self.pos_enc.weight = torch.nn.Parameter(torch.from_numpy(pos_enc),requires_grad=False) 23 | 24 | def forward(self, input_len): 25 | ''' 26 | 27 | :param input_len: [7,7,7,7,...,7] shape=[batch_size] 28 | :return: 29 | ''' 30 | input_pos = torch.tensor([list(range(0,len)) for len in input_len]).cuda() 31 | return self.pos_enc(input_pos) 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | expansion = 1 36 | 37 | def __init__(self, inplanes, planes, stride=1, downsample=None): 38 | super(BasicBlock, self).__init__() 39 | self.conv1 = conv3x3(inplanes, planes, stride) 40 | self.bn1 = nn.BatchNorm2d(planes) 41 | self.relu = nn.ReLU() 42 | self.conv2 = conv3x3(planes, planes) 43 | self.bn2 = nn.BatchNorm2d(planes) 44 | self.downsample = downsample 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | 57 | if self.downsample is not None: 58 | residual = self.downsample(x) 59 | 60 | out += residual 61 | out = self.relu(out) 62 | 63 | return out 64 | 65 | class Permutation(nn.Module): 66 | def __init__(self,d_model,classes): 67 | super(Permutation, self).__init__() 68 | self.d_model = d_model 69 | self.classes = classes 70 | 71 | self.classifier = nn.Sequential(*[nn.Linear(7*512,2048),nn.BatchNorm1d(2048),nn.ReLU(inplace=True), 72 | nn.Linear(2048,512),nn.BatchNorm1d(512),nn.ReLU(inplace=True), 73 | nn.Linear(512,self.classes)]) 74 | 75 | self.weights_init(self.classifier) 76 | 77 | def forward(self, input): 78 | output = self.classifier(input) 79 | return output 80 | 81 | def weights_init(self,model): 82 | if type(model) in [nn.ConvTranspose2d, nn.Linear]: 83 | nn.init.xavier_normal(model.weight.data) 84 | nn.init.constant(model.bias.data, 0.1) 85 | 86 | 87 | class EST(nn.Module): 88 | def __init__(self, block, layers,clips=7,img_num_per_clip=5,d_model=512,nhead=4,dropout=0.1,encoder_nums=3,decoder_nums=3, use_norm=True,per_classes=10): 89 | self.inplanes = 64 90 | self.key_clips = 1 91 | self.per_classes = per_classes 92 | self.d_model = d_model 93 | self.clips = clips 94 | self.img_num_per_clip = img_num_per_clip 95 | self.use_norm = use_norm 96 | super(EST, self).__init__() 97 | ### First layer 98 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 99 | bias=False) 100 | self.bn1 = nn.BatchNorm2d(64) 101 | self.relu = nn.ReLU() 102 | self.maxpool = nn.MaxPool2d(kernel_size=3, 
stride=2, padding=1) 103 | self.layer1 = self._make_layer(block, 64, layers[0]) 104 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 105 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 106 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 107 | self.avgpool = nn.AdaptiveAvgPool2d(1) 108 | self.dropout = nn.Dropout(0.5) 109 | self.dropout2 = nn.Dropout(0.5) 110 | self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 111 | self.norm1 = nn.LayerNorm(d_model) 112 | self.dropout3 = nn.Dropout(0.1) 113 | self.cos = nn.CosineSimilarity(dim=2) 114 | self.maxpool2 = nn.AdaptiveMaxPool1d(1) 115 | 116 | ### Second layer 117 | self.transformer = Transformer(d_model=d_model, nhead=nhead, num_encoder_layers=encoder_nums, 118 | num_decoder_layers=decoder_nums 119 | , dropout=dropout) 120 | self.query_embed = torch.nn.Embedding(self.key_clips, self.d_model) 121 | self.pos_enc = PosEncoding(max_seq_len=self.clips, d_model=self.d_model) 122 | 123 | #############Third layer 124 | self.per_branch = Permutation(self.d_model,self.per_classes) 125 | 126 | ### Forth layer 127 | self.classifier = nn.Sequential(*[nn.Linear(512, 128), nn.ReLU(True), 128 | nn.Linear(128, 64), nn.ReLU(True), 129 | nn.Dropout(0.4), 130 | nn.Sequential(nn.Linear(64, 7))]) 131 | 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 135 | m.weight.data.normal_(0, math.sqrt(2. / n)) 136 | elif isinstance(m, nn.BatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | 140 | def _make_layer(self, block, planes, blocks, stride=1): 141 | downsample = None 142 | if stride != 1 or self.inplanes != planes * block.expansion: 143 | downsample = nn.Sequential( 144 | nn.Conv2d(self.inplanes, planes * block.expansion, 145 | kernel_size=1, stride=stride, bias=False), 146 | nn.BatchNorm2d(planes * block.expansion), 147 | ) 148 | 149 | layers = [] 150 | layers.append(block(self.inplanes, planes, stride, downsample)) 151 | self.inplanes = planes * block.expansion 152 | for i in range(1, blocks): 153 | layers.append(block(self.inplanes, planes)) 154 | 155 | return nn.Sequential(*layers) 156 | 157 | def forward(self, x,per=torch.tensor([0,1,2,3,4,5,6])): 158 | 159 | video_feature = [] 160 | b = x.size(0) 161 | x = x.view(b*35,3,224,224) 162 | f = self.conv1(x) 163 | f = self.bn1(f) 164 | f = self.relu(f) 165 | f = self.maxpool(f) 166 | 167 | f = self.layer1(f) 168 | f = self.layer2(f) 169 | f = self.layer3(f) 170 | f = self.layer4(f) 171 | f = self.avgpool(f) 172 | f = f.squeeze() 173 | #print(f.size()) 174 | #f = f.view(self.img_num_per_clip,b*self.clips,self.d_model) 175 | #x = f.view(b,self.clips,self.img_num_per_clip,self.d_model) 176 | x = f.view(self.img_num_per_clip,b*self.clips,self.d_model) 177 | output,weight = self.attn(x,x,x) 178 | output = x + self.dropout3(output) 179 | output = self.norm1(output).permute(1,2,0) 180 | global_feature = self.maxpool2(output).view(1,self.clips*b,self.d_model) 181 | local_feature = output.permute(2,0,1) 182 | #dis_list = [] 183 | #local_feature = local_feature.view(b*self.img_num_per_clip*self.clips,self.d_model) 184 | #print(local_feature.size(),global_feature.size()) 185 | dis_alpha = self.cos(local_feature,global_feature).unsqueeze(dim=2) 186 | #print(dis_alpha.size()) 187 | dis_alpha = dis_alpha.permute(1,2,0) 188 | #for i in range(self.img_num_per_clip): 189 | # dis = self.cos(local_feature[i],global_feature) 190 | # dis_list.append(dis) 191 | #dis_alpha = 
torch.stack(dis_list,dim=1).unsqueeze(dim=1) 192 | #print(dis_alpha.size()) 193 | dis_alpha = torch.clamp(dis_alpha, min=1e-8) 194 | output = output.mul(dis_alpha).sum(2).div(dis_alpha.sum(2)) 195 | #print(output.size()) 196 | output.view(b,7,self.d_model) 197 | video_feature.append(output) 198 | #for i in range(self.clips): 199 | 200 | # vs = [] 201 | 202 | # ff = x[:, :, :, :, :, i] # x[batch_size,3,224,224, 5, 7] 203 | # for j in range(self.img_num_per_clip): 204 | # f = ff[:, :, :, :, j] # ff[batch_size,3,224,224, 3] 205 | # ff[batch_size,3,224,224] 206 | 207 | # f = self.conv1(f) 208 | # f = self.bn1(f) 209 | # f = self.relu(f) 210 | # f = self.maxpool(f) 211 | 212 | # f = self.layer1(f) 213 | # f = self.layer2(f) 214 | # f = self.layer3(f) 215 | # f = self.layer4(f) 216 | # f = self.avgpool(f) 217 | 218 | # f = f.squeeze(3).squeeze(2) # f[1, 512, 1, 1] ---> f[1, 512] 219 | 220 | # vs.append(f) 221 | 222 | # vs_stack = torch.stack(vs,dim=2) 223 | # vs_stack = vs_stack.permute(2,0,1) 224 | # output,weight = self.attn(vs_stack,vs_stack,vs_stack) 225 | # if self.use_norm: 226 | # output = vs_stack + self.dropout3(output) 227 | # output = self.norm1(output).permute(1,2,0) 228 | 229 | # global_feature = self.maxpool2(output).squeeze(dim=2) 230 | # local_feature = output.permute(2,0,1) # 5*b*512 231 | # dis_list = [] 232 | # for t in range(self.img_num_per_clip): 233 | # dis = self.cos(local_feature[i],global_feature) 234 | # dis_list.append(dis) 235 | # dis_alpha = torch.stack(dis_list,dim=1).unsqueeze(dim=1) 236 | # dis_alpha = torch.clamp(dis_alpha, min=1e-8) 237 | # output = output.mul(dis_alpha).sum(2).div(dis_alpha.sum(2)) 238 | 239 | # video_feature.append(output) 240 | 241 | 242 | ori_video_feature = torch.stack(video_feature,dim=1) #video_feature = bacth_size * clip_num * clip_feature的维度 即 bacth_size * 7 * 512 b = video_feature.size(0) 243 | b = int(ori_video_feature.size(0)) 244 | pos = torch.tensor([self.clips] * b) 245 | pos_enc = self.pos_enc(pos) 246 | for i in range(b): 247 | pos_enc[i] = pos_enc[i][per[i]] 248 | x = ori_video_feature + pos_enc 249 | x = x.permute(1, 0, 2) 250 | tgt = self.query_embed.weight 251 | tgt = tgt.unsqueeze(1).repeat(1,b,1) 252 | emotion_clip_feature = self.transformer(x, tgt).permute(1, 2, 0) 253 | 254 | emotion_clip_feature = emotion_clip_feature.reshape(-1,self.d_model * self.key_clips) 255 | 256 | 257 | #####permutation 258 | #ori_video_feature_detach = ori_video_feature.detach() 259 | #per_feature = ori_video_feature_detach + emotion_clip_feature.unsqueeze(dim=1) 260 | #per_output = self.per_branch(per_feature.view(b,-1)) 261 | 262 | output = self.classifier(emotion_clip_feature) 263 | 264 | return output 265 | 266 | 267 | def resnet18_EST(pretrained=False, **kwargs): 268 | # Constructs base a ResNet-18 model. 269 | model = EST(BasicBlock, [2, 2, 2, 2], **kwargs) 270 | return model -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from torch.nn import functional as F 4 | from torch.nn import Module 5 | from torch.nn import MultiheadAttention 6 | from torch.nn.modules import ModuleList 7 | from torch.nn.init import xavier_uniform_,kaiming_uniform_ 8 | from torch.nn import Dropout 9 | from torch.nn import Linear 10 | from torch.nn import LayerNorm 11 | 12 | 13 | class Transformer(Module): 14 | r"""A transformer model. User is able to modify the attributes as needed. 
The architecture 15 | is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, 16 | Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and 17 | Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information 18 | Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805) 19 | model with corresponding parameters. 20 | 21 | Args: 22 | d_model: the number of expected features in the encoder/decoder inputs (default=512). 23 | nhead: the number of heads in the multiheadattention models (default=8). 24 | num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). 25 | num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). 26 | dim_feedforward: the dimension of the feedforward network model (default=2048). 27 | dropout: the dropout value (default=0.1). 28 | activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu). 29 | custom_encoder: custom encoder (default=None). 30 | custom_decoder: custom decoder (default=None). 31 | 32 | Examples:: 33 | >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) 34 | >>> src = torch.rand((10, 32, 512)) 35 | >>> tgt = torch.rand((20, 32, 512)) 36 | >>> out = transformer_model(src, tgt) 37 | 38 | Note: A full example to apply nn.Transformer module for the word language model is available in 39 | https://github.com/pytorch/examples/tree/master/word_language_model 40 | """ 41 | 42 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 43 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 44 | activation="relu", custom_encoder=None, custom_decoder=None,draw_weight=False): 45 | super(Transformer, self).__init__() 46 | 47 | if custom_encoder is not None: 48 | self.encoder = custom_encoder 49 | else: 50 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation) 51 | encoder_norm = LayerNorm(d_model) 52 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 53 | 54 | if custom_decoder is not None: 55 | self.decoder = custom_decoder 56 | else: 57 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation,draw_weight) 58 | decoder_norm = LayerNorm(d_model) 59 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,draw_weight) 60 | 61 | self._reset_parameters() 62 | 63 | self.d_model = d_model 64 | self.nhead = nhead 65 | self.draw_weight = draw_weight 66 | 67 | def forward(self, src, tgt, src_mask=None, tgt_mask=None, 68 | memory_mask=None, src_key_padding_mask=None, 69 | tgt_key_padding_mask=None, memory_key_padding_mask=None): 70 | r"""Take in and process masked source/target sequences. 71 | 72 | Args: 73 | src: the sequence to the encoder (required). 74 | tgt: the sequence to the decoder (required). 75 | src_mask: the additive mask for the src sequence (optional). 76 | tgt_mask: the additive mask for the tgt sequence (optional). 77 | memory_mask: the additive mask for the encoder output (optional). 78 | src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). 79 | tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). 80 | memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). 81 | 82 | Shape: 83 | - src: :math:`(S, N, E)`. 84 | - tgt: :math:`(T, N, E)`. 85 | - src_mask: :math:`(S, S)`. 86 | - tgt_mask: :math:`(T, T)`. 
87 | - memory_mask: :math:`(T, S)`. 88 | - src_key_padding_mask: :math:`(N, S)`. 89 | - tgt_key_padding_mask: :math:`(N, T)`. 90 | - memory_key_padding_mask: :math:`(N, S)`. 91 | 92 | Note: [src/tgt/memory]_mask should be filled with 93 | float('-inf') for the masked positions and float(0.0) else. These masks 94 | ensure that predictions for position i depend only on the unmasked positions 95 | j and are applied identically for each sequence in a batch. 96 | [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions 97 | that should be masked with float('-inf') and False values will be unchanged. 98 | This mask ensures that no information will be taken from position i if 99 | it is masked, and has a separate mask for each sequence in a batch. 100 | 101 | - output: :math:`(T, N, E)`. 102 | 103 | Note: Due to the multi-head attention architecture in the transformer model, 104 | the output sequence length of a transformer is the same as the input sequence 105 | (i.e. target) length of the decoder. 106 | 107 | where S is the source sequence length, T is the target sequence length, N is the 108 | batch size, E is the feature number. 109 | 110 | Examples: 111 | >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) 112 | """ 113 | 114 | if src.size(1) != tgt.size(1): 115 | raise RuntimeError("the batch number of src and tgt must be equal") 116 | 117 | if src.size(2) != self.d_model or tgt.size(2) != self.d_model: 118 | raise RuntimeError("the feature number of src and tgt must be equal to d_model") 119 | 120 | memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) 121 | if self.draw_weight: 122 | output,weight_list = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, 123 | tgt_key_padding_mask=tgt_key_padding_mask, 124 | memory_key_padding_mask=memory_key_padding_mask) 125 | return output,weight_list 126 | else: 127 | output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, 128 | tgt_key_padding_mask=tgt_key_padding_mask, 129 | memory_key_padding_mask=memory_key_padding_mask) 130 | return output,None 131 | 132 | def generate_square_subsequent_mask(self, sz): 133 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 134 | Unmasked positions are filled with float(0.0). 135 | """ 136 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 137 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 138 | return mask 139 | 140 | def _reset_parameters(self): 141 | r"""Initialize parameters in the transformer model.""" 142 | 143 | for p in self.parameters(): 144 | if p.dim() > 1: 145 | #xavier_uniform_(p) 146 | kaiming_uniform_(p) 147 | 148 | 149 | class TransformerEncoder(Module): 150 | r"""TransformerEncoder is a stack of N encoder layers. 151 | 152 | Args: 153 | encoder_layer: an instance of the TransformerEncoderLayer() class (required). 154 | num_layers: the number of sub-encoder-layers in the encoder (required). 155 | norm: the layer normalization component (optional).
156 | 157 | Examples:: 158 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 159 | >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) 160 | >>> src = torch.rand(10, 32, 512) 161 | >>> out = transformer_encoder(src) 162 | """ 163 | 164 | def __init__(self, encoder_layer, num_layers, norm=None): 165 | super(TransformerEncoder, self).__init__() 166 | self.layers = _get_clones(encoder_layer, num_layers) 167 | self.num_layers = num_layers 168 | self.norm = norm 169 | 170 | def forward(self, src, mask=None, src_key_padding_mask=None): 171 | r"""Pass the input through the encoder layers in turn. 172 | 173 | Args: 174 | src: the sequence to the encoder (required). 175 | mask: the mask for the src sequence (optional). 176 | src_key_padding_mask: the mask for the src keys per batch (optional). 177 | 178 | Shape: 179 | see the docs in Transformer class. 180 | """ 181 | output = src 182 | 183 | for i in range(self.num_layers): 184 | output = self.layers[i](output, src_mask=mask, 185 | src_key_padding_mask=src_key_padding_mask) 186 | 187 | if self.norm: 188 | output = self.norm(output) 189 | 190 | return output 191 | 192 | 193 | class TransformerDecoder(Module): 194 | r"""TransformerDecoder is a stack of N decoder layers. 195 | 196 | Args: 197 | decoder_layer: an instance of the TransformerDecoderLayer() class (required). 198 | num_layers: the number of sub-decoder-layers in the decoder (required). 199 | norm: the layer normalization component (optional). 200 | 201 | Examples:: 202 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 203 | >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) 204 | >>> memory = torch.rand(10, 32, 512) 205 | >>> tgt = torch.rand(20, 32, 512) 206 | >>> out = transformer_decoder(tgt, memory) 207 | """ 208 | 209 | def __init__(self, decoder_layer, num_layers, norm=None, draw_weight=False): 210 | super(TransformerDecoder, self).__init__() 211 | self.layers = _get_clones(decoder_layer, num_layers) 212 | self.num_layers = num_layers 213 | self.norm = norm 214 | self.draw_weight = draw_weight 215 | 216 | def forward(self, tgt, memory, tgt_mask=None, 217 | memory_mask=None, tgt_key_padding_mask=None, 218 | memory_key_padding_mask=None): 219 | r"""Pass the inputs (and mask) through the decoder layers in turn. 220 | 221 | Args: 222 | tgt: the sequence to the decoder (required). 223 | memory: the sequence from the last layer of the encoder (required). 224 | tgt_mask: the mask for the tgt sequence (optional). 225 | memory_mask: the mask for the memory sequence (optional). 226 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 227 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 228 | 229 | Shape: 230 | see the docs in Transformer class.
231 | """ 232 | output = tgt 233 | weight_list = [] 234 | for i in range(self.num_layers): 235 | if self.draw_weight: 236 | output,weight = self.layers[i](output, memory, tgt_mask=tgt_mask, 237 | memory_mask=memory_mask, 238 | tgt_key_padding_mask=tgt_key_padding_mask, 239 | memory_key_padding_mask=memory_key_padding_mask) 240 | weight_list.append(weight) 241 | else: 242 | output = self.layers[i](output, memory, tgt_mask=tgt_mask, 243 | memory_mask=memory_mask, 244 | tgt_key_padding_mask=tgt_key_padding_mask, 245 | memory_key_padding_mask=memory_key_padding_mask) 246 | 247 | if self.norm: 248 | output = self.norm(output) 249 | 250 | if self.draw_weight: 251 | return output,weight_list 252 | else: 253 | return output 254 | 255 | class TransformerEncoderLayer(Module): 256 | r"""TransformerEncoderLayer is made up of self-attn and feedforward network. 257 | This standard encoder layer is based on the paper "Attention Is All You Need". 258 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 259 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 260 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 261 | in a different way during application. 262 | 263 | Args: 264 | d_model: the number of expected features in the input (required). 265 | nhead: the number of heads in the multiheadattention models (required). 266 | dim_feedforward: the dimension of the feedforward network model (default=2048). 267 | dropout: the dropout value (default=0.1). 268 | activation: the activation function of intermediate layer, relu or gelu (default=relu). 269 | 270 | Examples:: 271 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 272 | >>> src = torch.rand(10, 32, 512) 273 | >>> out = encoder_layer(src) 274 | """ 275 | 276 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): 277 | super(TransformerEncoderLayer, self).__init__() 278 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 279 | # Implementation of Feedforward model 280 | self.linear1 = Linear(d_model, dim_feedforward) 281 | self.dropout = Dropout(dropout) 282 | self.linear2 = Linear(dim_feedforward, d_model) 283 | 284 | self.norm1 = LayerNorm(d_model) 285 | self.norm2 = LayerNorm(d_model) 286 | self.dropout1 = Dropout(dropout) 287 | self.dropout2 = Dropout(dropout) 288 | 289 | self.activation = _get_activation_fn(activation) 290 | 291 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 292 | r"""Pass the input through the endocder layer. 293 | 294 | Args: 295 | src: the sequnce to the encoder layer (required). 296 | src_mask: the mask for the src sequence (optional). 297 | src_key_padding_mask: the mask for the src keys per batch (optional). 298 | 299 | Shape: 300 | see the docs in Transformer class. 301 | """ 302 | src2,weight = self.self_attn(src, src, src, attn_mask=src_mask, 303 | key_padding_mask=src_key_padding_mask) 304 | src = src + self.dropout1(src2) 305 | src = self.norm1(src) 306 | if hasattr(self, "activation"): 307 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 308 | else: # for backward compatibility 309 | src2 = self.linear2(self.dropout(F.relu(self.linear1(src)))) 310 | src = src + self.dropout2(src2) 311 | src = self.norm2(src) 312 | return src 313 | 314 | 315 | class TransformerDecoderLayer(Module): 316 | r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. 
317 | This standard decoder layer is based on the paper "Attention Is All You Need". 318 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 319 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 320 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 321 | in a different way during application. 322 | 323 | Args: 324 | d_model: the number of expected features in the input (required). 325 | nhead: the number of heads in the multiheadattention models (required). 326 | dim_feedforward: the dimension of the feedforward network model (default=2048). 327 | dropout: the dropout value (default=0.1). 328 | activation: the activation function of intermediate layer, relu or gelu (default=relu). 329 | 330 | Examples:: 331 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 332 | >>> memory = torch.rand(10, 32, 512) 333 | >>> tgt = torch.rand(20, 32, 512) 334 | >>> out = decoder_layer(tgt, memory) 335 | """ 336 | 337 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", draw_weight=False): 338 | super(TransformerDecoderLayer, self).__init__() 339 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 340 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 341 | # Implementation of Feedforward model 342 | self.linear1 = Linear(d_model, dim_feedforward) 343 | self.dropout = Dropout(dropout) 344 | self.linear2 = Linear(dim_feedforward, d_model) 345 | 346 | self.norm1 = LayerNorm(d_model) 347 | self.norm2 = LayerNorm(d_model) 348 | self.norm3 = LayerNorm(d_model) 349 | self.dropout1 = Dropout(dropout) 350 | self.dropout2 = Dropout(dropout) 351 | self.dropout3 = Dropout(dropout) 352 | 353 | self.activation = _get_activation_fn(activation) 354 | 355 | self.draw_weight = draw_weight 356 | 357 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, 358 | tgt_key_padding_mask=None, memory_key_padding_mask=None): 359 | r"""Pass the inputs (and mask) through the decoder layer. 360 | 361 | Args: 362 | tgt: the sequence to the decoder layer (required). 363 | memory: the sequence from the last layer of the encoder (required). 364 | tgt_mask: the mask for the tgt sequence (optional). 365 | memory_mask: the mask for the memory sequence (optional). 366 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 367 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 368 | 369 | Shape: 370 | see the docs in Transformer class.
371 | """ 372 | tgt2,weight_mask = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, 373 | key_padding_mask=tgt_key_padding_mask) 374 | tgt = tgt + self.dropout1(tgt2) 375 | tgt = self.norm1(tgt) 376 | tgt2,weight = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, 377 | key_padding_mask=memory_key_padding_mask) 378 | tgt = tgt + self.dropout2(tgt2) 379 | tgt = self.norm2(tgt) 380 | if hasattr(self, "activation"): 381 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 382 | else: # for backward compatibility 383 | tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt)))) 384 | tgt = tgt + self.dropout3(tgt2) 385 | tgt = self.norm3(tgt) 386 | if self.draw_weight: 387 | return tgt,weight 388 | else: 389 | return tgt 390 | 391 | 392 | def _get_clones(module, N): 393 | return ModuleList([copy.deepcopy(module) for i in range(N)]) 394 | 395 | 396 | def _get_activation_fn(activation): 397 | if activation == "relu": 398 | return F.relu 399 | elif activation == "gelu": 400 | return F.gelu 401 | else: 402 | raise RuntimeError("activation should be relu/gelu, not %s." % activation) 403 | --------------------------------------------------------------------------------