├── .DS_Store
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── Defeasible_Visual_Entailment.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── webServers.xml
├── Data
│   ├── DVE_dev.csv
│   ├── DVE_test.csv
│   ├── DVE_train.csv
│   ├── cleaned_dev.jsonl
│   ├── cleaned_test.jsonl
│   └── cleaned_train.jsonl
├── Evaluator
│   ├── inference_demo.py
│   ├── visual_text_inference.py
│   └── visual_text_training.py
├── README.md
└── requirements.txt
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/skywalkerzhang/Defeasible_Visual_Entailment/ffb870e42a0538d9bd97c7ebca06e9c0a2bbc032/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore all .pth files
2 | *.pth
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/Defeasible_Visual_Entailment.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/webServers.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Evaluator/inference_demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
4 | from transformers import BertTokenizer
5 | from visual_text_training import BERTModelModule, set_seed
6 |
7 |
8 | def demo_infer(hypothesis, image_path, updates,
9 | model_path='evaluator_weights.pth',
10 | gpu=7, use_classification_head=True):
11 | # Set seed for reproducibility
12 | set_seed(42)
13 |
14 | # Load tokenizer and model
15 | tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
16 | device = f"cuda:{gpu}" if torch.cuda.is_available() else "cpu"
17 | model_name = "bert-large-uncased"
18 | model = BERTModelModule(model_name=model_name, use_classification_head=use_classification_head).to(device)
19 |
20 | # Load model weights
21 |     model.load_state_dict(torch.load(model_path, map_location=device))
22 | model.eval()
23 |
24 | # Process the image
25 | image_transform = Compose([
26 | Resize((224, 224)),
27 | CenterCrop(224),
28 | ToTensor(),
29 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
30 | ])
31 | image = Image.open(image_path).convert('RGB')
32 | image = image_transform(image).unsqueeze(0).to(device)
33 |
34 |     # Encode each (hypothesis, update) pair as a sentence pair, matching the encoding used during training
35 |     update_inputs = [tokenizer(hypothesis, update, return_tensors="pt", padding=True, truncation=True).to(device)
36 |                      for update in updates]
38 |
39 | # Run inference and collect scores
40 | scores = []
41 | with torch.no_grad():
42 | for update_input in update_inputs:
43 |             # Forward pass with the image features and the encoded (hypothesis, update) pair
44 | _, score = model(image, update_input['input_ids'], update_input['attention_mask'])
45 | scores.append(score.item())
46 |
47 | return scores
48 |
49 |
50 | if __name__ == "__main__":
51 |
52 | hypothesis = "A dog chases a rabbit."
53 |     image_path = "../Data/flickr30k_images/3486831913.jpg"  # adjust to your local Flickr30k path
54 | updates = [
55 | "A rabbit could photobomb this chase, and the dog would not even look up — it’s got its eye on nothing else but its ball.",
56 | "With a ball tossed by its owner, the dog’s attention is fully absorbed in the game, showing zero interest in rabbits.",
57 | "The dog is too absorbed in chasing the ball to even notice a rabbit.",
58 | "The dog looks like it’s going to chase something any second now.",
59 | "Every muscle in the dog’s body is alert, signaling it’s primed for a chase.",
60 | "With that intense look, could the dog be any more ready to chase?"
61 | ]
62 |
63 | # Call the demo_infer function
64 | scores = demo_infer(hypothesis, image_path, updates)
65 |
66 | # Print the output scores for each update
67 | for i, score in enumerate(scores):
68 | print(f"Update {i + 1} score: {score}")
69 |
--------------------------------------------------------------------------------
/Evaluator/visual_text_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from PIL import Image
4 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
5 | import pandas as pd
6 | from torch.utils.data import DataLoader
7 | from tqdm.auto import tqdm
8 | from visual_text_training import VisualTextDataset, BERTModelModule, set_seed, process_single_df, generate_filename, \
9 | test
10 | from transformers import BertTokenizer
11 |
12 |
13 | def parse_infer_args():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--test_csv_file', type=str, default='Data/DVE_test.csv',
16 | help='CSV file containing image paths, captions, and targets for testing.')
17 | parser.add_argument('--image_dir', type=str, default='Data/flickr30k_images',
18 | help='Directory containing images.')
19 | parser.add_argument('--batch_size', type=int, default=64, help='Batch size for inference.')
20 | parser.add_argument('--model_path', type=str,
21 | default='evaluator_weights.pth',
22 | help='Path to the fine-tuned model checkpoint.')
23 | parser.add_argument('--output_file', type=str, default='visual_text_bs64.csv',
24 | help='File to save the inference results.')
25 | parser.add_argument('--gpu', type=int, default=7, help='GPU id to use.')
26 | parser.add_argument('--use_classification_head', action='store_true', help='Whether to use classification head.')
27 | parser.add_argument('--calculate_accuracy', action='store_true', help='Whether to calculate the acc')
28 | return parser.parse_args()
29 |
30 |
31 | def get_score(df, batch_size=64, model_path='evaluator_weights.pth',
32 | gpu=7, use_classification_head=True, image_dir='Data/flickr30k_images'):
33 | set_seed(42)
34 | tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
35 |
36 | test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types = process_single_df(df, image_dir)
37 |
38 | image_transform = Compose([
39 | Resize((224, 224)),
40 | CenterCrop(224),
41 | ToTensor(),
42 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
43 | ])
44 |
45 | device = f"cuda:{gpu}" if torch.cuda.is_available() else "cpu"
46 | model_name = "bert-large-uncased"
47 | model = BERTModelModule(model_name=model_name, use_classification_head=use_classification_head).to(device)
48 |
49 |     model.load_state_dict(torch.load(model_path, map_location=device))
50 |
51 | test_dataset = VisualTextDataset(test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types,
52 | tokenizer, image_transform)
53 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
54 | results_df = test(model, test_loader, device, args=None, return_df=True, calculate_accuracy=False)
55 | return results_df
56 |
57 |
58 | def main():
59 | args = parse_infer_args()
60 | set_seed(42)
61 |
62 | test_df = pd.read_csv(args.test_csv_file)
63 |
64 | tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
65 |
66 | test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types = process_single_df(test_df,
67 | args.image_dir)
68 |
69 | image_transform = Compose([
70 | Resize((224, 224)),
71 | CenterCrop(224),
72 | ToTensor(),
73 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
74 | ])
75 |
76 | device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
77 | model_name = "bert-large-uncased"
78 | model = BERTModelModule(model_name=model_name, use_classification_head=args.use_classification_head).to(device)
79 |
80 |     model.load_state_dict(torch.load(args.model_path, map_location=device))
81 |
82 | test_dataset = VisualTextDataset(test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types,
83 | tokenizer, image_transform)
84 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
85 |
86 | if args.calculate_accuracy:
87 | results_df = test(model, test_loader, device, args, return_df=True, calculate_accuracy=True)
88 | else:
89 | results_df = test(model, test_loader, device, args, return_df=True, calculate_accuracy=False)
90 | final_df = pd.concat([test_df, results_df], axis=1)
91 | final_df.to_csv(args.output_file, index=False)
92 | print(f"Test results saved to {args.output_file}")
93 |
94 |
95 | if __name__ == "__main__":
96 | main()
97 |
--------------------------------------------------------------------------------
/Evaluator/visual_text_training.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from PIL import Image
4 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
5 | import numpy as np
6 | import os
7 | from torch.utils.data import Dataset, DataLoader
8 | import wandb
9 | import pandas as pd
10 | from tqdm.auto import tqdm
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.optim.lr_scheduler import LambdaLR
14 | import random
15 | from transformers import BertTokenizer, BertModel, BertConfig, AdamW
16 | import math
17 |
18 |
19 | class ClassificationHead(nn.Module):
20 | """Head for classification tasks."""
21 |
22 | def __init__(self, hidden_size, hidden_dropout_prob=0.1, num_labels=1):
23 | super().__init__()
24 | self.dense = nn.Linear(hidden_size, hidden_size)
25 | self.dropout = nn.Dropout(hidden_dropout_prob)
26 | self.out_proj = nn.Linear(hidden_size, num_labels)
27 |
28 | def forward(self, hidden_states):
29 | hidden_states = self.dropout(hidden_states)
30 | hidden_states = self.dense(hidden_states)
31 | hidden_states = torch.tanh(hidden_states)
32 | hidden_states = self.dropout(hidden_states)
33 | output = self.out_proj(hidden_states)
34 | return output
35 |
36 |
37 | class AttentionHead(nn.Module):
38 | def __init__(self, hidden_size, num_labels=1):
39 | super(AttentionHead, self).__init__()
40 | self.query = nn.Linear(hidden_size, hidden_size)
41 | self.key = nn.Linear(hidden_size, hidden_size)
42 | self.value = nn.Linear(hidden_size, hidden_size)
43 | self.out_proj = nn.Linear(hidden_size, num_labels)
44 | self.dropout = nn.Dropout(0.1)
45 |
46 | def forward(self, hidden_states):
47 | query = self.query(hidden_states)
48 | key = self.key(hidden_states)
49 | value = self.value(hidden_states)
50 |
51 | attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(hidden_states.size(-1))
52 | attention_probs = F.softmax(attention_scores, dim=-1)
53 |
54 | attention_output = torch.matmul(attention_probs, value)
55 | attention_output = self.dropout(attention_output)
56 | output = self.out_proj(attention_output)
57 | return output
58 |
59 |
60 | # Set the random seed for reproducibility
61 | def set_seed(seed):
62 | random.seed(seed)
63 | np.random.seed(seed)
64 | torch.manual_seed(seed)
65 | torch.cuda.manual_seed(seed)
66 | torch.cuda.manual_seed_all(seed)
67 | torch.backends.cudnn.benchmark = False
68 | torch.backends.cudnn.deterministic = True
69 |
70 |
71 | def process_single_df(df, image_dir):
72 | image_paths = df['SNLIPairId'].apply(lambda x: os.path.join(image_dir, x.split('#')[0])).tolist()
73 | hypotheses = df['Hypothesis'].astype(str).tolist()
74 | if 'Premise' in df.columns:
75 | captions = df['Premise'].astype(str).tolist()
76 | else:
77 | captions = None
78 | updates = df['Update'].astype(str).tolist()
79 |     update_types = df['UpdateType'].apply(lambda x: 1 if x == 'strengthener' else -1).tolist()  # strengthener -> 1, weakener -> -1
80 | return image_paths, hypotheses, captions, updates, update_types
81 |
82 |
83 | class VisualTextDataset(Dataset):
84 | def __init__(self, image_paths, hypotheses, premises, updates, update_types, tokenizer, image_transform, max_length=128):
85 | self.image_paths = image_paths
86 | self.hypotheses = hypotheses
87 | self.premises = premises
88 | self.updates = updates
89 | self.update_types = update_types
90 | self.tokenizer = tokenizer
91 | self.image_transform = image_transform
92 | self.max_length = max_length
93 |
94 | def __getitem__(self, idx):
95 | try:
96 | image_path = self.image_paths[idx]
97 | image = Image.open(image_path).convert("RGB")
98 | image = self.image_transform(image)
99 |
100 | hypothesis = self.hypotheses[idx]
101 | update = self.updates[idx]
102 | update_type = self.update_types[idx]
103 |
104 |             # Encode the hypothesis and update together as a sentence pair
105 | inputs_hypothesis_update = self.tokenizer(
106 | hypothesis,
107 | update,
108 | return_tensors="pt",
109 | padding="max_length",
110 | truncation=True,
111 | max_length=self.max_length
112 | )
113 |
114 |             # If premises are available, encode the hypothesis and premise together as a sentence pair
115 | if self.premises is not None:
116 | premise = self.premises[idx]
117 | inputs_hypothesis_premise = self.tokenizer(
118 | hypothesis,
119 | premise,
120 | return_tensors="pt",
121 | padding="max_length",
122 | truncation=True,
123 | max_length=self.max_length
124 | )
125 | inputs_hypothesis_premise = {key: val.squeeze(0) for key, val in inputs_hypothesis_premise.items()}
126 | else:
127 |                 # Otherwise fall back to empty tensors as placeholders
128 | inputs_hypothesis_premise = {"input_ids": torch.tensor([], dtype=torch.long),
129 | "attention_mask": torch.tensor([], dtype=torch.long)}
130 |
131 | inputs_hypothesis_update = {key: val.squeeze(0) for key, val in inputs_hypothesis_update.items()}
132 |
133 | return image, inputs_hypothesis_premise, inputs_hypothesis_update, update_type
134 | except Exception as e:
135 | print(f"Error processing index {idx}: {e}")
136 | return None
137 |
138 | def __len__(self):
139 | return len(self.hypotheses)
140 |
141 |
142 | class CustomLoss(nn.Module):
143 | def __init__(self, reduction='mean'):
144 | super(CustomLoss, self).__init__()
145 | self.reduction = reduction
146 |
147 | def forward(self, score_hypo_premise, score_hypo_update, labels):
148 | outputs = (score_hypo_update - score_hypo_premise).view(-1)
149 | loss = - torch.mean(torch.log(torch.sigmoid(outputs * labels.view(-1))))
150 | return loss
151 |
152 |
153 | class VisualEncoder(nn.Module):
154 | def __init__(self, model_name='resnet50', pretrained=True):
155 | super(VisualEncoder, self).__init__()
156 | from torchvision.models import resnet50, ResNet50_Weights
157 | self.model = resnet50(weights=ResNet50_Weights.DEFAULT)
158 |         self.model = nn.Sequential(*list(self.model.children())[:-1])  # drop the final classification layer
159 |
160 | def forward(self, images):
161 | features = self.model(images)
162 |         features = features.view(features.size(0), -1)  # flatten to (batch_size, 2048)
163 | return features
164 |
165 |
166 | class BERTModelModule(nn.Module):
167 | def __init__(self, model_name: str, use_classification_head=False, classification_weight=0.9):
168 | super(BERTModelModule, self).__init__()
169 | self.bert_model = BertModel.from_pretrained(model_name)
170 | self.tokenizer = BertTokenizer.from_pretrained(model_name)
171 | self.use_classification_head = use_classification_head
172 | self.classification_weight = classification_weight
173 | self.visual_encoder = VisualEncoder()
174 | self.custom_loss_fn = CustomLoss()
175 |
176 | if self.use_classification_head:
177 | self.classifier = nn.Linear(self.bert_model.config.hidden_size + 2048, 2)
178 | self.ce_loss_fn = nn.CrossEntropyLoss()
179 |
180 | self.regressor = nn.Linear(self.bert_model.config.hidden_size + 2048, 1)
181 |
182 | def forward(self, images, input_ids, attention_mask):
183 | visual_features = self.visual_encoder(images)
184 | text_outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
185 |         pooled_output = text_outputs.pooler_output  # pooled [CLS] representation
186 | combined_features = torch.cat((pooled_output, visual_features), dim=1)
187 | if self.use_classification_head:
188 | logits = self.classifier(combined_features)
189 | else:
190 | logits = None
191 | score = self.regressor(combined_features)
192 | return logits, score
193 |
194 | def compute_loss_and_scores(self, batch):
195 | weight = self.classification_weight
196 | images, hypo_premise_inputs, hypo_update_inputs, update_types = batch
197 |
198 | if hypo_premise_inputs and hypo_premise_inputs['input_ids'].numel() > 0:
199 |             # Fuse text and image features to get classification logits and a regression score for (hypothesis, premise)
200 | logits_hypo_premise, score_hypo_premise = self.forward(images, hypo_premise_inputs['input_ids'],
201 | hypo_premise_inputs['attention_mask'])
202 | else:
203 | logits_hypo_premise = None
204 | score_hypo_premise = None
205 |
206 |         # Fuse text and image features to get classification logits and a regression score for (hypothesis, update)
207 | logits_hypo_update, score_hypo_update = self.forward(images, hypo_update_inputs['input_ids'],
208 | hypo_update_inputs['attention_mask'])
209 |
210 |         # Pairwise ranking loss between the premise and update scores
211 | if hypo_premise_inputs and hypo_premise_inputs['input_ids'].numel() > 0:
212 | custom_loss = self.custom_loss_fn(score_hypo_premise, score_hypo_update, update_types)
213 | else:
214 | custom_loss = None
215 |
216 | if self.use_classification_head:
217 | if logits_hypo_update is not None:
218 |                 # Classification loss on the update (strengthener vs. weakener)
219 |                 update_types_classification = (update_types + 1) // 2  # map -1 / 1 to 0 / 1
220 | ce_loss = self.ce_loss_fn(logits_hypo_update, update_types_classification.long())
221 | else:
222 | ce_loss = None
223 |             # Weighted total loss
224 | if custom_loss is not None and ce_loss is not None:
225 | loss = weight * ce_loss + (1 - weight) * custom_loss
226 | elif ce_loss is not None:
227 | loss = ce_loss
228 | else:
229 | loss = custom_loss
230 | else:
231 | loss = custom_loss
232 |
233 | return loss, logits_hypo_premise, logits_hypo_update, score_hypo_premise, score_hypo_update
234 |
235 |
236 | def parse_args():
237 | parser = argparse.ArgumentParser()
238 |     parser.add_argument('--train_csv_file', type=str, default='../Data/DVE_train.csv',
239 | help='CSV file containing image paths, captions, and targets for training.')
240 | parser.add_argument('--val_csv_file', type=str, default='../Data/DVE_dev.csv',
241 | help='CSV file containing image paths, captions, and targets for validation.')
242 | parser.add_argument('--test_csv_file', type=str, default='../Data/DVE_test.csv',
243 | help='CSV file containing image paths, captions, and targets for testing.')
244 | parser.add_argument('--image_dir', type=str, default='../Data/flickr30k_images',
245 | help='Directory containing images.')
246 | parser.add_argument('--epochs', type=int, default=20, help='Number of epochs for fine-tuning.')
247 | parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate for fine-tuning.')
248 | parser.add_argument('--batch_size', type=int, default=32, help='Batch size for fine-tuning.')
249 | parser.add_argument('--wandb_project', type=str, default="DI_visual_text", help='wandb project name.')
250 | parser.add_argument('--output_model', type=str, default="DI_visual_text",
251 | help='Base name for the output model file.')
252 | parser.add_argument('--dataset_type', type=str, choices=['VE', 'DI'], default='DI',
253 | help='Dataset type: VE or DI.')
254 | parser.add_argument('--test_batch_size', type=int, default=64, help='Batch size for inference.')
255 | parser.add_argument('--model_path', type=str, help='Path to the fine-tuned model checkpoint.')
256 | parser.add_argument('--output_file', type=str, default='test_results.csv',
257 | help='File to save the test results.')
258 | parser.add_argument('--gpu', type=int, default=7)
259 |     parser.add_argument('--classification_weight', type=float, default=0.9, help='The weight for classification loss.')
260 | parser.add_argument('--use_classification_head', action='store_true', help='Whether to use classification head.')
261 | return parser.parse_args()
262 |
263 |
264 | def generate_filename(base_name, lr, accuracy, batch_size, ext):
265 | return f"{base_name}_lr{lr}_acc{accuracy:.4f}_bs{batch_size}.{ext}"
266 |
267 |
268 | def train(model, train_loader, val_loader, optimizer, scheduler, device, epochs, args):
269 | best_val_accuracy = 0.0
270 |
271 | for epoch in range(epochs):
272 | model.train()
273 | train_losses = []
274 |
275 | for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}"):
276 | images, hypo_premise_inputs, hypo_update_inputs, update_types = batch
277 | images = images.to(device)
278 | hypo_premise_inputs = {k: v.to(device) for k, v in hypo_premise_inputs.items()}
279 | hypo_update_inputs = {k: v.to(device) for k, v in hypo_update_inputs.items()}
280 |             update_types = update_types.clone().detach().to(device=device, dtype=torch.float32)
281 | optimizer.zero_grad()
282 |
283 | loss, logits_hypo_premise, logits_hypo_update, score_hypo_premise, score_hypo_update = model.compute_loss_and_scores(
284 | (images, hypo_premise_inputs, hypo_update_inputs, update_types))
285 |
286 | loss.backward()
287 | optimizer.step()
288 |
289 | train_losses.append(loss.item())
290 | wandb.log({"step_loss": loss.item()})
291 |
292 | avg_train_loss = np.mean(train_losses)
293 | wandb.log({"epoch_train_loss": avg_train_loss, "epoch": epoch + 1})
294 | print(f"Epoch {epoch + 1}, Train Loss: {avg_train_loss}")
295 |
296 | model.eval()
297 | val_losses = []
298 | val_correct_predictions = 0
299 | val_total_predictions = 0
300 | val_classification_correct_predictions = 0
301 |
302 | with torch.no_grad():
303 | for batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}"):
304 | images, hypo_premise_inputs, hypo_update_inputs, update_types = batch
305 | images = images.to(device)
306 | hypo_premise_inputs = {k: v.to(device) for k, v in hypo_premise_inputs.items()}
307 | hypo_update_inputs = {k: v.to(device) for k, v in hypo_update_inputs.items()}
308 |                 update_types = update_types.clone().detach().to(device=device, dtype=torch.float32)
309 |
310 | val_loss, logits_hypo_premise, logits_hypo_update, score_hypo_premise, score_hypo_update = model.compute_loss_and_scores(
311 | (images, hypo_premise_inputs, hypo_update_inputs, update_types))
312 |
313 | val_losses.append(val_loss.item())
314 |
315 |                 # Pairwise (ranking) accuracy: the update score should move in the labeled direction
316 | for score_hp, score_hu, update_type in zip(score_hypo_premise, score_hypo_update, update_types):
317 | if (update_type == 1 and score_hu > score_hp) or (update_type == -1 and score_hu < score_hp):
318 | val_correct_predictions += 1
319 | val_total_predictions += 1
320 |
321 |                 # Classification accuracy
322 | if args.use_classification_head:
323 | preds = torch.argmax(logits_hypo_update, dim=1)
324 |                     update_types_classification = (update_types + 1) // 2  # map -1 / 1 to 0 / 1
325 | val_classification_correct_predictions += (preds == update_types_classification.long()).sum().item()
326 |
327 | avg_val_loss = np.mean(val_losses)
328 | val_accuracy = val_correct_predictions / val_total_predictions if val_total_predictions > 0 else 0
329 | val_classification_accuracy = val_classification_correct_predictions / val_total_predictions if val_total_predictions > 0 else 0
330 |
331 | wandb.log({
332 | "epoch_val_loss": avg_val_loss,
333 | "val_accuracy": val_accuracy,
334 | "val_classification_accuracy": val_classification_accuracy,
335 | "epoch": epoch + 1
336 | })
337 |
338 | print(
339 | f"Epoch {epoch + 1}, Val Loss: {avg_val_loss}, Val Accuracy: {val_accuracy}, Val Classification Accuracy: {val_classification_accuracy}")
340 |
341 | if val_accuracy > best_val_accuracy:
342 | best_val_accuracy = val_accuracy
343 | model_filename = generate_filename(args.output_model, args.lr, val_accuracy, args.batch_size, "pth")
344 | torch.save(model.state_dict(), model_filename)
345 | print(f"Best model saved with accuracy: {val_accuracy}")
346 |
347 | scheduler.step()
348 |
349 |
350 | def test(model, test_loader, device, args=None, calculate_accuracy=True, return_df=False):
351 | model.eval()
352 | results = []
353 | correct_predictions = 0
354 | total_predictions = 0
355 | classification_correct_predictions = 0
356 | premise_valid = False
357 |
358 | with torch.no_grad():
359 | for batch in tqdm(test_loader, desc="Testing"):
360 | images, hypo_premise_inputs, hypo_update_inputs, targets = batch
361 | images = images.to(device)
362 | hypo_update_inputs = {k: v.to(device) for k, v in hypo_update_inputs.items()}
363 | targets = targets.to(device)
364 |
365 | if hypo_premise_inputs['input_ids'].size(0) > 0 and hypo_premise_inputs['attention_mask'].size(0) > 0:
366 | hypo_premise_inputs = {k: v.to(device) for k, v in hypo_premise_inputs.items()}
367 | premise_valid = True
368 |
369 | loss, logits_hypo_premise, logits_hypo_update, score_hypo_premise, score_hypo_update = model.compute_loss_and_scores(
370 | (images, hypo_premise_inputs if premise_valid else None, hypo_update_inputs, targets))
371 |
372 |             # Pairwise (ranking) accuracy
373 | if calculate_accuracy:
374 | if premise_valid:
375 | for score_hp, score_hu, target in zip(score_hypo_premise, score_hypo_update, targets):
376 | if (target == 1 and score_hu > score_hp) or (target == -1 and score_hu < score_hp):
377 | correct_predictions += 1
378 | total_predictions += 1
379 | results.append((score_hu.item(), score_hp.item(), target.item()))
380 | else:
381 | for score_hu, target in zip(score_hypo_update, targets):
382 | total_predictions += 1
383 | results.append((score_hu.item(), None, target.item()))
384 | else:
385 | for score_hu, target in zip(score_hypo_update, targets):
386 | results.append((score_hu.item(), None, target.item()))
387 |
388 |             # Classification accuracy
389 | if calculate_accuracy and args.use_classification_head:
390 | preds = torch.argmax(logits_hypo_update, dim=1)
391 |                 update_types_classification = (targets + 1) // 2  # map -1 / 1 to 0 / 1
392 | classification_correct_predictions += (preds == update_types_classification.long()).sum().item()
393 |
394 | if calculate_accuracy:
395 | accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
396 | classification_accuracy = classification_correct_predictions / total_predictions if total_predictions > 0 else 0
397 | print(f"Accuracy: {accuracy * 100:.2f}%, Classification Accuracy: {classification_accuracy * 100:.2f}%")
398 |
399 | results_df = pd.DataFrame(results, columns=['Score_Update', 'Score_Premise', 'UpdateType'])
400 | if not premise_valid or not calculate_accuracy:
401 | results_df = results_df.drop(columns=['Score_Premise'])
402 |
403 | if return_df:
404 | return results_df
405 | else:
406 | if calculate_accuracy:
407 | csv_filename = generate_filename(args.output_file.replace(".csv", ""), args.lr, accuracy,
408 | args.test_batch_size, "csv")
409 | else:
410 | csv_filename = generate_filename(args.output_file.replace(".csv", ""), args.lr, 0, args.test_batch_size,
411 | "csv")
412 | results_df.to_csv(csv_filename, index=False)
413 | print(f"Test results saved to {csv_filename}")
414 |         return accuracy if calculate_accuracy else None
415 |
416 |
417 | def main():
418 | args = parse_args()
419 |
420 | wandb.init(project=args.wandb_project)
421 |
422 | train_df = pd.read_csv(args.train_csv_file)
423 | val_df = pd.read_csv(args.val_csv_file)
424 | test_df = pd.read_csv(args.test_csv_file)
425 |
426 | tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
427 |
428 | train_image_paths, train_hypotheses, train_premises, train_updates, train_update_types = process_single_df(train_df, args.image_dir)
429 | val_image_paths, val_hypotheses, val_premises, val_updates, val_update_types = process_single_df(val_df, args.image_dir)
430 | test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types = process_single_df(test_df, args.image_dir)
431 |
432 | image_transform = Compose([
433 | Resize((224, 224)),
434 | CenterCrop(224),
435 | ToTensor(),
436 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
437 | ])
438 |
439 | device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
440 | model_name = "bert-large-uncased"
441 |     model = BERTModelModule(model_name=model_name, use_classification_head=args.use_classification_head, classification_weight=args.classification_weight).to(device)
442 |
443 | train_dataset = VisualTextDataset(train_image_paths, train_hypotheses, train_premises, train_updates, train_update_types, tokenizer, image_transform)
444 | val_dataset = VisualTextDataset(val_image_paths, val_hypotheses, val_premises, val_updates, val_update_types, tokenizer, image_transform)
445 | test_dataset = VisualTextDataset(test_image_paths, test_hypotheses, test_premises, test_updates, test_update_types, tokenizer, image_transform)
446 |
447 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)
448 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
449 | test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False, num_workers=4)
450 |
451 | optimizer = AdamW(model.parameters(), lr=args.lr)
452 | scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1 - epoch / args.epochs)
453 |
454 |     # Train the model
455 | train(model, train_loader, val_loader, optimizer, scheduler, device, args.epochs, args)
456 |
457 |     # Evaluate on the test set
458 | test_accuracy = test(model, test_loader, device, args)
459 | wandb.log({"test_accuracy": test_accuracy})
460 |
461 | wandb.finish()
462 |
463 |
464 | if __name__ == "__main__":
465 | set_seed(42)
466 | main()
467 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Defeasible Visual Entailment
2 |
3 | This is the official code implementation for the AAAI 2025 paper *"Defeasible Visual Entailment: Benchmark, Evaluator, and Reward-Driven Optimization"*.
4 |
5 | ## Dataset
6 |
7 | - **Image Data**: The image dataset can be downloaded by filling out the form at [this link](https://forms.illinois.edu/sec/229675).
8 | - **Text Data**: The text data is included in the `Data/` directory of this repository.
9 |
10 | We would like to thank the creators of the following datasets for their contributions:
11 |
12 | - **Flickr30k**: A large-scale image dataset with natural language descriptions.
13 | - **SNLI (Stanford Natural Language Inference)**: A dataset for developing and evaluating models for natural language inference.
14 |
15 | ## Installation
16 |
17 | 1. **Clone the repository:**
18 |
19 | ```bash
20 | git clone https://github.com/yourusername/Defeasible_Visual_Entailment.git
21 | cd Defeasible_Visual_Entailment
22 | ```
23 |
24 | 2. **Install the necessary dependencies:**
25 |
26 | ```bash
27 | pip install -r requirements.txt
28 | ```
29 |
30 | ## Dataset Details
31 |
32 | The dataset consists of image premises paired with text hypotheses and textual updates. Each example is annotated for defeasible visual entailment: the update either **strengthens** or **weakens** the degree to which the image entails the hypothesis.
33 |
34 | | Split | Number of Samples | Weakener Count | Strengthener Count | Unique Images |
35 | |----------------|------------------|---------------|----------------|---------------|
36 | | **Training** | 93,082 | 46,541 | 46,541 | 9,507 |
37 | | **Validation** | 1,888 | 944 | 944 | 195 |
38 | | **Test** | 1,972 | 986 | 986 | 203 |
39 |
40 | Each sample contains:
41 | - An **image premise**
42 | - A **text hypothesis**
43 | - A textual **update** that either strengthens or weakens the hypothesis
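These fields map onto the CSV columns consumed by `process_single_df` in `Evaluator/visual_text_training.py`. A minimal sketch for inspecting a split from the repository root (the column selection below assumes the standard DVE CSV layout):

```python
import pandas as pd

df = pd.read_csv("Data/DVE_train.csv")
# SNLIPairId encodes the Flickr30k image filename before the '#' separator;
# UpdateType is "strengthener" or "weakener" (mapped to +1 / -1 during training).
# A "Premise" (caption) column is used when present.
print(df[["SNLIPairId", "Hypothesis", "Update", "UpdateType"]].head())
```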
44 | ## Usage
45 |
46 | To train the model, run the following command from the `Evaluator` directory:
47 |
48 | ```bash
49 | python visual_text_training.py \
50 | --train_csv_file ../Data/DVE_train.csv \
51 | --val_csv_file ../Data/DVE_dev.csv \
52 | --test_csv_file ../Data/DVE_test.csv \
53 | --image_dir ../Data/flickr30k_images \
54 | --epochs 20 \
55 | --lr 5e-6 \
56 | --batch_size 32 \
57 | --wandb_project "Defeasible_Visual_Entailment" \
58 |     --output_model "DVE_model" \
59 | --gpu 0 \
60 | --classification_weight 0.9 \
61 | --use_classification_head
62 | ```
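The best validation checkpoint is saved under a name produced by `generate_filename` in `visual_text_training.py`, so `--output_model` is treated as a base name rather than a full filename. A quick sketch of the resulting pattern (the accuracy value below is purely illustrative):

```python
def generate_filename(base_name, lr, accuracy, batch_size, ext):
    # Copied from Evaluator/visual_text_training.py
    return f"{base_name}_lr{lr}_acc{accuracy:.4f}_bs{batch_size}.{ext}"

print(generate_filename("DVE_model", 5e-06, 0.8123, 32, "pth"))
# -> DVE_model_lr5e-06_acc0.8123_bs32.pth
```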
63 |
64 | ## Model
65 |
66 | The evaluator handles visual and textual inputs, combining a ResNet-50 image encoder with a BERT text encoder to score how strongly a textual update strengthens or weakens the entailment between the image and the hypothesis.
67 |
68 | The model integrates a reasoning evaluator, which assesses the strength of updates and their impact on visual entailment tasks. This allows for reward-driven optimization, improving model performance over time.
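Concretely, the evaluator in `Evaluator/visual_text_training.py` is trained with a weighted sum of a strengthener/weakener classification loss and a pairwise ranking loss over the regression scores (the weight is set by `--classification_weight`). A minimal sketch of the ranking term, mirroring `CustomLoss`:

```python
import torch

def pairwise_update_loss(score_hypo_premise, score_hypo_update, labels):
    # labels: +1 for strengtheners, -1 for weakeners.
    # Pushes the update score above (strengthener) or below (weakener) the premise score.
    margin = (score_hypo_update - score_hypo_premise).view(-1)
    return -torch.mean(torch.log(torch.sigmoid(margin * labels.view(-1))))
```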
69 |
70 | ### Model Link
71 |
72 | The pre-trained evaluator model can be downloaded from Hugging Face:
73 |
74 | ➡ **[DVE Evaluator Model](https://huggingface.co/skywalkerzhang19/DVE_evaluator/resolve/main/evaluator_weights.pth?download=true)**
75 |
76 | You can also download it via `wget` or `curl`:
77 |
78 | ```bash
79 | wget "https://huggingface.co/skywalkerzhang19/DVE_evaluator/resolve/main/evaluator_weights.pth?download=true" -O evaluator_weights.pth
80 | ```
81 | or
82 |
83 | ```bash
84 | curl -L "https://huggingface.co/skywalkerzhang19/DVE_evaluator/resolve/main/evaluator_weights.pth?download=true" -o evaluator_weights.pth
85 | ```
86 |
87 | ## Evaluation
88 | To evaluate the model, run the following command from the `Evaluator` directory:
89 | ```bash
90 | python visual_text_inference.py \
91 | --test_csv_file ../Data/DVE_test.csv \
92 | --image_dir ../Data/flickr30k_images \
93 | --model_path evaluator_weights.pth \
94 | --output_file "test_results.csv" \
95 | --gpu 0 \
96 |     --batch_size 64
97 | ```
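The evaluator can also be called from Python via `get_score` in `visual_text_inference.py`. A minimal sketch, assuming it is run from the repository root with the downloaded weights and Flickr30k images in place (the GPU id is an assumption; adjust to your machine):

```python
import sys
import pandas as pd

sys.path.append("Evaluator")  # make visual_text_inference / visual_text_training importable
from visual_text_inference import get_score

test_df = pd.read_csv("Data/DVE_test.csv")
results_df = get_score(test_df, batch_size=64, model_path="evaluator_weights.pth",
                       gpu=0, image_dir="Data/flickr30k_images")
print(results_df.head())
```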
98 |
99 | `inference_demo.py` does not take command-line arguments. To run inference on your own example, edit the `hypothesis`, `image_path`, `updates`, and (if needed) `model_path`/`gpu` values in the script's `__main__` block, then run:
100 | ```bash
101 | cd Evaluator
102 | python inference_demo.py
103 | ```
104 | The script prints one evaluator score per update.
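Alternatively, `demo_infer` can be imported and called directly. A minimal sketch, run from the `Evaluator` directory with `evaluator_weights.pth` downloaded (the image filename and GPU id below are placeholders):

```python
from inference_demo import demo_infer

scores = demo_infer(
    hypothesis="A dog chases a rabbit.",
    image_path="../Data/flickr30k_images/3486831913.jpg",  # placeholder; point to a local image
    updates=["The dog is too absorbed in chasing the ball to even notice a rabbit."],
    model_path="evaluator_weights.pth",
    gpu=0,
)
print(scores)  # one score per update
```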
109 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | Pillow==10.3.0
4 | transformers==4.41.2
5 | numpy==1.26.4
6 | pandas==2.2.1
7 | tqdm==4.66.4
8 | scikit-learn==1.5.1
9 | wandb==0.17.0
10 |
--------------------------------------------------------------------------------