├── figure └── TILOSS.png ├── src ├── __pycache__ │ ├── config.cpython-39.pyc │ ├── models.cpython-39.pyc │ ├── solver.cpython-39.pyc │ ├── solverr.cpython-39.pyc │ ├── testsolver.cpython-39.pyc │ ├── data_loader.cpython-39.pyc │ └── create_dataset.cpython-39.pyc ├── utils │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── convert.cpython-39.pyc │ │ ├── functions.cpython-39.pyc │ │ └── time_track.cpython-39.pyc │ ├── convert.py │ ├── time_track.py │ └── functions.py ├── test.py ├── train.py ├── data_loader.py ├── config.py ├── testsolver.py ├── create_dataset.py ├── models.py └── solver.py ├── README.md └── requirements.txt /figure/TILOSS.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/figure/TILOSS.png -------------------------------------------------------------------------------- /src/__pycache__/config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/src/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/models.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/src/__pycache__/models.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/solver.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/src/__pycache__/solver.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/solverr.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/src/__pycache__/solverr.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .convert import * 2 | from .time_track import time_desc_decorator 3 | from .functions import * -------------------------------------------------------------------------------- /src/__pycache__/testsolver.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/src/__pycache__/testsolver.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/data_loader.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/src/__pycache__/data_loader.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/create_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/src/__pycache__/create_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/src/utils/__pycache__/__init__.cpython-39.pyc 
-------------------------------------------------------------------------------- /src/utils/__pycache__/convert.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/src/utils/__pycache__/convert.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/functions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/src/utils/__pycache__/functions.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/time_track.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-G-Y/SATI/HEAD/src/utils/__pycache__/time_track.cpython-39.pyc -------------------------------------------------------------------------------- /src/utils/convert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def to_gpu(x, on_cpu=False, gpu_id=None): 4 | """Tensor => Variable""" 5 | if torch.cuda.is_available() and not on_cpu: 6 | x = x.cuda(gpu_id) 7 | return x 8 | 9 | def to_cpu(x): 10 | """Variable => Tensor""" 11 | if torch.cuda.is_available(): 12 | x = x.cpu() 13 | return x.data -------------------------------------------------------------------------------- /src/utils/time_track.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | 5 | def base_time_desc_decorator(method, desc='test_description'): 6 | def timed(*args, **kwargs): 7 | 8 | # Print Description 9 | # print('#' * 50) 10 | print(desc) 11 | # print('#' * 50 + '\n') 12 | 13 | # Calculation Runtime 14 | start = time.time() 15 | 16 | # Run Method 17 | try: 18 | result = method(*args, **kwargs) 19 | except TypeError: 20 | result = method(**kwargs) 21 | 22 | # Print Runtime 23 | print('Done! It took {:.2f} secs\n'.format(time.time() - start)) 24 | 25 | if result is not None: 26 | return result 27 | 28 | return timed 29 | 30 | 31 | def time_desc_decorator(desc): return partial(base_time_desc_decorator, desc=desc) 32 | 33 | 34 | @time_desc_decorator('this is description') 35 | def time_test(arg, kwarg='this is kwarg'): 36 | time.sleep(3) 37 | print('Inside of time_test') 38 | print('printing arg: ', arg) 39 | print('printing kwarg: ', kwarg) 40 | 41 | 42 | @time_desc_decorator('this is second description') 43 | def no_arg_method(): 44 | print('this method has no argument') 45 | 46 | 47 | if __name__ == '__main__': 48 | time_test('hello', kwarg=3) 49 | time_test(3) 50 | no_arg_method() 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic-Guided Multimodal Sentiment Decoding with Adversarial Temporal-Invariant Learning 2 | ----------------------------------------------------------------------------------------------------------------------------------- 3 | Code for Semantic-Guided Multimodal Sentiment Decoding with Adversarial Temporal-Invariant Learning (SATI). 4 | Due to the double-blind review, we will not be providing the checkpoint link at this time. 5 | Our checkpoints can be downloaded from [here](https://drive.google.com/drive/folders/11umrB8wphhYgMyBPAU7q5MXQ1yOepd0s?usp=drive_link).
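
## **Environment**
A minimal setup sketch, assuming Python 3.9 (the repository ships `cpython-39` bytecode) and the pinned packages listed in `requirements.txt`:
- pip install -r requirements.txt
- Note that `requirements.txt` pins `mmsdk` to a local path (`file:///home/xgy/CMU-MultimodalSDK`), so install the CMU Multimodal SDK separately as described in **Data Download** below.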
6 | 7 | ## **Data Download** 8 | - Install the [CMU Multimodal SDK](https://github.com/CMU-MultiComp-Lab/CMU-MultimodalSDK). Ensure you can run `from mmsdk import mmdatasdk`. 9 | - Option 1: Download the [pre-computed splits](https://drive.google.com/drive/folders/1IBwWNH0XjPnZWaAlP1U2tIJH6Rb3noMI) provided by MOSI and place the contents inside the `datasets` folder. 10 | - Option 2: Re-create the splits by downloading the data through the MMSDK. For this, simply run the code as detailed next. 11 | 12 | ## **Running the code** 13 | - `cd src` 14 | - Set `word_emb_path` in `config.py` to the [glove file](https://drive.google.com/file/d/1dOCXST8Lxj_WgZJmTC1owiF5Vm_I0iEA/view?usp=sharing) provided by MOSI, and set the [roberta](https://drive.google.com/file/d/1KsZGuAP_s68WyU3wOZ2hZ7vcf7HB0zj3/view?usp=drive_link) path. 15 | - Set `sdk_dir` to the path of CMU-MultimodalSDK. 16 | - Run `python train.py --data mosi`. Replace `mosi` with `mosei` or `ur_funny` for the other datasets. 17 | 18 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | from random import random 5 | 6 | from config import get_config, activation_dict 7 | from data_loader import get_loader 8 | from testsolver import Solver 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | 14 | 15 | if __name__ == '__main__': 16 | 17 | # Setting random seed 18 | random_name = str(random()) 19 | random_seed = 336 20 | torch.manual_seed(random_seed) 21 | torch.cuda.manual_seed_all(random_seed) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = False 24 | np.random.seed(random_seed) 25 | 26 | # Setting the config for each stage 27 | train_config = get_config(mode='train') 28 | dev_config = get_config(mode='dev') 29 | test_config = get_config(mode='test') 30 | 31 | #print(train_config) 32 | 33 | # Creating pytorch dataloaders 34 | train_data_loader = get_loader(train_config, shuffle = True) 35 | dev_data_loader = get_loader(dev_config, shuffle = False) 36 | test_data_loader = get_loader(test_config, shuffle = False) 37 | 38 | 39 | folder = "/home/s22xjq/SATI/src/checkpoints_无敌!"
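# NOTE: `folder` above is a hard-coded checkpoint directory; point it at the directory containing
# model_best_acc.std (loaded by Solver.eval in testsolver.py) before running test.py.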
40 | solver = Solver 41 | solver = solver(train_config, dev_config, test_config, train_data_loader, dev_data_loader, test_data_loader, is_train=False) 42 | 43 | # Build the model 44 | solver.build() 45 | 46 | solver.eval(folder, mode="test", to_print=True) 47 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | from random import random 5 | 6 | from config import get_config, activation_dict 7 | from data_loader import get_loader 8 | from solver import Solver 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | 14 | 15 | if __name__ == '__main__': 16 | 17 | # Setting random seed 18 | random_name = str(random()) 19 | random_seed = 336 20 | torch.manual_seed(random_seed) 21 | torch.cuda.manual_seed_all(random_seed) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = False 24 | np.random.seed(random_seed) 25 | 26 | # Setting the config for each stage 27 | train_config = get_config(mode='train') 28 | dev_config = get_config(mode='dev') 29 | test_config = get_config(mode='test') 30 | 31 | print(train_config) 32 | 33 | # Creating pytorch dataloaders 34 | train_data_loader = get_loader(train_config, shuffle = True) 35 | dev_data_loader = get_loader(dev_config, shuffle = False) 36 | test_data_loader = get_loader(test_config, shuffle = False) 37 | 38 | 39 | 40 | 41 | #train_config.learning_rate = 32.0e-6+i*1e-6 42 | print("train_config.learning_rate", train_config.learning_rate) 43 | solver = Solver 44 | solver = solver(train_config, dev_config, test_config, train_data_loader, dev_data_loader, test_data_loader, is_train=True) 45 | # Build the modelq 46 | solver.build() 47 | # Train the model (test scores will be returned based on dev performance) 48 | solver.train() 49 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | anyio==4.4.0 3 | argon2-cffi==23.1.0 4 | argon2-cffi-bindings==21.2.0 5 | arrow==1.3.0 6 | asttokens==2.4.1 7 | async-lru==2.0.4 8 | attrs==23.2.0 9 | Babel==2.15.0 10 | beautifulsoup4==4.12.3 11 | bleach==6.1.0 12 | certifi==2024.7.4 13 | cffi==1.16.0 14 | charset-normalizer==3.3.2 15 | colorama==0.4.6 16 | comm==0.2.2 17 | contourpy==1.2.1 18 | cycler==0.12.1 19 | debugpy==1.8.2 20 | decorator==5.1.1 21 | defusedxml==0.7.1 22 | exceptiongroup==1.2.2 23 | executing==2.0.1 24 | fastjsonschema==2.20.0 25 | filelock==3.15.4 26 | fonttools==4.53.1 27 | fqdn==1.5.1 28 | fsspec==2024.6.1 29 | gensim==4.3.2 30 | graphviz==0.20.3 31 | grpcio==1.65.1 32 | h11==0.14.0 33 | h5py==3.11.0 34 | httpcore==1.0.5 35 | httpx==0.27.0 36 | huggingface-hub==0.23.4 37 | idna==3.7 38 | importlib_metadata==8.0.0 39 | importlib_resources==6.4.0 40 | ipykernel==6.29.5 41 | ipython==8.18.1 42 | ipywidgets==8.1.3 43 | isoduration==20.11.0 44 | jedi==0.19.1 45 | Jinja2==3.1.4 46 | joblib==1.4.2 47 | json5==0.9.25 48 | jsonpointer==3.0.0 49 | jsonschema==4.23.0 50 | jsonschema-specifications==2023.12.1 51 | jupyter==1.0.0 52 | jupyter-console==6.6.3 53 | jupyter-events==0.10.0 54 | jupyter-lsp==2.2.5 55 | jupyter_client==8.6.2 56 | jupyter_core==5.7.2 57 | jupyter_server==2.14.2 58 | jupyter_server_terminals==0.5.3 59 | jupyterlab==4.2.3 60 | jupyterlab_pygments==0.3.0 61 | jupyterlab_server==2.27.2 62 | 
jupyterlab_widgets==3.0.11 63 | kiwisolver==1.4.5 64 | Markdown==3.6 65 | MarkupSafe==2.1.5 66 | matplotlib==3.9.2 67 | matplotlib-inline==0.1.7 68 | mistune==3.0.2 69 | mmsdk @ file:///home/xgy/CMU-MultimodalSDK 70 | mpmath==1.3.0 71 | nbclient==0.10.0 72 | nbconvert==7.16.4 73 | nbformat==5.10.4 74 | nest-asyncio==1.6.0 75 | networkx==3.2.1 76 | notebook==7.2.1 77 | notebook_shim==0.2.4 78 | numpy==1.24.4 79 | nvidia-cublas-cu12==12.1.3.1 80 | nvidia-cuda-cupti-cu12==12.1.105 81 | nvidia-cuda-nvrtc-cu12==12.1.105 82 | nvidia-cuda-runtime-cu12==12.1.105 83 | nvidia-cudnn-cu12==8.9.2.26 84 | nvidia-cufft-cu12==11.0.2.54 85 | nvidia-curand-cu12==10.3.2.106 86 | nvidia-cusolver-cu12==11.4.5.107 87 | nvidia-cusparse-cu12==12.1.0.106 88 | nvidia-nccl-cu12==2.20.5 89 | nvidia-nvjitlink-cu12==12.5.82 90 | nvidia-nvtx-cu12==12.1.105 91 | overrides==7.7.0 92 | packaging==24.1 93 | pandocfilters==1.5.1 94 | parso==0.8.4 95 | pexpect==4.9.0 96 | pillow==10.4.0 97 | platformdirs==4.2.2 98 | prometheus_client==0.20.0 99 | prompt_toolkit==3.0.47 100 | protobuf==4.25.4 101 | psutil==6.0.0 102 | ptyprocess==0.7.0 103 | pure-eval==0.2.2 104 | pycparser==2.22 105 | Pygments==2.18.0 106 | pyparsing==3.1.2 107 | python-dateutil==2.9.0.post0 108 | python-json-logger==2.0.7 109 | PyYAML==6.0.1 110 | pyzmq==26.0.3 111 | qtconsole==5.5.2 112 | QtPy==2.4.1 113 | referencing==0.35.1 114 | regex==2024.5.15 115 | requests==2.32.3 116 | rfc3339-validator==0.1.4 117 | rfc3986-validator==0.1.1 118 | rpds-py==0.19.0 119 | safetensors==0.4.3 120 | scikit-learn==1.5.1 121 | scipy==1.9.0 122 | Send2Trash==1.8.3 123 | six==1.16.0 124 | smart-open==7.0.4 125 | sniffio==1.3.1 126 | soupsieve==2.5 127 | stack-data==0.6.3 128 | sympy==1.13.0 129 | tensorboard==2.17.0 130 | tensorboard-data-server==0.7.2 131 | terminado==0.18.1 132 | threadpoolctl==3.5.0 133 | tinycss2==1.3.0 134 | tokenizers==0.19.1 135 | tomli==2.0.1 136 | torch==2.3.1 137 | torchviz==0.0.2 138 | tornado==6.4.1 139 | tqdm==4.66.4 140 | traitlets==5.14.3 141 | transformers==4.42.4 142 | triton==2.3.1 143 | types-python-dateutil==2.9.0.20240316 144 | typing_extensions==4.12.2 145 | uri-template==1.3.0 146 | urllib3==2.2.2 147 | validators==0.32.0 148 | wcwidth==0.2.13 149 | webcolors==24.6.0 150 | webencodings==0.5.1 151 | websocket-client==1.8.0 152 | Werkzeug==3.0.3 153 | widgetsnbextension==4.0.11 154 | wrapt==1.16.0 155 | zipp==3.19.2 156 | -------------------------------------------------------------------------------- /src/utils/functions.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | import torch.nn as nn 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | """ 7 | Adapted from https://github.com/fungtion/DSN/blob/master/functions.py 8 | """ 9 | 10 | class ReverseLayerF(Function): 11 | 12 | @staticmethod 13 | def forward(ctx, x, p): 14 | ctx.p = p 15 | 16 | return x.view_as(x) 17 | 18 | @staticmethod 19 | def backward(ctx, grad_output): 20 | output = grad_output.neg() * ctx.p 21 | 22 | return output, None 23 | 24 | 25 | class MSE(nn.Module): 26 | def __init__(self): 27 | super(MSE, self).__init__() 28 | 29 | def forward(self, pred, real): 30 | diffs = torch.add(real, -pred) 31 | n = torch.numel(diffs.data) 32 | mse = torch.sum(diffs.pow(2)) / n 33 | 34 | return mse 35 | 36 | 37 | class SIMSE(nn.Module): 38 | 39 | def __init__(self): 40 | super(SIMSE, self).__init__() 41 | 42 | def forward(self, pred, real): 43 | diffs = torch.add(real, - pred) 44 | n = 
torch.numel(diffs.data) 45 | simse = torch.sum(diffs).pow(2) / (n ** 2) 46 | 47 | return simse 48 | 49 | import torch.nn as nn 50 | import torch.nn.functional as F 51 | class FocalLoss(nn.Module): 52 | def __init__(self, alpha=0.25, gamma=2.0): 53 | super(FocalLoss, self).__init__() 54 | self.alpha = alpha 55 | self.gamma = gamma 56 | def forward(self, inputs, targets): 57 | y_tilde_binary = torch.sign(inputs) 58 | y_tilde_binary[y_tilde_binary == -1] = 0 # map -1 to 0 59 | #targets = torch.sign(targets) 60 | #targets[targets == -1] = 0 # map -1 to 0 61 | y_binary = torch.sign(targets) 62 | y_binary[y_binary == -1] = 0 # map -1 to 0 63 | BCE_loss = F.binary_cross_entropy_with_logits(y_tilde_binary, y_binary, reduction='none') 64 | pt = torch.exp(-BCE_loss) # prevents nans when probability 0 65 | F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss 66 | return F_loss.mean() 67 | 68 | class WeightedMSELoss(nn.Module): 69 | def __init__(self, weights): 70 | super(WeightedMSELoss, self).__init__() 71 | self.weights = weights 72 | 73 | def forward(self, inputs, targets): 74 | # select the weight according to the sign of the target value 75 | weights = torch.where(targets > 0, self.weights[1], self.weights[0]) 76 | loss = weights * (inputs - targets) ** 2 77 | return loss.mean() 78 | 79 | class DiffLoss(nn.Module): 80 | 81 | def __init__(self): 82 | super(DiffLoss, self).__init__() 83 | 84 | def forward(self, input1, input2): 85 | 86 | batch_size = input1.size(0) 87 | input1 = input1.view(batch_size, -1) 88 | input2 = input2.view(batch_size, -1) 89 | 90 | # Zero mean 91 | input1_mean = torch.mean(input1, dim=0, keepdims=True) 92 | input2_mean = torch.mean(input2, dim=0, keepdims=True) 93 | input1 = input1 - input1_mean 94 | input2 = input2 - input2_mean 95 | 96 | input1_l2_norm = torch.norm(input1, p=2, dim=1, keepdim=True).detach() 97 | input1_l2 = input1.div(input1_l2_norm.expand_as(input1) + 1e-6) 98 | 99 | 100 | input2_l2_norm = torch.norm(input2, p=2, dim=1, keepdim=True).detach() 101 | input2_l2 = input2.div(input2_l2_norm.expand_as(input2) + 1e-6) 102 | 103 | diff_loss = torch.mean((input1_l2.t().mm(input2_l2)).pow(2)) 104 | 105 | return diff_loss 106 | 107 | class CMD(nn.Module): 108 | """ 109 | Adapted from https://github.com/wzell/cmd/blob/master/models/domain_regularizer.py 110 | """ 111 | 112 | def __init__(self): 113 | super(CMD, self).__init__() 114 | 115 | def forward(self, x1, x2, n_moments): 116 | mx1 = torch.mean(x1, 0) 117 | mx2 = torch.mean(x2, 0) 118 | sx1 = x1-mx1 119 | sx2 = x2-mx2 120 | dm = self.matchnorm(mx1, mx2) 121 | scms = dm 122 | for i in range(n_moments - 1): 123 | scms += self.scm(sx1, sx2, i + 2) 124 | return scms 125 | 126 | def matchnorm(self, x1, x2): 127 | power = torch.pow(x1-x2,2) 128 | summed = torch.sum(power) 129 | sqrt = summed**(0.5) 130 | return sqrt 131 | # return ((x1-x2)**2).sum().sqrt() 132 | 133 | def scm(self, sx1, sx2, k): 134 | ss1 = torch.mean(torch.pow(sx1, k), 0) 135 | ss2 = torch.mean(torch.pow(sx2, k), 0) 136 | return self.matchnorm(ss1, ss2) 137 | -------------------------------------------------------------------------------- /src/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm_notebook 6 | from collections import defaultdict 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence 11 | from torch.utils.data import DataLoader, Dataset 12 | from transformers import * 13 | 14 | 15 16
| from create_dataset import MOSI, MOSEI, UR_FUNNY, PAD, UNK 17 | 18 | import warnings 19 | warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.tokenization_utils_base") 20 | 21 | 22 | vocab_file = '/home/s22xjq/SATI/model/vocab.json' 23 | merges_file = '/home/s22xjq/SATI/model/merges.txt' 24 | roberta_tokenizer = RobertaTokenizer(vocab_file, merges_file) 25 | 26 | 27 | class MSADataset(Dataset): 28 | def __init__(self, config): 29 | 30 | ## Fetch dataset 31 | if "mosi" in str(config.data_dir).lower(): 32 | dataset = MOSI(config) 33 | elif "mosei" in str(config.data_dir).lower(): 34 | dataset = MOSEI(config) 35 | elif "ur_funny" in str(config.data_dir).lower(): 36 | dataset = UR_FUNNY(config) 37 | else: 38 | print("Dataset not defined correctly") 39 | exit() 40 | 41 | self.data, self.word2id, self.pretrained_emb = dataset.get_data(config.mode) 42 | #print(self.data) 43 | self.len = len(self.data) 44 | 45 | config.visual_size = self.data[0][0][1].shape[1] 46 | config.acoustic_size = self.data[0][0][2].shape[1] 47 | 48 | config.word2id = self.word2id 49 | config.pretrained_emb = self.pretrained_emb 50 | 51 | 52 | def __getitem__(self, index): 53 | return self.data[index] 54 | 55 | def __len__(self): 56 | return self.len 57 | 58 | 59 | 60 | def get_loader(config, shuffle=True): 61 | """Load DataLoader of given MSADataset""" 62 | 63 | dataset = MSADataset(config) 64 | 65 | #print(config.mode) 66 | config.data_len = len(dataset) 67 | 68 | 69 | def collate_fn(batch): 70 | ''' 71 | Collate functions assume batch = [Dataset[i] for i in index_set] 72 | ''' 73 | # for later use we sort the batch in descending order of length 74 | batch = sorted(batch, key=lambda x: x[0][0].shape[0], reverse=True) 75 | 76 | # get the data out of the batch - use pad sequence util functions from PyTorch to pad things 77 | 78 | 79 | labels = torch.cat([torch.from_numpy(sample[1]) for sample in batch], dim=0) 80 | sentences = pad_sequence([torch.LongTensor(sample[0][0]) for sample in batch], padding_value=PAD)#[44, 64] 81 | #print(sentences.shape) 82 | visual = pad_sequence([torch.FloatTensor(sample[0][1]) for sample in batch]) #[44, 64, 47] 83 | #print(visual) 84 | acoustic = pad_sequence([torch.FloatTensor(sample[0][2]) for sample in batch]) #[44, 64, 47] 85 | #print(sentences.shape, visual.shape) 86 | 87 | ## BERT-based features input prep 88 | 89 | SENT_LEN = sentences.size(0) 90 | # Create bert indices using tokenizer 91 | roberta_details = [] 92 | for sample in batch: 93 | text = " ".join(sample[0][3]) # join the word list into a single string 94 | encoded_roberta_sent = roberta_tokenizer.encode_plus( 95 | text, 96 | max_length=SENT_LEN, # no +2 needed for RoBERTa, since it only needs the [CLS] and [SEP] tokens 97 | add_special_tokens=True, 98 | padding='max_length', # pad to the maximum length 99 | truncation=True, # truncate anything beyond the maximum length 100 | return_tensors='pt' # return PyTorch tensors (if you are using PyTorch) 101 | ) 102 | roberta_details.append(encoded_roberta_sent) 103 | 104 | bert_sentences = torch.LongTensor([sample["input_ids"].squeeze(0).tolist() for sample in roberta_details]) 105 | bert_sentence_att_mask = torch.LongTensor([sample["attention_mask"].squeeze(0).tolist() for sample in roberta_details]) 106 | bert_sentence_types = torch.randn(0) 107 | # lengths are useful later in using RNNs 108 | lengths = torch.LongTensor([sample[0][0].shape[0] for sample in batch]) 109 | 110 | return sentences, visual, acoustic, labels, lengths, bert_sentences, bert_sentence_types, bert_sentence_att_mask 111 | 112 | 113 | data_loader = DataLoader( 114 | dataset=dataset, 115 |
batch_size=config.batch_size, 116 | shuffle=shuffle, 117 | collate_fn=collate_fn) 118 | 119 | return data_loader 120 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from datetime import datetime 4 | from collections import defaultdict 5 | from datetime import datetime 6 | from pathlib import Path 7 | import pprint 8 | from torch import optim 9 | import torch.nn as nn 10 | 11 | # path to a pretrained word embedding file 12 | word_emb_path = '/home/s22xjq/SATI/glove.840B.d.txt' 13 | assert(word_emb_path is not None) 14 | 15 | 16 | username = Path.home().name 17 | project_dir = Path(__file__).resolve().parent.parent 18 | sdk_dir = project_dir.joinpath('/home/s22xjq/CMU-MultimodalSDK') 19 | data_dir = project_dir.joinpath('datasets') 20 | data_dict = {'mosi': data_dir.joinpath('MOSI'), 'mosei': data_dir.joinpath( 21 | 'MOSEI'), 'ur_funny': data_dir.joinpath('UR_FUNNY')} 22 | optimizer_dict = {'RMSprop': optim.RMSprop, 'Adam': optim.Adam, 'NAdam':optim.NAdam} 23 | activation_dict = {'elu': nn.ELU, "hardshrink": nn.Hardshrink, "hardtanh": nn.Hardtanh, 24 | "leakyrelu": nn.LeakyReLU, "prelu": nn.PReLU, "relu": nn.ReLU, "rrelu": nn.RReLU, 25 | "tanh": nn.Tanh, "sigmoid" :nn.Sigmoid} 26 | 27 | 28 | def str2bool(v): 29 | """string to boolean""" 30 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 31 | return True 32 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 33 | return False 34 | else: 35 | raise argparse.ArgumentTypeError('Boolean value expected.') 36 | 37 | 38 | class Config(object): 39 | def __init__(self, **kwargs): 40 | """Configuration Class: set kwargs as class attributes with setattr""" 41 | if kwargs is not None: 42 | for key, value in kwargs.items(): 43 | if key == 'optimizer': 44 | value = optimizer_dict[value] 45 | if key == 'activation': 46 | value = activation_dict[value] 47 | setattr(self, key, value) 48 | 49 | # Dataset directory: ex) ./datasets/cornell/ 50 | self.dataset_dir = data_dict[self.data.lower()] 51 | self.sdk_dir = sdk_dir 52 | # Glove path 53 | self.word_emb_path = word_emb_path 54 | 55 | # Data Split ex) 'train', 'valid', 'test' 56 | # self.data_dir = self.dataset_dir.joinpath(self.mode) 57 | self.data_dir = self.dataset_dir 58 | 59 | def __str__(self): 60 | """Pretty-print configurations in alphabetical order""" 61 | config_str = 'Configurations\n' 62 | config_str += pprint.pformat(self.__dict__) 63 | return config_str 64 | 65 | 66 | def get_config(parse=True, **optional_kwargs): 67 | """ 68 | Get configurations as attributes of class 69 | 1. Parse configurations with argparse. 70 | 2. Create Config class initilized with parsed kwargs. 71 | 3. Return Config class. 
72 | """ 73 | parser = argparse.ArgumentParser() 74 | 75 | # Mode 76 | parser.add_argument('--mode', type=str, default='train') 77 | parser.add_argument('--runs', type=int, default=5) 78 | 79 | # Bert 80 | parser.add_argument('--use_cmd_sim', type=str2bool, default=True) 81 | parser.add_argument('--use_domain', type=str2bool, default=True) 82 | # Train 83 | time_now = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') 84 | #parser.add_argument('--name', type=str, default=f"{time_now}") 85 | parser.add_argument('--name', type=str, default=f"bestF1") 86 | parser.add_argument('--num_classes', type=int, default=0) 87 | parser.add_argument('--batch_size', type=int, default=128) 88 | 89 | parser.add_argument('--eval_batch_size', type=int, default=10) 90 | parser.add_argument('--n_epoch', type=int, default=50) 91 | #try: use early stop in mosei 92 | parser.add_argument('--patience', type=int, default=5) 93 | parser.add_argument('--start_saving', type=float, default=0.4) 94 | 95 | parser.add_argument('--diff_weight', type=float, default=0.4) 96 | parser.add_argument('--sim_weight', type=float, default=1.0) 97 | parser.add_argument('--sp_weight', type=float, default=0.0) 98 | parser.add_argument('--recon_weight', type=float, default=1.0) 99 | parser.add_argument('--jsd_weight', type=float, default=1) 100 | 101 | 102 | 103 | parser.add_argument('--learning_rate', type=float, default=1.7e-05) 104 | parser.add_argument('--optimizer', type=str, default='Adam') 105 | parser.add_argument('--clip', type=float, default=1.0) 106 | 107 | 108 | parser.add_argument('--rnncell', type=str, default='lstm') 109 | parser.add_argument('--embedding_size', type=int, default=300) 110 | parser.add_argument('--hidden_size', type=int, default=128) 111 | parser.add_argument('--dropout', type=float, default=0.6) 112 | parser.add_argument('--reverse_grad_weight', type=float, default=1.0) 113 | # Selectin activation from 'elu', "hardshrink", "hardtanh", "leakyrelu", "prelu", "relu", "rrelu", "tanh" 114 | parser.add_argument('--activation', type=str, default='relu') 115 | 116 | 117 | #my 118 | parser.add_argument('--queries', type=int, default='64') 119 | #layers of self-attention 120 | parser.add_argument('--layers', type=int, default=2) 121 | 122 | # Model 123 | parser.add_argument('--model', type=str, 124 | default='SATI', help='one of {SATI, }') 125 | 126 | # Data 127 | parser.add_argument('--data', type=str, default='mosi') 128 | 129 | # Parse arguments 130 | if parse: 131 | kwargs = parser.parse_args() 132 | else: 133 | kwargs = parser.parse_known_args()[0] 134 | 135 | print(kwargs.data) 136 | if kwargs.data == "mosi": 137 | kwargs.num_classes = 1 138 | kwargs.batch_size = 64 139 | kwargs.use_domain = True 140 | kwargs.start_saving = 0.4 141 | kwargs.patience = 100 142 | elif kwargs.data == "mosei": 143 | kwargs.num_classes = 1 144 | kwargs.batch_size = 16 145 | kwargs.use_domain = True 146 | kwargs.start_saving = 0.0 147 | kwargs.patience = 5 148 | elif kwargs.data == "ur_funny": 149 | kwargs.num_classes = 2 150 | kwargs.batch_size = 32 151 | else: 152 | print("No dataset mentioned") 153 | exit() 154 | 155 | # Namespace => Dictionary 156 | kwargs = vars(kwargs) 157 | kwargs.update(optional_kwargs) 158 | 159 | return Config(**kwargs) 160 | -------------------------------------------------------------------------------- /src/testsolver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from math import isnan 4 | import re 5 | import pickle 6 | import gensim 7 | 
import numpy as np 8 | from tqdm import tqdm 9 | from tqdm import tqdm_notebook 10 | from sklearn.metrics import classification_report, accuracy_score, f1_score 11 | from sklearn.metrics import confusion_matrix 12 | from sklearn.metrics import precision_recall_fscore_support 13 | from scipy.special import expit 14 | from torchviz import make_dot 15 | import torch 16 | import torch.nn as nn 17 | from torch.nn import functional as F 18 | torch.manual_seed(123) 19 | torch.cuda.manual_seed_all(123) 20 | 21 | from utils import to_gpu, time_desc_decorator, DiffLoss, MSE, SIMSE, CMD 22 | import models 23 | 24 | 25 | class Solver(object): 26 | def __init__(self, train_config, dev_config, test_config, train_data_loader, dev_data_loader, test_data_loader, is_train=True, model=None): 27 | self.train_accuracies = [] 28 | self.valid_accuracies = [] 29 | self.test_accuracies = [] 30 | self.train_losses = [] 31 | self.valid_losses = [] 32 | self.test_losses = [] 33 | self.train_maes = [] 34 | self.valid_maes = [] 35 | self.test_maes = [] 36 | self.train_f1_scores = [] 37 | self.valid_f1_scores = [] 38 | self.test_f1_scores = [] 39 | 40 | 41 | self.train_config = train_config 42 | self.epoch_i = 0 43 | self.train_data_loader = train_data_loader 44 | self.dev_data_loader = dev_data_loader 45 | self.test_data_loader = test_data_loader 46 | self.is_train = is_train 47 | self.model = model 48 | self.criterion = criterion = nn.L1Loss(reduction="mean") 49 | @time_desc_decorator('Build Graph') 50 | def build(self, cuda=True): 51 | 52 | if self.model is None: 53 | self.model = getattr(models, self.train_config.model)(self.train_config) 54 | 55 | # Final list 56 | for name, param in self.model.named_parameters(): 57 | 58 | # Bert freezing customizations 59 | if self.train_config.data == "mosei": 60 | if "bertmodel.encoder.layer" in name: 61 | layer_num = int(name.split("encoder.layer.")[-1].split(".")[0]) 62 | if layer_num <= (8): 63 | param.requires_grad = False 64 | elif self.train_config.data == "ur_funny": 65 | if "bert" in name: 66 | param.requires_grad = False 67 | 68 | if 'weight_hh' in name: 69 | nn.init.orthogonal_(param) 70 | #print('\t' + name, param.requires_grad) 71 | 72 | # Initialize weight of Embedding matrix with Glove embeddings 73 | 74 | if torch.cuda.is_available() and cuda: 75 | self.model.cuda() 76 | 77 | if self.is_train: 78 | self.optimizer = self.train_config.optimizer( 79 | filter(lambda p: p.requires_grad, self.model.parameters()), 80 | lr=self.train_config.learning_rate) 81 | 82 | 83 | def eval( self,folder, mode=None, to_print=False): 84 | assert(mode is not None) 85 | self.model.eval() 86 | 87 | y_true, y_pred = [], [] 88 | eval_loss, eval_loss_diff = [], [] 89 | 90 | if mode == "dev": 91 | dataloader = self.dev_data_loader 92 | elif mode == "test": 93 | dataloader = self.test_data_loader 94 | 95 | if to_print: 96 | self.model.load_state_dict(torch.load(folder+'/model_best_acc.std')) 97 | optimizer = self.train_config.optimizer( 98 | filter(lambda p: p.requires_grad, self.model.parameters()), 99 | lr=self.train_config.learning_rate) 100 | 101 | 102 | 103 | with torch.no_grad(): 104 | for batch in dataloader: 105 | self.model.zero_grad() 106 | t, v, a, y, l, bert_sent, bert_sent_type, bert_sent_mask = batch 107 | mean = 0.0 108 | std = 0.5 109 | 110 | gaussian_noise = torch.normal(mean=mean, std=std, size=v.shape) 111 | #v = v + 2.00*gaussian_noise 112 | gaussian_noise = torch.normal(mean=mean, std=std, size=t.shape) 113 | #t = t + 2.00*gaussian_noise 114 | gaussian_noise = 
torch.normal(mean=mean, std=std, size=a.shape) 115 | #a = a + 2.00*gaussian_noise 116 | t = to_gpu(t) 117 | v = to_gpu(v) 118 | a = to_gpu(a) 119 | y = to_gpu(y) 120 | l = to_gpu(l) 121 | bert_sent = to_gpu(bert_sent) 122 | bert_sent_type = to_gpu(bert_sent_type) 123 | bert_sent_mask = to_gpu(bert_sent_mask) 124 | y_tilde = self.model(t, v, a, l, bert_sent, bert_sent_type, bert_sent_mask) 125 | 126 | if self.train_config.data == "ur_funny": 127 | y = y.squeeze() 128 | 129 | cls_loss = self.criterion(y_tilde, y) 130 | loss = cls_loss 131 | 132 | eval_loss.append(loss.item()) 133 | y_pred.append(y_tilde.detach().cpu().numpy()) 134 | y_true.append(y.detach().cpu().numpy()) 135 | 136 | eval_loss = np.mean(eval_loss) 137 | y_true = np.concatenate(y_true, axis=0).squeeze() 138 | y_pred = np.concatenate(y_pred, axis=0).squeeze() 139 | 140 | accuracy = self.calc_metrics(y_true, y_pred, mode, to_print) 141 | mae = np.mean(np.abs(y_pred - y_true)) 142 | f1 = f1_score((y_pred >= 0), (y_true >=0), average='weighted') 143 | 144 | if mode == "dev": 145 | self.valid_losses.append(eval_loss) 146 | self.valid_accuracies.append(accuracy) 147 | self.valid_maes.append(mae) 148 | self.valid_f1_scores.append(f1) 149 | elif mode == "test": 150 | self.test_losses.append(eval_loss) 151 | self.test_accuracies.append(accuracy) 152 | self.test_maes.append(mae) 153 | self.test_f1_scores.append(f1) 154 | if to_print: 155 | print(f"Eval {mode} loss: {round(eval_loss, 4)}, Accuracy: {round(accuracy, 4)}, MAE: {round(mae, 4)}, F1-score: {round(f1, 4)}") 156 | 157 | 158 | return eval_loss, accuracy, mae, f1 159 | 160 | def multiclass_acc(self, preds, truths): 161 | """ 162 | Compute the multiclass accuracy w.r.t. groundtruth 163 | :param preds: Float array representing the predictions, dimension (N,) 164 | :param truths: Float/int array representing the groundtruth classes, dimension (N,) 165 | :return: Classification accuracy 166 | """ 167 | return np.sum(np.round(preds) == np.round(truths)) / float(len(truths)) 168 | 169 | def calc_metrics(self, y_true, y_pred, mode=None, to_print=False): 170 | """ 171 | Metric scheme adapted from: 172 | https://github.com/yaohungt/Multimodal-Transformer/blob/master/src/eval_metrics.py 173 | """ 174 | 175 | 176 | if self.train_config.data == "ur_funny": 177 | test_preds = np.argmax(y_pred, 1) 178 | test_truth = y_true 179 | 180 | if to_print: 181 | print("Confusion Matrix (pos/neg) :") 182 | print(confusion_matrix(test_truth, test_preds)) 183 | print("Classification Report (pos/neg) :") 184 | print(classification_report(test_truth, test_preds, digits=5)) 185 | print("Accuracy (pos/neg) ", accuracy_score(test_truth, test_preds)) 186 | 187 | return accuracy_score(test_truth, test_preds) 188 | 189 | else: 190 | test_preds = y_pred 191 | test_truth = y_true 192 | 193 | non_zeros = np.array([i for i, e in enumerate(test_truth) if e != 0]) 194 | 195 | test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.) 196 | test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.) 197 | test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.) 198 | test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.) 
199 | 200 | mae = np.mean(np.absolute(test_preds - test_truth)) # Average L1 distance between preds and truths 201 | corr = np.corrcoef(test_preds, test_truth)[0][1] 202 | mult_a7 = self.multiclass_acc(test_preds_a7, test_truth_a7) 203 | mult_a5 = self.multiclass_acc(test_preds_a5, test_truth_a5) 204 | 205 | pos_neg_f_score = f1_score((test_preds[non_zeros] > 0), (test_truth[non_zeros] > 0), average='weighted') 206 | 207 | # pos - neg 208 | binary_truth = (test_truth[non_zeros] > 0) 209 | binary_preds = (test_preds[non_zeros] > 0) 210 | 211 | if to_print: 212 | print("mae: ", mae) 213 | print("corr: ", corr) 214 | print("mult_acc: ", mult_a7) 215 | print("Classification Report (pos/neg) :") 216 | print(classification_report(binary_truth, binary_preds, digits=5)) 217 | print("Accuracy (pos/neg) ", accuracy_score(binary_truth, binary_preds)) 218 | print("F1 (pos/neg) ", pos_neg_f_score) 219 | 220 | # non-neg - neg 221 | binary_truth = (test_truth >= 0) 222 | binary_preds = (test_preds >= 0) 223 | non_neg_f_score = f1_score(binary_truth, binary_preds, average='weighted') 224 | if to_print: 225 | print("Classification Report (non-neg/neg) :") 226 | print(classification_report(binary_truth, binary_preds, digits=5)) 227 | print("Accuracy (non-neg/neg) ", accuracy_score(binary_truth, binary_preds)) 228 | 229 | print("F1 (non-neg/neg) ", non_neg_f_score) 230 | 231 | return accuracy_score(binary_truth, binary_preds) 232 | 233 | 234 | 235 | 236 | 237 | -------------------------------------------------------------------------------- /src/create_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import mmsdk 3 | import os 4 | import re 5 | import pickle 6 | import numpy as np 7 | from tqdm.auto import tqdm 8 | from collections import defaultdict 9 | from mmsdk import mmdatasdk as md 10 | from subprocess import check_call, CalledProcessError 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | 16 | def to_pickle(obj, path): 17 | with open(path, 'wb') as f: 18 | pickle.dump(obj, f) 19 | def load_pickle(path): 20 | with open(path, 'rb') as f: 21 | return pickle.load(f) 22 | 23 | 24 | # construct a word2id mapping that automatically takes increment when new words are encountered 25 | word2id = defaultdict(lambda: len(word2id)) 26 | UNK = word2id[''] 27 | PAD = word2id[''] 28 | 29 | 30 | # turn off the word2id - define a named function here to allow for pickling 31 | def return_unk(): 32 | return UNK 33 | 34 | 35 | def load_emb(w2i, path_to_embedding, embedding_size=300, embedding_vocab=2196017, init_emb=None): 36 | if init_emb is None: 37 | emb_mat = np.random.randn(len(w2i), embedding_size) 38 | else: 39 | emb_mat = init_emb 40 | f = open(path_to_embedding, 'r') 41 | found = 0 42 | for line in tqdm(f, total=embedding_vocab): 43 | content = line.strip().split() 44 | vector = np.asarray(list(map(lambda x: float(x), content[-300:]))) 45 | word = ' '.join(content[:-300]) 46 | if word in w2i: 47 | idx = w2i[word] 48 | emb_mat[idx, :] = vector 49 | found += 1 50 | print(f"Found {found} words in the embedding file.") 51 | return torch.tensor(emb_mat).float() 52 | 53 | 54 | 55 | 56 | 57 | class MOSI: 58 | def __init__(self, config): 59 | 60 | if config.sdk_dir is None: 61 | print("SDK path is not specified! 
Please specify first in constants/paths.py") 62 | exit(0) 63 | else: 64 | sys.path.append(str(config.sdk_dir)) 65 | 66 | DATA_PATH = str(config.dataset_dir) 67 | CACHE_PATH = DATA_PATH + '/embedding_and_mapping.pt' 68 | 69 | # If cached data if already exists 70 | try: 71 | self.train = load_pickle(DATA_PATH + '/train.pkl') 72 | self.dev = load_pickle(DATA_PATH + '/dev.pkl') 73 | self.test = load_pickle(DATA_PATH + '/test.pkl') 74 | self.pretrained_emb, self.word2id = torch.load(CACHE_PATH) 75 | 76 | except: 77 | 78 | # create folders for storing the data 79 | if not os.path.exists(DATA_PATH): 80 | check_call(' '.join(['mkdir', '-p', DATA_PATH]), shell=True) 81 | 82 | 83 | # download highlevel features, low-level (raw) data and labels for the dataset MOSI 84 | # if the files are already present, instead of downloading it you just load it yourself. 85 | # here we use CMU_MOSI dataset as example. 86 | DATASET = md.cmu_mosi 87 | try: 88 | md.mmdataset(DATASET.highlevel, DATA_PATH) 89 | except RuntimeError: 90 | print("High-level features have been downloaded previously.") 91 | 92 | try: 93 | md.mmdataset(DATASET.raw, DATA_PATH) 94 | except RuntimeError: 95 | print("Raw data have been downloaded previously.") 96 | 97 | try: 98 | md.mmdataset(DATASET.labels, DATA_PATH) 99 | except RuntimeError: 100 | print("Labels have been downloaded previously.") 101 | 102 | # define your different modalities - refer to the filenames of the CSD files 103 | #visual_field = 'CMU_MOSI_VisualFacet_4.1' 104 | #acoustic_field = 'CMU_MOSI_COVAREP' 105 | #text_field = 'CMU_MOSI_TimestampedWords' 106 | visual_field = 'CMU_MOSI_Visual_OpenFace_1' 107 | acoustic_field = 'CMU_MOSI_openSMILE_IS09' 108 | text_field = 'CMU_MOSI_TimestampedWords' 109 | 110 | features = [ 111 | text_field, 112 | visual_field, 113 | acoustic_field 114 | ] 115 | 116 | recipe = {feat: os.path.join(DATA_PATH, feat) + '.csd' for feat in features} 117 | print(recipe) 118 | dataset = md.mmdataset(recipe) 119 | 120 | # we define a simple averaging function that does not depend on intervals 121 | def avg(intervals: np.array, features: np.array) -> np.array: 122 | try: 123 | return np.average(features, axis=0) 124 | except: 125 | return features 126 | 127 | # first we align to words with averaging, collapse_function receives a list of functions 128 | dataset.align(text_field, collapse_functions=[avg]) 129 | 130 | label_field = 'CMU_MOSI_Opinion_Labels' 131 | 132 | # we add and align to lables to obtain labeled segments 133 | # this time we don't apply collapse functions so that the temporal sequences are preserved 134 | label_recipe = {label_field: os.path.join(DATA_PATH, label_field + '.csd')} 135 | dataset.add_computational_sequences(label_recipe, destination=None) 136 | dataset.align(label_field) 137 | 138 | # obtain the train/dev/test splits - these splits are based on video IDs 139 | train_split = DATASET.standard_folds.standard_train_fold 140 | dev_split = DATASET.standard_folds.standard_valid_fold 141 | test_split = DATASET.standard_folds.standard_test_fold 142 | 143 | 144 | # a sentinel epsilon for safe division, without it we will replace illegal values with a constant 145 | EPS = 1e-6 146 | 147 | 148 | 149 | # place holders for the final train/dev/test dataset 150 | self.train = train = [] 151 | self.dev = dev = [] 152 | self.test = test = [] 153 | self.word2id = word2id 154 | 155 | # define a regular expression to extract the video ID out of the keys 156 | pattern = re.compile('(.*)\[.*\]') 157 | num_drop = 0 # a counter to count how many 
data points went into some processing issues 158 | 159 | for segment in dataset[label_field].keys(): 160 | 161 | # get the video ID and the features out of the aligned dataset 162 | vid = re.search(pattern, segment).group(1) 163 | label = dataset[label_field][segment]['features'] 164 | _words = dataset[text_field][segment]['features'] 165 | _visual = dataset[visual_field][segment]['features'] 166 | _acoustic = dataset[acoustic_field][segment]['features'] 167 | 168 | # if the sequences are not same length after alignment, there must be some problem with some modalities 169 | # we should drop it or inspect the data again 170 | if not _words.shape[0] == _visual.shape[0] == _acoustic.shape[0]: 171 | print(f"Encountered datapoint {vid} with text shape {_words.shape}, visual shape {_visual.shape}, acoustic shape {_acoustic.shape}") 172 | num_drop += 1 173 | continue 174 | 175 | # remove nan values 176 | label = np.nan_to_num(label) 177 | _visual = np.nan_to_num(_visual) 178 | _acoustic = np.nan_to_num(_acoustic) 179 | 180 | # remove speech pause tokens - this is in general helpful 181 | # we should remove speech pauses and corresponding visual/acoustic features together 182 | # otherwise modalities would no longer be aligned 183 | actual_words = [] 184 | words = [] 185 | visual = [] 186 | acoustic = [] 187 | for i, word in enumerate(_words): 188 | if word[0] != b'sp': 189 | actual_words.append(word[0].decode('utf-8')) 190 | words.append(word2id[word[0].decode('utf-8')]) # SDK stores strings as bytes, decode into strings here 191 | visual.append(_visual[i, :]) 192 | acoustic.append(_acoustic[i, :]) 193 | 194 | words = np.asarray(words) 195 | visual = np.asarray(visual) 196 | acoustic = np.asarray(acoustic) 197 | 198 | 199 | # z-normalization per instance and remove nan/infs 200 | visual = np.nan_to_num((visual - visual.mean(0, keepdims=True)) / (EPS + np.std(visual, axis=0, keepdims=True))) 201 | acoustic = np.nan_to_num((acoustic - acoustic.mean(0, keepdims=True)) / (EPS + np.std(acoustic, axis=0, keepdims=True))) 202 | 203 | if vid in train_split: 204 | train.append(((words, visual, acoustic, actual_words), label, segment)) 205 | elif vid in dev_split: 206 | dev.append(((words, visual, acoustic, actual_words), label, segment)) 207 | elif vid in test_split: 208 | test.append(((words, visual, acoustic, actual_words), label, segment)) 209 | else: 210 | print(f"Found video that doesn't belong to any splits: {vid}") 211 | 212 | print(f"Total number of {num_drop} datapoints have been dropped.") 213 | 214 | word2id.default_factory = return_unk 215 | 216 | # Save glove embeddings cache too 217 | self.pretrained_emb = pretrained_emb = load_emb(word2id, config.word_emb_path) 218 | torch.save((pretrained_emb, word2id), CACHE_PATH) 219 | 220 | # Save pickles 221 | to_pickle(train, DATA_PATH + '/train.pkl') 222 | to_pickle(dev, DATA_PATH + '/dev.pkl') 223 | to_pickle(test, DATA_PATH + '/test.pkl') 224 | 225 | def get_data(self, mode): 226 | 227 | if mode == "train": 228 | return self.train, self.word2id, self.pretrained_emb 229 | elif mode == "dev": 230 | return self.dev, self.word2id, self.pretrained_emb 231 | elif mode == "test": 232 | return self.test, self.word2id, self.pretrained_emb 233 | else: 234 | print("Mode is not set properly (train/dev/test)") 235 | exit() 236 | 237 | 238 | 239 | 240 | class MOSEI: 241 | def __init__(self, config): 242 | 243 | if config.sdk_dir is None: 244 | print("SDK path is not specified! 
Please specify first in constants/paths.py") 245 | exit(0) 246 | else: 247 | sys.path.append(str(config.sdk_dir)) 248 | 249 | DATA_PATH = str(config.dataset_dir) 250 | CACHE_PATH = DATA_PATH + '/embedding_and_mapping.pt' 251 | 252 | # If cached data if already exists 253 | try: 254 | self.train = load_pickle(DATA_PATH + '/train.pkl') 255 | self.dev = load_pickle(DATA_PATH + '/dev.pkl') 256 | self.test = load_pickle(DATA_PATH + '/test.pkl') 257 | self.pretrained_emb, self.word2id = torch.load(CACHE_PATH) 258 | 259 | except: 260 | 261 | # create folders for storing the data 262 | if not os.path.exists(DATA_PATH): 263 | check_call(' '.join(['mkdir', '-p', DATA_PATH]), shell=True) 264 | 265 | 266 | # download highlevel features, low-level (raw) data and labels for the dataset MOSEI 267 | # if the files are already present, instead of downloading it you just load it yourself. 268 | DATASET = md.cmu_mosei 269 | """try: 270 | md.mmdataset(DATASET.highlevel, DATA_PATH) 271 | except RuntimeError: 272 | print("High-level features have been downloaded previously.") 273 | 274 | try: 275 | md.mmdataset(DATASET.raw, DATA_PATH) 276 | except RuntimeError: 277 | print("Raw data have been downloaded previously.")""" 278 | 279 | try: 280 | md.mmdataset(DATASET.labels, DATA_PATH) 281 | except RuntimeError: 282 | print("Labels have been downloaded previously.") 283 | 284 | # define your different modalities - refer to the filenames of the CSD files 285 | visual_field = 'CMU_MOSEI_VisualFacet42' 286 | acoustic_field = 'CMU_MOSEI_COVAREP' 287 | text_field = 'CMU_MOSEI_TimestampedWords' 288 | #text_field = "CMU_MOSEI_TimestampedWordVectors" 289 | 290 | 291 | features = [ 292 | text_field, 293 | visual_field, 294 | acoustic_field 295 | ] 296 | 297 | recipe = {feat: os.path.join(DATA_PATH, feat) + '.csd' for feat in features} 298 | print(recipe) 299 | dataset = md.mmdataset(recipe) 300 | 301 | # we define a simple averaging function that does not depend on intervals 302 | def avg(intervals: np.array, features: np.array) -> np.array: 303 | try: 304 | return np.average(features, axis=0) 305 | except: 306 | return features 307 | 308 | # first we align to words with averaging, collapse_function receives a list of functions 309 | dataset.align(text_field, collapse_functions=[avg]) 310 | 311 | label_field = 'CMU_MOSEI_Labels' 312 | 313 | # we add and align to lables to obtain labeled segments 314 | # this time we don't apply collapse functions so that the temporal sequences are preserved 315 | label_recipe = {label_field: os.path.join(DATA_PATH, label_field + '.csd')} 316 | dataset.add_computational_sequences(label_recipe, destination=None) 317 | dataset.align(label_field) 318 | 319 | # obtain the train/dev/test splits - these splits are based on video IDs 320 | train_split = DATASET.standard_folds.standard_train_fold 321 | dev_split = DATASET.standard_folds.standard_valid_fold 322 | test_split = DATASET.standard_folds.standard_test_fold 323 | 324 | 325 | # a sentinel epsilon for safe division, without it we will replace illegal values with a constant 326 | EPS = 1e-6 327 | 328 | 329 | 330 | # place holders for the final train/dev/test dataset 331 | self.train = train = [] 332 | self.dev = dev = [] 333 | self.test = test = [] 334 | self.word2id = word2id 335 | 336 | # define a regular expression to extract the video ID out of the keys 337 | pattern = re.compile('(.*)\[.*\]') 338 | num_drop = 0 # a counter to count how many data points went into some processing issues 339 | 340 | for segment in 
dataset[label_field].keys(): 341 | 342 | # get the video ID and the features out of the aligned dataset 343 | try: 344 | vid = re.search(pattern, segment).group(1) 345 | label = dataset[label_field][segment]['features'] 346 | _words = dataset[text_field][segment]['features'] 347 | _visual = dataset[visual_field][segment]['features'] 348 | _acoustic = dataset[acoustic_field][segment]['features'] 349 | except: 350 | continue 351 | 352 | # if the sequences are not same length after alignment, there must be some problem with some modalities 353 | # we should drop it or inspect the data again 354 | if not _words.shape[0] == _visual.shape[0] == _acoustic.shape[0]: 355 | print(f"Encountered datapoint {vid} with text shape {_words.shape}, visual shape {_visual.shape}, acoustic shape {_acoustic.shape}") 356 | num_drop += 1 357 | continue 358 | 359 | # remove nan values 360 | label = np.nan_to_num(label) 361 | _visual = np.nan_to_num(_visual) 362 | _acoustic = np.nan_to_num(_acoustic) 363 | 364 | # remove speech pause tokens - this is in general helpful 365 | # we should remove speech pauses and corresponding visual/acoustic features together 366 | # otherwise modalities would no longer be aligned 367 | actual_words = [] 368 | words = [] 369 | visual = [] 370 | acoustic = [] 371 | for i, word in enumerate(_words): 372 | if word[0] != b'sp': 373 | actual_words.append(word[0].decode('utf-8')) 374 | words.append(word2id[word[0].decode('utf-8')]) # SDK stores strings as bytes, decode into strings here 375 | visual.append(_visual[i, :]) 376 | acoustic.append(_acoustic[i, :]) 377 | 378 | words = np.asarray(words) 379 | visual = np.asarray(visual) 380 | acoustic = np.asarray(acoustic) 381 | 382 | # z-normalization per instance and remove nan/infs 383 | visual = np.nan_to_num((visual - visual.mean(0, keepdims=True)) / (EPS + np.std(visual, axis=0, keepdims=True))) 384 | acoustic = np.nan_to_num((acoustic - acoustic.mean(0, keepdims=True)) / (EPS + np.std(acoustic, axis=0, keepdims=True))) 385 | 386 | if vid in train_split: 387 | train.append(((words, visual, acoustic, actual_words), label, segment)) 388 | elif vid in dev_split: 389 | dev.append(((words, visual, acoustic, actual_words), label, segment)) 390 | elif vid in test_split: 391 | test.append(((words, visual, acoustic, actual_words), label, segment)) 392 | else: 393 | print(f"Found video that doesn't belong to any splits: {vid}") 394 | 395 | 396 | print(f"Total number of {num_drop} datapoints have been dropped.") 397 | 398 | word2id.default_factory = return_unk 399 | 400 | # Save glove embeddings cache too 401 | self.pretrained_emb = pretrained_emb = load_emb(word2id, config.word_emb_path) 402 | torch.save((pretrained_emb, word2id), CACHE_PATH) 403 | 404 | # Save pickles 405 | to_pickle(train, DATA_PATH + '/train.pkl') 406 | to_pickle(dev, DATA_PATH + '/dev.pkl') 407 | to_pickle(test, DATA_PATH + '/test.pkl') 408 | 409 | def get_data(self, mode): 410 | 411 | if mode == "train": 412 | return self.train, self.word2id, self.pretrained_emb 413 | elif mode == "dev": 414 | return self.dev, self.word2id, self.pretrained_emb 415 | elif mode == "test": 416 | return self.test, self.word2id, self.pretrained_emb 417 | else: 418 | print("Mode is not set properly (train/dev/test)") 419 | exit() 420 | 421 | 422 | 423 | 424 | class UR_FUNNY: 425 | def __init__(self, config): 426 | 427 | 428 | DATA_PATH = str(config.dataset_dir) 429 | CACHE_PATH = DATA_PATH + '/embedding_and_mapping.pt' 430 | 431 | # If cached data if already exists 432 | try: 433 | self.train = 
load_pickle(DATA_PATH + '/train.pkl') 434 | self.dev = load_pickle(DATA_PATH + '/dev.pkl') 435 | self.test = load_pickle(DATA_PATH + '/test.pkl') 436 | self.pretrained_emb, self.word2id = torch.load(CACHE_PATH) 437 | 438 | except: 439 | 440 | 441 | # create folders for storing the data 442 | if not os.path.exists(DATA_PATH): 443 | check_call(' '.join(['mkdir', '-p', DATA_PATH]), shell=True) 444 | 445 | 446 | data_folds=load_pickle(DATA_PATH + '/data_folds.pkl') 447 | train_split=data_folds['train'] 448 | dev_split=data_folds['dev'] 449 | test_split=data_folds['test'] 450 | 451 | 452 | 453 | word_aligned_openface_sdk=load_pickle(DATA_PATH + "/openface_features_sdk.pkl") 454 | word_aligned_covarep_sdk=load_pickle(DATA_PATH + "/covarep_features_sdk.pkl") 455 | word_embedding_idx_sdk=load_pickle(DATA_PATH + "/word_embedding_indexes_sdk.pkl") 456 | word_list_sdk=load_pickle(DATA_PATH + "/word_list.pkl") 457 | humor_label_sdk = load_pickle(DATA_PATH + "/humor_label_sdk.pkl") 458 | 459 | # a sentinel epsilon for safe division, without it we will replace illegal values with a constant 460 | EPS = 1e-6 461 | 462 | # place holders for the final train/dev/test dataset 463 | self.train = train = [] 464 | self.dev = dev = [] 465 | self.test = test = [] 466 | self.word2id = word2id 467 | 468 | num_drop = 0 # a counter to count how many data points went into some processing issues 469 | 470 | # Iterate over all possible utterances 471 | for key in humor_label_sdk.keys(): 472 | 473 | label = np.array(humor_label_sdk[key], dtype=int) 474 | _word_id = np.array(word_embedding_idx_sdk[key]['punchline_embedding_indexes']) 475 | _acoustic = np.array(word_aligned_covarep_sdk[key]['punchline_features']) 476 | _visual = np.array(word_aligned_openface_sdk[key]['punchline_features']) 477 | 478 | 479 | if not _word_id.shape[0] == _acoustic.shape[0] == _visual.shape[0]: 480 | num_drop += 1 481 | continue 482 | 483 | # remove nan values 484 | label = np.array([np.nan_to_num(label)])[:, np.newaxis] 485 | _visual = np.nan_to_num(_visual) 486 | _acoustic = np.nan_to_num(_acoustic) 487 | 488 | 489 | actual_words = [] 490 | words = [] 491 | visual = [] 492 | acoustic = [] 493 | for i, word_id in enumerate(_word_id): 494 | word = word_list_sdk[word_id] 495 | actual_words.append(word) 496 | words.append(word2id[word]) 497 | visual.append(_visual[i, :]) 498 | acoustic.append(_acoustic[i, :]) 499 | 500 | words = np.asarray(words) 501 | visual = np.asarray(visual) 502 | acoustic = np.asarray(acoustic) 503 | 504 | # z-normalization per instance and remove nan/infs 505 | visual = np.nan_to_num((visual - visual.mean(0, keepdims=True)) / (EPS + np.std(visual, axis=0, keepdims=True))) 506 | acoustic = np.nan_to_num((acoustic - acoustic.mean(0, keepdims=True)) / (EPS + np.std(acoustic, axis=0, keepdims=True))) 507 | 508 | if key in train_split: 509 | train.append(((words, visual, acoustic, actual_words), label)) 510 | elif key in dev_split: 511 | dev.append(((words, visual, acoustic, actual_words), label)) 512 | elif key in test_split: 513 | test.append(((words, visual, acoustic, actual_words), label)) 514 | else: 515 | print(f"Found video that doesn't belong to any splits: {key}") 516 | 517 | print(f"Total number of {num_drop} datapoints have been dropped.") 518 | word2id.default_factory = return_unk 519 | 520 | # Save glove embeddings cache too 521 | self.pretrained_emb = pretrained_emb = load_emb(word2id, config.word_emb_path) 522 | torch.save((pretrained_emb, word2id), CACHE_PATH) 523 | 524 | # Save pickles 525 | to_pickle(train, 
DATA_PATH + '/train.pkl') 526 | to_pickle(dev, DATA_PATH + '/dev.pkl') 527 | to_pickle(test, DATA_PATH + '/test.pkl') 528 | 529 | def get_data(self, mode): 530 | 531 | if mode == "train": 532 | return self.train, self.word2id, self.pretrained_emb 533 | elif mode == "dev": 534 | return self.dev, self.word2id, self.pretrained_emb 535 | elif mode == "test": 536 | return self.test, self.word2id, self.pretrained_emb 537 | else: 538 | print("Mode is not set properly (train/dev/test)") 539 | exit() -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence 8 | from transformers import BertModel, BertConfig 9 | from typing import Optional 10 | from utils import to_gpu 11 | from utils import ReverseLayerF 12 | import warnings 13 | from transformers import RobertaTokenizer, RobertaModel 14 | 15 | import math 16 | import torch.nn.functional as F 17 | import torch.nn as disable_weight_init 18 | 19 | class AdversarialDiscriminator(nn.Module): 20 | def __init__(self, feature_dim): 21 | super(AdversarialDiscriminator, self).__init__() 22 | self.fc1 = nn.Linear(feature_dim, 64) 23 | self.fc2 = nn.Linear(64, 32) 24 | self.fc3 = nn.Linear(32, 1) 25 | self.sigmoid = nn.Sigmoid() 26 | 27 | def forward(self, x): 28 | x = F.relu(self.fc1(x)) 29 | x = F.relu(self.fc2(x)) 30 | return self.sigmoid(self.fc3(x)) 31 | 32 | class RNNPoolingModel(nn.Module): 33 | def __init__(self, input_dim, hidden_dim, output_dim): 34 | super(RNNPoolingModel, self).__init__() 35 | self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=False) 36 | self.fc = nn.Linear(hidden_dim, output_dim) 37 | 38 | def forward(self, x): 39 | # x: [frame, batch, embedding] 40 | _, (h_n, _) = self.rnn(x) # h_n: [1, batch, hidden_dim] 41 | x = h_n.squeeze(0) # [batch, hidden_dim] 42 | x = self.fc(x) # [batch, output_dim] 43 | return x 44 | 45 | class MaskedSelfAttention(nn.Module): 46 | def __init__(self, embed_dim, num_heads): 47 | super(MaskedSelfAttention, self).__init__() 48 | self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) 49 | 50 | def forward(self, x, attn_mask=None): 51 | attn_output, _ = self.multihead_attn(x, x, x, attn_mask=attn_mask) 52 | return attn_output 53 | 54 | 55 | def masked_mean(tensor, mask, dim): 56 | """Finding the mean along dim""" 57 | masked = torch.mul(tensor, mask) 58 | return masked.sum(dim=dim) / mask.sum(dim=dim) 59 | 60 | def masked_max(tensor, mask, dim): 61 | """Finding the max along dim""" 62 | masked = torch.mul(tensor, mask) 63 | neg_inf = torch.zeros_like(tensor) 64 | neg_inf[~mask] = -math.inf 65 | return (masked + neg_inf).max(dim=dim) 66 | 67 | 68 | class CrossAttention(nn.Module): 69 | def __init__(self, dim_q, dim_kv, heads=8): 70 | super().__init__() 71 | self.dim_q = dim_q 72 | self.heads = heads 73 | self.wq = nn.Linear(dim_q, dim_q, bias=False) 74 | self.wk = nn.Linear(dim_kv, dim_q, bias=False) 75 | self.wv = nn.Linear(dim_kv, dim_q, bias=False) 76 | self.out_proj = nn.Linear(dim_q, dim_q) 77 | self.dropout = nn.Dropout(0.2) 78 | self.act = nn.ReLU 79 | def multihead_reshape(self, x): 80 | b, lens, dim = x.shape 81 | x = x.reshape(b, lens, self.heads, dim // self.heads) 82 | x = x.transpose(1, 2) 83 | x = x.reshape(b * self.heads, lens, dim // self.heads) 
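# x now has shape [batch * heads, seq_len, dim // heads], ready for per-head attention score computation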
84 | return x 85 | 86 | def multihead_reshape_inverse(self, x): 87 | b, lens, dim = x.shape 88 | x = x.reshape(b // self.heads, self.heads, lens, dim) 89 | x = x.transpose(1, 2) 90 | x = x.reshape(b // self.heads, lens, dim * self.heads) 91 | return x 92 | def forward(self, q, kv): 93 | q = self.wq(q) 94 | k = self.wk(kv) 95 | v = self.wv(kv) 96 | 97 | q = self.multihead_reshape(q) 98 | k = self.multihead_reshape(k) 99 | v = self.multihead_reshape(v) 100 | 101 | atten = q.bmm(k.transpose(1, 2)) * (self.dim_q // self.heads)**-0.5 102 | atten = atten.softmax(dim=-1) 103 | atten = atten.bmm(v) 104 | 105 | atten = self.multihead_reshape_inverse(atten) 106 | atten = self.out_proj(atten) 107 | atten = self.dropout(atten) 108 | return atten 109 | 110 | class FBP(nn.Module): 111 | def __init__(self, d_emb_1, d_emb_2, fbp_hid, fbp_k, dropout): 112 | super(FBP, self).__init__() 113 | self.fusion_1_matrix = nn.Linear(d_emb_1, fbp_hid*fbp_k, bias=False) 114 | self.fusion_2_matrix = nn.Linear(d_emb_2, fbp_hid*fbp_k, bias=False) 115 | self.fusion_dropout = nn.Dropout(dropout) 116 | self.fusion_pooling = nn.AvgPool1d(kernel_size=fbp_k) 117 | self.fbp_k = fbp_k 118 | 119 | def forward(self, seq1, seq2): 120 | seq1 = self.fusion_1_matrix(seq1) 121 | seq2 = self.fusion_2_matrix(seq2) 122 | fused_feature = torch.mul(seq1, seq2) 123 | if len(fused_feature.shape) == 2: 124 | fused_feature = fused_feature.unsqueeze(0) 125 | fused_feature = self.fusion_dropout(fused_feature) 126 | fused_feature = self.fusion_pooling(fused_feature).squeeze(0) * self.fbp_k # (bs, 512) 127 | fused_feature = F.normalize(fused_feature, dim=-1, p=2) 128 | return fused_feature 129 | 130 | 131 | class PositionalEncoding(nn.Module): 132 | def __init__(self, d_model, max_len=5000): 133 | super(PositionalEncoding, self).__init__() 134 | pe = torch.zeros(max_len, d_model) 135 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 136 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 137 | pe[:, 0::2] = torch.sin(position * div_term) 138 | pe[:, 1::2] = torch.cos(position * div_term) 139 | pe = pe.unsqueeze(1) # Add batch dimension 140 | self.register_buffer('pe', pe) 141 | 142 | def forward(self, x): 143 | return x + self.pe[:x.size(0), :] 144 | class SelfGating(nn.Module): 145 | def __init__(self, input_dim): 146 | super(SelfGating, self).__init__() 147 | self.gate = nn.Sequential( 148 | nn.Linear(input_dim, input_dim), 149 | nn.Sigmoid(), 150 | nn.Dropout(0.6) 151 | ) 152 | 153 | def forward(self, x): 154 | gate_value = self.gate(x) 155 | return x * gate_value 156 | def compute_frame_attention(self, features_m1, features_m2): 157 | S_g_list = [] 158 | for frame in range(features_m1.size(0)): 159 | Q = features_m1[frame] # [batch, ebd] 160 | K = features_m2[frame] # [batch, ebd] 161 | elementwise_mul = Q * K # [batch, ebd] 162 | 163 | # SumPooling 164 | sum_pooled = torch.sum(elementwise_mul, dim=1) # [batch] 165 | 166 | # L2 Normalization 167 | l2_normalized = F.normalize(sum_pooled, p=2, dim=1) # [batch] 168 | 169 | # Linear Projection 170 | S_g = torch.matmul(l2_normalized) # [batch] 171 | 172 | S_g_list.append(S_g) 173 | S_g_final = torch.stack(S_g_list, dim=0) # [frame, batch] 174 | return S_g_final 175 | 176 | class MyRNNModel(nn.Module): 177 | def __init__(self, input_dim, hidden_dim, num_classes): 178 | super(MyRNNModel, self).__init__() 179 | self.rnn = nn.LSTM(input_dim, hidden_dim, batch_first=True) 180 | self.dropout = nn.Dropout(0.6) 181 | 182 | def forward(self, x): 
183 | x = x.permute(1, 0, 2) 184 | _, (hn, _) = self.rnn(x) 185 | hn = hn[-1] 186 | return hn 187 | 188 | # let's define a simple model that can deal with multimodal variable length sequence 189 | class SATI(nn.Module): 190 | def __init__(self, config): 191 | super(SATI, self).__init__() 192 | 193 | self.config = config 194 | self.text_size = config.embedding_size 195 | self.visual_size = config.visual_size 196 | self.acoustic_size = config.acoustic_size 197 | self.input_sizes = input_sizes = [self.text_size, self.visual_size, self.acoustic_size] 198 | self.hidden_sizes = hidden_sizes = [int(self.text_size), int(self.visual_size), int(self.acoustic_size)] 199 | self.output_size = output_size = config.num_classes 200 | self.dropout_rate = dropout_rate = config.dropout 201 | self.activation = self.config.activation() 202 | self.softmax = nn.Softmax() 203 | self.tanh = nn.Tanh() 204 | self.domain_private = [] 205 | self.domain_shared = [] 206 | 207 | rnn = nn.LSTM if self.config.rnncell == "lstm" else nn.GRU 208 | # defining modules - two layer bidirectional LSTM with layer norm in between 209 | 210 | vocab_file = '/home/s22xjq/SATI/model/vocab.json' 211 | merges_file = '/home/s22xjq/SATI/model/merges.txt' 212 | print("using roberta") 213 | self.robertatokenizer = RobertaTokenizer(vocab_file, merges_file) 214 | self.roberta = RobertaModel.from_pretrained('/home/s22xjq/SATI/model/roberta-base/') 215 | 216 | 217 | self.vrnn1 = rnn(input_sizes[1], hidden_sizes[1], bidirectional=True) 218 | self.vrnn2 = rnn(2*hidden_sizes[1], hidden_sizes[1], bidirectional=True) 219 | 220 | self.arnn1 = rnn(input_sizes[2], hidden_sizes[2], bidirectional=True) 221 | self.arnn2 = rnn(2*hidden_sizes[2], hidden_sizes[2], bidirectional=True) 222 | 223 | self.poolingrnn = MyRNNModel(self.config.hidden_size, self.config.hidden_size, self.output_size) 224 | 225 | ########################################## 226 | # mapping modalities to same sized space 227 | ########################################## 228 | 229 | self.project_t = nn.Sequential() 230 | self.project_t.add_module('project_t', nn.Linear(in_features=768, out_features=config.hidden_size)) 231 | self.project_t.add_module('project_t_activation', self.activation) 232 | self.project_t.add_module('project_t_layer_norm', nn.LayerNorm(config.hidden_size)) 233 | 234 | self.project_v = nn.Sequential() 235 | self.project_v.add_module('project_v', nn.Linear(in_features=hidden_sizes[1], out_features=config.hidden_size)) 236 | self.project_v.add_module('project_v_activation', self.activation) 237 | self.project_v.add_module('project_v_layer_norm', nn.LayerNorm(config.hidden_size)) 238 | 239 | self.project_a = nn.Sequential() 240 | self.project_a.add_module('project_a', nn.Linear(in_features=hidden_sizes[2], out_features=config.hidden_size)) 241 | self.project_a.add_module('project_a_activation', self.activation) 242 | self.project_a.add_module('project_a_layer_norm', nn.LayerNorm(config.hidden_size)) 243 | 244 | self.project_h = nn.Sequential() 245 | self.project_h.add_module('project_h', nn.Linear(in_features=hidden_sizes[2]*4, out_features=config.hidden_size)) 246 | self.project_h.add_module('project_h_activation', self.activation) 247 | self.project_h.add_module('project_h_layer_norm', nn.LayerNorm(config.hidden_size)) 248 | 249 | ########################################## 250 | # private encoders 251 | ########################################## 252 | self.private_t = nn.Sequential() 253 | self.private_t.add_module('private_t_1', nn.Linear(in_features=config.hidden_size, 
out_features=config.hidden_size)) 254 | self.private_t.add_module('private_t_activation_1', nn.Sigmoid()) 255 | 256 | self.private_v = nn.Sequential() 257 | self.private_v.add_module('private_v_1', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)) 258 | self.private_v.add_module('private_v_activation_1', nn.Sigmoid()) 259 | 260 | self.private_a = nn.Sequential() 261 | self.private_a.add_module('private_a_3', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)) 262 | self.private_a.add_module('private_a_activation_3', nn.Sigmoid()) 263 | 264 | 265 | ########################################## 266 | # shared encoder 267 | ########################################## 268 | self.shared = nn.Sequential() 269 | self.shared.add_module('shared_1', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)) 270 | self.shared.add_module('shared_activation_1', nn.Sigmoid()) 271 | 272 | 273 | ########################################## 274 | # reconstruct 275 | ########################################## 276 | self.recon_t = nn.Sequential() 277 | self.recon_t.add_module('recon_t_1', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)) 278 | self.recon_v = nn.Sequential() 279 | self.recon_v.add_module('recon_v_1', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)) 280 | self.recon_a = nn.Sequential() 281 | self.recon_a.add_module('recon_a_1', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)) 282 | 283 | 284 | 285 | ########################################## 286 | # shared space adversarial discriminator 287 | ########################################## 288 | 289 | self.discriminator = nn.Sequential() 290 | self.discriminator.add_module('discriminator_layer_1', nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size)) 291 | self.discriminator.add_module('discriminator_layer_1_activation', self.activation) 292 | self.discriminator.add_module('discriminator_layer_1_dropout', nn.Dropout(dropout_rate)) 293 | 294 | self.W = nn.Sequential() 295 | self.W.add_module('discriminator_layer_2', nn.Linear(in_features=config.hidden_size, out_features=3, bias=False)) 296 | #self.discriminator.add_module('discriminator_layer_2_activation', self.softmax) 297 | ########################################## 298 | # shared-private collaborative discriminator 299 | ########################################## 300 | 301 | self.sp_discriminator = nn.Sequential() 302 | self.sp_discriminator.add_module('sp_discriminator_layer_1', nn.Linear(in_features=config.hidden_size, out_features=4)) 303 | 304 | 305 | 306 | self.fusion = nn.Sequential() 307 | self.fusion.add_module('fusion_layer_1', nn.Linear(in_features=self.config.hidden_size*2, out_features=self.config.hidden_size*1)) 308 | self.fusion.add_module('fusion_layer_1_dropout', nn.Dropout(dropout_rate)) 309 | self.fusion.add_module('fusion_layer_1_activation', self.activation) 310 | self.fusion.add_module('fusion_layer_3', nn.Linear(in_features=self.config.hidden_size*1, out_features= output_size)) 311 | 312 | self.my_fusion = nn.Sequential() 313 | self.my_fusion.add_module('fusion_layer_2', nn.Linear(in_features=self.config.hidden_size*2, out_features= output_size)) 314 | #self.my_fusion.add_module('fusion_layer_2_dropout', nn.Dropout(0.2)) 315 | 316 | 317 | 318 | self.tlayer_norm = nn.LayerNorm((hidden_sizes[0]*2,)) 319 | self.vlayer_norm = nn.LayerNorm((hidden_sizes[1]*2,)) 320 | self.alayer_norm = 
nn.LayerNorm((hidden_sizes[2]*2,)) 321 | self.hlayer_norm = nn.LayerNorm(self.config.hidden_size) 322 | self.player_norm = nn.LayerNorm(self.config.hidden_size) 323 | self.slayer_norm = nn.LayerNorm(self.config.hidden_size) 324 | 325 | 326 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.config.hidden_size*2, nhead=1) 327 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1) 328 | 329 | encoder_layer_v = nn.TransformerEncoderLayer(d_model=hidden_sizes[1], nhead=1) 330 | self.transformer_encoder_v = nn.TransformerEncoder(encoder_layer_v, num_layers=1) 331 | 332 | encoder_layer_a = nn.TransformerEncoderLayer(d_model=hidden_sizes[2], nhead=2) 333 | self.transformer_encoder_a = nn.TransformerEncoder(encoder_layer_a, num_layers=1) 334 | 335 | 336 | self.cross_attn_at = CrossAttention(dim_q=self.config.hidden_size, dim_kv=self.config.hidden_size, heads=4) 337 | self.norm_atten1 = nn.LayerNorm(normalized_shape=self.config.hidden_size, elementwise_affine=True) 338 | self.cross_attn_vt = CrossAttention(dim_q=self.config.hidden_size, dim_kv=self.config.hidden_size, heads=4) 339 | self.norm_atten2 = nn.LayerNorm(normalized_shape=self.config.hidden_size, elementwise_affine=True) 340 | 341 | #self.cross_encoder_av = TransformerEncoder(self.config.hidden_size, self.config.hidden_size, n_layers=4, d_inner=512, n_head=4, d_k=None, d_v=None, dropout=0.1, n_position=self.config.hidden_size, add_sa=False) 342 | #self.cross_encoder_at = TransformerEncoder(self.config.hidden_size, self.config.hidden_size, n_layers=4, d_inner=512, n_head=4, d_k=None, d_v=None, dropout=0.1, n_position=self.config.hidden_size, add_sa=False) 343 | 344 | 345 | ########################################## 346 | # PositionalEncoder 347 | ########################################## 348 | self.position_encoding = PositionalEncoding(d_model=self.config.hidden_size, max_len=1024) 349 | 350 | self.masked_attn_layer = MaskedSelfAttention(embed_dim = self.config.hidden_size, num_heads=1) 351 | 352 | self.pooling = RNNPoolingModel(input_dim=3*self.config.hidden_size, hidden_dim=self.config.hidden_size*3, output_dim=self.config.hidden_size*3) 353 | 354 | 355 | ########################################## 356 | # FBP Gate 357 | ########################################## 358 | self.fbp_at = FBP(self.config.hidden_size, self.config.hidden_size, fbp_hid=32, fbp_k=2, dropout=0.4) 359 | self.fc_gate_at = nn.Linear(32, 1) 360 | 361 | 362 | 363 | self.fbp_vt = FBP(self.config.hidden_size, self.config.hidden_size, fbp_hid=32, fbp_k=2, dropout=0.4) 364 | self.fc_gate_vt = nn.Linear(32, 1) 365 | self.gate_activate = nn.Tanh() 366 | 367 | ########################################## 368 | # self Gating 369 | ########################################## 370 | #self.self_gating = SelfGating(self.config.hidden_size) 371 | 372 | def extract_features(self, sequence, lengths, rnn1, rnn2, layer_norm): 373 | lengths = lengths.cpu().long() 374 | packed_sequence = pack_padded_sequence(sequence, lengths) 375 | 376 | if self.config.rnncell == "lstm": 377 | packed_h1, (final_h1, _) = rnn1(packed_sequence) 378 | else: 379 | packed_h1, final_h1 = rnn1(packed_sequence) 380 | 381 | padded_h1, _ = pad_packed_sequence(packed_h1) 382 | normed_h1 = layer_norm(padded_h1) 383 | packed_normed_h1 = pack_padded_sequence(normed_h1, lengths) 384 | 385 | if self.config.rnncell == "lstm": 386 | packed_h2, (final_h2, _) = rnn2(packed_normed_h1) 387 | else: 388 | packed_h2, final_h2 = rnn2(packed_normed_h1) 389 | padded_h2, _ = pad_packed_sequence(packed_h2) 
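# final_h1 / final_h2: last hidden states of the two stacked bi-directional RNN layers; padded_h1 / padded_h2: their full padded output sequences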
390 | return final_h1, final_h2, padded_h1, padded_h2 391 | 392 | def alignment(self, sentences, visual, acoustic, lengths, bert_sent, bert_sent_type, bert_sent_mask): 393 | 394 | batch_size = lengths.size(0) 395 | 396 | ########################################## 397 | # extract avt features 398 | ########################################## 399 | 400 | 401 | roberta_output = self.roberta(input_ids=bert_sent, 402 | attention_mask=bert_sent_mask) 403 | roberta_out = roberta_output.last_hidden_state 404 | roberta_out = roberta_out.transpose(0,1) # [46, 64, 768] 405 | utterance_text = roberta_out 406 | 407 | 408 | 409 | 410 | 411 | # extract features from visual modality 412 | utterance_video = self.transformer_encoder_v(visual) 413 | utterance_audio = self.transformer_encoder_a(acoustic) 414 | #print( utterance_audio.shape) 415 | 416 | 417 | 418 | # Shared-private encoders 419 | self.shared_private(utterance_text, utterance_video, utterance_audio) 420 | 421 | ########################################## 422 | # discriminator 423 | ########################################## 424 | private_t = torch.sum(self.utt_private_t, dim=0)/self.utt_private_t.size(0) 425 | shared_t = torch.sum(self.utt_shared_t, dim=0)/self.utt_shared_t.size(0) 426 | private_a = torch.sum(self.utt_private_a, dim=0)/self.utt_private_a.size(0) 427 | shared_a = torch.sum(self.utt_shared_a, dim=0)/self.utt_shared_a.size(0) 428 | private_v = torch.sum(self.utt_private_v, dim=0)/self.utt_private_v.size(0) 429 | shared_v = torch.sum(self.utt_shared_v, dim=0)/self.utt_shared_v.size(0) 430 | shared = (shared_v+shared_t+shared_a)/3 431 | private = (private_v+private_t+private_a)/3 432 | private = self.player_norm(private) 433 | shared = self.slayer_norm(shared) 434 | self.domain_private = private 435 | self.domain_shared = shared 436 | #print(private) 437 | 438 | reversed_shared_code_t = ReverseLayerF.apply(shared_t, self.config.reverse_grad_weight) 439 | reversed_shared_code_v = ReverseLayerF.apply(shared_v, self.config.reverse_grad_weight) 440 | reversed_shared_code_a = ReverseLayerF.apply(shared_a, self.config.reverse_grad_weight) 441 | reversed_private_code_t = ReverseLayerF.apply(private_t, self.config.reverse_grad_weight) 442 | reversed_private_code_v = ReverseLayerF.apply(private_v, self.config.reverse_grad_weight) 443 | reversed_private_code_a = ReverseLayerF.apply(private_a, self.config.reverse_grad_weight) 444 | self.domain_shared_t = self.discriminator(reversed_shared_code_t) 445 | self.domain_shared_v = self.discriminator(reversed_shared_code_v) 446 | self.domain_shared_a = self.discriminator(reversed_shared_code_a) 447 | self.domain_private_t = self.discriminator(reversed_private_code_t) 448 | self.domain_private_v = self.discriminator(reversed_private_code_v) 449 | self.domain_private_a = self.discriminator(reversed_private_code_a) 450 | 451 | # For reconstruction 452 | self.reconstruct() 453 | 454 | A = self.position_encoding(self.utt_private_a) 455 | V = self.position_encoding(self.utt_private_v) 456 | T = self.position_encoding(self.utt_private_t) 457 | 458 | 459 | at = self.cross_attn_at(T.transpose(0, 1), A.transpose(0, 1)) 460 | at = at.transpose(0, 1) 461 | gate_ = self.fbp_at(self.utt_shared_a, self.utt_shared_t) 462 | gate_ = self.gate_activate(self.fc_gate_at(gate_))#.double() 463 | gate_sign = gate_ / torch.abs(gate_) 464 | gate_ = (gate_sign + torch.abs(gate_sign)) / 2.0 465 | at = at*gate_ + T 466 | at = self.norm_atten1(at) 467 | 468 | 469 | vt = self.cross_attn_vt(T.transpose(0, 1), V.transpose(0, 1)) 470 | 
vt = vt.transpose(0, 1) 471 | gate_ = self.fbp_vt(self.utt_shared_v, self.utt_shared_t) 472 | gate_ = self.gate_activate(self.fc_gate_vt(gate_))#.double() 473 | gate_sign = gate_ / torch.abs(gate_) 474 | gate_ = (gate_sign + torch.abs(gate_sign)) / 2.0 475 | vt = vt*gate_ + T 476 | vt = self.norm_atten2(vt) 477 | h = torch.cat((at, vt), dim=2) 478 | h = self.transformer_encoder(h) 479 | h = torch.sum(h, dim=0)/h.size(0) 480 | o = self.fusion(h) 481 | #print(h) 482 | return o 483 | 484 | def reconstruct(self): 485 | # reconstruct each modality from its private + shared representation (used by the reconstruction loss) 486 | self.utt_t = (self.utt_private_t + self.utt_shared_t) 487 | self.utt_v = (self.utt_private_v + self.utt_shared_v) 488 | self.utt_a = (self.utt_private_a + self.utt_shared_a) 489 | self.utt_t = torch.sum(self.utt_t, dim=0)/self.utt_t.size(0) 490 | self.utt_v = torch.sum(self.utt_v, dim=0)/self.utt_v.size(0) 491 | self.utt_a = torch.sum(self.utt_a, dim=0)/self.utt_a.size(0) 492 | self.utt_t_recon = self.recon_t(self.utt_t) 493 | self.utt_v_recon = self.recon_v(self.utt_v) 494 | self.utt_a_recon = self.recon_a(self.utt_a) 495 | 496 | 497 | def shared_private(self, utterance_t, utterance_v, utterance_a): 498 | 499 | ########################################## 500 | # for recon_loss 501 | ########################################## 502 | # Projecting to same sized space 503 | self.utt_t_orig = utterance_t = self.project_t(utterance_t) 504 | self.utt_v_orig = utterance_v = self.project_v(utterance_v) 505 | self.utt_a_orig = utterance_a = self.project_a(utterance_a) 506 | 507 | 508 | # Private-shared components 509 | self.utt_private_t = self.private_t(utterance_t) 510 | self.utt_private_v = self.private_v(utterance_v) 511 | self.utt_private_a = self.private_a(utterance_a) # [frame, batch, hidden_size] 512 | 513 | 514 | self.utt_shared_t = self.shared(utterance_t) 515 | self.utt_shared_v = self.shared(utterance_v) 516 | self.utt_shared_a = self.shared(utterance_a) 517 | 518 | 519 | def forward(self, sentences, video, acoustic, lengths, bert_sent, bert_sent_type, bert_sent_mask): 520 | batch_size = lengths.size(0) 521 | o = self.alignment(sentences, video, acoustic, lengths, bert_sent, bert_sent_type, bert_sent_mask) 522 | return o 523 | -------------------------------------------------------------------------------- /src/solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from math import isnan 4 | import re 5 | import pickle 6 | import gensim 7 | import numpy as np 8 | from tqdm import tqdm 9 | from tqdm import tqdm_notebook 10 | from sklearn.metrics import classification_report, accuracy_score, f1_score 11 | from sklearn.metrics import confusion_matrix 12 | from sklearn.metrics import precision_recall_fscore_support 13 | from scipy.special import expit 14 | from torchviz import make_dot 15 | import torch 16 | import torch.nn as nn 17 | from torch.nn import functional as F 18 | torch.manual_seed(123) 19 | torch.cuda.manual_seed_all(123) 20 | 21 | from utils import to_gpu, time_desc_decorator, DiffLoss, MSE, SIMSE, CMD, FocalLoss, WeightedMSELoss 22 | import models 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | class Solver(object): 26 | def __init__(self, train_config, dev_config, test_config, train_data_loader, dev_data_loader, test_data_loader, is_train=True, model=None): 27 | self.train_accuracies = [] 28 | self.valid_accuracies = [] 29 | self.test_accuracies = [] 30 | self.train_losses = [] 31 | self.valid_losses = [] 32 | self.test_losses = [] 33 | self.train_maes = [] 34 |
self.valid_maes = [] 35 | self.test_maes = [] 36 | self.train_f1_scores = [] 37 | self.valid_f1_scores = [] 38 | self.test_f1_scores = [] 39 | 40 | 41 | self.train_config = train_config 42 | self.epoch_i = 0 43 | self.train_data_loader = train_data_loader 44 | self.dev_data_loader = dev_data_loader 45 | self.test_data_loader = test_data_loader 46 | self.is_train = is_train 47 | self.model = model 48 | 49 | @time_desc_decorator('Build Graph') 50 | def build(self, cuda=True): 51 | 52 | if self.model is None: 53 | self.model = getattr(models, self.train_config.model)(self.train_config) 54 | 55 | # Final list 56 | for name, param in self.model.named_parameters(): 57 | 58 | # Bert freezing customizations 59 | if self.train_config.data == "mosei": 60 | if "bertmodel.encoder.layer" in name: 61 | layer_num = int(name.split("encoder.layer.")[-1].split(".")[0]) 62 | if layer_num <= (8): 63 | param.requires_grad = False 64 | elif self.train_config.data == "ur_funny": 65 | if "bert" in name: 66 | param.requires_grad = False 67 | 68 | if 'weight_hh' in name: 69 | nn.init.orthogonal_(param) 70 | #print('\t' + name, param.requires_grad) 71 | 72 | # Initialize weight of Embedding matrix with Glove embeddings 73 | 74 | if torch.cuda.is_available() and cuda: 75 | self.model.cuda() 76 | 77 | if self.is_train: 78 | self.optimizer = self.train_config.optimizer( 79 | filter(lambda p: p.requires_grad, self.model.parameters()), 80 | lr=self.train_config.learning_rate) 81 | 82 | 83 | @time_desc_decorator('Training Start!') 84 | def train(self): 85 | curr_patience = patience = self.train_config.patience 86 | num_trials = 3 87 | writer = SummaryWriter() 88 | # self.criterion = criterion = nn.L1Loss(reduction="mean") 89 | if self.train_config.data == "ur_funny": 90 | self.criterion = criterion = nn.CrossEntropyLoss(reduction="mean") 91 | else: # mosi and mosei are regression datasets 92 | self.criterion = criterion = nn.MSELoss(reduction="mean") 93 | 94 | 95 | self.domain_loss_criterion = nn.CrossEntropyLoss(reduction="mean") 96 | self.sp_loss_criterion = nn.CrossEntropyLoss(reduction="mean") 97 | self.loss_diff = DiffLoss() 98 | self.loss_recon = MSE() 99 | self.loss_cmd = CMD() 100 | 101 | best_test_acc = float('-inf') 102 | best_valid_f1 = float('-inf') 103 | best_test_f1 = float('-inf') 104 | best_valid_acc = float('-inf') 105 | 106 | 107 | lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.3) 108 | 109 | #train_losses = [] 110 | #valid_losses = []F 111 | 112 | start_saving_epoch = int(self.train_config.n_epoch * self.train_config.start_saving) 113 | 114 | 115 | for e in range(self.train_config.n_epoch): 116 | self.model.train() 117 | 118 | train_loss_cls, train_loss_sim, train_loss_diff = [], [], [] 119 | train_loss_recon = [] 120 | train_loss_sp = [] 121 | train_loss = [] 122 | train_y_true, train_y_pred = [], [] 123 | train_loss_jsd = [] 124 | pos = 0 125 | nos = 0 126 | nig = 0 127 | 128 | for batch in self.train_data_loader: 129 | self.model.zero_grad() 130 | t, v, a, y, l, bert_sent, bert_sent_type, bert_sent_mask = batch 131 | 132 | batch_size = t.size(0) 133 | t = to_gpu(t) 134 | v = to_gpu(v) 135 | a = to_gpu(a) 136 | y = to_gpu(y) 137 | l = to_gpu(l) 138 | bert_sent = to_gpu(bert_sent) 139 | bert_sent_type = to_gpu(bert_sent_type) 140 | bert_sent_mask = to_gpu(bert_sent_mask) 141 | #for name, param in self.model.named_parameters(): 142 | #print(name, param.shape) 143 | y_tilde = self.model(t, v, a, l, bert_sent, bert_sent_type, bert_sent_mask) 144 | 145 | #print(y) 146 | 
#print(y_tilde) 147 | if self.train_config.data == "ur_funny": 148 | y = y.squeeze() 149 | 150 | cls_loss = criterion(y_tilde, y) 151 | dif_loss = self.get_diff_loss() 152 | domain_loss = self.get_domain_loss() 153 | #print((cls_loss.item())) 154 | recon_loss = self.get_recon_loss() 155 | cmd_loss = self.get_cmd_loss() 156 | jsd_loss = self.get_jsd_loss() 157 | #print(y) 158 | if self.train_config.use_domain: 159 | diff_loss = domain_loss #+ soft_loss 160 | else: 161 | diff_loss = dif_loss 162 | #print(type(domain_loss)) 163 | 164 | if self.train_config.use_cmd_sim: 165 | similarity_loss = cmd_loss 166 | else: 167 | similarity_loss = domain_loss 168 | 169 | loss = cls_loss + \ 170 | self.train_config.diff_weight * domain_loss + \ 171 | self.train_config.sim_weight * similarity_loss + \ 172 | self.train_config.recon_weight * recon_loss + \ 173 | self.train_config.jsd_weight * jsd_loss 174 | loss.backward() 175 | 176 | #print("backward") 177 | #print("latent_queries gradient:", self.model.latent_queries.grad) 178 | torch.nn.utils.clip_grad_value_([param for param in self.model.parameters() if param.requires_grad], self.train_config.clip) 179 | self.optimizer.step() 180 | 181 | train_loss_cls.append(cls_loss.item()) 182 | train_loss_diff.append(diff_loss.item()) 183 | train_loss_recon.append(recon_loss.item()) 184 | train_loss.append(loss.item()) 185 | train_loss_sim.append(similarity_loss.item()) 186 | train_loss_jsd.append(jsd_loss.item()) 187 | train_y_true.append(y.detach().cpu().numpy()) 188 | train_y_pred.append(y_tilde.detach().cpu().numpy()) 189 | #print(train_y_true) 190 | print("\n") 191 | writer.add_scalar('Loss/Train_cls', np.mean(train_loss_cls), e) 192 | writer.add_scalar('Loss/Train_diff', np.mean(train_loss_diff),e ) 193 | writer.add_scalar('Loss/Train_recon', np.mean(train_loss_recon), e) 194 | writer.add_scalar('Loss/Train_similarity', np.mean(train_loss_sim), e) 195 | #writer.add_scalar('Loss/Train_jsd', np.mean(train_loss_jsd), e) 196 | writer.add_scalar('Loss/Train_total', np.mean(train_loss), e) 197 | train_loss = np.mean(train_loss) 198 | #train_losses.append(train_loss) 199 | self.train_losses.append(train_loss) 200 | train_y_true = np.concatenate(train_y_true, axis=0).squeeze() 201 | train_y = np.array(train_y_true) 202 | pos +=len([x for x in train_y if x > 0]) 203 | nos +=len([x for x in train_y if x < 0]) 204 | nig +=len([x for x in train_y if x == 0]) 205 | train_y_pred = np.concatenate(train_y_pred, axis=0).squeeze() 206 | train_acc = self.calc_metrics(train_y_true, train_y_pred, mode="train") 207 | train_mae = np.mean(np.abs(train_y_pred - train_y_true)) 208 | train_f1 = f1_score((train_y_pred > 0), (train_y_true > 0), average='weighted') 209 | writer.add_scalar('Accuracy/Train', train_acc, e) 210 | writer.add_scalar('F1/Train', train_f1, e) 211 | print(f"Epoch {e+1} - Training loss: {round(train_loss, 4)}, Accuracy: {round(train_acc, 4)}, MAE: {round(train_mae, 4)}, F1-score: {round(train_f1, 4)}") 212 | 213 | self.train_accuracies.append(train_acc) 214 | self.train_maes.append(train_mae) 215 | self.train_f1_scores.append(train_f1) 216 | #print(pos, nos, nig) 217 | pos = 0 218 | nos = 0 219 | nig = 0 220 | #print(f"Training loss: {round(np.mean(train_loss), 4)}") 221 | 222 | 223 | valid_loss, valid_acc, valid_mae, valid_f1 = self.eval(mode="dev") 224 | test_loss, test_acc, test_mae, test_f1 = self.eval(mode="test") 225 | self.valid_losses.append(valid_loss) 226 | self.valid_accuracies.append(valid_acc) 227 | self.valid_maes.append(valid_mae) 228 | 
self.valid_f1_scores.append(valid_f1) 229 | print(f"Epoch {e+1} - Validation loss: {round(valid_loss, 4)}, Accuracy: {round(valid_acc, 4)}, MAE: {round(valid_mae, 4)}, F1-score: {round(valid_f1, 4)}") 230 | writer.add_scalar('Accuracy/Valid', valid_acc, e) 231 | writer.add_scalar('F1/Valid', valid_f1, e) 232 | 233 | #valid_loss, valid_acc = self.eval(mode="dev") 234 | 235 | if e >= start_saving_epoch and valid_f1 >= best_valid_f1: 236 | print(best_valid_f1) 237 | best_valid_f1 = valid_f1 238 | print("Found new best model on dev set! f1") 239 | if not os.path.exists(f'checkpoints/checkpoints/checkpoints_{self.train_config.learning_rate}'): 240 | os.makedirs(f'checkpoints/checkpoints/checkpoints_{self.train_config.learning_rate}') 241 | torch.save(self.model.state_dict(), f'checkpoints/checkpoints/checkpoints_{self.train_config.learning_rate}/model_{self.train_config.name}.std') 242 | torch.save(self.optimizer.state_dict(), f'checkpoints/checkpoints/checkpoints_{self.train_config.learning_rate}/optim_{self.train_config.name}.std') 243 | curr_patience = patience 244 | 245 | if e >= start_saving_epoch and valid_acc >= best_valid_acc: 246 | best_valid_acc = valid_acc 247 | print("Found new best model on dev set! acc") 248 | if not os.path.exists(f'checkpoints/checkpoints/checkpoints_{self.train_config.learning_rate}'): 249 | os.makedirs(f'checkpoints/checkpoints/checkpoints_{self.train_config.learning_rate}') 250 | torch.save(self.model.state_dict(), f'checkpoints/checkpoints/checkpoints_{self.train_config.learning_rate}/model_best_acc.std') 251 | torch.save(self.optimizer.state_dict(), f'checkpoints/checkpoints/checkpoints_{self.train_config.learning_rate}/optim_best_acc.std') 252 | curr_patience = patience 253 | elif e >= start_saving_epoch: 254 | curr_patience -= 1 255 | if curr_patience <= -1: 256 | print("Running out of patience, loading previous best model.") 257 | num_trials -= 1 258 | curr_patience = patience 259 | self.model.load_state_dict(torch.load(f'checkpoints/checkpoints/checkpoints_{self.train_config.learning_rate}/model_{self.train_config.name}.std')) 260 | self.optimizer.load_state_dict(torch.load(f'checkpoints/checkpoints/checkpoints_{self.train_config.learning_rate}/optim_{self.train_config.name}.std')) 261 | lr_scheduler.step() 262 | print(f"Current learning rate: {self.optimizer.state_dict()['param_groups'][0]['lr']}") 263 | 264 | if num_trials <= 0: 265 | print("Running out of patience, early stopping.") 266 | break 267 | 268 | self.eval(mode="test", to_print=True) 269 | print("best_valid_f1", best_valid_f1) 270 | print("best_test_acc", best_test_acc) 271 | 272 | 273 | 274 | 275 | def eval(self,mode=None, to_print=False): 276 | assert(mode is not None) 277 | self.model.eval() 278 | 279 | y_true, y_pred = [], [] 280 | eval_loss, eval_loss_diff = [], [] 281 | 282 | y_true2, y_pred2 = [], [] 283 | eval_loss2, eval_loss_diff2 = [], [] 284 | 285 | if mode == "dev": 286 | dataloader = self.dev_data_loader 287 | elif mode == "test": 288 | dataloader = self.test_data_loader 289 | 290 | if to_print: 291 | self.model.load_state_dict(torch.load( 292 | f'checkpoints/checkpoints/checkpoints_{self.train_config.learning_rate}/model_best_acc.std')) 293 | 294 | 295 | with torch.no_grad(): 296 | for batch in dataloader: 297 | self.model.zero_grad() 298 | t, v, a, y, l, bert_sent, bert_sent_type, bert_sent_mask = batch 299 | 300 | t = to_gpu(t) 301 | v = to_gpu(v) 302 | a = to_gpu(a) 303 | y = to_gpu(y) 304 | l = to_gpu(l) 305 | bert_sent = to_gpu(bert_sent) 306 | bert_sent_type = 
to_gpu(bert_sent_type) 307 | bert_sent_mask = to_gpu(bert_sent_mask) 308 | 309 | y_tilde = self.model(t, v, a, l, bert_sent, bert_sent_type, bert_sent_mask) 310 | train_y_pred_np = np.array(y_tilde.cpu()) 311 | binary_pred = (train_y_pred_np > 0).astype(int) 312 | #print(binary_pred) 313 | if self.train_config.data == "ur_funny": 314 | y = y.squeeze() 315 | cls_loss = self.criterion(y_tilde, y) 316 | loss = cls_loss 317 | 318 | eval_loss.append(loss.item()) 319 | y_pred.append(y_tilde.detach().cpu().numpy()) 320 | y_true.append(y.detach().cpu().numpy()) 321 | 322 | 323 | 324 | eval_loss = np.mean(eval_loss) 325 | y_true = np.concatenate(y_true, axis=0).squeeze() 326 | y_pred = np.concatenate(y_pred, axis=0).squeeze() 327 | 328 | accuracy = self.calc_metrics(y_true, y_pred, mode, to_print) 329 | mae = np.mean(np.abs(y_pred - y_true)) 330 | f1 = f1_score((y_pred >= 0), (y_true >= 0), average='weighted') 331 | 332 | if mode == "dev": 333 | self.valid_losses.append(eval_loss) 334 | self.valid_accuracies.append(accuracy) 335 | self.valid_maes.append(mae) 336 | self.valid_f1_scores.append(f1) 337 | elif mode == "test": 338 | self.test_losses.append(eval_loss) 339 | self.test_accuracies.append(accuracy) 340 | self.test_maes.append(mae) 341 | self.test_f1_scores.append(f1) 342 | if to_print: 343 | print(f"Eval {mode} loss: {round(eval_loss, 4)}, Accuracy: {round(accuracy, 4)}, MAE: {round(mae, 4)}, F1-score: {round(f1, 4)}") 344 | 345 | return eval_loss, accuracy, mae, f1 346 | 347 | def multiclass_acc(self, preds, truths): 348 | """ 349 | Compute the multiclass accuracy w.r.t. groundtruth 350 | :param preds: Float array representing the predictions, dimension (N,) 351 | :param truths: Float/int array representing the groundtruth classes, dimension (N,) 352 | :return: Classification accuracy 353 | """ 354 | return np.sum(np.round(preds) == np.round(truths)) / float(len(truths)) 355 | 356 | def calc_metrics(self, y_true, y_pred, mode=None, to_print=False): 357 | """ 358 | Metric scheme adapted from: 359 | https://github.com/yaohungt/Multimodal-Transformer/blob/master/src/eval_metrics.py 360 | """ 361 | 362 | 363 | if self.train_config.data == "ur_funny": 364 | test_preds = np.argmax(y_pred, 1) 365 | test_truth = y_true 366 | 367 | if to_print: 368 | print("Confusion Matrix (pos/neg) :") 369 | print(confusion_matrix(test_truth, test_preds)) 370 | print("Classification Report (pos/neg) :") 371 | print(classification_report(test_truth, test_preds, digits=5)) 372 | print("Accuracy (pos/neg) ", accuracy_score(test_truth, test_preds)) 373 | 374 | return accuracy_score(test_truth, test_preds) 375 | 376 | else: 377 | test_preds = y_pred 378 | test_truth = y_true 379 | 380 | non_zeros = np.array([i for i, e in enumerate(test_truth) if e != 0]) 381 | 382 | test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.) 383 | test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.) 384 | test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.) 385 | test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.) 
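# clipping to [-3, 3] and [-2, 2] follows the standard CMU-MOSI/MOSEI protocol: rounded scores give 7-class and 5-class accuracy respectively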
386 | 387 | mae = np.mean(np.absolute(test_preds - test_truth)) # Average L1 distance between preds and truths 388 | corr = np.corrcoef(test_preds, test_truth)[0][1] 389 | mult_a7 = self.multiclass_acc(test_preds_a7, test_truth_a7) 390 | mult_a5 = self.multiclass_acc(test_preds_a5, test_truth_a5) 391 | 392 | f_score = f1_score((test_preds[non_zeros] > 0), (test_truth[non_zeros] > 0), average='weighted') 393 | 394 | # pos - neg 395 | binary_truth = (test_truth[non_zeros] > 0) 396 | binary_preds = (test_preds[non_zeros] > 0) 397 | 398 | if to_print: 399 | print("mae: ", mae) 400 | print("corr: ", corr) 401 | print("mult_acc: ", mult_a7) 402 | print("Classification Report (pos/neg) :") 403 | print(classification_report(binary_truth, binary_preds, digits=5)) 404 | print("Accuracy (pos/neg) ", accuracy_score(binary_truth, binary_preds)) 405 | print("F1 (pos/neg) ", f1_score(binary_truth, binary_preds, average='weighted')) 406 | 407 | # non-neg - neg 408 | binary_truth = (test_truth >= 0) 409 | binary_preds = (test_preds >= 0) 410 | 411 | if to_print: 412 | print("Classification Report (non-neg/neg) :") 413 | print(classification_report(binary_truth, binary_preds, digits=5)) 414 | print("Accuracy (non-neg/neg) ", accuracy_score(binary_truth, binary_preds)) 415 | print("F1 (non-neg/neg) ", f1_score(binary_truth, binary_preds, average='weighted')) 416 | 417 | return accuracy_score(binary_truth, binary_preds) 418 | 419 | 420 | def angular_margin_loss(self, feature, label, weight, scale_factor=30.0, margin=0.5, lambda_l2=0.01): 421 | # additive cosine-margin softmax over the three modality classes, with L2 regularisation on the classifier weight 422 | labels = [0, 1, 2] 423 | normalized_feature = F.normalize(feature, p=2, dim=0) 424 | normalized_weight = F.normalize(weight, p=2, dim=1) 425 | 426 | 427 | cos_theta = torch.matmul(normalized_feature, normalized_weight.t()) # shape: (num_classes,) 428 | 429 | correct_class_cos_theta = cos_theta[label] 430 | 431 | cos_theta_m = correct_class_cos_theta - margin 432 | 433 | cos_theta_with_margin = cos_theta.clone() 434 | cos_theta_with_margin[label] = cos_theta_m 435 | 436 | exp_cos_theta = torch.exp(scale_factor * cos_theta_with_margin) 437 | softmax_output = exp_cos_theta / exp_cos_theta.sum() 438 | 439 | loss = -torch.log(softmax_output[label]) 440 | 441 | l2_reg = lambda_l2 * (weight ** 2).sum() 442 | 443 | total_loss = loss + l2_reg 444 | 445 | return total_loss 446 | 447 | def discriminator_loss(self, real_output, fake_output): # standard binary GAN discriminator loss: real -> 1, fake -> 0 448 | real_loss = F.binary_cross_entropy(real_output, torch.ones_like(real_output)) 449 | return real_loss + F.binary_cross_entropy(fake_output, torch.zeros_like(fake_output)) 450 | def get_domain_loss(self): 451 | 452 | #if self.train_config.use_cmd_sim: 453 | # return 0.0 454 | pred_shared_t = self.model.domain_shared_t 455 | pred_shared_a = self.model.domain_shared_a 456 | pred_shared_v = self.model.domain_shared_v 457 | pred_private_t = self.model.domain_private_t 458 | pred_private_a = self.model.domain_private_a 459 | pred_private_v = self.model.domain_private_v 460 | 461 | W = self.model.W.discriminator_layer_2.weight 462 | 463 | Lami = torch.zeros(1, device="cuda") 464 | Lams = torch.zeros(1, device="cuda") 465 | for i in range(pred_private_v.size(0)): 466 | Lami += (self.angular_margin_loss(feature=pred_shared_t[i], label=0, weight=W) + self.angular_margin_loss(feature=pred_shared_a[i], label=1, weight=W) \ 467 | + self.angular_margin_loss(feature=pred_shared_v[i], label=2, weight=W))/3.0 468 | Lams += (self.angular_margin_loss(feature=pred_private_t[i], label=0, weight=W) + self.angular_margin_loss(feature=pred_private_v[i], label=2, weight=W) \ 469 | + self.angular_margin_loss(feature=pred_private_a[i], label=1,
weight=W))/3.0 470 | Lami = Lami/pred_private_v.size(0) 471 | Lams = Lams/pred_private_v.size(0) 472 | return Lami + Lams 473 | 474 | def get_cmd_loss(self,): 475 | 476 | if not self.train_config.use_cmd_sim: 477 | return 0.0 478 | shared_t = torch.sum(self.model.utt_shared_t, dim=0)/self.model.utt_shared_t.size(0) 479 | shared_v = torch.sum(self.model.utt_shared_v, dim=0)/self.model.utt_shared_v.size(0) 480 | shared_a = torch.sum(self.model.utt_shared_a, dim=0)/self.model.utt_shared_a.size(0) 481 | # losses between shared states 482 | loss = self.loss_cmd(shared_t, shared_v, 5) 483 | loss += self.loss_cmd(shared_t, shared_a, 5) 484 | loss += self.loss_cmd(shared_a, shared_v, 5) 485 | loss = loss/3.0 486 | 487 | return loss 488 | 489 | def get_diff_loss(self): 490 | 491 | shared_t = self.model.utt_shared_t 492 | shared_v = self.model.utt_shared_v 493 | shared_a = self.model.utt_shared_a 494 | private_t = self.model.utt_private_t 495 | private_v = self.model.utt_private_v 496 | private_a = self.model.utt_private_a 497 | #print(self.model.utt_private_a.shape) 498 | shared_t = torch.sum(shared_t, dim=0)/shared_t.size(0) 499 | #print(shared_t.shape) 500 | shared_v = torch.sum(shared_v, dim=0)/shared_t.size(0) 501 | shared_a = torch.sum(shared_a, dim=0)/shared_t.size(0) 502 | private_t = torch.sum(private_t, dim=0)/shared_t.size(0) 503 | private_v = torch.sum(private_v, dim=0)/shared_t.size(0) 504 | private_a = torch.sum(private_a, dim=0)/shared_t.size(0) 505 | 506 | # Between private and shared 507 | loss = self.loss_diff(private_t, shared_t) 508 | loss += self.loss_diff(private_v, shared_v) 509 | loss += self.loss_diff(private_a, shared_a) 510 | 511 | # Across privates 512 | loss += self.loss_diff(private_a, private_t) 513 | loss += self.loss_diff(private_a, private_v) 514 | loss += self.loss_diff(private_t, private_v) 515 | 516 | return loss 517 | 518 | def get_recon_loss(self, ): 519 | 520 | loss = self.loss_recon(self.model.utt_t_recon, self.model.utt_t_orig) 521 | loss += self.loss_recon(self.model.utt_v_recon, self.model.utt_v_orig) 522 | loss += self.loss_recon(self.model.utt_a_recon, self.model.utt_a_orig) 523 | loss = loss/3.0 524 | return loss 525 | 526 | 527 | def kl_divergence(self, p, q): 528 | p = p + 1e-10 529 | q = q + 1e-10 530 | return torch.sum(p * torch.log(p / q)) 531 | 532 | def jsd(self, p, q): 533 | m = (p + q ) / 2 534 | return (self.kl_divergence(p, m) + self.kl_divergence(q, m) ) / 2 535 | 536 | 537 | def get_jsd_loss(self, ): 538 | jsd_loss = 0 539 | features_v = torch.cat((self.model.utt_private_v,self.model.utt_shared_v), dim=2) 540 | #features_v = self.model.utt_private_v 541 | seq, batch, ebd = features_v.shape 542 | for i in range(seq - 1): 543 | p = F.softmax(features_v[i], dim=-1).mean(dim=0) 544 | q = F.softmax(features_v[i + 1], dim=-1).mean(dim=0) 545 | jsd = self.jsd(p, q) 546 | jsd_loss = jsd_loss + jsd 547 | jsd_loss = jsd_loss/(seq-1) 548 | return jsd_loss 549 | 550 | 551 | 552 | 553 | --------------------------------------------------------------------------------
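Note: the temporal-invariance term computed by Solver.get_jsd_loss above averages the Jensen-Shannon divergence between the batch-mean softmax distributions of consecutive frames of the concatenated private/shared visual features. The snippet below is a minimal, self-contained sketch of that computation on dummy tensors; the function name temporal_jsd_loss and the tensor sizes are illustrative assumptions, not part of the repository.

import torch
import torch.nn.functional as F

def kl_divergence(p, q, eps=1e-10):
    # KL(p || q) with epsilon smoothing, mirroring Solver.kl_divergence
    p, q = p + eps, q + eps
    return torch.sum(p * torch.log(p / q))

def jsd(p, q):
    # Jensen-Shannon divergence: average KL of p and q to their mixture m
    m = (p + q) / 2
    return (kl_divergence(p, m) + kl_divergence(q, m)) / 2

def temporal_jsd_loss(features):
    # features: [seq_len, batch, dim]; penalises distribution shift between consecutive frames
    seq_len = features.size(0)
    loss = features.new_zeros(())
    for i in range(seq_len - 1):
        p = F.softmax(features[i], dim=-1).mean(dim=0)
        q = F.softmax(features[i + 1], dim=-1).mean(dim=0)
        loss = loss + jsd(p, q)
    return loss / (seq_len - 1)

if __name__ == "__main__":
    dummy = torch.randn(12, 4, 256)        # assumed [frames, batch, hidden] sizes
    print(temporal_jsd_loss(dummy).item()) # a small positive scalar; identical frames would give 0.0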