├── dataset
│   └── dataset.txt
├── pretrained_model
│   └── sentilare_model
│       └── sentilare.txt
├── networks
│   ├── SentiLARE
│   │   ├── __init__.py
│   │   ├── modeling_sentilr_roberta.py
│   │   └── modeling_sentilr.py
│   └── subnet
│       └── CEmodule.py
├── requirement.txt
├── config
│   └── global_configs.py
├── utils
│   ├── metric.py
│   ├── set_seed.py
│   └── databuilder.py
├── README.md
└── train.py

/dataset/dataset.txt:
--------------------------------------------------------------------------------
The dataset pkl file is placed here.
--------------------------------------------------------------------------------
/pretrained_model/sentilare_model/sentilare.txt:
--------------------------------------------------------------------------------
The SentiLARE model files are placed here.
--------------------------------------------------------------------------------
/networks/SentiLARE/__init__.py:
--------------------------------------------------------------------------------
from .modeling_sentilr_roberta import RobertaForSequenceClassification
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
pytorch-transformers==1.2.0
numpy==1.20.2
torch==1.8.1
torchvision==0.9.1
tqdm==4.58.0
transformers==3.0.2
scikit-learn==0.24.2
six==1.15.0
wandb==0.10.25
--------------------------------------------------------------------------------
/config/global_configs.py:
--------------------------------------------------------------------------------
import torch

DEVICE = torch.device("cuda:0")

"""
# MOSEI SETTING
ACOUSTIC_DIM = 74
VISUAL_DIM = 35
TEXT_DIM = 768
"""

# MOSI SETTING
ACOUSTIC_DIM = 74
VISUAL_DIM = 27
TEXT_DIM = 768

ROBERTA_INJECTION_INDEX = 1

input_size = VISUAL_DIM  # ACOUSTIC_DIM or VISUAL_DIM
hidden_size = 768
ffn_num_hiddens = 1024
max_sequence_len = 50
num_head = 8
label_size = 16

head_dropout = 0.1  # the original "0.1," made this a one-element tuple
head_hidden_dim = 64
--------------------------------------------------------------------------------
/utils/metric.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def score_model(preds, labels, use_zero=False):
    mae = np.mean(np.absolute(preds - labels))
    corr = np.corrcoef(preds, labels)[0][1]
    # Keep only samples with non-zero labels unless use_zero is set.
    non_zeros = np.array(
        [i for i, e in enumerate(labels) if e != 0 or use_zero])
    preds = preds[non_zeros]
    labels = labels[non_zeros]
    preds = preds >= 0
    labels = labels >= 0
    f_score = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)

    return acc, mae, corr, f_score
--------------------------------------------------------------------------------
/utils/set_seed.py:
--------------------------------------------------------------------------------
import torch
import random
import numpy as np
import os
import argparse

def seed(s):
    if isinstance(s, int):
        if 0 <= s <= 9999:
            return s
        else:
            raise argparse.ArgumentTypeError(
                "Seed must be between 0 and 9999. Received {0}".format(s)
            )
    elif s == "random":
        return np.random.randint(0, 9999)
    else:
        raise argparse.ArgumentTypeError(
            "Integer value or 'random' is expected. Received {0}".format(s)
        )

def set_random_seed(seed: int):
    """
    Helper function to seed the experiment for reproducibility.
    If "random" is passed on the command line, a random seed from 0~9999 is used.

    Args:
        seed (int): integer to be used as seed
    """
    print("Seed: {}".format(seed))

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CENet
> PyTorch implementation of "Cross-modal Enhancement Network for Multimodal Sentiment Analysis" (TMM 2022): https://ieeexplore.ieee.org/document/9797846
# Prepare
## Dataset
Download the MOSI pkl file (https://drive.google.com/drive/folders/1_u1Vt0_4g0RLoQbdslBwAdMslEdW1avI?usp=sharing) and put it under the "./dataset" directory.

## Pre-trained language model
Download the SentiLARE language model files (https://drive.google.com/file/d/1onz0ds0CchBRFcSc_AkTLH_AZX_iNTjO/view?usp=share_link), and then put them into the "./pretrained_model/sentilare_model" directory.

# Run
```
python train.py
```

Note: the MOSI dataset is small, so training is not stable. To get results close to those in the CENet paper, you can set the seed argument to 6758. The experimental results of the paper were obtained on Windows.
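For example, to pass the reported seed explicitly (`--seed` accepts an integer or the string "random"):
```
python train.py --seed 6758
```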
# Paper

Please cite our paper if you find our work useful for your research:

```
@ARTICLE{9797846,
  author={Wang, Di and Liu, Shuai and Wang, Quan and Tian, Yumin and He, Lihuo and Gao, Xinbo},
  journal={IEEE Transactions on Multimedia},
  title={Cross-modal Enhancement Network for Multimodal Sentiment Analysis},
  year={2022},
  pages={1-13},
  doi={10.1109/TMM.2022.3183830}
}
```
--------------------------------------------------------------------------------
/networks/subnet/CEmodule.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from config.global_configs import *
import math

class CE(nn.Module):
    def __init__(self, beta_shift_a=0.5, beta_shift_v=0.5, dropout_prob=0.2):
        super(CE, self).__init__()
        # Index `label_size` is the padding id for both embedding tables.
        self.visual_embedding = nn.Embedding(label_size + 1, TEXT_DIM, padding_idx=label_size)
        self.acoustic_embedding = nn.Embedding(label_size + 1, TEXT_DIM, padding_idx=label_size)
        self.hv = SelfAttention(TEXT_DIM)
        self.ha = SelfAttention(TEXT_DIM)
        self.cat_connect = nn.Linear(2 * TEXT_DIM, TEXT_DIM)

    def forward(self, text_embedding, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None):
        visual_ = self.visual_embedding(visual_ids)
        acoustic_ = self.acoustic_embedding(acoustic_ids)
        visual_ = self.hv(text_embedding, visual_)
        acoustic_ = self.ha(text_embedding, acoustic_)
        visual_acoustic = torch.cat((visual_, acoustic_), dim=-1)
        shift = self.cat_connect(visual_acoustic)
        embedding_shift = shift + text_embedding

        return embedding_shift

class Attention(nn.Module):
    # Unused single-head variant, kept for reference.
    def __init__(self, text_dim):
        super(Attention, self).__init__()
        self.text_dim = text_dim
        self.dim = text_dim
        self.Wq = nn.Linear(text_dim, text_dim)
        self.Wk = nn.Linear(self.dim, text_dim)
        self.Wv = nn.Linear(self.dim, text_dim)

    def forward(self, text_embedding, embedding):
        Q = self.Wq(text_embedding)
        K = self.Wk(embedding)
        V = self.Wv(embedding)
        # Scaled dot-product attention: divide (not multiply) by sqrt(d).
        weight_matrix = F.softmax(torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.text_dim), dim=-1)

        return torch.matmul(weight_matrix, V)


class SelfAttention(nn.Module):
    def __init__(self, hidden_size, head_num=1):
        super(SelfAttention, self).__init__()
        self.head_num = head_num
        self.s_d = hidden_size // self.head_num
        self.all_head_size = self.head_num * self.s_d
        self.Wq = nn.Linear(hidden_size, hidden_size)
        self.Wk = nn.Linear(hidden_size, hidden_size)
        self.Wv = nn.Linear(hidden_size, hidden_size)

    def transpose_for_scores(self, x):
        x = x.view(x.size(0), x.size(1), self.head_num, -1)
        return x.permute(0, 2, 1, 3)

    def forward(self, text_embedding, embedding):
        Q = self.Wq(text_embedding)
        K = self.Wk(embedding)
        V = self.Wv(embedding)
        Q = self.transpose_for_scores(Q)
        K = self.transpose_for_scores(K)
        V = self.transpose_for_scores(V)
        weight_score = torch.matmul(Q, K.transpose(-1, -2))
        # The released code uses a fixed scaling constant of 8 here rather than
        # the usual 1/sqrt(d) temperature.
        weight_prob = nn.Softmax(dim=-1)(weight_score * 8)

        context_layer = torch.matmul(weight_prob, V)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
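A quick shape check for the CE module (a minimal sketch with random toy tensors; `TEXT_DIM` and `label_size` come from `config/global_configs.py`):

```python
import torch
from networks.subnet.CEmodule import CE
from config.global_configs import TEXT_DIM, label_size

# Toy batch: 2 sequences of 50 positions each.
ce = CE()
text = torch.randn(2, 50, TEXT_DIM)
visual_ids = torch.randint(0, label_size, (2, 50))    # index label_size is the padding id
acoustic_ids = torch.randint(0, label_size, (2, 50))
shifted = ce(text, visual_ids=visual_ids, acoustic_ids=acoustic_ids)
assert shifted.shape == text.shape  # CE adds a shift; the hidden size is unchanged
```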
--------------------------------------------------------------------------------
/utils/databuilder.py:
--------------------------------------------------------------------------------
import pickle
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from pytorch_transformers import BertTokenizer, XLNetTokenizer, RobertaTokenizer

class InputFeatures(object):

    def __init__(self, input_ids, visual_ids, acoustic_ids, pos_ids, senti_ids, polarity_ids, visual, acoustic, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.visual_ids = visual_ids
        self.acoustic_ids = acoustic_ids
        self.pos_ids = pos_ids
        self.senti_ids = senti_ids
        self.polarity_ids = polarity_ids
        self.visual = visual
        self.acoustic = acoustic
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id

class MultimodalConfig(object):
    def __init__(self, beta_shift, dropout_prob):
        self.beta_shift = beta_shift
        self.dropout_prob = dropout_prob

def convert_to_features(args, examples, max_seq_length, tokenizer):
    features = []

    for (ex_index, example) in enumerate(examples):

        (words, visual, acoustic, pos_ids, senti_ids, visual_ids, acoustic_ids), label_id, segment = example

        tokens, inversions = [], []
        for idx, word in enumerate(words):
            tokenized = tokenizer.tokenize(word)
            tokens.extend(tokenized)
            inversions.extend([idx] * len(tokenized))

        # Check inversion
        assert len(tokens) == len(inversions)

        aligned_pos_ids = []
        aligned_senti_ids = []

        for inv_idx in inversions:
            aligned_pos_ids.append(pos_ids[inv_idx])
            aligned_senti_ids.append(senti_ids[inv_idx])

        visual = np.array(visual)
        visual_ids = np.array(visual_ids)
        acoustic = np.array(acoustic)
        acoustic_ids = np.array(acoustic_ids)
        pos_ids = aligned_pos_ids
        senti_ids = aligned_senti_ids

        # Truncate input if necessary
        if len(tokens) > max_seq_length - 3:
            tokens = tokens[: max_seq_length - 3]
            words = words[: max_seq_length - 3]
            pos_ids = pos_ids[: max_seq_length - 3]
            senti_ids = senti_ids[: max_seq_length - 3]

        input_ids, visual_ids, acoustic_ids, pos_ids, senti_ids, polarity_ids, visual, acoustic, input_mask, segment_ids = prepare_sentilare_input(
            args, tokens, visual_ids, acoustic_ids, pos_ids, senti_ids, visual, acoustic, tokenizer
        )
        # Check input length
        assert len(input_ids) == args.max_seq_length
        assert len(input_mask) == args.max_seq_length
        assert len(segment_ids) == args.max_seq_length

        features.append(
            InputFeatures(
                input_ids=input_ids,
                visual_ids=visual_ids,
                acoustic_ids=acoustic_ids,
                pos_ids=pos_ids,
                senti_ids=senti_ids,
                polarity_ids=polarity_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                visual=visual,
                acoustic=acoustic,
                label_id=label_id,
            )
        )
    return features

def prepare_sentilare_input(args, tokens, visual_ids, acoustic_ids, pos_ids, senti_ids, visual, acoustic, tokenizer):
    CLS = tokenizer.cls_token
    SEP = tokenizer.sep_token
    # RoBERTa-style layout: <s> tokens </s> </s>, with the padding ids
    # 4 (POS), 2 (word-level sentiment) and 5 (polarity) on special positions.
    tokens = [CLS] + tokens + [SEP] + [SEP]
    pos_ids = [4] + pos_ids + [4] + [4]
    senti_ids = [2] + senti_ids + [2] + [2]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    segment_ids = [0] * len(input_ids)
    input_mask = [1] * len(input_ids)

    pad_length = args.max_seq_length - len(input_ids)
    padding = [0] * pad_length

    # Pad inputs
    input_ids += padding
    pos_ids += [4] * pad_length
    senti_ids += [2] * pad_length
    polarity_ids = [5] * len(input_ids)
    input_mask += padding
    segment_ids += padding

    return input_ids, visual_ids, acoustic_ids, pos_ids, senti_ids, polarity_ids, visual, acoustic, input_mask, segment_ids
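# To make prepare_sentilare_input's layout concrete, a toy illustration for a
# hypothetical 3-token sentence with max_seq_length = 8 (tag values made up):
#
#   tokens       = ["<s>", "good", "movie", "!", "</s>", "</s>"]
#   pos_ids      = [4, 0, 0, 3, 4, 4, 4, 4]   # 4 = POS padding id (specials + padding)
#   senti_ids    = [2, 1, 1, 2, 2, 2, 2, 2]   # 2 = word-sentiment padding id
#   polarity_ids = [5] * 8                    # 5 = polarity padding id at fine-tuning time
#   input_mask   = [1, 1, 1, 1, 1, 1, 0, 0]   # real positions vs. padding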
def get_tokenizer(args):
    return RobertaTokenizer.from_pretrained(args.model_name_or_path, do_lower_case=False)

def get_appropriate_dataset(args, data):

    tokenizer = get_tokenizer(args)

    features = convert_to_features(args, data, args.max_seq_length, tokenizer)
    all_input_ids = torch.tensor(
        [f.input_ids for f in features], dtype=torch.long)
    all_visual_ids = torch.tensor(
        [f.visual_ids for f in features], dtype=torch.long)
    all_acoustic_ids = torch.tensor(
        [f.acoustic_ids for f in features], dtype=torch.long)
    all_pos_ids = torch.tensor(
        [f.pos_ids for f in features], dtype=torch.long)
    all_senti_ids = torch.tensor(
        [f.senti_ids for f in features], dtype=torch.long)
    all_polarity_ids = torch.tensor(
        [f.polarity_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor(
        [f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor(
        [f.segment_ids for f in features], dtype=torch.long)
    all_visual = torch.tensor([f.visual for f in features], dtype=torch.float)
    all_acoustic = torch.tensor([f.acoustic for f in features], dtype=torch.float)
    all_label_ids = torch.tensor(
        [f.label_id for f in features], dtype=torch.float)

    dataset = TensorDataset(
        all_input_ids,
        all_visual_ids,
        all_acoustic_ids,
        all_pos_ids,
        all_senti_ids,
        all_polarity_ids,
        all_visual,
        all_acoustic,
        all_input_mask,
        all_segment_ids,
        all_label_ids,
    )
    return dataset


def set_up_data_loader(args):
    with open(args.data_path, "rb") as handle:
        data = pickle.load(handle)

    train_data = data["train"]
    dev_data = data["dev"]
    test_data = data["test"]

    train_dataset = get_appropriate_dataset(args, train_data)
    dev_dataset = get_appropriate_dataset(args, dev_data)
    test_dataset = get_appropriate_dataset(args, test_data)

    num_train_optimization_steps = (
        int(
            len(train_dataset) / args.train_batch_size /
            args.gradient_accumulation_step
        )
        * args.n_epochs
    )

    train_dataloader = DataLoader(
        train_dataset, batch_size=args.train_batch_size, shuffle=True, drop_last=True
    )

    dev_dataloader = DataLoader(
        dev_dataset, batch_size=args.dev_batch_size, shuffle=True
    )

    test_dataloader = DataLoader(
        test_dataset, batch_size=args.test_batch_size, shuffle=True,
    )

    return (
        train_dataloader,
        dev_dataloader,
        test_dataloader,
        num_train_optimization_steps,
    )
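A minimal way to exercise the loaders outside of `train.py` (a sketch; assumes the MOSI pkl and SentiLARE files from the README are in place):

```python
import argparse
from utils.databuilder import set_up_data_loader

args = argparse.Namespace(
    data_path="./dataset/MOSI_16_sentilare_unaligned_data.pkl",
    model_name_or_path="./pretrained_model/sentilare_model/",
    max_seq_length=50,
    train_batch_size=64,
    dev_batch_size=128,
    test_batch_size=128,
    gradient_accumulation_step=1,
    n_epochs=40,
)
train_loader, dev_loader, test_loader, num_steps = set_up_data_loader(args)
print(len(train_loader.dataset), num_steps)
```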
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function

import argparse
from pytorch_transformers.modeling_roberta import RobertaConfig
import torch
import numpy as np
import wandb
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from torch.nn import L1Loss, MSELoss
from pytorch_transformers import WarmupLinearSchedule, AdamW
from networks.SentiLARE import RobertaForSequenceClassification
from utils.databuilder import set_up_data_loader
from utils.set_seed import set_random_seed, seed
from utils.metric import score_model
from config.global_configs import DEVICE

def parser_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str,
                        choices=["mosi", "mosei"], default="mosi")
    parser.add_argument("--data_path", type=str, default='./dataset/MOSI_16_sentilare_unaligned_data.pkl')
    parser.add_argument("--max_seq_length", type=int, default=50)
    parser.add_argument("--train_batch_size", type=int, default=64)
    parser.add_argument("--dev_batch_size", type=int, default=128)
    parser.add_argument("--test_batch_size", type=int, default=128)
    parser.add_argument("--n_epochs", type=int, default=40)
    parser.add_argument("--beta_shift", type=float, default=1.0)
    parser.add_argument("--dropout_prob", type=float, default=0.5)
    parser.add_argument(
        "--model",
        type=str,
        choices=["bert-base-uncased", "xlnet-base-cased", "roberta-base"],
        default="roberta-base")
    parser.add_argument("--model_name_or_path", default='./pretrained_model/sentilare_model/', type=str,
                        help="Path to pre-trained model or shortcut name")
    parser.add_argument("--learning_rate", type=float, default=6e-5)
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--gradient_accumulation_step", type=int, default=1)
    parser.add_argument("--test_step", type=int, default=20)
    parser.add_argument("--max_grad_norm", type=int, default=2)
    parser.add_argument("--warmup_proportion", type=float, default=0.4)
    parser.add_argument("--seed", type=seed, default=6758, help="integer or 'random'")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    return parser.parse_args()

def prep_for_training(args, num_train_optimization_steps: int):
    config = RobertaConfig.from_pretrained(args.model_name_or_path, num_labels=1, finetuning_task='sst')
    model = RobertaForSequenceClassification.from_pretrained(
        args.model_name_or_path, config=config, pos_tag_embedding=True, senti_embedding=True, polarity_embedding=True)
    model.to(DEVICE)

    # Prepare the optimizer: three parameter groups -- decayed backbone weights,
    # the CE module, and undecayed bias/LayerNorm weights.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.weight"]
    CE_params = ['CE']
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and not any(nd in n for nd in CE_params)
            ],
            "weight_decay": args.weight_decay,
        },
        {"params": model.roberta.encoder.CE.parameters(), 'lr': args.learning_rate, "weight_decay": args.weight_decay},
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay) and not any(nd in n for nd in CE_params)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(
        optimizer,
        warmup_steps=args.warmup_proportion * num_train_optimization_steps,
        t_total=num_train_optimization_steps,
    )
    return model, optimizer, scheduler
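# Note on the schedule above: with the default warmup_proportion = 0.4, the
# learning rate ramps up for the first 40% of all optimizer steps and decays
# linearly afterwards. A sketch with hypothetical numbers:
#   num_train_optimization_steps = 20 * 40   # e.g. 20 steps/epoch for 40 epochs = 800
#   warmup_steps = 0.4 * 800                 # LR peaks at step 320, reaches 0 at step 800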
"weight_decay": 0.0, 74 | }, 75 | ] 76 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 77 | scheduler = WarmupLinearSchedule( 78 | optimizer, 79 | warmup_steps=args.warmup_proportion * num_train_optimization_steps, 80 | t_total=num_train_optimization_steps, 81 | ) 82 | return model, optimizer, scheduler 83 | 84 | def train_epoch(args, model: nn.Module, train_dataloader: DataLoader, optimizer, scheduler): 85 | model.train() 86 | preds = [] 87 | labels = [] 88 | tr_loss = 0 89 | 90 | nb_tr_steps = 0 91 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 92 | batch = tuple(t.to(DEVICE) for t in batch) 93 | input_ids, visual_ids, acoustic_ids, pos_ids, senti_ids, polarity_ids, visual, acoustic, input_mask, segment_ids, label_ids = batch 94 | visual = torch.squeeze(visual, 1) 95 | outputs = model( 96 | input_ids, 97 | visual, 98 | acoustic, 99 | visual_ids, 100 | acoustic_ids, 101 | pos_ids, senti_ids, polarity_ids, 102 | attention_mask=input_mask, 103 | token_type_ids=segment_ids, 104 | ) 105 | logits = outputs[0] 106 | loss_fct = MSELoss() 107 | loss = loss_fct(logits.view(-1), label_ids.view(-1)) 108 | if args.gradient_accumulation_step > 1: 109 | loss = loss / args.gradient_accumulation_step 110 | loss.backward() 111 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 112 | tr_loss += loss.item() 113 | nb_tr_steps += 1 114 | if (step + 1) % args.gradient_accumulation_step == 0: 115 | optimizer.step() 116 | scheduler.step() 117 | optimizer.zero_grad() 118 | logits = logits.detach().cpu().numpy() 119 | label_ids = label_ids.detach().cpu().numpy() 120 | logits = np.squeeze(logits).tolist() 121 | label_ids = np.squeeze(label_ids).tolist() 122 | preds.extend(logits) 123 | labels.extend(label_ids) 124 | 125 | preds = np.array(preds) 126 | labels = np.array(labels) 127 | 128 | return tr_loss / nb_tr_steps, preds, labels 129 | 130 | def evaluate_epoch(args, model: nn.Module, dataloader: DataLoader): 131 | model.eval() 132 | preds = [] 133 | labels = [] 134 | loss = 0 135 | nb_dev_examples, nb_steps = 0, 0 136 | with torch.no_grad(): 137 | for step, batch in enumerate(dataloader): 138 | batch = tuple(t.to(DEVICE) for t in batch) 139 | input_ids, visual_ids, acoustic_ids, pos_ids, senti_ids, polarity_ids, visual, acoustic, input_mask, segment_ids, label_ids = batch 140 | visual = torch.squeeze(visual, 1) 141 | outputs = model( 142 | input_ids, 143 | visual, 144 | acoustic, 145 | visual_ids, 146 | acoustic_ids, 147 | pos_ids, senti_ids, polarity_ids, 148 | token_type_ids=segment_ids, 149 | attention_mask=input_mask, 150 | labels=None, 151 | ) 152 | logits = outputs[0] 153 | loss_fct = MSELoss() 154 | loss = loss_fct(logits.view(-1), label_ids.view(-1)) 155 | if args.gradient_accumulation_step > 1: 156 | loss = loss / args.gradient_accumulation_step 157 | loss += loss.item() 158 | nb_steps += 1 159 | logits = logits.detach().cpu().numpy() 160 | label_ids = label_ids.detach().cpu().numpy() 161 | logits = np.squeeze(logits).tolist() 162 | label_ids = np.squeeze(label_ids).tolist() 163 | preds.extend(logits) 164 | labels.extend(label_ids) 165 | 166 | preds = np.array(preds) 167 | labels = np.array(labels) 168 | 169 | return loss / nb_steps, preds, labels 170 | 171 | def train( 172 | args, 173 | model, 174 | train_dataloader, 175 | validation_dataloader, 176 | test_data_loader, 177 | optimizer, 178 | scheduler,): 179 | valid_losses = [] 180 | test_accuracies = [] 181 | for epoch_i in range(int(args.n_epochs)): 182 | 
def train(
        args,
        model,
        train_dataloader,
        validation_dataloader,
        test_data_loader,
        optimizer,
        scheduler,):
    valid_losses = []
    test_accuracies = []
    for epoch_i in range(int(args.n_epochs)):
        train_loss, train_pre, train_label = train_epoch(args, model, train_dataloader, optimizer, scheduler)
        valid_loss, valid_pre, valid_label = evaluate_epoch(args, model, validation_dataloader)
        test_loss, test_pre, test_label = evaluate_epoch(args, model, test_data_loader)
        train_acc, train_mae, train_corr, train_f_score = score_model(train_pre, train_label)
        # use_zero=True keeps zero (neutral) labels; the default call drops them,
        # so the non0_* metrics below use the default.
        test_acc, test_mae, test_corr, test_f_score = score_model(test_pre, test_label, use_zero=True)
        non0_test_acc, _, _, non0_test_f_score = score_model(test_pre, test_label)
        valid_acc, valid_mae, valid_corr, valid_f_score = score_model(valid_pre, valid_label)
        print(
            "epoch:{}, train_loss:{}, train_acc:{}, valid_loss:{}, valid_acc:{}, test_loss:{}, test_acc:{}".format(
                epoch_i, train_loss, train_acc, valid_loss, valid_acc, test_loss, test_acc
            )
        )
        valid_losses.append(valid_loss)
        test_accuracies.append(test_acc)
        wandb.log(
            {
                "train_loss": train_loss,
                "valid_loss": valid_loss,
                "train_acc": train_acc,
                "train_corr": train_corr,
                "valid_acc": valid_acc,
                "valid_corr": valid_corr,
                "test_loss": test_loss,
                "test_acc": test_acc,
                "test_mae": test_mae,
                "test_corr": test_corr,
                "test_f_score": test_f_score,
                "non0_test_acc": non0_test_acc,
                "non0_test_f_score": non0_test_f_score,
                "best_valid_loss": min(valid_losses),
                "best_test_acc": max(test_accuracies),
            }
        )

def main():
    args = parser_args()
    wandb.init(project="CENet", reinit=True)

    set_random_seed(args.seed)
    wandb.config.update(args)

    (train_data_loader,
     dev_data_loader,
     test_data_loader,
     num_train_optimization_steps,
     ) = set_up_data_loader(args)

    model, optimizer, scheduler = prep_for_training(args, num_train_optimization_steps)

    train(
        args,
        model,
        train_data_loader,
        dev_data_loader,
        test_data_loader,
        optimizer,
        scheduler,
    )

if __name__ == "__main__":
    main()
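The `use_zero` flag in `utils/metric.py` controls whether zero (neutral) labels enter the binary accuracy and F1; a toy illustration with hypothetical predictions:

```python
import numpy as np
from utils.metric import score_model

preds = np.array([0.8, -0.3, 0.1, -1.2])
labels = np.array([1.0, -0.5, 0.0, -2.0])

non0_acc, mae, corr, f1 = score_model(preds, labels)          # drops the label == 0 sample
all_acc, _, _, _ = score_model(preds, labels, use_zero=True)  # keeps all four samples
```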
--------------------------------------------------------------------------------
/networks/SentiLARE/modeling_sentilr_roberta.py:
--------------------------------------------------------------------------------
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import logging
from config.global_configs import hidden_size
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss, BCELoss

from .modeling_sentilr import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu
from pytorch_transformers import RobertaConfig
from pytorch_transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP

logger = logging.getLogger(__name__)

class RobertaEmbeddings(BertEmbeddings):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    """
    def __init__(self, config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False):
        super(RobertaEmbeddings, self).__init__(config, pos_tag_embedding=pos_tag_embedding, senti_embedding=senti_embedding, polarity_embedding=polarity_embedding)
        self.padding_idx = 1

    def forward(self, input_ids, token_type_ids=None, position_ids=None, pos_tag_ids=None, senti_word_ids=None, polarity_ids=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            # Position numbers begin at padding_idx+1. Padding symbols are ignored.
            # cf. fairseq's `utils.make_positions`
            position_ids = torch.arange(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        return super(RobertaEmbeddings, self).forward(input_ids,
                                                      token_type_ids=token_type_ids,
                                                      position_ids=position_ids,
                                                      pos_tag_ids=pos_tag_ids,
                                                      senti_word_ids=senti_word_ids,
                                                      polarity_ids=polarity_ids)
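# The padding_idx + 1 offset above mirrors fairseq's make_positions: e.g. with
# padding_idx = 1 and seq_length = 5,
#   torch.arange(2, 7)  ->  tensor([2, 3, 4, 5, 6])
# so real positions start after the slots reserved around the padding index.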
" 55 | "Please specify add_special_tokens=True in your encoding.") 56 | return super(RobertaModel, self).forward(input_ids, 57 | visual=visual, acoustic=acoustic, 58 | visual_ids=visual_ids, acoustic_ids=acoustic_ids, 59 | pos_ids=pos_ids, 60 | senti_word_ids=senti_word_ids, 61 | polarity_ids=polarity_ids, 62 | attention_mask=attention_mask, 63 | token_type_ids=token_type_ids, 64 | position_ids=position_ids, 65 | head_mask=head_mask, 66 | ) 67 | 68 | 69 | 70 | class RobertaLMHead(nn.Module): 71 | """Roberta Head for masked language modeling.""" 72 | 73 | def __init__(self, config): 74 | super(RobertaLMHead, self).__init__() 75 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 76 | self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 77 | 78 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 79 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 80 | 81 | def forward(self, features, **kwargs): 82 | x = self.dense(features) 83 | x = gelu(x) 84 | x = self.layer_norm(x) 85 | 86 | # project back to size of vocabulary with bias 87 | x = self.decoder(x) + self.bias 88 | 89 | return x 90 | 91 | 92 | class RobertaForSequenceClassification(BertPreTrainedModel): 93 | 94 | config_class = RobertaConfig 95 | pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 96 | base_model_prefix = "roberta" 97 | 98 | def __init__(self, config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False): 99 | super(RobertaForSequenceClassification, self).__init__(config) 100 | self.num_labels = config.num_labels 101 | 102 | self.roberta = RobertaModel(config, pos_tag_embedding=pos_tag_embedding, 103 | senti_embedding=senti_embedding, 104 | polarity_embedding=polarity_embedding) 105 | self.classifier = RobertaClassificationHead(config) 106 | 107 | def forward(self, input_ids, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None, pos_tag_ids=None, 108 | senti_word_ids=None, polarity_ids=None, attention_mask=None, token_type_ids=None, 109 | position_ids=None, head_mask=None, labels=None): 110 | outputs = self.roberta(input_ids, 111 | visual=visual, 112 | acoustic=acoustic, 113 | visual_ids=visual_ids, 114 | acoustic_ids=acoustic_ids, 115 | pos_ids=pos_tag_ids, 116 | senti_word_ids=senti_word_ids, 117 | polarity_ids=polarity_ids, 118 | attention_mask=attention_mask, 119 | token_type_ids=token_type_ids, 120 | position_ids=position_ids, 121 | head_mask=head_mask, 122 | ) 123 | sequence_output = outputs[0] 124 | logits = self.classifier(sequence_output, visual, visual_ids) 125 | 126 | outputs = (logits,) + outputs[2:] 127 | if labels is not None: 128 | if self.num_labels == 1: 129 | # We are doing regression 130 | loss_fct = MSELoss() 131 | loss = loss_fct(logits.view(-1), labels.view(-1)) 132 | else: 133 | loss_fct = CrossEntropyLoss() 134 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 135 | outputs = (loss,) + outputs 136 | 137 | return outputs # (loss), logits, (hidden_states), (attentions) 138 | 139 | 140 | class RobertaForMultiLabelClassification(BertPreTrainedModel): 141 | 142 | config_class = RobertaConfig 143 | pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 144 | base_model_prefix = "roberta" 145 | 146 | def __init__(self, config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False): 147 | super(RobertaForMultiLabelClassification, self).__init__(config) 148 | self.num_labels = config.num_labels 149 | 150 | self.roberta = RobertaModel(config, 
        self.roberta = RobertaModel(config, pos_tag_embedding=pos_tag_embedding,
                                    senti_embedding=senti_embedding,
                                    polarity_embedding=polarity_embedding)
        self.classifier = RobertaClassificationHead(config)
        self.sigmoid_layer = nn.Sigmoid()

    def forward(self, input_ids, pos_tag_ids, senti_word_ids, polarity_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None):
        outputs = self.roberta(input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids,
                               position_ids=position_ids,
                               head_mask=head_mask,
                               pos_ids=pos_tag_ids,
                               senti_word_ids=senti_word_ids,
                               polarity_ids=polarity_ids
                               )
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)
        logits = self.sigmoid_layer(logits)

        outputs = (logits,) + outputs[2:]
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = BCELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)


class RobertaForMultipleChoice(BertPreTrainedModel):
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False):
        super(RobertaForMultipleChoice, self).__init__(config)

        self.roberta = RobertaModel(config, pos_tag_embedding=pos_tag_embedding,
                                    senti_embedding=senti_embedding,
                                    polarity_embedding=polarity_embedding)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.init_weights()

    def forward(
        self,
        input_ids, pos_tag_ids, senti_word_ids, polarity_ids,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        position_ids=None,
        head_mask=None,
    ):

        num_choices = input_ids.shape[1]

        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_pos_tag_ids = pos_tag_ids.view(-1, pos_tag_ids.size(-1))
        flat_senti_word_ids = senti_word_ids.view(-1, senti_word_ids.size(-1))
        flat_polarity_ids = polarity_ids.view(-1, polarity_ids.size(-1))
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        outputs = self.roberta(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            head_mask=head_mask,
            pos_ids=flat_pos_tag_ids,
            senti_word_ids=flat_senti_word_ids,
            polarity_ids=flat_polarity_ids
        )
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            outputs = (loss,) + outputs

        return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super(RobertaClassificationHead, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, visual=None, visual_ids=None, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class RobertaForTokenClassification(BertPreTrainedModel):

    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False):
        super(RobertaForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.roberta = RobertaModel(config, pos_tag_embedding=pos_tag_embedding,
                                    senti_embedding=senti_embedding,
                                    polarity_embedding=polarity_embedding)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        pos_tag_ids=None,
        senti_word_ids=None,
        polarity_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
    ):

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            pos_ids=pos_tag_ids,
            senti_word_ids=senti_word_ids,
            polarity_ids=polarity_ids
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)
--------------------------------------------------------------------------------
/networks/SentiLARE/modeling_sentilr.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
import os
import sys
from io import open

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss, BCELoss
from networks.subnet.CEmodule import CE
from pytorch_transformers import BertConfig
from pytorch_transformers.modeling_utils import PreTrainedModel, prune_linear_layer
from pytorch_transformers import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from config.global_configs import *

logger = logging.getLogger(__name__)

def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf  # TensorFlow is only needed for this conversion path
    except ImportError:
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] == 'kernel' or l[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'output_bias' or l[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, l[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model


def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


try:
    from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except (ImportError, AttributeError) as e:
    BertLayerNorm = torch.nn.LayerNorm

class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position, token_type, POS, word-level and sentence-level sentiment embeddings.
    """
    def __init__(self, config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        if senti_embedding:
            self.senti_embeddings = nn.Embedding(3, config.hidden_size, padding_idx=2)
        else:
            self.register_parameter('senti_embeddings', None)
        if pos_tag_embedding:
            self.pos_tag_embeddings = nn.Embedding(5, config.hidden_size, padding_idx=4)
        else:
            self.register_parameter('pos_tag_embeddings', None)
        if polarity_embedding:
            self.polarity_embeddings = nn.Embedding(6, config.hidden_size, padding_idx=5)
        else:
            self.register_parameter('polarity_embeddings', None)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None, position_ids=None, pos_tag_ids=None, senti_word_ids=None, polarity_ids=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        if senti_word_ids is not None and self.senti_embeddings is not None:
            senti_word_embeddings = self.senti_embeddings(senti_word_ids)
        else:
            senti_word_embeddings = 0

        if pos_tag_ids is not None and self.pos_tag_embeddings is not None:
            pos_tag_embeddings = self.pos_tag_embeddings(pos_tag_ids)
        else:
            pos_tag_embeddings = 0

        if polarity_ids is not None and self.polarity_embeddings is not None:
            polarity_embeddings = self.polarity_embeddings(polarity_ids)
        else:
            polarity_embeddings = 0

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings + senti_word_embeddings + pos_tag_embeddings + polarity_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in BertModel's forward() function)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
        return outputs


class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertAttention(nn.Module):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, input_tensor, attention_mask, head_mask=None):
        self_outputs = self.self(input_tensor, attention_mask, head_mask)
        attention_output = self.output(self_outputs[0], input_tensor)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class BertIntermediate(nn.Module):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class BertLayer(nn.Module):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask, head_mask=None):
        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
        return outputs


class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.CE = CE()

    def forward(self, hidden_states, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None, attention_mask=None, head_mask=None):
        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Inject the cross-modal enhancement shift once, right before the
            # encoder layer selected by ROBERTA_INJECTION_INDEX (1 by default).
            if i == ROBERTA_INJECTION_INDEX:
                hidden_states = self.CE(hidden_states, visual=visual, acoustic=acoustic, visual_ids=visual_ids, acoustic_ids=acoustic_ids)

            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)


class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super(BertPreTrainingHeads, self).__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BertPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
417 | class BertPreTrainedModel(PreTrainedModel):
418 |     """ An abstract class to handle weights initialization and
419 |         a simple interface for downloading and loading pretrained models.
420 |     """
421 |     config_class = BertConfig
422 |     pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
423 |     load_tf_weights = load_tf_weights_in_bert
424 |     base_model_prefix = "bert"
425 | 
426 |     def _init_weights(self, module):
427 |         """ Initialize the weights """
428 |         if isinstance(module, (nn.Linear, nn.Embedding)):
429 |             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
430 |         elif isinstance(module, BertLayerNorm):
431 |             module.bias.data.zero_()
432 |             module.weight.data.fill_(1.0)
433 |         if isinstance(module, nn.Linear) and module.bias is not None:
434 |             module.bias.data.zero_()
435 | 
436 | 
437 | class BertModel(BertPreTrainedModel):
438 | 
439 |     def __init__(self, config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False):
440 |         super(BertModel, self).__init__(config)
441 | 
442 |         self.embeddings = BertEmbeddings(config, pos_tag_embedding=pos_tag_embedding, senti_embedding=senti_embedding, polarity_embedding=polarity_embedding)
443 |         self.encoder = BertEncoder(config)
444 |         self.pooler = BertPooler(config)
445 | 
446 |         self.init_weights()
447 | 
448 |     def _resize_token_embeddings(self, new_num_tokens):
449 |         old_embeddings = self.embeddings.word_embeddings
450 |         new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
451 |         self.embeddings.word_embeddings = new_embeddings
452 |         return self.embeddings.word_embeddings
453 | 
454 |     def _prune_heads(self, heads_to_prune):
455 |         """ Prunes heads of the model.
456 |             heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
457 |             See base class PreTrainedModel
458 |         """
459 |         for layer, heads in heads_to_prune.items():
460 |             self.encoder.layer[layer].attention.prune_heads(heads)
461 | 
462 |     def forward(self, input_ids, visual=None, acoustic=None, visual_ids=None, acoustic_ids=None, pos_ids=None, senti_word_ids=None, polarity_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
463 |         if attention_mask is None:
464 |             attention_mask = torch.ones_like(input_ids)
465 |         if token_type_ids is None:
466 |             token_type_ids = torch.zeros_like(input_ids)
467 | 
468 |         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # (B, L) -> (B, 1, 1, L), broadcastable over heads
469 |         extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
470 |         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0  # 0.0 for attended positions, -10000.0 for masked ones
471 | 
472 |         if head_mask is not None:
473 |             if head_mask.dim() == 1:
474 |                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
475 |                 head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
476 |             elif head_mask.dim() == 2:
477 |                 head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
478 |             head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
479 |         else:
480 |             head_mask = [None] * self.config.num_hidden_layers
481 | 
482 |         embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, pos_tag_ids=pos_ids,
483 |                                            senti_word_ids=senti_word_ids, polarity_ids=polarity_ids)
484 |         encoder_outputs = self.encoder(embedding_output, visual, acoustic, visual_ids, acoustic_ids,
485 |                                        attention_mask=extended_attention_mask,
486 |                                        head_mask=head_mask)
487 |         sequence_output = encoder_outputs[0]
488 |         pooled_output = self.pooler(sequence_output)
489 | 
490 |         outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
491 |         return outputs
492 | 
493 | 
494 | 
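Note: the extended_attention_mask arithmetic in BertModel.forward converts a 0/1 padding mask into an additive bias that is summed onto the attention scores before softmax, so padded positions end up with ~0 attention weight. A standalone check of that arithmetic, independent of the model:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])        # (B=1, L=5); 0 marks padding
ext = attention_mask.unsqueeze(1).unsqueeze(2).float()  # (1, 1, 1, 5)
ext = (1.0 - ext) * -10000.0                            # 0.0 for real tokens, -10000.0 for pads

scores = torch.zeros(1, 1, 5, 5)                        # dummy pre-softmax attention scores
probs = torch.softmax(scores + ext, dim=-1)             # bias is added before softmax
print(probs[0, 0, 0])                                   # padded columns receive ~0 probability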
495 | class BertForSequenceClassification(BertPreTrainedModel):
496 | 
497 |     def __init__(self, config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False):
498 |         super(BertForSequenceClassification, self).__init__(config)
499 |         self.num_labels = config.num_labels
500 | 
501 |         self.bert = BertModel(config, pos_tag_embedding=pos_tag_embedding,
502 |                               senti_embedding=senti_embedding,
503 |                               polarity_embedding=polarity_embedding)
504 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
505 |         self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
506 | 
507 |         self.init_weights()
508 | 
509 |     def forward(self, input_ids, pos_tag_ids, senti_word_ids, polarity_ids, attention_mask=None, token_type_ids=None,
510 |                 position_ids=None, head_mask=None, labels=None):
511 | 
512 |         outputs = self.bert(input_ids,
513 |                             attention_mask=attention_mask,
514 |                             token_type_ids=token_type_ids,
515 |                             position_ids=position_ids,
516 |                             head_mask=head_mask,
517 |                             pos_ids=pos_tag_ids,
518 |                             senti_word_ids=senti_word_ids,
519 |                             polarity_ids=polarity_ids)
520 | 
521 |         pooled_output = outputs[1]
522 | 
523 |         pooled_output = self.dropout(pooled_output)
524 |         logits = self.classifier(pooled_output)
525 | 
526 |         outputs = (logits,) + outputs[2:]
527 | 
528 |         if labels is not None:
529 |             if self.num_labels == 1:  # a single label is treated as a regression target
530 |                 loss_fct = MSELoss()
531 |                 loss = loss_fct(logits.view(-1), labels.view(-1))
532 |             else:
533 |                 loss_fct = CrossEntropyLoss()
534 |                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
535 |             outputs = (loss,) + outputs
536 | 
537 |         return outputs  # (loss), logits, (hidden_states), (attentions)
538 | 
539 | 
540 | class BertForMultilabelClassification(BertPreTrainedModel):
541 | 
542 |     def __init__(self, config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False):
543 |         super(BertForMultilabelClassification, self).__init__(config)
544 |         self.num_labels = config.num_labels
545 | 
546 |         self.bert = BertModel(config, pos_tag_embedding=pos_tag_embedding,
547 |                               senti_embedding=senti_embedding,
548 |                               polarity_embedding=polarity_embedding)
549 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
550 |         self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
551 |         self.sigmoid_layer = nn.Sigmoid()
552 | 
553 |         self.init_weights()
554 | 
555 |     def forward(self, input_ids, pos_tag_ids, senti_word_ids, polarity_ids, attention_mask=None, token_type_ids=None,
556 |                 position_ids=None, head_mask=None, labels=None):
557 | 
558 |         outputs = self.bert(input_ids,
559 |                             attention_mask=attention_mask,
560 |                             token_type_ids=token_type_ids,
561 |                             position_ids=position_ids,
562 |                             head_mask=head_mask,
563 |                             pos_ids=pos_tag_ids,
564 |                             senti_word_ids=senti_word_ids,
565 |                             polarity_ids=polarity_ids)
566 | 
567 |         pooled_output = outputs[1]
568 | 
569 |         pooled_output = self.dropout(pooled_output)
570 |         logits = self.classifier(pooled_output)
571 |         logits = self.sigmoid_layer(logits)  # per-label probabilities for multi-label targets
572 | 
573 |         outputs = (logits,) + outputs[2:]
574 | 
575 |         if labels is not None:
576 |             if self.num_labels == 1:
577 |                 loss_fct = MSELoss()
578 |                 loss = loss_fct(logits.view(-1), labels.view(-1))
579 |             else:
580 |                 loss_fct = BCELoss()
581 |                 loss = loss_fct(logits.view(-1), labels.view(-1))
582 |             outputs = (loss,) + outputs
583 | 
584 |         return outputs  # (loss), logits, (hidden_states), (attentions)
585 | 
586 | 
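Note: the heads above select the loss from config.num_labels — MSELoss for a single regression target, CrossEntropyLoss for single-label classification, and BCELoss on sigmoid outputs for multi-label targets. A standalone illustration of the three branches with dummy tensors:

import torch
from torch.nn import MSELoss, CrossEntropyLoss, BCELoss

batch = 4

# num_labels == 1: regression (e.g., a continuous sentiment score)
logits_reg = torch.randn(batch, 1)
labels_reg = torch.randn(batch)
print(MSELoss()(logits_reg.view(-1), labels_reg.view(-1)))

# num_labels > 1: single-label classification with integer class targets
num_labels = 3
logits_cls = torch.randn(batch, num_labels)
labels_cls = torch.randint(0, num_labels, (batch,))
print(CrossEntropyLoss()(logits_cls.view(-1, num_labels), labels_cls.view(-1)))

# multi-label: BCELoss expects probabilities (hence the Sigmoid) and float 0/1 targets
probs_ml = torch.sigmoid(torch.randn(batch, num_labels))
labels_ml = torch.randint(0, 2, (batch, num_labels)).float()
print(BCELoss()(probs_ml.view(-1), labels_ml.view(-1)))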
587 | class BertForMultipleChoice(BertPreTrainedModel):
588 | 
589 |     def __init__(self, config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False):
590 |         super(BertForMultipleChoice, self).__init__(config)
591 | 
592 |         self.bert = BertModel(config, pos_tag_embedding=pos_tag_embedding,
593 |                               senti_embedding=senti_embedding,
594 |                               polarity_embedding=polarity_embedding)
595 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
596 |         self.classifier = nn.Linear(config.hidden_size, 1)
597 | 
598 |         self.init_weights()
599 | 
600 |     def forward(self, input_ids, pos_tag_ids, senti_word_ids, polarity_ids, attention_mask=None, token_type_ids=None,
601 |                 position_ids=None, head_mask=None, labels=None):
602 |         num_choices = input_ids.shape[1]
603 | 
604 |         input_ids = input_ids.view(-1, input_ids.size(-1))  # flatten (batch, num_choices, seq_len) -> (batch * num_choices, seq_len)
605 |         pos_tag_ids = pos_tag_ids.view(-1, pos_tag_ids.size(-1))
606 |         senti_word_ids = senti_word_ids.view(-1, senti_word_ids.size(-1))
607 |         polarity_ids = polarity_ids.view(-1, polarity_ids.size(-1))
608 |         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
609 |         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
610 |         position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
611 | 
612 |         outputs = self.bert(input_ids,
613 |                             attention_mask=attention_mask,
614 |                             token_type_ids=token_type_ids,
615 |                             position_ids=position_ids,
616 |                             head_mask=head_mask,
617 |                             pos_ids=pos_tag_ids,
618 |                             senti_word_ids=senti_word_ids,
619 |                             polarity_ids=polarity_ids)
620 | 
621 |         pooled_output = outputs[1]
622 | 
623 |         pooled_output = self.dropout(pooled_output)
624 |         logits = self.classifier(pooled_output)
625 |         reshaped_logits = logits.view(-1, num_choices)  # regroup per-choice scores back to (batch, num_choices)
626 | 
627 |         outputs = (reshaped_logits,) + outputs[2:]
628 | 
629 |         if labels is not None:
630 |             loss_fct = CrossEntropyLoss()
631 |             loss = loss_fct(reshaped_logits, labels)
632 |             outputs = (loss,) + outputs
633 | 
634 |         return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
635 | 
636 | 
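Note: BertForMultipleChoice scores every choice in a single encoder pass by flattening the choice dimension into the batch, then regrouping the per-choice logits for a softmax over choices. A minimal shape check of that trick:

import torch

batch, num_choices, seq_len = 2, 4, 16
input_ids = torch.randint(0, 1000, (batch, num_choices, seq_len))

flat = input_ids.view(-1, input_ids.size(-1))      # (8, 16): one row per (example, choice) pair
per_choice_score = torch.randn(flat.size(0), 1)    # stands in for classifier(pooled_output)
reshaped = per_choice_score.view(-1, num_choices)  # (2, 4): one score column per choice
print(flat.shape, reshaped.shape)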
637 | class BertForTokenClassification(BertPreTrainedModel):
638 | 
639 |     def __init__(self, config, pos_tag_embedding=False, senti_embedding=False, polarity_embedding=False):
640 |         super(BertForTokenClassification, self).__init__(config)
641 |         self.num_labels = config.num_labels
642 | 
643 |         self.bert = BertModel(config, pos_tag_embedding=pos_tag_embedding,
644 |                               senti_embedding=senti_embedding,
645 |                               polarity_embedding=polarity_embedding)
646 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
647 |         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
648 | 
649 |         self.init_weights()
650 | 
651 |     def forward(self, input_ids, pos_tag_ids, senti_word_ids, polarity_ids, attention_mask=None, token_type_ids=None,
652 |                 position_ids=None, head_mask=None, labels=None):
653 | 
654 |         outputs = self.bert(input_ids,
655 |                             attention_mask=attention_mask,
656 |                             token_type_ids=token_type_ids,
657 |                             position_ids=position_ids,
658 |                             head_mask=head_mask,
659 |                             pos_ids=pos_tag_ids,
660 |                             senti_word_ids=senti_word_ids,
661 |                             polarity_ids=polarity_ids)
662 | 
663 |         sequence_output = outputs[0]
664 | 
665 |         sequence_output = self.dropout(sequence_output)
666 |         logits = self.classifier(sequence_output)
667 | 
668 |         outputs = (logits,) + outputs[2:]
669 |         if labels is not None:
670 |             loss_fct = CrossEntropyLoss(ignore_index=-1)
671 |             if attention_mask is not None:
672 |                 active_loss = attention_mask.view(-1) == 1  # restrict the loss to non-padding positions
673 |                 active_logits = logits.view(-1, self.num_labels)[active_loss]
674 |                 active_labels = labels.view(-1)[active_loss]
675 |                 loss = loss_fct(active_logits, active_labels)
676 |             else:
677 |                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
678 |             outputs = (loss,) + outputs
679 | 
680 |         return outputs  # (loss), scores, (hidden_states), (attentions)
681 | 
--------------------------------------------------------------------------------