├── .DS_Store
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── Defeasible_Visual_Entailment.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── webServers.xml
├── Data
│   ├── DVE_dev.csv
│   ├── DVE_test.csv
│   ├── DVE_train.csv
│   ├── cleaned_dev.jsonl
│   ├── cleaned_test.jsonl
│   └── cleaned_train.jsonl
├── Evaluator
│   ├── inference_demo.py
│   ├── visual_text_inference.py
│   └── visual_text_training.py
├── README.md
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
# Ignore all .pth files
*.pth
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
--------------------------------------------------------------------------------
/Evaluator/inference_demo.py:
--------------------------------------------------------------------------------
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import BertTokenizer
from visual_text_training import BERTModelModule, set_seed


def demo_infer(hypothesis, image_path, updates,
               model_path='evaluator_weights.pth',
               gpu=7, use_classification_head=True):
    # Set seed for reproducibility
    set_seed(42)

    # Load tokenizer and model
    tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
    device = f"cuda:{gpu}" if torch.cuda.is_available() else "cpu"
    model_name = "bert-large-uncased"
    model = BERTModelModule(model_name=model_name, use_classification_head=use_classification_head).to(device)

    # Load model weights
    model.load_state_dict(torch.load(model_path))
    model.eval()

    # Process the image
    image_transform = Compose([
        Resize((224, 224)),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')
    image = image_transform(image).unsqueeze(0).to(device)

    # Encode each (hypothesis, update) pair, matching the input format used during training
    update_inputs = [tokenizer(hypothesis, update, return_tensors="pt", padding=True, truncation=True).to(device)
                     for update in updates]

    # Run inference and collect scores
    scores = []
    with torch.no_grad():
        for update_input in update_inputs:
            # Forward pass with the image and the encoded hypothesis-update pair
            _, score = model(image, update_input['input_ids'], update_input['attention_mask'])
            scores.append(score.item())

    return scores


if __name__ == "__main__":

    hypothesis = "A dog chases a rabbit."
    image_path = "/Data/flickr30k_images/3486831913.jpg"
    updates = [
        "A rabbit could photobomb this chase, and the dog would not even look up — it’s got its eye on nothing else but its ball.",
        "With a ball tossed by its owner, the dog’s attention is fully absorbed in the game, showing zero interest in rabbits.",
        "The dog is too absorbed in chasing the ball to even notice a rabbit.",
        "The dog looks like it’s going to chase something any second now.",
        "Every muscle in the dog’s body is alert, signaling it’s primed for a chase.",
        "With that intense look, could the dog be any more ready to chase?"
    ]

    # Call the demo_infer function
    scores = demo_infer(hypothesis, image_path, updates)

    # Print the output scores for each update
    for i, score in enumerate(scores):
        print(f"Update {i + 1} score: {score}")
--------------------------------------------------------------------------------
/Evaluator/visual_text_inference.py:
--------------------------------------------------------------------------------
import argparse
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import pandas as pd
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from visual_text_training import VisualTextDataset, BERTModelModule, set_seed, process_single_df, generate_filename, \
    test
from transformers import BertTokenizer, BertModel, BertConfig, AdamW


def parse_infer_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_csv_file', type=str, default='Data/DVE_test.csv',
                        help='CSV file containing image paths, captions, and targets for testing.')
    parser.add_argument('--image_dir', type=str, default='Data/flickr30k_images',
                        help='Directory containing images.')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size for inference.')
    parser.add_argument('--model_path', type=str,
                        default='evaluator_weights.pth',
                        help='Path to the fine-tuned model checkpoint.')
    parser.add_argument('--output_file', type=str, default='visual_text_bs64.csv',
                        help='File to save the inference results.')
    parser.add_argument('--gpu', type=int, default=7, help='GPU id to use.')
    parser.add_argument('--use_classification_head', action='store_true', help='Whether to use classification head.')
    parser.add_argument('--calculate_accuracy', action='store_true', help='Whether to calculate the accuracy.')
    return parser.parse_args()


def get_score(df, batch_size=64, model_path='evaluator_weights.pth',
              gpu=7, use_classification_head=True, image_dir='Data/flickr30k_images'):
    set_seed(42)
    tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")

    test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types = process_single_df(df, image_dir)

    image_transform = Compose([
        Resize((224, 224)),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    device = f"cuda:{gpu}" if torch.cuda.is_available() else "cpu"
    model_name = "bert-large-uncased"
    model = BERTModelModule(model_name=model_name, use_classification_head=use_classification_head).to(device)

    model.load_state_dict(torch.load(model_path))

    test_dataset = VisualTextDataset(test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types,
                                     tokenizer, image_transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    results_df = test(model, test_loader, device, args=None, return_df=True, calculate_accuracy=False)
    return results_df


def main():
    args = parse_infer_args()
    set_seed(42)

    test_df = pd.read_csv(args.test_csv_file)

    tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")

    test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types = process_single_df(test_df,
                                                                                                          args.image_dir)

    image_transform = Compose([
        Resize((224, 224)),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
    model_name = "bert-large-uncased"
    model = BERTModelModule(model_name=model_name, use_classification_head=args.use_classification_head).to(device)

    model.load_state_dict(torch.load(args.model_path))

    test_dataset = VisualTextDataset(test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types,
                                     tokenizer, image_transform)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    if args.calculate_accuracy:
        results_df = test(model, test_loader, device, args, return_df=True, calculate_accuracy=True)
    else:
        results_df = test(model, test_loader, device, args, return_df=True, calculate_accuracy=False)
    final_df = pd.concat([test_df, results_df], axis=1)
    final_df.to_csv(args.output_file, index=False)
    print(f"Test results saved to {args.output_file}")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/Evaluator/visual_text_training.py:
--------------------------------------------------------------------------------
import argparse
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import wandb
import pandas as pd
from tqdm.auto import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
import random
from transformers import BertTokenizer, BertModel, BertConfig, AdamW
import math


class ClassificationHead(nn.Module):
    """Head for classification tasks."""

    def __init__(self, hidden_size, hidden_dropout_prob=0.1, num_labels=1):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states):
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        output = self.out_proj(hidden_states)
        return output


class AttentionHead(nn.Module):
    def __init__(self, hidden_size, num_labels=1):
        super(AttentionHead, self).__init__()
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, num_labels)
        self.dropout = nn.Dropout(0.1)

    def forward(self, hidden_states):
        query = self.query(hidden_states)
        key = self.key(hidden_states)
        value = self.value(hidden_states)

        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(hidden_states.size(-1))
        attention_probs = F.softmax(attention_scores, dim=-1)

        attention_output = torch.matmul(attention_probs, value)
        attention_output = self.dropout(attention_output)
        output = self.out_proj(attention_output)
        return output


# Set the random seed for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def process_single_df(df, image_dir):
    image_paths = df['SNLIPairId'].apply(lambda x: os.path.join(image_dir, x.split('#')[0])).tolist()
    hypotheses = df['Hypothesis'].astype(str).tolist()
    if 'Premise' in df.columns:
        captions = df['Premise'].astype(str).tolist()
    else:
        captions = None
    updates = df['Update'].astype(str).tolist()
    update_types = df['UpdateType'].apply(lambda x: 1 if x == 'strengthener' else -1).tolist()  # strengthener -> 1, weakener -> -1
    return image_paths, hypotheses, captions, updates, update_types


class VisualTextDataset(Dataset):
    def __init__(self, image_paths, hypotheses, premises, updates, update_types, tokenizer, image_transform, max_length=128):
        self.image_paths = image_paths
        self.hypotheses = hypotheses
        self.premises = premises
        self.updates = updates
        self.update_types = update_types
        self.tokenizer = tokenizer
        self.image_transform = image_transform
        self.max_length = max_length

    def __getitem__(self, idx):
        try:
            image_path = self.image_paths[idx]
            image = Image.open(image_path).convert("RGB")
            image = self.image_transform(image)

            hypothesis = self.hypotheses[idx]
            update = self.updates[idx]
            update_type = self.update_types[idx]

            # Encode the hypothesis and update together as a sentence pair
            inputs_hypothesis_update = self.tokenizer(
                hypothesis,
                update,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=self.max_length
            )

            # If premises are available, also encode the hypothesis and premise together
            if self.premises is not None:
                premise = self.premises[idx]
                inputs_hypothesis_premise = self.tokenizer(
                    hypothesis,
                    premise,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.max_length
                )
                inputs_hypothesis_premise = {key: val.squeeze(0) for key, val in inputs_hypothesis_premise.items()}
            else:
                # No premise available: fall back to empty tensors
                inputs_hypothesis_premise = {"input_ids": torch.tensor([], dtype=torch.long),
                                             "attention_mask": torch.tensor([], dtype=torch.long)}

            inputs_hypothesis_update = {key: val.squeeze(0) for key, val in inputs_hypothesis_update.items()}

            return image, inputs_hypothesis_premise, inputs_hypothesis_update, update_type
        except Exception as e:
            print(f"Error processing index {idx}: {e}")
            return None

    def __len__(self):
        return len(self.hypotheses)


class CustomLoss(nn.Module):
    def __init__(self, reduction='mean'):
        super(CustomLoss, self).__init__()
        self.reduction = reduction

    def forward(self, score_hypo_premise, score_hypo_update, labels):
        # Pairwise logistic loss: pushes the update score above the premise score for
        # strengtheners (label = 1) and below it for weakeners (label = -1).
        outputs = (score_hypo_update - score_hypo_premise).view(-1)
        loss = - torch.mean(torch.log(torch.sigmoid(outputs * labels.view(-1))))
        return loss


class VisualEncoder(nn.Module):
    def __init__(self, model_name='resnet50', pretrained=True):
        super(VisualEncoder, self).__init__()
        from torchvision.models import resnet50, ResNet50_Weights
        self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.model = nn.Sequential(*list(self.model.children())[:-1])  # drop the final classification layer

    def forward(self, images):
        features = self.model(images)
        features = features.view(features.size(0), -1)  # flatten to 2-D
        return features


class BERTModelModule(nn.Module):
    def __init__(self, model_name: str, use_classification_head=False, classification_weight=0.9):
        super(BERTModelModule, self).__init__()
        self.bert_model = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.use_classification_head = use_classification_head
        self.classification_weight = classification_weight
        self.visual_encoder = VisualEncoder()
        self.custom_loss_fn = CustomLoss()

        if self.use_classification_head:
            self.classifier = nn.Linear(self.bert_model.config.hidden_size + 2048, 2)
            self.ce_loss_fn = nn.CrossEntropyLoss()

        self.regressor = nn.Linear(self.bert_model.config.hidden_size + 2048, 1)

    def forward(self, images, input_ids, attention_mask):
        visual_features = self.visual_encoder(images)
        text_outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = text_outputs.pooler_output  # pooled [CLS] token output
        combined_features = torch.cat((pooled_output, visual_features), dim=1)
        if self.use_classification_head:
            logits = self.classifier(combined_features)
        else:
            logits = None
        score = self.regressor(combined_features)
        return logits, score

    def compute_loss_and_scores(self, batch):
        weight = self.classification_weight
        images, hypo_premise_inputs, hypo_update_inputs, update_types = batch

        if hypo_premise_inputs and hypo_premise_inputs['input_ids'].numel() > 0:
            # Fuse text and image features and compute classification and regression scores for (hypothesis, premise)
            logits_hypo_premise, score_hypo_premise = self.forward(images, hypo_premise_inputs['input_ids'],
                                                                   hypo_premise_inputs['attention_mask'])
        else:
            logits_hypo_premise = None
            score_hypo_premise = None

        # Fuse text and image features and compute classification and regression scores for (hypothesis, update)
        logits_hypo_update, score_hypo_update = self.forward(images, hypo_update_inputs['input_ids'],
                                                             hypo_update_inputs['attention_mask'])

        # Compute the custom pairwise loss
        if hypo_premise_inputs and hypo_premise_inputs['input_ids'].numel() > 0:
            custom_loss = self.custom_loss_fn(score_hypo_premise, score_hypo_update, update_types)
        else:
            custom_loss = None

        if self.use_classification_head:
            if logits_hypo_update is not None:
                # Compute the classification loss
                update_types_classification = (update_types + 1) // 2  # convert -1, 1 to 0, 1
                ce_loss = self.ce_loss_fn(logits_hypo_update, update_types_classification.long())
            else:
                ce_loss = None
            # Total loss
            if custom_loss is not None and ce_loss is not None:
                loss = weight * ce_loss + (1 - weight) * custom_loss
            elif ce_loss is not None:
                loss = ce_loss
            else:
                loss = custom_loss
        else:
            loss = custom_loss

        return loss, logits_hypo_premise, logits_hypo_update, score_hypo_premise, score_hypo_update


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_csv_file', type=str, default='../Data/DVE_train.csv',
                        help='CSV file containing image paths, captions, and targets for training.')
    parser.add_argument('--val_csv_file', type=str, default='../Data/DVE_dev.csv',
                        help='CSV file containing image paths, captions, and targets for validation.')
    parser.add_argument('--test_csv_file', type=str,
                        default='../Data/DVE_test.csv',
                        help='CSV file containing image paths, captions, and targets for testing.')
    parser.add_argument('--image_dir', type=str, default='../Data/flickr30k_images',
                        help='Directory containing images.')
    parser.add_argument('--epochs', type=int, default=20, help='Number of epochs for fine-tuning.')
    parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate for fine-tuning.')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for fine-tuning.')
    parser.add_argument('--wandb_project', type=str, default="DI_visual_text", help='wandb project name.')
    parser.add_argument('--output_model', type=str, default="DI_visual_text",
                        help='Base name for the output model file.')
    parser.add_argument('--dataset_type', type=str, choices=['VE', 'DI'], default='DI',
                        help='Dataset type: VE or DI.')
    parser.add_argument('--test_batch_size', type=int, default=64, help='Batch size for inference.')
    parser.add_argument('--model_path', type=str, help='Path to the fine-tuned model checkpoint.')
    parser.add_argument('--output_file', type=str, default='test_results.csv',
                        help='File to save the test results.')
    parser.add_argument('--gpu', type=int, default=7)
    parser.add_argument('--classification_weight', type=float, help='The weight for classification loss.')
    parser.add_argument('--use_classification_head', action='store_true', help='Whether to use classification head.')
    return parser.parse_args()


def generate_filename(base_name, lr, accuracy, batch_size, ext):
    return f"{base_name}_lr{lr}_acc{accuracy:.4f}_bs{batch_size}.{ext}"


def train(model, train_loader, val_loader, optimizer, scheduler, device, epochs, args):
    best_val_accuracy = 0.0

    for epoch in range(epochs):
        model.train()
        train_losses = []

        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}"):
            images, hypo_premise_inputs, hypo_update_inputs, update_types = batch
            images = images.to(device)
            hypo_premise_inputs = {k: v.to(device) for k, v in hypo_premise_inputs.items()}
            hypo_update_inputs = {k: v.to(device) for k, v in hypo_update_inputs.items()}
            update_types = torch.tensor(update_types, dtype=torch.float32).clone().detach().to(device)
            optimizer.zero_grad()

            loss, logits_hypo_premise, logits_hypo_update, score_hypo_premise, score_hypo_update = model.compute_loss_and_scores(
                (images, hypo_premise_inputs, hypo_update_inputs, update_types))

            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            wandb.log({"step_loss": loss.item()})

        avg_train_loss = np.mean(train_losses)
        wandb.log({"epoch_train_loss": avg_train_loss, "epoch": epoch + 1})
        print(f"Epoch {epoch + 1}, Train Loss: {avg_train_loss}")

        model.eval()
        val_losses = []
        val_correct_predictions = 0
        val_total_predictions = 0
        val_classification_correct_predictions = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}"):
                images, hypo_premise_inputs, hypo_update_inputs, update_types = batch
                images = images.to(device)
                hypo_premise_inputs = {k: v.to(device) for k, v in hypo_premise_inputs.items()}
                hypo_update_inputs = {k: v.to(device) for k, v in hypo_update_inputs.items()}
                update_types = torch.tensor(update_types,
                                            dtype=torch.float32).clone().detach().to(device)

                val_loss, logits_hypo_premise, logits_hypo_update, score_hypo_premise, score_hypo_update = model.compute_loss_and_scores(
                    (images, hypo_premise_inputs, hypo_update_inputs, update_types))

                val_losses.append(val_loss.item())

                # Compute the custom (ranking) accuracy
                for score_hp, score_hu, update_type in zip(score_hypo_premise, score_hypo_update, update_types):
                    if (update_type == 1 and score_hu > score_hp) or (update_type == -1 and score_hu < score_hp):
                        val_correct_predictions += 1
                    val_total_predictions += 1

                # Compute the classification accuracy
                if args.use_classification_head:
                    preds = torch.argmax(logits_hypo_update, dim=1)
                    update_types_classification = (update_types + 1) // 2  # convert -1, 1 to 0, 1
                    val_classification_correct_predictions += (preds == update_types_classification.long()).sum().item()

        avg_val_loss = np.mean(val_losses)
        val_accuracy = val_correct_predictions / val_total_predictions if val_total_predictions > 0 else 0
        val_classification_accuracy = val_classification_correct_predictions / val_total_predictions if val_total_predictions > 0 else 0

        wandb.log({
            "epoch_val_loss": avg_val_loss,
            "val_accuracy": val_accuracy,
            "val_classification_accuracy": val_classification_accuracy,
            "epoch": epoch + 1
        })

        print(
            f"Epoch {epoch + 1}, Val Loss: {avg_val_loss}, Val Accuracy: {val_accuracy}, Val Classification Accuracy: {val_classification_accuracy}")

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            model_filename = generate_filename(args.output_model, args.lr, val_accuracy, args.batch_size, "pth")
            torch.save(model.state_dict(), model_filename)
            print(f"Best model saved with accuracy: {val_accuracy}")

        scheduler.step()


def test(model, test_loader, device, args=None, calculate_accuracy=True, return_df=False):
    model.eval()
    results = []
    correct_predictions = 0
    total_predictions = 0
    classification_correct_predictions = 0
    premise_valid = False

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            images, hypo_premise_inputs, hypo_update_inputs, targets = batch
            images = images.to(device)
            hypo_update_inputs = {k: v.to(device) for k, v in hypo_update_inputs.items()}
            targets = targets.to(device)

            if hypo_premise_inputs['input_ids'].size(0) > 0 and hypo_premise_inputs['attention_mask'].size(0) > 0:
                hypo_premise_inputs = {k: v.to(device) for k, v in hypo_premise_inputs.items()}
                premise_valid = True

            loss, logits_hypo_premise, logits_hypo_update, score_hypo_premise, score_hypo_update = model.compute_loss_and_scores(
                (images, hypo_premise_inputs if premise_valid else None, hypo_update_inputs, targets))

            # Compute the custom (ranking) accuracy
            if calculate_accuracy:
                if premise_valid:
                    for score_hp, score_hu, target in zip(score_hypo_premise, score_hypo_update, targets):
                        if (target == 1 and score_hu > score_hp) or (target == -1 and score_hu < score_hp):
                            correct_predictions += 1
                        total_predictions += 1
                        results.append((score_hu.item(), score_hp.item(), target.item()))
                else:
                    for score_hu, target in zip(score_hypo_update, targets):
                        total_predictions += 1
                        results.append((score_hu.item(), None, target.item()))
            else:
                for score_hu, target in zip(score_hypo_update,
                                             targets):
                    results.append((score_hu.item(), None, target.item()))

            # Compute the classification accuracy
            if calculate_accuracy and args.use_classification_head:
                preds = torch.argmax(logits_hypo_update, dim=1)
                update_types_classification = (targets + 1) // 2  # convert -1, 1 to 0, 1
                classification_correct_predictions += (preds == update_types_classification.long()).sum().item()

    if calculate_accuracy:
        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
        classification_accuracy = classification_correct_predictions / total_predictions if total_predictions > 0 else 0
        print(f"Accuracy: {accuracy * 100:.2f}%, Classification Accuracy: {classification_accuracy * 100:.2f}%")

    results_df = pd.DataFrame(results, columns=['Score_Update', 'Score_Premise', 'UpdateType'])
    if not premise_valid or not calculate_accuracy:
        results_df = results_df.drop(columns=['Score_Premise'])

    if return_df:
        return results_df
    else:
        if calculate_accuracy:
            csv_filename = generate_filename(args.output_file.replace(".csv", ""), args.lr, accuracy,
                                             args.test_batch_size, "csv")
        else:
            csv_filename = generate_filename(args.output_file.replace(".csv", ""), args.lr, 0, args.test_batch_size,
                                             "csv")
        results_df.to_csv(csv_filename, index=False)
        print(f"Test results saved to {csv_filename}")
        return None


def main():
    args = parse_args()

    wandb.init(project=args.wandb_project)

    train_df = pd.read_csv(args.train_csv_file)
    val_df = pd.read_csv(args.val_csv_file)
    test_df = pd.read_csv(args.test_csv_file)

    tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")

    train_image_paths, train_hypotheses, train_premises, train_updates, train_update_types = process_single_df(train_df, args.image_dir)
    val_image_paths, val_hypotheses, val_premises, val_updates, val_update_types = process_single_df(val_df, args.image_dir)
    test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types = process_single_df(test_df, args.image_dir)

    image_transform = Compose([
        Resize((224, 224)),
        CenterCrop(224),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
    model_name = "bert-large-uncased"
    model = BERTModelModule(model_name=model_name, use_classification_head=args.use_classification_head).to(device)

    train_dataset = VisualTextDataset(train_image_paths, train_hypotheses, train_premises, train_updates, train_update_types, tokenizer, image_transform)
    val_dataset = VisualTextDataset(val_image_paths, val_hypotheses, val_premises, val_updates, val_update_types, tokenizer, image_transform)
    test_dataset = VisualTextDataset(test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types, tokenizer, image_transform)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, num_workers=4)

    optimizer = AdamW(model.parameters(), lr=args.lr)
    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1 - epoch / args.epochs)

    # Train the model
    train(model, train_loader, val_loader, optimizer, scheduler, device, args.epochs, args)

    # Test the model
    test_accuracy = test(model, test_loader, device, args)
    wandb.log({"test_accuracy": test_accuracy})

    wandb.finish()


if __name__ == "__main__":
    set_seed(42)
    main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Defeasible Visual Entailment

This is the official code implementation for the AAAI 2025 paper *"Defeasible Visual Entailment: Benchmark, Evaluator, and Reward-Driven Optimization"*.

## Dataset

- **Image Data**: The image dataset can be downloaded by filling out the form at [this link](https://forms.illinois.edu/sec/229675).
- **Text Data**: The text data is already included in this repository under the `Data/` directory.

We would like to thank the creators of the following datasets for their contributions:

- **Flickr30k**: A large-scale image dataset with natural language descriptions.
- **SNLI (Stanford Natural Language Inference)**: A dataset for developing and evaluating models for natural language inference.

## Installation

1. **Clone the repository:**

   ```bash
   git clone https://github.com/yourusername/Defeasible_Visual_Entailment.git
   cd Defeasible_Visual_Entailment
   ```

2. **Install the necessary dependencies:**

   ```bash
   pip install -r requirements.txt
   ```

## Dataset Details

The dataset consists of images paired with text hypotheses. Each image-hypothesis pair is annotated for defeasible visual entailment: textual updates either **strengthen** or **weaken** the hypothesis with respect to the image.

| Split          | Number of Samples | Weakener Count | Strengthener Count | Unique Images |
|----------------|-------------------|----------------|--------------------|---------------|
| **Training**   | 93,082            | 46,541         | 46,541             | 9,507         |
| **Validation** | 1,888             | 944            | 944                | 195           |
| **Test**       | 1,972             | 986            | 986                | 203           |

Each sample contains:
- An **image premise**
- A **text hypothesis**
- A textual **update** that either strengthens or weakens the hypothesis

In the CSV files these correspond to the `SNLIPairId`, `Hypothesis`, `Update`, and `UpdateType` columns, with an optional `Premise` caption column (see `process_single_df` in `Evaluator/visual_text_training.py`).

## Usage

To train the model, run the following command:

```bash
python visual_text_training.py \
    --train_csv_file ../Data/DVE_train.csv \
    --val_csv_file ../Data/DVE_dev.csv \
    --test_csv_file ../Data/DVE_test.csv \
    --image_dir ../Data/flickr30k_images \
    --epochs 20 \
    --lr 5e-6 \
    --batch_size 32 \
    --wandb_project "Defeasible_Visual_Entailment" \
    --output_model "DVE_model" \
    --gpu 0 \
    --classification_weight 0.9 \
    --use_classification_head
```

Note that `--output_model` is a base name; the training script appends the learning rate, best validation accuracy, batch size, and the `.pth` extension when saving checkpoints.

## Model

The model handles visual and textual inputs jointly, combining them to score how strongly a textual update strengthens or weakens the entailment between the image and the hypothesis.

The model integrates a reasoning evaluator, which assesses the strength of updates and their impact on visual entailment tasks. This allows for reward-driven optimization, improving model performance over time.
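
The evaluator can also be called directly from Python. The sketch below is a minimal example built on the `demo_infer` helper in `Evaluator/inference_demo.py`; it assumes you run it from inside the `Evaluator/` directory, that `evaluator_weights.pth` has been downloaded there (see the link below), and that the image path points at one of the Flickr30k images you obtained separately, so adjust these to your setup.

```python
# Minimal sketch: score how each update shifts the evaluator's judgement of a hypothesis.
# Assumptions: run from Evaluator/, evaluator_weights.pth present, valid Flickr30k image path.
from inference_demo import demo_infer

hypothesis = "A dog chases a rabbit."
image_path = "../Data/flickr30k_images/3486831913.jpg"  # placeholder path
updates = [
    "The dog is too absorbed in chasing the ball to even notice a rabbit.",        # weakener
    "Every muscle in the dog's body is alert, signaling it's primed for a chase.",  # strengthener
]

scores = demo_infer(hypothesis, image_path, updates,
                    model_path="evaluator_weights.pth", gpu=0)

# The evaluator is trained so that strengthening updates receive higher scores than weakening ones.
for update, score in zip(updates, scores):
    print(f"{score:+.4f}  {update}")
```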

### Model Link

The pre-trained evaluator model can be downloaded from Hugging Face:

➡ **[DVE Evaluator Model](https://huggingface.co/skywalkerzhang19/DVE_evaluator/resolve/main/evaluator_weights.pth?download=true)**

You can also download it via `wget` or `curl`:

```bash
wget "https://huggingface.co/skywalkerzhang19/DVE_evaluator/resolve/main/evaluator_weights.pth?download=true" -O evaluator_weights.pth
```

or

```bash
curl -L "https://huggingface.co/skywalkerzhang19/DVE_evaluator/resolve/main/evaluator_weights.pth?download=true" -o evaluator_weights.pth
```

## Evaluation

To evaluate the model, use the following command:

```bash
python visual_text_inference.py \
    --test_csv_file ../Data/DVE_test.csv \
    --image_dir ../Data/flickr30k_images \
    --model_path evaluator_weights.pth \
    --output_file "test_results.csv" \
    --gpu 0 \
    --batch_size 64
```

To run inference on your own example with `inference_demo.py`, edit the `hypothesis`, `image_path`, and `updates` variables in its `__main__` block (the script takes no command-line arguments) and run:

```bash
python inference_demo.py
```
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
torchvision
Pillow==10.3.0
transformers==4.41.2
numpy==1.26.4
pandas==2.2.1
tqdm==4.66.4
scikit-learn==1.5.1
wandb==0.17.0
--------------------------------------------------------------------------------