├── text_data
│   └── sentiment.txt
├── models
│   ├── __pycache__
│   │   ├── text_models.cpython-37.pyc
│   │   ├── text_models.cpython-39.pyc
│   │   ├── vision_text.cpython-37.pyc
│   │   ├── vision_models.cpython-37.pyc
│   │   └── self_attention.cpython-37.pyc
│   ├── self_attention.py
│   ├── vision_models.py
│   ├── text_models.py
│   └── vision_text.py
├── utility
│   ├── __pycache__
│   │   └── text_sentiment.cpython-37.pyc
│   ├── sarcasm_image.py
│   └── text_sentiment.py
├── README.md
└── main.py

/text_data/sentiment.txt:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/__pycache__/text_models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/downdric/MSD/HEAD/models/__pycache__/text_models.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/text_models.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/downdric/MSD/HEAD/models/__pycache__/text_models.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/vision_text.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/downdric/MSD/HEAD/models/__pycache__/vision_text.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/vision_models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/downdric/MSD/HEAD/models/__pycache__/vision_models.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/self_attention.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/downdric/MSD/HEAD/models/__pycache__/self_attention.cpython-37.pyc
--------------------------------------------------------------------------------
/utility/__pycache__/text_sentiment.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/downdric/MSD/HEAD/utility/__pycache__/text_sentiment.cpython-37.pyc
--------------------------------------------------------------------------------
/utility/sarcasm_image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | 
4 | 
5 | def get_label(text_path):
6 |     res = {}
7 |     for line in open(text_path, 'r').readlines():
8 |         content = eval(line)
9 |         skip_words = ['exgag', 'sarcasm', 'sarcastic', '', 'reposting', 'joke', 'humor', 'humour', 'jokes', 'irony', 'ironic']
10 |         flag = False
11 |         for skip_word in skip_words:
12 |             if skip_word in content[1]: flag = True
13 |         if flag: continue
14 |         res[content[0]] = content[2]
15 |     return res
16 | 
17 | 
18 | def mv_img(labels, root_dir, target_dir):
19 |     anno = ['non-sarcasm', 'sarcasm']
20 |     for cur_anno in anno:
21 |         cur_dir = os.path.join(target_dir, cur_anno)
22 |         if not os.path.exists(cur_dir): os.makedirs(cur_dir)
23 | 
24 |     for cur_key in labels.keys():
25 |         cur_img = cur_key + '.jpg'
26 |         target_path = os.path.join(target_dir, anno[labels[cur_key]], cur_img)
27 |         origin_path = os.path.join(root_dir, cur_img)
28 |         shutil.copy(origin_path, target_path)
29 | 
30 | 
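Editor's note: a minimal driver for the two helpers above, assuming the annotation file follows the dataset's `[id, text, label]` line format; the paths and the `__main__` block are illustrative, not part of the original script.

```python
# Hypothetical usage of get_label / mv_img; paths are placeholders.
if __name__ == '__main__':
    labels = get_label('text_data/train.txt')                      # {file_id: 0/1}, lines containing meta words are skipped
    mv_img(labels, root_dir='imgs/all', target_dir='imgs/train')   # copies imgs/all/<id>.jpg into imgs/train/{non-sarcasm,sarcasm}/
```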
--------------------------------------------------------------------------------
/utility/text_sentiment.py:
--------------------------------------------------------------------------------
1 | from senticnet.senticnet import SenticNet
2 | import numpy as np
3 | import torch
4 | 
5 | 
6 | sn = SenticNet()
7 | def get_word_level_sentiment(texts, tokenizer, device):
8 |     res = []  # one polarity tensor per text, with one value per token
9 |     for text in texts:
10 |         if tokenizer is not None: word_list = tokenizer.tokenize(text)
11 |         else: word_list = text.split()
12 |         text_res = []
13 |         for word in word_list:
14 |             try:
15 |                 word_polarity_value = float(sn.concept(word)['polarity_value'])
16 |             except Exception:  # word not covered by SenticNet
17 |                 word_polarity_value = float(0)
18 |             text_res.append(word_polarity_value)
19 |         res.append(torch.tensor(text_res).to(device))
20 |     return res
21 | 
22 | 
23 | def get_text_sentiment(texts, tokenizer, device):
24 |     res = []  # one averaged polarity value per text
25 |     for text in texts:
26 |         if tokenizer is not None: word_list = tokenizer.tokenize(text)
27 |         else: word_list = text.split()
28 |         text_res, cnt = 0, 0
29 |         for word in word_list:
30 |             try:
31 |                 text_res += float(sn.concept(word)['polarity_value'])
32 |                 cnt += 1
33 |             except Exception:  # word not covered by SenticNet
34 |                 text_res += float(0)
35 |         if cnt != 0: text_res = text_res / cnt
36 |         else: text_res = 0.0  # no recognized concepts: keep a neutral score
37 |         res.append(text_res)
38 |     res = torch.tensor(res).to(device)
39 |     return res
40 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | 
3 | This is the official implementation of the paper "DIP: Dual Incongruity Perceiving Network for Sarcasm Detection", which was accepted at CVPR 2023.
4 | 
5 | ## Abstract
6 | 
7 | Sarcasm indicates that the literal meaning is contrary to the actual attitude. Considering the popularity and complementarity of image-text data, we investigate multi-modal sarcasm detection. Different from other multi-modal tasks, sarcastic data exhibit an intrinsic incongruity between the paired image and text, as demonstrated in psychological theories.
8 | 
9 | To tackle this issue, we propose a Dual Incongruity Perceiving (DIP) network consisting of two branches that mine sarcastic information at the factual and affective levels. For the factual aspect, we introduce a channel-wise reweighting strategy to obtain semantically discriminative embeddings, and leverage a Gaussian distribution to model the uncertain correlation caused by the incongruity. The distribution is estimated from the latest data stored in a memory bank, which adaptively models the difference in semantic similarity between sarcastic and non-sarcastic data. For the affective aspect, we utilize Siamese layers with shared parameters to learn cross-modal sentiment information. Furthermore, we use the polarity value to construct a relation graph over the mini-batch, which yields a continuous contrastive loss for acquiring affective embeddings. Extensive experiments demonstrate that our proposed method performs favorably against state-of-the-art approaches.
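As a rough illustration of the affective branch described above, the continuous contrastive loss compares a cross-modal similarity distribution against a soft target built from polarity differences. The following is a simplified sketch distilled from the batch-wise loss in `models/vision_text.py`; the function name, signature, and tensor shapes are illustrative only.

```python
import torch
import torch.nn.functional as F

def continuous_contrastive_loss(img_sent_emb, txt_sent_emb, img_polarity, txt_polarity, tau=0.2):
    """Sketch of the polarity-weighted contrastive objective (cf. MSD_Net.forward)."""
    # Relation graph: image-text pairs with similar predicted polarity receive larger target weights.
    target = torch.exp(-torch.abs(img_polarity.unsqueeze(1) - txt_polarity.unsqueeze(0)))
    target = target / target.sum(dim=1, keepdim=True)
    # Cross-modal similarity of the sentiment embeddings, row-normalized into a distribution.
    sim = torch.exp(torch.mm(F.normalize(img_sent_emb, dim=1),
                             F.normalize(txt_sent_emb, dim=1).t()) / tau)
    sim = sim / sim.sum(dim=1, keepdim=True)
    # KL divergence between the predicted and the polarity-derived distributions.
    return F.kl_div(torch.log(sim), target, reduction='batchmean')
```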
10 | 
11 | ## Installation
12 | Step 1: Download the data from ["Multi-Modal Sarcasm Detection in Twitter with Hierarchical Fusion Model"](https://github.com/ZLJ2015106/pytorch-multimodal_sarcasm_detection.git).
13 | 
14 | Step 2: Install the following packages before running the code:
15 | - torch == 1.13.0
16 | - torchtext == 0.14.0
17 | - torchvision == 0.14.0
18 | - transformers == 4.23.1
19 | - tokenizers == 0.13.1
20 | - senticnet == 1.6
21 | 
22 | ## Usage
23 | 
24 | ```bash
25 | python main.py
26 | ```
27 | 
--------------------------------------------------------------------------------
/models/self_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from einops import rearrange
5 | 
6 | 
7 | class FeedForward(nn.Module):
8 |     def __init__(self, dim, hidden_dim):
9 |         super().__init__()
10 |         self.net = nn.Sequential(
11 |             nn.LayerNorm(dim),
12 |             nn.Linear(dim, hidden_dim),
13 |             nn.GELU(),
14 |             nn.Linear(hidden_dim, dim),
15 |         )
16 |     def forward(self, x):
17 |         return self.net(x)
18 | 
19 | class Attention(nn.Module):
20 |     def __init__(self, dim, heads=8, dim_head=64):
21 |         super().__init__()
22 |         inner_dim = dim_head * heads
23 |         self.heads = heads
24 |         self.scale = dim_head ** -0.5
25 |         self.norm = nn.LayerNorm(dim)
26 | 
27 |         self.attend = nn.Softmax(dim=-1)
28 | 
29 |         self.to_q = nn.Linear(dim, inner_dim, bias=False)
30 |         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
31 |         self.to_out = nn.Linear(inner_dim, dim, bias=False)
32 | 
33 |     def forward(self, x):
34 |         x = self.norm(x)
35 |         m1 = x[:, 0, :].unsqueeze(1)  # first modality token (query)
36 |         m2 = x[:, 1, :].unsqueeze(1)  # second modality token (key/value)
37 |         q = self.to_q(m1)
38 |         kv = self.to_kv(m2).chunk(2, dim=-1)
39 |         q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)  # (12,1,512) -> (12,8,1,64)
40 |         k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), kv)
41 |         dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
42 | 
43 |         attn = self.attend(dots)
44 | 
45 |         out = torch.matmul(attn, v)
46 |         out = rearrange(out, 'b h n d -> b n (h d)')
47 |         return self.to_out(out)
48 | 
49 | 
50 | class Transformer(nn.Module):
51 |     def __init__(self, dim, depth, heads, dim_head, mlp_dim):
52 |         super().__init__()
53 |         self.layers = nn.ModuleList([])
54 |         for _ in range(depth):
55 |             self.layers.append(nn.ModuleList([
56 |                 Attention(dim, heads=heads, dim_head=dim_head),
57 |                 FeedForward(dim, mlp_dim)
58 |             ]))
59 |     def forward(self, x):
60 |         for attn, ff in self.layers:
61 |             x = attn(x) + x[:, 1, :].unsqueeze(1)
62 |             x = ff(x) + x
63 |         return x
64 | 
65 | if __name__ == '__main__':
66 |     # quick shape check; guarded so it does not run when the module is imported
67 |     transformer = Transformer(dim=768, depth=1, heads=8, dim_head=64, mlp_dim=768)
68 |     x = torch.rand(12, 2, 768)
69 |     y = transformer(x)
70 | 
--------------------------------------------------------------------------------
/models/vision_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | # import torchvision
3 | from torchvision.datasets import ImageFolder
4 | from torchvision import transforms, models
5 | from torch.utils import data
6 | from torch import optim, nn
7 | import torch.nn.functional as F
8 | from transformers import ViTModel, get_linear_schedule_with_warmup
9 | 
10 | 
11 | def initialize_transforms(args):
12 |     input_size = 224
13 |     img_mean, img_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
14 |     image_transforms = {}
15 |     image_transforms['train'] = transforms.Compose([
16 |         transforms.RandomResizedCrop(input_size),
17 | 
transforms.RandomHorizontalFlip(), 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=img_mean, std=img_std) 20 | ]) 21 | image_transforms['test'] = transforms.Compose([ 22 | transforms.Resize(input_size), 23 | transforms.CenterCrop(input_size), 24 | transforms.ToTensor(), 25 | transforms.Normalize(mean=img_mean, std=img_std) 26 | ]) 27 | return image_transforms 28 | 29 | 30 | def generate_vision_loader(args, image_transforms): 31 | train_set = ImageFolder(args.train_data_dir, image_transforms['train']) 32 | test_set = ImageFolder(args.test_data_dir, image_transforms['test']) 33 | train_loader = data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 34 | test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) 35 | return train_loader, test_loader 36 | 37 | 38 | class VIT_MODEL(nn.Module): 39 | def __init__(self, vit, output_dim, alg): 40 | super().__init__() 41 | self.vit = vit 42 | if alg == 'base': self.vit.classifier = nn.Linear(768, output_dim) 43 | else: self.vit.classifier = nn.Linear(1024, output_dim) 44 | 45 | def forward(self, image): 46 | output = self.vit(image, return_dict=True) 47 | res = {'embeddings': output['last_hidden_state'][:, 1:, :], 'cls': output['pooler_output']} 48 | return res 49 | 50 | 51 | class ResNet_MODEL(nn.Module): 52 | def __init__(self): 53 | super().__init__() 54 | self.model = models.resnet50(pretrained=True) 55 | self.model = nn.Sequential(*list(self.model.children())[:-2]) 56 | self.fc = nn.Linear(2048, 768, bias=True) 57 | 58 | def forward(self, image): 59 | output = self.model(image) 60 | batch_size, channles, w, h = output.size() 61 | output = output.view(batch_size, channles, -1) 62 | output = output.transpose(1, 2) 63 | output = self.fc(output) 64 | 65 | cls = torch.sum(output, dim=1) / (w*h) 66 | res = {'embeddings': output, 'cls': cls} 67 | return res 68 | 69 | 70 | def get_vision_model(args): 71 | if args.vision_backbone == 'vit': 72 | if args.vision_model == 'base': vit = ViTModel.from_pretrained('google/vit-base-patch16-224') 73 | elif args.vision_model == 'large' : vit = ViTModel.from_pretrained('google/vit-large-patch16-224') 74 | else: 75 | print('Only support base and large models') 76 | exit(0) 77 | base_model = VIT_MODEL(vit, args.output_dim, args.vision_model) 78 | else: 79 | base_model = ResNet_MODEL() 80 | return base_model 81 | 82 | 83 | def get_vision_configuration(args, model): 84 | optimizer = optim.Adam(model.parameters(), lr=args.vision_lr, weight_decay=args.vision_weight_decay) 85 | num_training_steps = int(args.train_set_len / args.batch_size * args.epoch) 86 | scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps) 87 | criterion = nn.BCEWithLogitsLoss() 88 | return optimizer, scheduler, criterion 89 | -------------------------------------------------------------------------------- /models/text_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import BertTokenizer 3 | from torchtext import data, datasets 4 | # from torchtext.data import Example 5 | from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import time 9 | import torch.nn.functional as F 10 | import tqdm 11 | from torchcrf import CRF 12 | import numpy as np 13 | from transformers import AutoTokenizer 14 | import 
torchtext.vocab as vocab 15 | 16 | 17 | def initialize_tokenizer(args): 18 | if args.text_model == 'base': 19 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 20 | max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased'] 21 | elif args.text_model == 'large': 22 | tokenizer = BertTokenizer.from_pretrained('bert-large-uncased') 23 | max_input_length = tokenizer.max_model_input_sizes['bert-large-uncased'] 24 | else: 25 | print('Only support base and large') 26 | exit(0) 27 | print(len(tokenizer.vocab)) 28 | 29 | print(max_input_length) 30 | 31 | init_token = tokenizer.cls_token 32 | eos_token = tokenizer.sep_token 33 | pad_token = tokenizer.pad_token 34 | unk_token = tokenizer.unk_token 35 | print(init_token, eos_token, pad_token, unk_token) 36 | 37 | return tokenizer 38 | 39 | 40 | def generate_text_loader(args, device, tokenizer, max_input_length, init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx): 41 | def tokenize_and_cut(sentence): 42 | tokens = tokenizer.tokenize(sentence) 43 | tokens = tokens[:max_input_length-2] 44 | return tokens 45 | 46 | TEXT = data.Field(batch_first = True, 47 | use_vocab = False, 48 | tokenize = tokenize_and_cut, 49 | preprocessing = tokenizer.convert_tokens_to_ids, 50 | init_token = init_token_idx, 51 | eos_token = eos_token_idx, 52 | pad_token = pad_token_idx, 53 | unk_token = unk_token_idx) 54 | 55 | LABEL = data.LabelField(dtype = torch.float) 56 | 57 | train_fields = [('text', TEXT), ('label', LABEL)] 58 | test_fields = [('text', TEXT), ('label', LABEL)] 59 | train_examples = [] 60 | test_examples = [] 61 | 62 | skip_words = ['exgag', 'sarcasm', 'sarcastic', '', 'reposting', 'joke', 'humor', 'humour', 'jokes', 'irony', 'ironic'] 63 | 64 | for line in tqdm.tqdm(open(args.train_text_path, 'r').readlines()): 65 | content = eval(line) 66 | flag = False 67 | for skip_word in skip_words: 68 | if skip_word in content[1]: flag = True 69 | if flag: continue 70 | text, label = content[1], content[2] 71 | # train_examples.append(Example.fromlist([text, label], train_fields)) 72 | 73 | for line in tqdm.tqdm(open(args.test_text_path, 'r').readlines()): 74 | content = eval(line) 75 | # flag = False 76 | # for skip_word in skip_words: 77 | # if skip_word in content[1]: flag = True 78 | # if flag: continue 79 | text, label = content[1], content[2] 80 | # test_examples.append(Example.fromlist([text, label], test_fields)) 81 | 82 | train_set = datasets(train_examples, train_fields) 83 | test_set = datasets(test_examples, test_fields) 84 | 85 | LABEL.build_vocab(train_set) 86 | 87 | train_iterator, test_iterator = data.BucketIterator.splits((train_set, test_set), batch_size=args.batch_size, device=device, sort=False) 88 | 89 | return train_iterator, test_iterator 90 | 91 | 92 | class BERT_MODEL(nn.Module): 93 | def __init__(self, bert, output_dim, alg, embedding=False): 94 | super().__init__() 95 | self.bert = bert 96 | self.embedding = embedding 97 | if alg == 'base': self.fc = nn.Linear(768, output_dim) 98 | else: self.fc = nn.Linear(1024, output_dim) 99 | 100 | def forward(self, text): 101 | output = self.bert(text) 102 | res = {'embeddings': output['last_hidden_state'][:, 1:, :], 'cls': output['pooler_output']} 103 | return res 104 | 105 | 106 | class LSTM_MODEL(nn.Module): 107 | def __init__(self, hidden_size): 108 | super().__init__() 109 | self.hidden_size = hidden_size 110 | self.embedding = nn.Embedding.from_pretrained(vocab.GloVe(name='42B', dim=300).vectors) 111 | 112 | self.biLSTM = nn.LSTM(input_size=300, 
hidden_size=hidden_size, num_layers=1, bidirectional=True, batch_first=True) 113 | self.fc = nn.Linear(2*hidden_size, 768, bias=True) 114 | # self.biLSTM = nn.LSTM(input_size=300, hidden_size=hidden_size, num_layers=1, batch_first=True) 115 | # self.fc = nn.Linear(hidden_size, 768, bias=True) 116 | 117 | def forward(self, text): 118 | embedded = self.embedding(text) 119 | output, _ = self.biLSTM(embedded) 120 | output = self.fc(output) 121 | # output = F.relu(self.fc(output)) 122 | 123 | return {'embeddings': output} 124 | 125 | 126 | def get_text_model(args, embedding=False): 127 | if args.text_backbone == 'bert': 128 | if args.text_model == 'base': 129 | bert = BertModel.from_pretrained('bert-base-uncased') 130 | elif args.text_model == 'large': 131 | bert = BertModel.from_pretrained('bert-large-uncased') 132 | else: 133 | print('error and tokenizer may have something wrong') 134 | exit(0) 135 | model = BERT_MODEL(bert, args.output_dim, args.text_model, embedding=embedding) 136 | else: 137 | model = LSTM_MODEL(hidden_size=256) 138 | 139 | return model 140 | 141 | 142 | def get_text_configuration(args, model): 143 | optimizer = optim.Adam(model.parameters(), lr=args.text_lr, weight_decay=args.text_weight_decay) 144 | num_training_steps = int(args.train_set_len / args.batch_size * args.epoch) 145 | scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps) 146 | criterion = nn.BCEWithLogitsLoss() 147 | return optimizer, scheduler, criterion 148 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils import data 4 | import argparse 5 | import random 6 | import numpy as np 7 | import time 8 | import torch.nn.functional as F 9 | from sklearn import metrics 10 | from models.text_models import get_text_configuration, get_text_model, initialize_tokenizer 11 | from models.vision_models import initialize_transforms, get_vision_model, get_vision_configuration 12 | from models.vision_text import MultiModalDataset, get_multimodal_model, get_multimodal_configuration 13 | from utility.text_sentiment import get_word_level_sentiment, get_text_sentiment 14 | import torchtext.vocab as vocab 15 | 16 | 17 | def set_seed(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | def binary_accuracy(preds, y): 25 | acc = metrics.accuracy_score(y, preds) 26 | return acc 27 | 28 | 29 | def get_sentiment(data_names, sentiment_labels): 30 | sentiment_res = [] 31 | for file_name in data_names: 32 | assert file_name in sentiment_labels.keys() 33 | sentiment_res.append(sentiment_labels[file_name]) 34 | return sentiment_res 35 | 36 | 37 | def train(args, text_tools, vision_model, text_model, multimodal_model, loader, optimizer, criterion, scheduler, device, vision_optimizer, vision_scheduler, text_optimizer, text_scheduler): 38 | epoch_loss = 0 39 | epoch_acc = 0 40 | vision_model.train() 41 | text_model.train() 42 | multimodal_model.train() 43 | for iter, data in enumerate(loader): 44 | data_names, vision_data, text_data, label = data[0], data[1], data[2], data[3] 45 | 46 | if args.text_backbone == 'bert': 47 | text_sentiment = get_word_level_sentiment(text_data, text_tools['tokenizer'], device) 48 | text_ids = text_tools['tokenizer'](text_data, padding='longest', truncation=True, 
return_tensors='pt')['input_ids'].to(device) 49 | 50 | text_ids = text_tools['tokenizer'](text_data, padding='longest', truncation=True, return_tensors='pt')['input_ids'].to(device) 51 | 52 | vision_data, text_ids, label = vision_data.to(device), text_ids.to(device), label.to(device) 53 | 54 | vision_embeddings = vision_model(vision_data) 55 | text_embeddings = text_model(text_ids) 56 | 57 | predictions, sentiment_contrast_loss, text_sentiment_loss = multimodal_model(vision_embeddings, text_embeddings, text_sentiment, label) 58 | 59 | loss = criterion(predictions, label.float()) + sentiment_contrast_loss + text_sentiment_loss 60 | acc = binary_accuracy(torch.round(torch.sigmoid(predictions)).cpu().detach().numpy().tolist(), label.cpu().detach().numpy().tolist()) 61 | vision_optimizer.zero_grad() 62 | text_optimizer.zero_grad() 63 | optimizer.zero_grad() 64 | loss.backward() 65 | vision_optimizer.step() 66 | text_optimizer.step() 67 | optimizer.step() 68 | vision_scheduler.step() 69 | text_scheduler.step() 70 | scheduler.step() 71 | 72 | epoch_loss += loss.item() 73 | epoch_acc += acc.item() 74 | return epoch_loss / len(loader), epoch_acc / len(loader) 75 | 76 | 77 | def evaluate(args, text_tools, epoch, all_epoch, best_acc, vision_model, text_model, multimodal_model, loader, criterion, device): 78 | epoch_loss = 0 79 | vision_model.eval() 80 | text_model.eval() 81 | multimodal_model.eval() 82 | preds = [] 83 | labels = [] 84 | with torch.no_grad(): 85 | for iter, data in enumerate(loader): 86 | data_name, vision_data, text_data, label = data[0], data[1], data[2], data[3] 87 | if args.text_backbone == 'bert': 88 | text_sentiment = get_word_level_sentiment(text_data, text_tools['tokenizer'], device) 89 | text_ids = text_tools['tokenizer'](text_data, padding='longest', truncation=True, return_tensors='pt')['input_ids'].to(device) 90 | 91 | text_ids = text_tools['tokenizer'](text_data, padding='longest', truncation=True, return_tensors='pt')['input_ids'].to(device) 92 | 93 | vision_data, text_ids, label = vision_data.to(device), text_ids.to(device), label.to(device) 94 | 95 | vision_embeddings = vision_model(vision_data) 96 | text_embeddings = text_model(text_ids) 97 | 98 | predictions, sentiment_contrast_loss, text_sentiment_loss = multimodal_model(vision_embeddings, text_embeddings, text_sentiment) 99 | 100 | loss = criterion(predictions, label.float()) 101 | preds.extend(torch.round(torch.sigmoid(predictions)).cpu().detach().numpy().tolist()) 102 | labels.extend(label.cpu().detach().numpy().tolist()) 103 | epoch_loss += loss.item() 104 | 105 | acc = metrics.accuracy_score(labels, preds) 106 | binary_f1 = metrics.f1_score(labels[:], preds[:]) 107 | binary_precision = metrics.precision_score(labels[:], preds[:]) 108 | binary_recall = metrics.recall_score(labels[:], preds[:]) 109 | macro_f1 = metrics.f1_score(labels[:], preds[:], average='macro') 110 | macro_precision = metrics.precision_score(labels[:], preds[:], average='macro') 111 | macro_recall = metrics.recall_score(labels[:], preds[:], average='macro') 112 | best_acc = max(best_acc, acc) 113 | print('Epoch: {}/{}: Macro F1: {} Macro Precision: {} Macro Recall: {} Binary F1: {} Binary Precision: {} Binary Recall: {} Acc: {} Best Acc: {}'.format( 114 | epoch, all_epoch, macro_f1, macro_precision, macro_recall, binary_f1, binary_precision, binary_recall, acc, best_acc 115 | )) 116 | return epoch_loss / len(loader), acc 117 | 118 | 119 | def epoch_time(start_time, end_time): 120 | elapsed_time = end_time - start_time 121 | elapsed_mins = 
int(elapsed_time / 60) 122 | elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) 123 | return elapsed_mins, elapsed_secs 124 | 125 | 126 | def run(args): 127 | set_seed(args.seed) 128 | device = torch.device('cuda', args.gpu) 129 | 130 | if args.text_backbone == 'bert': tokenizer = initialize_tokenizer(args) 131 | else: tokenizer = vocab.GloVe(name='42B', dim=300).stoi 132 | text_tools = { 'tokenizer': tokenizer } 133 | vision_transforms = initialize_transforms(args) 134 | train_set = MultiModalDataset(text_tools, vision_transforms, args, 'train') 135 | test_set = MultiModalDataset(text_tools, vision_transforms, args, 'test') 136 | train_loader = data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 137 | test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 138 | 139 | vision_model = get_vision_model(args) 140 | text_model = get_text_model(args, embedding=True) 141 | multimodal_model = get_multimodal_model(args) 142 | vision_optimizer, vision_scheduler, _ = get_vision_configuration(args, vision_model) 143 | text_optimizer, text_scheduler, _ = get_text_configuration(args, text_model) 144 | optimizer, scheduler, criterion = get_multimodal_configuration(args, multimodal_model) 145 | vision_model.to(device) 146 | text_model.to(device) 147 | multimodal_model.to(device) 148 | criterion.to(device) 149 | best_test_acc = -float('inf') 150 | 151 | for epoch in range(1, args.epoch+1): 152 | start_time = time.time() 153 | 154 | train_loss, train_acc = train(args, text_tools, vision_model, text_model, multimodal_model, train_loader, optimizer, criterion, scheduler, device, vision_optimizer, vision_scheduler, text_optimizer, text_scheduler) 155 | test_loss, test_acc = evaluate(args, text_tools, epoch, args.epoch, best_test_acc, vision_model, text_model, multimodal_model, test_loader, criterion, device) 156 | 157 | end_time = time.time() 158 | epoch_mins, epoch_secs = epoch_time(start_time, end_time) 159 | 160 | if test_acc > best_test_acc: 161 | best_test_acc = test_acc 162 | torch.save({ 163 | 'vision_model': vision_model.state_dict(), 164 | 'text_model': text_model.state_dict(), 165 | 'multimodal': multimodal_model.state_dict(), 166 | 'acc': test_acc 167 | }, os.path.join(args.save_dir, args.save_name)) 168 | 169 | print(f'Epoch: {epoch:02} | Epoch Time: {epoch_mins}m {epoch_secs}s') 170 | print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%') 171 | print(f'\tTest. Loss: {test_loss:.3f} | Test. 
Acc: {test_acc*100:.2f}%') 172 | 173 | 174 | def main(): 175 | parser = argparse.ArgumentParser(description='') 176 | 177 | # save information 178 | parser.add_argument('--save_dir', type=str, default='./saved_models') 179 | parser.add_argument('--save_name', type=str, default='best.pth') 180 | parser.add_argument('--seed', default=12345, type=int, help='seed for initializing training.') 181 | parser.add_argument('--gpu', default=0, type=int, help='GPU id to use.') 182 | 183 | # train information 184 | parser.add_argument('--vision_backbone', type=str, default='vit') 185 | parser.add_argument('--vision_model', type=str, default='base') 186 | parser.add_argument('--text_backbone', type=str, default='bert') 187 | parser.add_argument('--text_model', type=str, default='base') 188 | parser.add_argument('--output_dim', type=int, default=1) 189 | parser.add_argument('--epoch', type=int, default=20) 190 | parser.add_argument('--batch_size', type=int, default=16) 191 | parser.add_argument('--vision_lr', type=float, default=2e-5) 192 | parser.add_argument('--vision_weight_decay', type=float, default=1e-5) 193 | parser.add_argument('--text_lr', type=float, default=2e-5) 194 | parser.add_argument('--text_weight_decay', type=float, default=1e-5) 195 | parser.add_argument('--multimodal_lr', type=float, default=5e-5) 196 | parser.add_argument('--multimodal_weight_decay', type=float, default=1e-5) 197 | parser.add_argument('--multimodal_fusion', type=str, default='product') 198 | parser.add_argument('--multilevel_fusion', type=str, default='concat') 199 | parser.add_argument('--lambda_sentiment', type=float, default=1) 200 | parser.add_argument('--lambda_semantic', type=float, default=1) 201 | parser.add_argument('--constant', type=float, default=0) 202 | parser.add_argument('--memory_length', type=int, default=256) 203 | parser.add_argument('--explicit_t', type=float, default=0.2) 204 | 205 | # dataset configuration 206 | parser.add_argument('--train_text_path', type=str, default='/home/ubuntu14/wcs/up_load/text_data/train.txt') 207 | parser.add_argument('--test_text_path', type=str, default='/home/ubuntu14/wcs/up_load/text_data/valid.txt') 208 | parser.add_argument('--train_image_path', type=str, default='/home/ubuntu14/wcs/imgs/train') 209 | parser.add_argument('--test_image_path', type=str, default='/home/ubuntu14/wcs/imgs/valid') 210 | parser.add_argument('--train_set_len', type=int, default=29040) 211 | parser.add_argument('--num_workers', type=int, default=20) 212 | 213 | args = parser.parse_args() 214 | 215 | run(args) 216 | 217 | 218 | main() 219 | -------------------------------------------------------------------------------- /models/vision_text.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | import os 4 | from PIL import Image 5 | from transformers import BertTokenizer 6 | from torch.utils.data import Dataset 7 | from torch import optim, nn 8 | from transformers import ViTForImageClassification, get_linear_schedule_with_warmup 9 | import torch.nn.functional as F 10 | from queue import Queue 11 | import scipy.stats as st 12 | import numpy as np 13 | import math 14 | from torch.optim.lr_scheduler import LambdaLR 15 | import cv2 16 | import json 17 | def read_json(path): 18 | with open(path,"r",encoding = 'utf-8') as f: 19 | data = json.load(f) 20 | return data 21 | def write_json(path,data): 22 | with open(path,"w",encoding = 'utf-8') as f: 23 | json.dump(data,f) 24 | def get_cosine_schedule_with_warmup(optimizer, 25 | 
num_training_steps, 26 | num_cycles=7. / 16., 27 | num_warmup_steps=0, 28 | last_epoch=-1): 29 | 30 | def _lr_lambda(current_step): 31 | 32 | if current_step < num_warmup_steps: 33 | _lr = float(current_step) / float(max(1, num_warmup_steps)) 34 | else: 35 | num_cos_steps = float(current_step - num_warmup_steps) 36 | num_cos_steps = num_cos_steps / float(max(1, num_training_steps - num_warmup_steps)) 37 | _lr = max(0.0, math.cos(math.pi * num_cycles * num_cos_steps)) 38 | return _lr 39 | return LambdaLR(optimizer, _lr_lambda, last_epoch) 40 | 41 | 42 | class MultiModalDataset(Dataset): 43 | def __init__(self, text_tools, vision_transforms, args, mode): 44 | self.args = args 45 | self.vision_transform = vision_transforms[mode] 46 | self.mode = mode 47 | self.text_arr, self.img_path, self.label, self.idx2file = self.init_data() 48 | 49 | def init_data(self): 50 | if self.mode == 'train': 51 | text_path = self.args.train_text_path 52 | vision_path = self.args.train_image_path 53 | else: 54 | text_path = self.args.test_text_path 55 | vision_path = self.args.test_image_path 56 | 57 | text_arr, img_path, labels, idx2file = {}, {}, {}, [] 58 | skip_words = ['exgag', 'sarcasm', 'sarcastic', '', 'reposting', 'joke', 'humor', 'humour', 'jokes', 'irony', 'ironic'] 59 | for line in open(text_path, 'r').readlines(): 60 | content = eval(line) 61 | file_name, text, label = content[0], content[1], content[2] 62 | flag = False 63 | for skip_word in skip_words: 64 | if skip_word in content[1]: flag = True 65 | if flag: continue 66 | 67 | cur_img_path = os.path.join(vision_path, file_name+'.jpg') 68 | if not os.path.exists(cur_img_path): 69 | print(file_name) 70 | continue 71 | 72 | text_arr[file_name], labels[file_name] = text, label 73 | img_path[file_name] = os.path.join(vision_path, file_name+'.jpg') 74 | idx2file.append(file_name) 75 | return text_arr, img_path, labels, idx2file 76 | 77 | def __getitem__(self, idx): 78 | file_name = self.idx2file[idx] 79 | text = self.text_arr[file_name] 80 | img_path = self.img_path[file_name] 81 | label = self.label[file_name] 82 | 83 | img = Image.open(img_path).convert("RGB") 84 | img = self.vision_transform(img) 85 | return file_name, img, text, label 86 | 87 | def __len__(self): 88 | return len(self.label) 89 | 90 | 91 | class MSD_Net(nn.Module): 92 | def __init__(self, args): 93 | super().__init__() 94 | self.args = args 95 | self.sentiment_fc1 = nn.Linear(768, 768, bias=True) 96 | self.ReLu=nn.ReLU() 97 | self.dropout = nn.Dropout(p=0.1, inplace=False) 98 | self.sentiment_fc2 = nn.Linear(768, 1, bias=True) 99 | self.sentiment_criterion = nn.MSELoss() 100 | 101 | self.correlation_conv = nn.Sequential( 102 | nn.Conv2d(1, 64, 3, stride=1, padding=1), 103 | nn.Conv2d(64, 1, 3, stride=1, padding=1), 104 | nn.ReLU() 105 | ) 106 | 107 | self.multimodal_fusion = args.multimodal_fusion 108 | self.multilevel_fusion = args.multilevel_fusion 109 | if self.multilevel_fusion != 'concat' and self.multimodal_fusion != 'concat': self.final_fc = nn.Linear(768, 1, bias=True) 110 | elif self.multilevel_fusion == 'concat' and self.multimodal_fusion == 'concat': self.final_fc = nn.Linear(4*768, 1, bias=True) 111 | else: self.final_fc = nn.Linear(2*768, 1, bias=True) 112 | 113 | self.memory_length = args.memory_length 114 | self.sarcasm_bank = Queue(maxsize=self.memory_length) 115 | self.non_sarcasm_bank = Queue(maxsize=self.memory_length) 116 | 117 | def fusion(self, embeddings1, embeddings2, strategy): 118 | assert strategy in ['sum', 'product', 'concat'] 119 | if strategy == 'sum': 
return (embeddings1+embeddings2) / 2 120 | elif strategy == 'product': return embeddings1 * embeddings2 121 | else: return torch.cat([embeddings1, embeddings2], dim=1) 122 | 123 | def forward(self, vision_embeddings, text_embeddings, text_sentiment, label=None): 124 | vision_embeddings, text_embeddings = vision_embeddings['embeddings'], text_embeddings['embeddings'] 125 | 126 | batch_size = vision_embeddings.size()[0] 127 | 128 | text_embedd = text_embeddings.transpose(1, 2) 129 | vision_embedd = vision_embeddings 130 | attention_map = torch.bmm(vision_embedd, text_embedd) 131 | attention_map = self.correlation_conv(attention_map.unsqueeze(1)).squeeze() 132 | vision_c, text_c = attention_map.size(1), attention_map.size(2) 133 | 134 | vision_attention, text_attention = torch.sum(attention_map, dim=2)/text_c, torch.sum(attention_map, dim=1)/vision_c 135 | vision_attention, text_attention = torch.sigmoid(vision_attention), torch.sigmoid(text_attention) 136 | aligned_vision_embeddings = vision_attention.unsqueeze(-1) * vision_embedd 137 | aligned_text_embeddings = text_attention.unsqueeze(-1) * text_embedd.transpose(1,2) 138 | 139 | vision_embeddings = aligned_vision_embeddings 140 | text_embeddings = aligned_text_embeddings 141 | vision_nums, text_nums = vision_embeddings.size(1), text_embeddings.size(1) 142 | vision_CLS = torch.sum(vision_embeddings, dim=1) / vision_nums 143 | text_CLS = torch.sum(text_embeddings, dim=1) / text_nums 144 | 145 | 146 | # sentiment model 147 | text_sentiment_loss = 0 148 | for idx, cur_text_sentiment in enumerate(text_sentiment): 149 | cur_text_len = len(cur_text_sentiment) 150 | if self.args.text_backbone == 'bert': 151 | if cur_text_len > 510: cur_text_len, cur_text_sentiment = 510, cur_text_sentiment[:510] 152 | else: 153 | if cur_text_len > 512: cur_text_len, cur_text_sentiment = 512, cur_text_sentiment[:512] 154 | cur_text_embeddings = text_embeddings[idx, 0:cur_text_len, :] 155 | predicted_text_sentiment_embedding = self.sentiment_fc1(cur_text_embeddings) 156 | predicted_text_sentiment_embedding = self.dropout(self.ReLu(predicted_text_sentiment_embedding)) 157 | predicted_text_sentiment = self.sentiment_fc2(predicted_text_sentiment_embedding) 158 | 159 | mask = torch.ones_like(cur_text_sentiment) 160 | mask[cur_text_sentiment == 0] = 0 161 | predicted_text_sentiment = predicted_text_sentiment * mask.unsqueeze(1) 162 | text_sentiment_loss += self.sentiment_criterion(predicted_text_sentiment.squeeze(1), cur_text_sentiment) 163 | 164 | text_sentiment_loss /= len(text_sentiment) 165 | 166 | 167 | text_cls_sentiment_embedding = self.sentiment_fc1(text_CLS) 168 | vision_cls_sentiment_embedding = self.sentiment_fc1(vision_CLS) 169 | 170 | with torch.no_grad(): 171 | vision_cls_sentiment_embedd = self.dropout(self.ReLu(vision_cls_sentiment_embedding)) 172 | vision_cls_sentiment = self.sentiment_fc2(self.ReLu(vision_cls_sentiment_embedd)) 173 | text_cls_sentiment_embedd = self.dropout(text_cls_sentiment_embedding) 174 | text_cls_sentiment = self.sentiment_fc2(text_cls_sentiment_embedd) 175 | 176 | contrast_label = torch.abs(vision_cls_sentiment - text_cls_sentiment.t()) 177 | contrast_label = torch.exp(-contrast_label) 178 | contrast_label = contrast_label / contrast_label.sum(1, keepdim=True) 179 | 180 | sim = torch.exp(torch.mm(F.normalize(vision_cls_sentiment_embedding, dim=1), F.normalize(text_cls_sentiment_embedding, dim=1).t()) / 0.2) 181 | sim = sim / sim.sum(1, keepdim=True) 182 | sentiment_contrast_loss = F.kl_div(torch.log(sim), contrast_label, 
reduction='batchmean') 183 | 184 | lamda_sentiment = torch.abs(text_cls_sentiment.squeeze(1) - vision_cls_sentiment.squeeze(1)) 185 | 186 | # semantic model 187 | variance_vision = torch.nn.functional.normalize(torch.var(vision_CLS, dim=0), dim=-1) 188 | variance_text = torch.nn.functional.normalize(torch.var(text_CLS, dim=0), dim=-1) 189 | semantic_vision_embeddings = vision_CLS + vision_CLS*variance_vision.unsqueeze(0).repeat(batch_size,1) 190 | semantic_text_embeddings = text_CLS + text_CLS*variance_text.unsqueeze(0).repeat(batch_size,1) 191 | 192 | COS = nn.CosineSimilarity(dim=-1, eps=1e-6) 193 | 194 | sims = COS(semantic_vision_embeddings, semantic_text_embeddings) 195 | 196 | if label is not None: 197 | with torch.no_grad(): 198 | for id in range (batch_size): 199 | if label[id]==0: 200 | if self.non_sarcasm_bank.full() == True: self.non_sarcasm_bank.get() 201 | self.non_sarcasm_bank.put(sims[id]) 202 | elif label[id]==1: 203 | if self.sarcasm_bank.full() == True: self.sarcasm_bank.get() 204 | self.sarcasm_bank.put(sims[id]) 205 | 206 | if self.non_sarcasm_bank.full() == True and self.sarcasm_bank.full() == True: 207 | with torch.no_grad(): 208 | sarcasm_list = list(self.sarcasm_bank.queue) 209 | mu_sarcasm = sum(sarcasm_list) / self.args.memory_length 210 | sigma_sarcasm = torch.sqrt(sum([(tmp-mu_sarcasm)**2 for tmp in sarcasm_list])) 211 | 212 | non_sarcasm_list = list(self.non_sarcasm_bank.queue) 213 | mu_non_sarcasm = sum(non_sarcasm_list) / self.args.memory_length 214 | sigma_non_sarcasm = torch.sqrt(sum([(tmp-mu_non_sarcasm)**2 for tmp in non_sarcasm_list])) 215 | 216 | prob_sarcasm = (1/(sigma_sarcasm*np.sqrt(2*math.pi))) * torch.exp(-50*((sims - mu_sarcasm)/sigma_sarcasm)**2) 217 | prob_non_sarcasm = (1/(sigma_non_sarcasm*np.sqrt(2*math.pi))) * torch.exp(-50*((sims - mu_non_sarcasm)/sigma_non_sarcasm)**2) 218 | lamda_semantic = prob_sarcasm - prob_non_sarcasm 219 | else: 220 | lamda_semantic = torch.zeros_like(lamda_sentiment) 221 | prob_sarcasm=torch.zeros_like(lamda_sentiment) 222 | prob_non_sarcasm=torch.zeros_like(lamda_sentiment) 223 | 224 | 225 | # fusion 226 | semantic_cls = self.fusion(semantic_vision_embeddings, semantic_text_embeddings, self.multimodal_fusion) 227 | sentiment_cls = self.fusion(vision_cls_sentiment_embedding, text_cls_sentiment_embedding, self.multimodal_fusion) 228 | final_cls = self.fusion(semantic_cls, sentiment_cls, self.multilevel_fusion) 229 | final_cls = self.final_fc(final_cls).squeeze() 230 | fuse_final_cls = final_cls + self.args.lambda_sentiment*lamda_sentiment + self.args.lambda_semantic*lamda_semantic - self.args.constant 231 | 232 | return fuse_final_cls, sentiment_contrast_loss, text_sentiment_loss 233 | 234 | 235 | def get_multimodal_model(args): 236 | return MSD_Net(args) 237 | 238 | 239 | def get_multimodal_configuration(args, model): 240 | optimizer = optim.Adam(model.parameters(), lr=args.multimodal_lr, weight_decay=args.multimodal_weight_decay) 241 | num_training_steps = int(args.train_set_len / args.batch_size * args.epoch) 242 | scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps) 243 | # scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer, num_training_steps=num_training_steps) 244 | criterion = nn.BCEWithLogitsLoss() 245 | return optimizer, scheduler, criterion 246 | --------------------------------------------------------------------------------
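Editor's note: for reference, a minimal smoke test of the fusion head defined in `models/vision_text.py`, using random tensors in place of the ViT/BERT features. This is a sketch under assumptions: the `Namespace` fields mirror the argparse defaults in `main.py`, the shapes are illustrative, and the repository's dependencies (including opencv-python and scipy for the imports in `vision_text.py`) are assumed to be installed. The batch size should be at least 2 because of the `squeeze()` on the correlation map in `forward`.

```python
import torch
from argparse import Namespace
from models.vision_text import MSD_Net

# Fields copied from the argparse defaults in main.py.
args = Namespace(multimodal_fusion='product', multilevel_fusion='concat',
                 memory_length=256, text_backbone='bert',
                 lambda_sentiment=1.0, lambda_semantic=1.0, constant=0.0)

net = MSD_Net(args)
B, Nv, Nt, D = 4, 196, 32, 768                    # batch, ViT patches, text tokens, hidden size
vision = {'embeddings': torch.randn(B, Nv, D)}    # stand-in for VIT_MODEL output
text = {'embeddings': torch.randn(B, Nt, D)}      # stand-in for BERT_MODEL output
sentiment = [torch.rand(Nt) for _ in range(B)]    # per-token polarity, as from get_word_level_sentiment
label = torch.randint(0, 2, (B,))

logits, contrast_loss, sentiment_loss = net(vision, text, sentiment, label)
print(logits.shape)                               # torch.Size([4]): one sarcasm logit per sample
```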