├── assets
│   ├── teaser.png
│   └── commentary.png
├── soccer_words_llama3.pkl
├── inference_result
│   └── inference_result_scores.csv
├── evaluation
│   ├── score_group.py
│   ├── score_gpt.py
│   └── score_single.py
├── features
│   └── preprocess.py
├── environment.yaml
├── alignment
│   ├── matchtime_model.py
│   ├── soccer_whisperx.py
│   ├── soccer_asr2events.py
│   ├── soccer_align_from_event.py
│   └── do_alignment.py
├── inference.py
├── inference_single_video_CLIP.py
├── train.py
├── matchvoice_dataset.py
├── README.md
└── models
    ├── matchvoice_model.py
    └── Qformer.py
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jyrao/MatchTime/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/assets/commentary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jyrao/MatchTime/HEAD/assets/commentary.png
--------------------------------------------------------------------------------
/soccer_words_llama3.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jyrao/MatchTime/HEAD/soccer_words_llama3.pkl
--------------------------------------------------------------------------------
/inference_result/inference_result_scores.csv:
--------------------------------------------------------------------------------
1 | filename,BLEU-1,BLEU-4,METEOR,ROUGE-L,CIDER,sBERT
2 | sample.csv,30.462,8.638,26.374,23.724,36.012,69.208
3 |
--------------------------------------------------------------------------------
/evaluation/score_group.py:
--------------------------------------------------------------------------------
1 | from score_single import calculate_metrics
2 | import argparse
3 | import os
4 | import glob
5 | import pandas as pd
6 | from tqdm import tqdm
7 |
8 | def process_all_csv(folder_path, output_folder):
9 |
10 | csv_files = glob.glob(os.path.join(folder_path, '*.csv'))
11 | results = []
12 |
13 | for csv_file in tqdm(csv_files, desc="Processing CSV files"):
14 | metrics = calculate_metrics(csv_file)
15 | row = [os.path.basename(csv_file)]
16 | row.extend(metrics.values())
17 | results.append(row)
18 |
19 | column_names = ['filename']
20 | if results:
21 | column_names.extend(metrics.keys())
22 |
23 | df = pd.DataFrame(results, columns=column_names)
24 | basename = os.path.basename(folder_path) + '_scores.csv'
25 | output_file = os.path.join(output_folder, basename)
26 |
27 | df.to_csv(output_file, index=False)
28 | print(f"Results saved to {output_file}")
29 |
30 |
31 |
32 | if __name__ == "__main__":
33 | parser = argparse.ArgumentParser(description='Process CSV files for football match commentary.')
34 | parser.add_argument('--folder_path', type=str, default='./inference_result', help='Path to the folder containing CSV files')
35 | parser.add_argument('--output_folder', type=str, default='./inference_result', help='Output folder for processed results')
36 |
37 | args = parser.parse_args()
38 |
39 | process_all_csv(args.folder_path, args.output_folder)
--------------------------------------------------------------------------------
/features/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 |
5 | def rename_and_relocate_grandchild_folders(main_dir):
6 | for subdir, dirs, files in os.walk(main_dir):
7 | for dir in dirs:
8 | grandchild_path = os.path.join(subdir, dir)
9 | for grandchild in os.listdir(grandchild_path):
10 | grandchild_full_path = os.path.join(grandchild_path, grandchild)
11 | if os.path.isdir(grandchild_full_path):
12 | new_name = f"{dir}_{grandchild}"
13 | new_full_path = os.path.join(main_dir, new_name)
14 | counter = 1
15 | while os.path.exists(new_full_path):
16 | new_full_path = os.path.join(main_dir, f"{new_name}_{counter}")
17 | counter += 1
18 | shutil.move(grandchild_full_path, new_full_path)
19 | print(f"Moved {grandchild_full_path} to {new_full_path}")
20 |
21 | for dir in dirs:
22 | try:
23 | os.rmdir(os.path.join(subdir, dir))
24 | print(f"Removed empty directory {os.path.join(subdir, dir)}")
25 | except OSError as e:
26 | print(f"Error: {e.strerror} - {os.path.join(subdir, dir)}")
27 |
28 | if __name__ == "__main__":
29 | # Default could be './features/baidu_soccer_embeddings'
30 | parser = argparse.ArgumentParser(description="Rename and relocate grandchild folders.")
31 | parser.add_argument('--main_dir', type=str, required=True, help='The main directory containing the folders to be processed.')
32 | args = parser.parse_args()
33 | rename_and_relocate_grandchild_folders(args.main_dir)
--------------------------------------------------------------------------------
/evaluation/score_gpt.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | import re, sys
3 | import pandas as pd
4 | from tqdm import tqdm
5 |
6 |
7 | client = OpenAI(
8 |     api_key="YOUR_API_KEY_HERE"  # replace with your OpenAI API key
9 | )
10 |
11 |
12 | def generate_prompt(gt, candidate):
13 | prompt = f"You are a grader of soccer game commentaries. There is a predicted commentary by AI model about a soccer game video clip and you need to score it comparing with ground truth. \n\nYou should rate an integer score from 0 to 10 about the degree of similarity with ground truth commentary (The higher the score, the more correct the candidate is). You must first consider the accuracy of the soccer events, then to consider about the semantic information in expressions and the professional soccer terminologies. The names of players and teams are masked by \"[PLAYER]\" and \"[TEAM]\". \n\nThe ground truth commentary of this soccer game video clip is:\n\n\"{gt}\"\n\n I need you to rate the following predicted commentary from 0 to 10:\n\n\"{candidate}\"\n\nThe score you give is (Just return one number, no other word or sentences):"
14 | return prompt
15 |
16 | def score(client, prompt):
17 | completion = client.chat.completions.create(
18 | model="gpt-3.5-turbo",
19 | messages=[
20 | {"role": "system", "content": "You are an expert in professional soccer commentary."},
21 | {"role": "user", "content": prompt},
22 | ],
23 | stop=["\n"],
24 | temperature=0,
25 | max_tokens=5,
26 | top_p=1,
27 | frequency_penalty=0.0,
28 | presence_penalty=0.0
29 |
30 | )
31 |
32 | res = completion.choices[0].message.content
33 | result = re.search(r'\b([0-9]|10)\b', res).group(0)
34 | return int(result)
35 |
36 | file_path = sys.argv[1]
37 | data = pd.read_csv(file_path)
38 | if 'llm_score' not in data.columns:
39 | data['llm_score'] = None
40 | data.to_csv(file_path, index=False)
41 |
42 | for start in tqdm(range(0, pd.read_csv(file_path).shape[0], 10)):  # score and checkpoint in groups of 10 rows
43 | data = pd.read_csv(file_path)
44 |
45 | end = start + 10
46 | group_data = data.iloc[start:end]
47 |
48 | if 'llm_score' not in data.columns:
49 | data['llm_score'] = None
50 |
51 | for index, row in group_data.iterrows():
52 | if pd.isna(row['llm_score']):
53 |             gt = row['anonymized']
54 | candidate = row['predicted_res']
55 | prompt = generate_prompt(gt, candidate)
56 | result = score(client, prompt)
57 | data.at[index, 'llm_score'] = result
58 |
59 | data.to_csv(file_path, index=False)
60 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: matchtime
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - bzip2=1.0.8=h5eee18b_6
8 | - ca-certificates=2024.7.2=h06a4308_0
9 | - expat=2.6.2=h6a678d5_0
10 | - ld_impl_linux-64=2.38=h1181459_1
11 | - libffi=3.4.4=h6a678d5_1
12 | - libgcc-ng=11.2.0=h1234567_1
13 | - libgomp=11.2.0=h1234567_1
14 | - libstdcxx-ng=11.2.0=h1234567_1
15 | - libuuid=1.41.5=h5eee18b_0
16 | - ncurses=6.4=h6a678d5_0
17 | - openssl=3.0.14=h5eee18b_0
18 | - pip=24.2=py312h06a4308_0
19 | - python=3.12.4=h5148396_1
20 | - readline=8.2=h5eee18b_0
21 | - setuptools=72.1.0=py312h06a4308_0
22 | - sqlite=3.45.3=h5eee18b_0
23 | - tk=8.6.14=h39e8969_0
24 | - tzdata=2024a=h04d1e81_0
25 | - wheel=0.43.0=py312h06a4308_0
26 | - xz=5.4.6=h5eee18b_1
27 | - zlib=1.2.13=h5eee18b_1
28 | - pip:
29 | - argparse==1.4.0
30 | - certifi==2024.7.4
31 | - charset-normalizer==3.3.2
32 | - contourpy==1.3.0
33 | - cycler==0.12.1
34 | - einops==0.8.0
35 | - filelock==3.13.1
36 | - fonttools==4.53.1
37 | - fsspec==2024.2.0
38 | - ftfy==6.2.3
39 | - huggingface-hub==0.24.6
40 | - idna==3.8
41 | - jinja2==3.1.3
42 | - kiwisolver==1.4.5
43 | - markupsafe==2.1.5
44 | - matplotlib==3.9.2
45 | - mpmath==1.3.0
46 | - networkx==3.2.1
47 | - numpy==1.26.3
48 | - nvidia-cublas-cu11==11.11.3.6
49 | - nvidia-cuda-cupti-cu11==11.8.87
50 | - nvidia-cuda-nvrtc-cu11==11.8.89
51 | - nvidia-cuda-runtime-cu11==11.8.89
52 | - nvidia-cudnn-cu11==8.7.0.84
53 | - nvidia-cufft-cu11==10.9.0.58
54 | - nvidia-curand-cu11==10.3.0.86
55 | - nvidia-cusolver-cu11==11.4.1.48
56 | - nvidia-cusparse-cu11==11.7.5.86
57 | - nvidia-nccl-cu11==2.20.5
58 | - nvidia-nvtx-cu11==11.8.86
59 | - opencv-python==4.10.0.84
60 | - packaging==24.1
61 | - pillow==10.4.0
62 | - pycocoevalcap==1.2
63 | - pycocotools==2.0.8
64 | - pyparsing==3.1.4
65 | - python-dateutil==2.9.0.post0
66 | - pyyaml==6.0.2
67 | - regex==2024.7.24
68 | - requests==2.32.3
69 | - safetensors==0.4.4
70 | - six==1.16.0
71 | - sympy==1.12
72 | - tokenizers==0.19.1
73 | - torch==2.3.1+cu118
74 | - torchaudio==2.3.1+cu118
75 | - torchvision==0.18.1+cu118
76 | - tqdm==4.66.4
77 | - transformers==4.42.3
78 | - typing-extensions==4.9.0
79 | - urllib3==2.2.2
80 | - wcwidth==0.2.13
81 | - clip==1.0
82 |
--------------------------------------------------------------------------------
/alignment/matchtime_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 | import clip
6 |
7 |
8 | class VideoEncoder(nn.Module):
9 | def __init__(self, input_dim=512, output_dim=128):
10 | super(VideoEncoder, self).__init__()
11 | self.fc1 = nn.Linear(input_dim, 384)
12 | self.bn1 = nn.BatchNorm1d(384)
13 | self.fc2 = nn.Linear(384, 256)
14 | self.bn2 = nn.BatchNorm1d(256)
15 | self.fc3 = nn.Linear(256, output_dim)
16 | self.bn3 = nn.BatchNorm1d(output_dim)
17 |
18 | def forward(self, x):
19 | bs = x.shape[0]
20 | x = rearrange(x, 'b l a f -> (b l) (a f)')
21 | x = F.leaky_relu(self.bn1(self.fc1(x)))
22 | x = F.leaky_relu(self.bn2(self.fc2(x)))
23 | x = self.bn3(self.fc3(x))
24 | x = rearrange(x, '(b l) f -> b l f', b=bs)
25 | return x
26 |
27 | class TextEncoder(nn.Module):
28 | def __init__(self):
29 | super(TextEncoder, self).__init__()
30 | self.fc1 = nn.Linear(512, 384)
31 | self.bn1 = nn.BatchNorm1d(384)
32 | self.fc2 = nn.Linear(384, 256)
33 | self.bn2 = nn.BatchNorm1d(256)
34 | self.fc3 = nn.Linear(256, 128)
35 | self.bn3 = nn.BatchNorm1d(128)
36 |
37 | def forward(self, x):
38 | x = F.leaky_relu(self.bn1(self.fc1(x)))
39 | x = F.leaky_relu(self.bn2(self.fc2(x)))
40 | x = self.bn3(self.fc3(x))
41 | return x
42 |
43 | class ContrastiveLearningModel(nn.Module):
44 | def __init__(self, feature_dim=512, embedding_dim=128, device="cuda:7"):
45 | super(ContrastiveLearningModel, self).__init__()
46 | self.video_encoder = VideoEncoder(input_dim=feature_dim, output_dim=embedding_dim).to(device=device, dtype=torch.bfloat16)
47 | self.text_encoder = TextEncoder().to(device=device, dtype=torch.bfloat16)
48 | self.model, _ = clip.load("ViT-B/32", device=device)
49 | for name, param in self.model.named_parameters():
50 | param.requires_grad = False
51 |
52 | def forward(self, anchor_caption, concat_feature):
53 | anchor_embeddings = self.model.encode_text(anchor_caption).to(torch.bfloat16)
54 | anchor_encoded = self.text_encoder(anchor_embeddings).unsqueeze(2)
55 | concat_encoded = self.video_encoder(concat_feature)
56 |
57 | logits = torch.bmm(concat_encoded, anchor_encoded).squeeze(2)
58 | labels = torch.zeros(anchor_embeddings.shape[0], dtype=torch.long).to(device=anchor_embeddings.device)
59 | loss = F.cross_entropy(logits, labels)
60 |
61 | return loss, logits
--------------------------------------------------------------------------------
/alignment/soccer_whisperx.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import whisperx
4 | import json
5 |
6 | def convert_all_mkv_files(folder_path, output_directory, device):
7 | batch_size = 4
8 | compute_type = "float16"
9 | model = whisperx.load_model("large-v3", "cuda", device_index=device, language="en", compute_type=compute_type)
10 | device = "cuda"
11 | model_a, metadata = whisperx.load_align_model(language_code='en', device=device)
12 | print("Align loaded!")
13 |
14 | subdirectories = [subdir for subdir in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, subdir))]
15 | progress_bar = tqdm(total=len(subdirectories), desc="Processing")
16 | for subdir in subdirectories:
17 | subfolder_path = os.path.join(folder_path, subdir)
18 | mkv_files = [file for file in os.listdir(subfolder_path) if file.endswith(".mkv")]
19 | for mkv_file in mkv_files:
20 |
21 | new_game_folder = os.path.join(output_directory, subdir)
22 | os.makedirs(new_game_folder, exist_ok=True)
23 |
24 | audio_file_path = os.path.join(folder_path, subdir, mkv_file)
25 | output_file = os.path.join(new_game_folder, mkv_file[:-4] + ".json")
26 | if os.path.exists(output_file) and os.path.getsize(output_file) == 0:
27 | os.remove(output_file)
28 | if os.path.exists(output_file):
29 | continue
30 |
31 | try:
32 | audio = whisperx.load_audio(audio_file_path)
33 | print("Audio loaded!")
34 | result = model.transcribe(audio, batch_size=batch_size)
35 | print("Finished transcribe!")
36 | result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
37 | print("Aligned!")
38 | res = [{"start": item["start"], "end": item["end"], "text": item["text"]} for item in result["segments"]]
39 | with open(output_file, 'w') as file:
40 | json.dump(res, file, indent=4)
41 | print("Saved:",output_file)
42 |             except Exception as e:
43 |                 print("Failed:", output_file, "-", e)
44 | progress_bar.update(1)
45 | progress_bar.close()
46 |
47 | import argparse
48 |
49 |
50 | # Set up the argparse parser
51 | parser = argparse.ArgumentParser(description="Run WhisperX ASR on all .mkv videos under a league+season directory")
52 | parser.add_argument('--process_directory', type=str, required=True, help='input directory of league+year')
53 | parser.add_argument('--output_directory', type=str, required=True, help='output directory of league+year')
54 | parser.add_argument('--device', type=int, required=True, help='index of the CUDA device to use')
55 |
56 | # Parse command-line arguments
57 | args = parser.parse_args()
58 | folder_path = args.process_directory
59 | output_directory = args.output_directory
60 | device = args.device
61 | print("Start!")
62 | convert_all_mkv_files(folder_path, output_directory, device)
--------------------------------------------------------------------------------
/evaluation/score_single.py:
--------------------------------------------------------------------------------
1 | from pycocoevalcap.bleu.bleu import Bleu as Bleuold
2 | from pycocoevalcap.bleu.bleu_scorer import BleuScorer
3 | from pycocoevalcap.meteor.meteor import Meteor
4 | from pycocoevalcap.rouge.rouge import Rouge
5 | from pycocoevalcap.cider.cider import Cider
6 | from sentence_transformers import SentenceTransformer
7 | import csv, argparse
8 | import numpy as np
9 |
10 | class Bleu(Bleuold):
11 | # Same as SoccerNet Evaluation
12 | def compute_score(self, gts, res):
13 |
14 | assert(gts.keys() == res.keys())
15 | imgIds = gts.keys()
16 |
17 | bleu_scorer = BleuScorer(n=self._n)
18 | for id in imgIds:
19 | hypo = res[id]
20 | ref = gts[id]
21 | assert(type(hypo) is list)
22 | assert(len(hypo) == 1)
23 | assert(type(ref) is list)
24 | assert(len(ref) >= 1)
25 | bleu_scorer += (hypo[0], ref)
26 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
27 |
28 | return score, scores
29 |
30 | def cosine_similarity(vec1, vec2):
31 | vec1_np = vec1.cpu().numpy()
32 | vec2_np = vec2.cpu().numpy()
33 |
34 | dot_product = np.dot(vec1_np, vec2_np.T)
35 | norm_vec1 = np.linalg.norm(vec1_np, axis=1, keepdims=True)
36 | norm_vec2 = np.linalg.norm(vec2_np, axis=1, keepdims=True)
37 | cosine_sim = dot_product / np.dot(norm_vec1, norm_vec2.T)
38 |
39 | return cosine_sim
40 |
41 | def calculate_metrics(csv_file_path):
42 | # Initialize scorers
43 | bleu4_scorer = Bleu(4)
44 | meteor_scorer = Meteor()
45 | rouge_scorer = Rouge()
46 | cider_scorer = Cider()
47 | sbert_model = SentenceTransformer('all-MiniLM-L6-v2') # Load the sBERT model
48 |
49 |
50 | references = {}
51 | hypotheses = {}
52 | with open(csv_file_path, newline='', encoding='utf-8') as csvfile:
53 | reader = csv.reader(csvfile)
54 | next(reader)
55 | for i, row in enumerate(reader):
56 | references[i] = [row[5]] # Ground truth in the 6th column (index 5)
57 | hypotheses[i] = [row[6]] # Predicted caption in the 7th column (index 6)
58 |
59 | # Calculate BLEU scores
60 | bleu4_score, _ = bleu4_scorer.compute_score(references, hypotheses)
61 |
62 | # Calculate METEOR scores
63 | meteor_score, _ = meteor_scorer.compute_score(references, hypotheses)
64 |
65 | # Calculate ROUGE scores, focusing on ROUGE-L
66 | _, rouge_scores = rouge_scorer.compute_score(references, hypotheses)
67 | rouge_l_score = rouge_scores.mean()
68 |
69 | # Calculate CIDER scores
70 | cider_score, _ = cider_scorer.compute_score(references, hypotheses)
71 |
72 | # Calculate sBERT scores
73 | ref_sentences = [refs[0] for refs in references.values()]
74 | hyp_sentences = [hyps[0] for hyps in hypotheses.values()]
75 | ref_embeddings = sbert_model.encode(ref_sentences, convert_to_tensor=True)
76 | hyp_embeddings = sbert_model.encode(hyp_sentences, convert_to_tensor=True)
77 | cosine_scores = np.diag(cosine_similarity(ref_embeddings, hyp_embeddings))
78 | sbert_score = np.mean(cosine_scores)
79 |
80 | return {
81 | "BLEU-1": f"{bleu4_score[0]*100:.3f}",
82 | "BLEU-4": f"{bleu4_score[3]*100:.3f}",
83 | "METEOR": f"{meteor_score*100:.3f}",
84 | "ROUGE-L": f"{rouge_l_score*100:.3f}",
85 | "CIDER": f"{cider_score*100:.3f}",
86 | "sBERT": f"{sbert_score*100:.3f}"
87 | }
88 |
89 | def main():
90 | parser = argparse.ArgumentParser(description="Calculate metrics from a CSV file.")
91 | parser.add_argument("--csv_path", type=str, default="./inference_result/sample.csv", help="Path to the CSV file containing the data.")
92 | args = parser.parse_args()
93 | results = calculate_metrics(args.csv_path)
94 | print(results)
95 |
96 | if __name__ == "__main__":
97 | main()
98 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from matchvoice_dataset import MatchVoice_Dataset
2 | from torch.utils.data import DataLoader
3 | from models.matchvoice_model import matchvoice_model
4 | import torch
5 | import argparse
6 | import os
7 | import csv
8 | from tqdm import tqdm
9 |
10 | def predict(args):
11 | '''
12 |     The outputs will be written to a CSV file with the following columns:
13 |     - league: the league and season of the soccer game
14 |     - game: the name of this soccer game
15 |     - half: 1st/2nd half of the game
16 |     - timestamp: the second within this half
17 |     - type: the type of this soccer event
18 |     - anonymized: the ground-truth commentary of this video clip
19 |     - predicted_res_{i}: the i-th predicted commentary for this video clip
20 | '''
21 | os.makedirs(os.path.dirname(args.csv_output_path), exist_ok=True)
22 | print(args.ann_root)
23 | test_dataset = MatchVoice_Dataset(
24 | feature_root=args.feature_root,
25 | ann_root=args.ann_root,
26 | fps=args.fps,
27 | timestamp_key="gt_gameTime",
28 | tokenizer_name=args.tokenizer_name,
29 | window=args.window
30 | )
31 | test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers, drop_last=False, shuffle=False, pin_memory=True, collate_fn=test_dataset.collater)
32 | print("===== Video features data loaded! =====")
33 | predict_model = matchvoice_model(llm_ckpt=args.tokenizer_name,tokenizer_ckpt=args.tokenizer_name,num_video_query_token=args.num_video_query_token, num_features=args.num_features, device=args.device, inference=True)
34 |
35 | # Load checkpoints
36 | other_parts_state_dict = torch.load(args.model_ckpt)
37 | new_model_state_dict = predict_model.state_dict()
38 | for key, value in other_parts_state_dict.items():
39 | if key in new_model_state_dict:
40 | new_model_state_dict[key] = value
41 | predict_model.load_state_dict(new_model_state_dict)
42 |
43 | predict_model.eval()
44 | print("===== Model and Checkpoints loaded! =====")
45 | headers = ['league', 'game', 'half', 'timestamp', 'type', 'anonymized']
46 | headers += [f'predicted_res_{i}' for i in range(args.generate_num)]
47 | with open(args.csv_output_path, 'w', newline='') as file:
48 | writer = csv.writer(file)
49 | writer.writerow(headers)
50 |
51 | # predict process
52 | with torch.no_grad():
53 | for samples in tqdm(test_data_loader):
54 | all_predictions = []
55 | for _ in range(args.generate_num):
56 | predicted_res = predict_model(samples)
57 | all_predictions.append(predicted_res)
58 |
59 | caption_info = samples["caption_info"]
60 | with open(args.csv_output_path, 'a', newline='') as file:
61 | writer = csv.writer(file)
62 | for info in zip(*all_predictions, caption_info):
63 | row = [info[-1][4], info[-1][5], info[-1][0], info[-1][1], info[-1][2], info[-1][3]] + list(info[:-1])
64 | writer.writerow(row)
65 |
66 | if __name__ == "__main__":
67 |     parser = argparse.ArgumentParser(description="Run inference with the MatchVoice model on the MatchTime test set.")
68 | parser.add_argument("--feature_root", type=str, default="./features/baidu_soccer_embeddings")
69 | parser.add_argument("--ann_root", type=str, default="./dataset/SN-Caption-test-align")
70 | parser.add_argument("--model_ckpt", type=str, default="./ckpt/models_ckpt/baidu/model_save_best_CIDEr.pth")
71 | parser.add_argument("--window", type=float, default=15)
72 | parser.add_argument("--tokenizer_name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct", help="LLM checkpoints, use path in your computer is fine as well")
73 | parser.add_argument("--batch_size", type=int, default=60)
74 | parser.add_argument("--num_workers", type=int, default=32)
75 | parser.add_argument("--num_query_tokens", type=int, default=32)
76 | parser.add_argument("--num_video_query_token", type=int, default=32)
77 | parser.add_argument("--num_features", type=int, default=512)
78 | parser.add_argument("--generate_num", type=int, default=1, help="You can determine how many sentences you want to comment (on the same video clip) here.")
79 | parser.add_argument("--csv_output_path", type=str, default="./inference_result/predict_baidu_window_15.csv", help="the path to the output predictions")
80 | parser.add_argument("--device", type=str, default="cuda:0")
81 | parser.add_argument("--fps", type=int, default=2, help="the FPS of your feature")
82 |
83 | args = parser.parse_args()
84 | predict(args)
85 |
--------------------------------------------------------------------------------
/inference_single_video_CLIP.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import Dataset, DataLoader
3 | from torchvision import transforms
4 | import clip
5 | from PIL import Image
6 | import torch, os, cv2, argparse
7 | from models.matchvoice_model import matchvoice_model
8 |
9 | class VideoDataset(Dataset):
10 | def __init__(self, video_path, size=224, fps=2):
11 | self.video_path = video_path
12 | self.size = size
13 | self.fps = fps
14 | self.transforms = transforms.Compose([
15 | transforms.Resize((self.size, self.size)),
16 | transforms.ToTensor(),
17 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
18 | ])
19 | # Load video using OpenCV
20 | self.cap = cv2.VideoCapture(self.video_path)
21 | self.length = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
22 | # Calculate frames to capture based on FPS
23 | self.frame_indices = [int(x * self.cap.get(cv2.CAP_PROP_FPS) / self.fps) for x in range(int(self.length / self.cap.get(cv2.CAP_PROP_FPS) * self.fps))]
24 |
25 | def __len__(self):
26 | return len(self.frame_indices)
27 |
28 | def __getitem__(self, idx):
29 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_indices[idx])
30 | ret, frame = self.cap.read()
31 | if not ret:
32 | print("Error in reading frame")
33 | return None
34 | # Convert color from BGR to RGB
35 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
36 | # Apply transformations
37 | frame = self.transforms(Image.fromarray(frame))
38 | return frame.to(torch.float16)
39 |
40 | def close(self):
41 | self.cap.release()
42 |
43 | def encode_features(data_loader, encoder, device):
44 |     all_features = None  # initialized to None, assigned on the first batch
45 |     for frames in data_loader:
46 |         features = encoder(frames.to(device))
47 |         if all_features is None:
48 |             all_features = features  # first batch: assign directly
49 |         else:
50 |             all_features = torch.cat((all_features, features), dim=0)  # later batches: concatenate along dim 0
51 |     return all_features
52 |
53 | def predict_single_video_CLIP(video_path, predict_model, visual_encoder, size, fps, device):
54 | # Loading features
55 | try:
56 | dataset = VideoDataset(video_path, size=size, fps=fps)
57 | data_loader = DataLoader(dataset, batch_size=40, shuffle=False, pin_memory=True, num_workers=0)
58 | # print("Start encoding!")
59 | features = encode_features(data_loader, visual_encoder, device)
60 | dataset.close()
61 | print("Features of this video loaded with shape of:", features.shape)
62 |     except Exception as e:
63 |         print("Error with loading:", video_path, e)
64 |         return
65 | sample = {
66 | "features": features.unsqueeze(dim=0),
67 | "labels": None,
68 | "attention_mask": None,
69 | "input_ids": None
70 | }
71 |
72 | # Doing prediction:
73 | comment = predict_model(sample)
74 | print("The commentary is:", comment)
75 |
76 |
77 |
78 | if __name__ == '__main__':
79 | parser = argparse.ArgumentParser(description='Process video files for feature extraction.')
80 | parser.add_argument('--video_path', type=str, default="./examples/eng.mkv", help='Path to the soccer game video clip.')
81 | parser.add_argument('--device', type=str, default="cuda:0", help='Device to extract.')
82 | parser.add_argument('--size', type=int, default=224, help='Size to which each video frame is resized.')
83 | parser.add_argument('--fps', type=int, default=2, help='Frames per second to sample from the video.')
84 | parser.add_argument("--tokenizer_name", type=str, default="meta-llama/Meta-Llama-3-8B", help="LLM checkpoints, use path in your computer is fine as well")
85 | parser.add_argument("--model_ckpt", type=str, default="./ckpt/CLIP_matchvoice.pth")
86 | parser.add_argument("--num_query_tokens", type=int, default=32)
87 | parser.add_argument("--num_video_query_token", type=int, default=32)
88 | parser.add_argument("--num_features", type=int, default=512)
89 |
90 | args = parser.parse_args()
91 |
92 |     # Create and configure the CLIP and MatchVoice models
93 | model, preprocess = clip.load("ViT-B/32", device=args.device)
94 | model.eval()
95 | # print(model.dtype)
96 | clip_image_encoder = model.encode_image
97 | predict_model = matchvoice_model(llm_ckpt=args.tokenizer_name,tokenizer_ckpt=args.tokenizer_name,num_video_query_token=args.num_video_query_token, num_features=args.num_features, device=args.device, inference=True)
98 | # Load checkpoints
99 | other_parts_state_dict = torch.load(args.model_ckpt)
100 | new_model_state_dict = predict_model.state_dict()
101 | for key, value in other_parts_state_dict.items():
102 | if key in new_model_state_dict:
103 | new_model_state_dict[key] = value
104 | predict_model.load_state_dict(new_model_state_dict)
105 | predict_model.eval()
106 |
107 | predict_single_video_CLIP(video_path=args.video_path, predict_model=predict_model, visual_encoder=clip_image_encoder, device=args.device, size=args.size, fps=args.fps)
108 |
--------------------------------------------------------------------------------
/alignment/soccer_asr2events.py:
--------------------------------------------------------------------------------
1 | import json, os
2 | import numpy as np
3 | import transformers
4 | from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
5 | import torch
6 | import argparse
7 | from tqdm import tqdm
8 | import re
9 |
10 | # Load JSON data from the file
11 | def load_data(filename):
12 | with open(filename, 'r') as file:
13 | data = json.load(file)
14 | return data
15 |
16 | # Organize data into one-minute intervals
17 | def organize_data(data):
18 | grouped_data = {}
19 | for entry in data:
20 | start_minute = int(entry['start'] // 60)
21 | end_minute = int(entry['end'] // 60)
22 | if start_minute not in grouped_data:
23 | grouped_data[start_minute] = []
24 | if end_minute not in grouped_data:
25 | grouped_data[end_minute] = []
26 | if entry not in grouped_data[start_minute]:
27 | grouped_data[start_minute].append(entry)
28 | if entry not in grouped_data[end_minute]:
29 | grouped_data[end_minute].append(entry)
30 | return grouped_data
31 |
32 |
33 | # Generate a comprehensive prompt for each minute and segment summaries for every 10 seconds
34 | def generate_prompt(grouped_data):
35 | prompt_all = {}
36 | for minute, entries in grouped_data.items():
37 | # Create a single text with timestamps for the whole minute
38 | full_minute_text = "\n".join(f"{entry['start']-minute*60:.2f}-{entry['end']-minute*60:.2f}s: {entry['text']}" for entry in entries)
39 | prompt = f"I will give you an automatically recognized speech with timestamps from a soccer game video. The narrator in the video is commenting on the soccer game. Your task is to summarize the key events for every 10 seconds, each commentary should be clear about the person name and soccer terminology. Here is this automatically recognized speech: \n\n{full_minute_text}\n\n You need to summarize 6 sentence commentaries for 0-10s, 10-20s, 20-30s, 30-40s, 40-50s, 50-60s according to the timestamps in automatically recognized speech results, every single sentence commentary should be clear and consise about the incidents happened within that 10 seconds for around 20-30 words. Now please write these 6 commentaries.\nAnswer:"
40 | prompt_all[minute] = prompt
41 | return prompt_all
42 |
43 | def asr2events(asr_json_path, output_json_path, model, tokenizer, device):
44 | asr_data = load_data(asr_json_path)
45 | grouped_data = organize_data(asr_data)
46 | prompts = generate_prompt(grouped_data)
47 |
48 | commentary_dict = {}
49 | for min in sorted(prompts.keys()):
50 | prompt = prompts[min]
51 | # print(f"Key: {min}, Value: {prompt}")
52 | input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
53 | generated_ids = model.generate(input_ids, max_length = 2000)
54 | generated_res = tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
55 | try:
56 | extracted_commentaries = re.findall(r'\d+-\d+s: (.*?)(?=\n|$)', generated_res.split('Answer:')[1])
57 | extracted_commentaries = extracted_commentaries[:6]
58 | extracted_commentaries = [item.replace("assistant", "") if item.endswith("assistant") else item for item in extracted_commentaries]
59 | extracted_commentaries = [item.replace("assistant.", "") if item.endswith("assistant.") else item for item in extracted_commentaries]
60 | # Calculate keys and store the results
61 | for idx, commentary in enumerate(extracted_commentaries):
62 | key = min * 60 + (5 + idx * 10) # Calculate the key as specified
63 | commentary_dict[key] = commentary
64 | except:
65 | print("Error with", min, "minute in", asr_json_path)
66 |
67 | with open(output_json_path, 'w') as json_file:
68 | json.dump(commentary_dict, json_file, indent=4)
69 |
70 |
71 | parser = argparse.ArgumentParser(description='Process ASR data for football matches.')
72 | parser.add_argument('--device', type=str, default="cuda:0", help='Device to run the model on.')
73 | parser.add_argument('--model_name', type=str, default="meta-llama/Meta-Llama-3-8B", help='Path to the pretrained model.')
74 | parser.add_argument('--base_path', type=str, required=True, help='The folder of ASR results')
75 | parser.add_argument('--output_dir', type=str, required=True, help='Output directory for processed JSON files.')
76 |
77 | args = parser.parse_args()
78 |
79 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
80 | model = AutoModelForCausalLM.from_pretrained(args.model_name)
81 | model.config.pad_token_id = model.config.eos_token_id
82 | tokenizer.pad_token = tokenizer.eos_token
83 | model.to(args.device)
84 |
85 | tasks = []
86 | for match in os.listdir(args.base_path):
87 | match_path = os.path.join(args.base_path, match)
88 | if os.path.isdir(match_path):
89 | for file in os.listdir(match_path):
90 | file_path = os.path.join(match_path, file)
91 |
92 | if file_path.endswith("224p.json"):
93 | asr_json_path = os.path.join(match_path, file_path)
94 | output_file_name = os.path.basename(file_path.replace("224p.json", "narrator_event.json"))
95 | output_json_path = os.path.join(args.output_dir, os.path.basename(args.base_path), match, output_file_name)
96 | tasks.append((asr_json_path, output_json_path))
97 |
98 | league_year_name = os.path.basename(os.path.dirname(os.path.dirname(tasks[0][0])))
99 | print("Processing:",league_year_name)
100 | # Iterate over all tasks with tqdm and process them
101 | for asr_json_path, output_json_path in tqdm(tasks, desc="Processing files"):
102 | # print(asr_json_path, output_json_path)
103 | if os.path.exists(output_json_path):
104 | continue
105 | os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
106 | try:
107 | asr2events(asr_json_path, output_json_path, model, tokenizer, args.device)
108 | except:
109 | print("Error with", asr_json_path)
110 |
111 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from matchvoice_dataset import MatchVoice_Dataset
3 | from models.matchvoice_model import matchvoice_model
4 | from torch.utils.data import DataLoader
5 | from tqdm import tqdm
6 | from transformers import AdamW
7 | import torch
8 | import numpy as np
9 | import random
10 | import os
11 | from pycocoevalcap.cider.cider import Cider
12 |
13 | # Use CIDEr score to do validation
14 | def eval_cider(predicted_captions, gt_captions):
15 | cider_evaluator = Cider()
16 | predicted_captions_dict =dict()
17 | gt_captions_dict = dict()
18 | for i, caption in enumerate(predicted_captions):
19 | predicted_captions_dict[i] = [caption]
20 | for i, caption in enumerate(gt_captions):
21 | gt_captions_dict[i] = [caption]
22 | _, cider_scores = cider_evaluator.compute_score(predicted_captions_dict, gt_captions_dict)
23 | return cider_scores.tolist()
24 |
25 | def train(args):
26 | train_dataset = MatchVoice_Dataset(feature_root=args.feature_root, ann_root=args.train_ann_root,
27 | window=args.window, fps=args.fps, tokenizer_name=args.tokenizer_name, timestamp_key=args.train_timestamp_key)
28 | val_dataset = MatchVoice_Dataset(feature_root=args.feature_root, ann_root=args.val_ann_root,
29 | window=args.window, fps=args.fps, tokenizer_name=args.tokenizer_name, timestamp_key=args.val_timestamp_key)
30 |
31 | train_data_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, num_workers=args.train_num_workers, drop_last=False, shuffle=True, pin_memory=True, collate_fn=train_dataset.collater)
32 | val_data_loader = DataLoader(val_dataset, batch_size=args.val_batch_size, num_workers=args.val_num_workers, drop_last=True, shuffle=True, pin_memory=True, collate_fn=train_dataset.collater)
33 | print("===== Video features data loaded! =====")
34 | model = matchvoice_model(llm_ckpt=args.tokenizer_name, tokenizer_ckpt=args.tokenizer_name ,window=args.window, num_query_tokens=args.num_query_tokens, num_video_query_token=args.num_video_query_token, num_features=args.num_features, device=args.device).to(args.device)
35 | if args.continue_train:
36 | model.load_state_dict(torch.load(args.load_ckpt))
37 | optimizer = AdamW(model.parameters(), lr=args.lr)
38 | os.makedirs(args.model_output_dir, exist_ok=True)
39 | print("===== Model and Checkpoints loaded! =====")
40 |
41 | max_val_CIDEr = max(float(0), args.pre_max_CIDEr)
42 | for epoch in range(args.pre_epoch, args.num_epoch):
43 | model.train()
44 | train_loss_accum = 0.0
45 | train_pbar = tqdm(train_data_loader, desc=f'Epoch {epoch+1}/{args.num_epoch} Training')
46 | for samples in train_pbar:
47 |
48 | optimizer.zero_grad()
49 |             try:
50 |                 loss = model(samples)
51 |                 loss.backward()
52 |                 optimizer.step()
53 |                 train_loss_accum += loss.item()
54 |                 train_pbar.set_postfix({"Loss": f"{loss.item():.4f}"})
55 |             except Exception as e:
56 |                 print(f"Skipping batch due to error: {e}")
57 |         avg_train_loss = train_loss_accum / len(train_data_loader)
58 |
59 | model.eval()
60 | val_CIDEr = 0.0
61 | val_pbar = tqdm(val_data_loader, desc=f'Epoch {epoch+1}/{args.num_epoch} Validation')
62 | with torch.no_grad():
63 | for samples in val_pbar:
64 | temp_res_text, anonymized = model(samples, True)
65 | cur_CIDEr_score = eval_cider(temp_res_text,anonymized)
66 | val_CIDEr += sum(cur_CIDEr_score)/len(cur_CIDEr_score)
67 | val_pbar.set_postfix({"Scores": f"|C:{sum(cur_CIDEr_score)/len(cur_CIDEr_score):.4f}"})
68 |
69 | avg_val_CIDEr = val_CIDEr / len(val_data_loader)
70 | print(f"Epoch {epoch+1} Summary: Average Training Loss: {avg_train_loss:.3f}, Average Validation scores: C:{avg_val_CIDEr*100:.3f}")
71 |
72 | if epoch % 5 == 0:
73 | file_path = f"{args.model_output_dir}/model_save_{epoch+1}.pth"
74 | save_matchvoice_model(model, file_path)
75 |
76 | if avg_val_CIDEr > max_val_CIDEr:
77 | max_val_CIDEr = avg_val_CIDEr
78 | file_path = f"{args.model_output_dir}/model_save_best_val_CIDEr.pth"
79 | save_matchvoice_model(model, file_path)
80 |
81 | def save_matchvoice_model(model, file_path):
82 | state_dict = model.cpu().state_dict()
83 | state_dict_without_llama = {}
84 |     # Iterate over the model's state_dict and exclude the llama_model layer weights
85 | for key, value in state_dict.items():
86 | if "llama_model.model.layers" not in key:
87 | state_dict_without_llama[key] = value
88 | torch.save(state_dict_without_llama, file_path)
89 | model.to(model.device)
90 |
91 | if __name__ == "__main__":
92 |
93 |
94 | torch.manual_seed(42)
95 | np.random.seed(42)
96 | random.seed(42)
97 | if torch.cuda.is_available():
98 | torch.cuda.manual_seed_all(42)
99 |
100 |     parser = argparse.ArgumentParser(description="Train the MatchVoice model on the MatchTime dataset.")
101 | parser.add_argument("--feature_root", type=str, default="./features/baidu_soccer_embeddings")
102 | parser.add_argument("--window", type=float, default=15)
103 | parser.add_argument("--tokenizer_name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct")
104 |
105 | parser.add_argument("--train_ann_root", type=str, default="./dataset/MatchTime/train")
106 | parser.add_argument("--train_batch_size", type=int, default=32)
107 | parser.add_argument("--train_num_workers", type=int, default=32)
108 | parser.add_argument("--train_timestamp_key", type=str, default="gameTime")
109 |
110 | parser.add_argument("--val_ann_root", type=str, default="./dataset/MatchTime/valid")
111 | parser.add_argument("--val_batch_size", type=int, default=20)
112 | parser.add_argument("--val_num_workers", type=int, default=32)
113 | parser.add_argument("--val_timestamp_key", type=str, default="gameTime")
114 |
115 | parser.add_argument("--lr", type=float, default=1e-4)
116 | parser.add_argument("--num_epoch", type=int, default=80)
117 | parser.add_argument("--num_query_tokens", type=int, default=32)
118 | parser.add_argument("--num_video_query_token", type=int, default=32)
119 | parser.add_argument("--num_features", type=int, default=512)
120 | parser.add_argument("--fps", type=int, default=2)
121 | parser.add_argument("--model_output_dir", type=str, default="./ckpt")
122 | parser.add_argument("--device", type=str, default="cuda:0")
123 |
124 | # If continue training from any epoch
125 | parser.add_argument("--continue_train", type=bool, default=False)
126 | parser.add_argument("--pre_max_CIDEr", type=float, default=0.0)
127 | parser.add_argument("--pre_epoch", type=int, default=0)
128 | parser.add_argument("--load_ckpt", type=str, default="./ckpt/model_save_best_val_CIDEr.pth")
129 |
130 |
131 |
132 | args = parser.parse_args()
133 | train(args)
134 |
--------------------------------------------------------------------------------
/matchvoice_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from torch.utils.data import Dataset
4 | import torch
5 | import numpy as np
6 | import json
7 | from transformers import AutoTokenizer
8 | import copy
9 |
10 | IGNORE_INDEX = -100
11 |
12 | class MatchVoice_Dataset(Dataset):
13 | def __init__(self, feature_root, ann_root, window = 15, fps = 2, timestamp_key="gameTime",
14 | tokenizer_name = 'meta-llama/Meta-Llama-3-8B', max_token_length =128):
15 |
16 | self.caption = traverse_and_parse(ann_root, timestamp_key)
17 | self.feature_root = feature_root
18 | self.window = window
19 | self.fps = fps
20 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
21 | self.tokenizer.pad_token_id = 128001
22 | self.tokenizer.add_tokens(["[PLAYER]","[TEAM]","[COACH]","[REFEREE]","([TEAM])"], special_tokens=True)
23 | self.max_token_length = max_token_length
24 |
25 | def __getitem__(self, index):
26 | num_retries = 50
27 | fetched_features = None
28 | for _ in range(num_retries):
29 | try:
30 | half, timestamp, type, anonymized, league, game = self.caption[index]
31 | feature_folder = os.path.join(self.feature_root, league, game)
32 | file_paths = [os.path.join(feature_folder, file) for file in os.listdir(feature_folder) if file.startswith(str(half)) and file.endswith(".npy")]
33 | fetched_features = torch.from_numpy(load_adjusted_features(file_paths[0], timestamp, self.window, self.fps))
34 |
35 | anonymized_tokens = self.tokenizer(
36 | anonymized,
37 | return_tensors = "pt",
38 | max_length=self.max_token_length,
39 | truncation=True
40 | ).input_ids[0]
41 |
42 | except:
43 | index = random.randint(0, len(self) - 1)
44 | continue
45 | break
46 | else:
47 | raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
48 | return {
49 | "features": fetched_features,
50 | "tokens_input_ids": anonymized_tokens,
51 | "caption_info": self.caption[index]
52 | }
53 |
54 | def __len__(self):
55 | return len(self.caption)
56 |
57 | def collater(self, instances):
58 | input_ids = [
59 | torch.cat((torch.tensor([self.tokenizer.convert_tokens_to_ids("<|begin_of_text|>")]),
60 | instance["tokens_input_ids"],
61 | torch.tensor([self.tokenizer.convert_tokens_to_ids("<|end_of_text|>")]))) for instance in instances] # add end token
62 | labels = copy.deepcopy(input_ids)
63 | caption_info = [instance["caption_info"] for instance in instances]
64 | input_ids = torch.nn.utils.rnn.pad_sequence(
65 | input_ids,
66 | batch_first=True,
67 | padding_value=self.tokenizer.convert_tokens_to_ids("<|end_of_text|>"))
68 | labels = torch.nn.utils.rnn.pad_sequence(
69 | labels,
70 | batch_first=True,
71 | padding_value=IGNORE_INDEX)
72 |
73 | batch = dict(
74 | input_ids=input_ids,
75 | attention_mask=input_ids.ne(self.tokenizer.convert_tokens_to_ids("<|end_of_text|>")),
76 | labels=labels,
77 | caption_info=caption_info
78 | )
79 |
80 | if 'features' in instances[0]:
81 | features = [instance['features'] for instance in instances]
82 | if all(x is not None and x.shape == features[0].shape for x in features):
83 | batch['features'] = torch.stack(features)
84 | else:
85 | batch['features'] = features
86 | return batch
87 |
88 |
89 | def load_adjusted_features(feature_path, timestamp, window, fps=2):
90 | """
91 | Load and adjust video features based on the given timestamp and window.
92 |
93 | Args:
94 | - feature_path (str): The path to the .npy file containing video features.
95 | - timestamp (int): The target timestamp in seconds.
96 | - window (float): The window size in seconds.
97 |
98 | Returns:
99 | - np.array: The adjusted array of video features.
100 | """
101 | features = np.load(feature_path)
102 | total_frames = int(window * 2 * fps) # Total frames to extract
103 | if timestamp * fps > len(features):
104 | return None
105 |
106 | start_frame = int(max(0, timestamp - window) * fps + 1)
107 | end_frame = int((timestamp + window) * fps + 1)
108 | if end_frame > len(features):
109 | start_frame = int(max(0, len(features) - total_frames)) # Adjust to get the last total_frames
110 | ad = features[start_frame:start_frame+total_frames]
111 | return ad
112 |
113 | def parse_labels_caption(file_path, league, game, timestamp_key):
114 | """
115 | Parses a Labels-caption.json file and extracts the required data.
116 | Parameters:
117 | file_path (str): The path to the Labels-caption.json file.
118 | league (str): The league name.
119 | game (str): The game name.
120 | Returns:
121 | list: A list of tuples containing (half, timestamp, type, anonymized, league, game).
122 | """
123 | with open(file_path, 'r') as file:
124 | data = json.load(file)
125 |
126 | result = []
127 | for annotation in data.get('annotations', []):
128 | try:
129 | gameTime, _ = annotation.get(timestamp_key, ' - ').split(' - ')
130 | half = int(gameTime.split(' ')[0])
131 | if half not in [1, 2]:
132 | continue
133 | minutes, seconds = map(int, _.split(':'))
134 | timestamp = minutes * 60 + seconds
135 | label = annotation.get('label', '')
136 | anonymized = annotation.get('anonymized', '')
137 | result.append((half, timestamp, label, anonymized, league, game))
138 | except ValueError:
139 | continue
140 | # print(len(result))
141 | return result
142 |
143 | def traverse_and_parse(root_dir, timestamp_key):
144 | """
145 | Traverses a directory and its subdirectories to find and parse all Labels-caption.json files.
146 | Parameters:
147 | root_dir (str): The root directory to start traversal.
148 | Returns:
149 | list: A combined list of tuples from all Labels-caption.json files found.
150 | """
151 | all_data = []
152 | for subdir, dirs, files in os.walk(root_dir):
153 | for file in files:
154 | if file == 'Labels-caption.json' or file == "Labels-caption_with_gt.json" or file == "Labels-caption_event_aligned_with_contrastive.json":
155 | league = os.path.basename(os.path.dirname(subdir))
156 | game = os.path.basename(subdir)
157 | file_path = os.path.join(subdir, file)
158 | all_data.extend(parse_labels_caption(file_path, league, game, timestamp_key))
159 | return all_data
160 |
161 |
--------------------------------------------------------------------------------
/alignment/soccer_align_from_event.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoModelForCausalLM
2 | import json, os
3 | import re, argparse
4 | from tqdm import tqdm
5 |
6 | def find_closest_keys(event, total_seconds):
7 | below = sorted((k for k in event.keys() if int(k) < total_seconds and int(k) >= max(0, total_seconds - 70)), key=lambda x: total_seconds - int(x))
8 | above = sorted((k for k in event.keys() if int(k) > total_seconds and int(k) <= total_seconds + 70), key=lambda x: int(x) - total_seconds)
9 |
10 | closest_below = below[:min(6, len(below))]
11 | closest_above = above[:min(3, len(above))]
12 |
13 | closest_keys = closest_below + closest_above
14 | return closest_keys
15 |
16 | def generate_prompt(caption_with_gt_path, event_folder):
17 | event_1 = None
18 | event_2 = None
19 | data = None
20 | try:
21 | with open(os.path.join(event_folder, "1_narrator_event.json"), 'r') as file:
22 | event_1 = json.load(file)
23 | with open(os.path.join(event_folder, "2_narrator_event.json"), 'r') as file:
24 | event_2 = json.load(file)
25 | with open(caption_with_gt_path, 'r') as file:
26 | data = json.load(file)
27 | except:
28 | print("Load error")
29 | pass
30 |
31 | all_prompts = []
32 | for annotation in data['annotations']:
33 | if annotation['gt_gameTime'] != "":
34 | half, time_str = annotation['gameTime'].split(' - ')
35 | minutes, seconds = map(int, time_str.split(':'))
36 | total_seconds = minutes * 60 + seconds
37 | half, time_str = annotation['gt_gameTime'].split(' - ')
38 | minutes, seconds = map(int, time_str.split(':'))
39 | total_gt_seconds = minutes * 60 + seconds
40 | description = annotation['description']
41 |
42 | if half == "1":
43 | current_event = event_1
44 | elif half == "2":
45 | current_event = event_2
46 | time_stamp_candidates = find_closest_keys(current_event, total_seconds)
47 |
48 | prompt = f"I have a text commentary of a soccer game event at the original time stamp:\n\n{total_seconds}: {description}\n\nand I want to locate the time of this commentary among the following events with timestamp:\n"
49 |
50 | for time_stamp in time_stamp_candidates:
51 | prompt = prompt + f"{int(time_stamp)-5}-{int(time_stamp)+5}: {current_event[time_stamp]}\n"
52 |
53 | prompt = prompt + "These are the words said by narrator and I want you to temporally align the first text commentary according to these words by narrators since there is a fair chance that the original timestamp is somehow inaccurate in time. So please return me with a number of time stamp that event is most likely to happen. I hope that you can choose a number of time stamp from the ranges of candidates. But if really none of the candidates is suitable, you can just return me with the original time stamp. Your answer is:"
54 |
55 | all_prompts.append((half, total_seconds, description, prompt, total_gt_seconds))
56 | return all_prompts
57 |
58 | parser = argparse.ArgumentParser(description='Process ASR data for football matches.')
59 | parser.add_argument('--device', type=str, default="cuda:5", help='Device to run the model on.')
60 | parser.add_argument('--model_name', type=str, default="meta-llama/Meta-Llama-3-8B", help='Path to the pretrained model.')
61 | parser.add_argument('--event_path', type=str, help='Base path to the seasons directory.')
62 | parser.add_argument('--caption_with_gt_dir', type=str, default="./dataset/SN-Caption-test-align", help='Base path to the seasons directory.')
63 | parser.add_argument('--output_dir', type=str, help='Output directory for processed JSON files.')
64 | args = parser.parse_args()
65 |
66 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
67 | model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map=args.device)
68 | model.config.pad_token_id = model.config.eos_token_id
69 | tokenizer.pad_token = tokenizer.eos_token
70 |
71 |
72 | def align_from_event(event_folder, caption_with_gt_path, output_json_path, model, tokenizer, device):
73 |
74 | all_gt = []
75 | all_aligned = []
76 | all_description = []
77 | all_prompts = None
78 | json_contents = None
79 | try:
80 | with open(caption_with_gt_path, 'r') as file:
81 | json_contents = json.load(file)
82 | all_prompts = generate_prompt(caption_with_gt_path, event_folder)
83 | except:
84 |         print("Error loading inputs:", event_folder)
85 | return
86 |
87 | for half, original_time, description, prompt, gt_time in all_prompts:
88 | try:
89 | input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
90 | generated_ids = model.generate(input_ids, max_length = 800)
91 | generated_res = tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
92 | answer = re.search(r'Your answer is: (\d+)', generated_res).group(1)
93 | original_time_2d = "{:02d}:{:02d}".format(*divmod(int(original_time), 60))
94 | answer_2d = "{:02d}:{:02d}".format(*divmod(int(answer), 60))
95 | gt_time_2d = "{:02d}:{:02d}".format(*divmod(int(gt_time), 60))
96 | print(f"{half} - {original_time_2d}", f"{half} - {answer_2d}", f"{half} - {gt_time_2d}")
97 |
98 | all_aligned.append(f"{half} - {answer_2d}")
99 | all_gt.append(f"{half} - {gt_time_2d}")
100 | all_description.append(description)
101 | except:
102 |             print("Error while aligning an annotation in:", event_folder)
103 |
104 | for annotation in json_contents['annotations']:
105 | annotation['event_aligned_gameTime'] = ""
106 | for gt, aligned, description in zip(all_gt, all_aligned, all_description):
107 | for annotation in json_contents['annotations']:
108 | if annotation['gt_gameTime'] == gt and annotation['description'] == description:
109 | annotation['event_aligned_gameTime'] = aligned
110 | with open(output_json_path, 'w') as outfile:
111 | json.dump(json_contents, outfile, indent=4)
112 |
113 | print(f'Updated JSON file has been saved to {output_json_path}')
114 |
115 |
116 | tasks = []
117 | league_year_name = os.path.basename(args.event_path)
118 | print("Processing:",league_year_name)
119 |
120 | for match in os.listdir(args.event_path):
121 | match_path = os.path.join(args.event_path, match)
122 | if os.path.isdir(match_path):
123 | gt_json_path = os.path.join(args.caption_with_gt_dir, league_year_name, match, "Labels-caption_with_gt.json")
124 | output_json_path = os.path.join(args.output_dir, league_year_name, match, "Labels-caption_event_aligned_with_gt.json")
125 | tasks.append((match_path, gt_json_path, output_json_path))
126 |
127 | for match_path, gt_json_path, output_json_path in tqdm(tasks, desc="Processing files"):
128 | if os.path.exists(output_json_path):
129 | continue
130 | os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
131 | try:
132 | print(match_path, gt_json_path, output_json_path)
133 | align_from_event(match_path, gt_json_path, output_json_path, model, tokenizer, args.device)
134 | except:
135 | print("Error with", match_path)
136 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MatchTime: Towards Automatic Soccer Game Commentary Generation (EMNLP 2024 Oral)
2 | This repository contains the official PyTorch implementation of MatchTime: https://arxiv.org/abs/2406.18530/
3 |
4 |
5 | ![teaser](./assets/teaser.png)
6 |
7 |
8 |
9 | ![commentary](./assets/commentary.png)
10 |
11 |
12 | ## Some Information
13 | [Project Page](https://haoningwu3639.github.io/MatchTime/) $\cdot$ [Paper](https://arxiv.org/abs/2406.18530/) $\cdot$ [Dataset](https://drive.google.com/drive/folders/14tb6lV2nlTxn3VygwAPdmtKm7v0Ss8wG) $\cdot$ [Checkpoint](https://huggingface.co/Homie0609/MatchVoice) $\cdot$ [Demo Video (YouTube)](https://www.youtube.com/watch?v=E3RxHR-M6y0) $\cdot$ [Demo Video (bilibili)](https://www.bilibili.com/video/BV1L4421U76m)
14 |
15 | ## Requirements
16 | - Python >= 3.8 (we recommend [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
17 | - [PyTorch >= 2.0.0](https://pytorch.org/) (required if you use an A100 GPU)
18 | - transformers >= 4.42.3
19 | - pycocoevalcap >= 1.2
20 |
21 | A suitable [conda](https://conda.io/) environment named `matchtime` can be created and activated with:
22 | ```
23 | cd MatchTime
24 | conda env create -f environment.yaml
25 | conda activate matchtime
26 | ```
27 |
28 | ## Training
29 | Before training, make sure you have prepared the [features](https://pypi.org/project/SoccerNet/) and caption [data](https://drive.google.com/drive/folders/14tb6lV2nlTxn3VygwAPdmtKm7v0Ss8wG), and put them into the corresponding folders. The directory structure after collating should look like this:
30 | ``````
31 | └─ MatchTime
32 | ├─ dataset
33 | │ ├─ MatchTime
34 | │ │ ├─ valid
35 | │ │ └─ train
36 | │ │ ├─ england_epl_2014-2015
37 | │ │ ... ├─ 2015-02-21 - 18-00 Chelsea 1 - 1 Burnley
38 | │ │ ... └─ Labels-caption.json
39 | │ │
40 | │ ├─ SN-Caption
41 | │ └─ SN-Caption-test-align
42 | │ ├─ england_epl_2015-2016
43 | │ ... ├─ 2015-08-16 - 18-00 Manchester City 3 - 0 Chelsea
44 | │ ... └─ Labels-caption_with_gt.json
45 | │
46 | ├─ features
47 | │ ├─ baidu_soccer_embeddings
48 | │ │ ├─ england_epl_2014-2015
49 | ... │ ... ├─ 2015-02-21 - 18-00 Chelsea 1 - 1 Burnley
50 | │ ... ├─ 1_baidu_soccer_embeddings.npy
51 | │ └─ 2_baidu_soccer_embeddings.npy
52 | ├─ C3D_PCA512
53 | ...
54 | ``````
55 | The format of the features can then be adjusted with:
56 | ```
57 | python ./features/preprocess.py --main_dir directory_path_of_feature
58 | ```
59 | The example above shows the layout for the Baidu features; in our experiments we also used the ResNET_PCA_512 and C3D_PCA_512 features from the official website. If you want to use [CLIP](https://github.com/openai/CLIP) (2 FPS) or [InternVideo](https://github.com/OpenGVLab/InternVideo/tree/main/InternVideo1) (1 FPS) features, you can follow their official repositories to extract them, or contact us for the features.
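As a quick sanity check on the collated features (the path below is only a placeholder), each half-match `.npy` file should load as a 2-D array; with the Baidu embeddings and the defaults in *train.py* / *inference.py* this is expected to be `(num_frames, 512)`, sampled at 2 FPS:
```
import numpy as np

# Placeholder path: point this at one of your preprocessed feature files.
feats = np.load("./features/baidu_soccer_embeddings/<league_game>/1_baidu_soccer_embeddings.npy")
print(feats.shape)  # expected with the default settings: (num_frames, 512) at 2 FPS
```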
60 |
61 | After preparing the data and features, you can pre-train (or fine-tune) with the following command (check the hyper-parameters at the bottom of *train.py*):
62 | ```
63 | python train.py
64 | ```
65 | ## Inference
66 |
67 | We provide two types of inference:
68 |
69 | #### For all test set
70 |
71 | You can test the ***MatchVoice*** model on the whole test set and generate a *.csv* file with the following command (check the hyper-parameters at the bottom of *inference.py*):
72 |
73 | ```
74 | python inference.py
75 | ```
76 |
77 | There is a sample of this type of inference in *./inference_result/sample.csv*.
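To take a quick look at such a generated file, something like the following works (the exact column layout is defined by *inference.py*):

```python
import pandas as pd

# Inspect the first few generated commentaries; column names depend on inference.py.
df = pd.read_csv("./inference_result/sample.csv")
print(df.head())
```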
78 |
79 | #### For Single Video
80 |
81 | We also provide a version that predicts the commentary for a single video (for our checkpoints, use a 30-second video):
82 | ```
83 | python inference_single_video_CLIP.py single_video_path
84 | ```
85 | Here we only provide the version using CLIP features (ViT-B/32); to extract CLIP features, please check [here](https://github.com/openai/CLIP). CLIP features are not the best-performing ones, but they are the most convenient for new videos.
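For reference, a minimal sketch of 2 FPS ViT-B/32 feature extraction with the [openai/clip](https://github.com/openai/CLIP) package is shown below. The exact preprocessing used by *inference_single_video_CLIP.py* may differ, and `opencv-python`/`Pillow` are assumed extra dependencies here, so treat this only as a starting point:

```python
import cv2                      # assumed dependency for frame decoding
import clip
import numpy as np
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

def extract_clip_features(video_path, target_fps=2):
    """Sample frames at roughly target_fps and encode them with CLIP ViT-B/32."""
    cap = cv2.VideoCapture(video_path)
    video_fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    step = max(int(round(video_fps / target_fps)), 1)
    feats, idx = [], 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % step == 0:
            image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            with torch.no_grad():
                feat = model.encode_image(preprocess(image).unsqueeze(0).to(device))
            feats.append(feat.squeeze(0).float().cpu().numpy())
        idx += 1
    cap.release()
    return np.stack(feats)      # (num_sampled_frames, 512)
```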
86 |
87 | ## Alignment
88 |
89 | Before doing alignment, you should download the videos from [here](https://www.soccer-net.org/data) (224p is enough) and organize them in the following format:
90 |
91 | ``````
92 | └─ MatchTime
93 | ├─ videos_224p
94 | ... ├─ england_epl_2014-2015
95 | ... ├─ 2015-02-21 - 18-00 Chelsea 1 - 1 Burnley
96 |                 ├─ 1_224p.mkv
97 | └─ 2_224p.mkv
98 | ``````
99 |
100 | ### Pre-process (Coarse Align)
101 |
102 | We use [WhisperX](https://github.com/m-bain/whisperX) and [LLaMA3](https://huggingface.co/docs/transformers/model_doc/llama3) (as an agent) to perform coarse alignment, with the following steps:
103 |
104 | *WhisperX ASR:*
105 | ```
106 | python ./alignment/soccer_whisperx.py --process_directory video_folder(eg. ./videos_224p/england_epl_2014-2015) --output_directory output_folder(eg. ./ASR_results/england_epl_2014-2015)
107 | ```
108 | *Transform to Events:*
109 | ```
110 | python ./alignment/soccer_asr2events.py --base_path ASR_results_folder(eg. ./ASR_results/england_epl_2014-2015) --output_dir event_results_folder(eg. ./event_results/england_epl_2014-2015)
111 | ```
112 |
113 | *Align from Events:*
114 | ```
115 | python ./alignment/soccer_align_from_event.py --event_path event_results_folder(eg. ./event_results/england_epl_2014-2015) --output_dir output_directory(eg. ./pre-processed/england_epl_2014-2015)
116 | ```
117 |
118 | More details can be found in the paper.
119 |
120 | ### Contrastive Learning (Fine-grained Align)
121 |
122 | After downloading the checkpoint from [here](https://huggingface.co/Homie0609/MatchTime/tree/main), use the following command to perform alignment with contrastive learning:
123 | ```
124 | python ./alignment/do_alignment.py
125 | ```
126 | By changing the hyper-parameter ***finding_words***, you can freely align from ASR, event, or the original SN-Caption timestamps.
127 |
128 | You can also use the alignment model directly via:
129 | ```
130 | from alignment.matchtime_model import ContrastiveLearningModel
131 | ```
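For reference, *./alignment/do_alignment.py* builds and loads the model roughly as follows (the checkpoint path shown is that script's default argument):

```python
import torch
from alignment.matchtime_model import ContrastiveLearningModel

# Build the alignment model and load the downloaded checkpoint (default path in do_alignment.py).
model = ContrastiveLearningModel(device="cuda:0")
model.load_state_dict(torch.load("./ckpt/matchtime.pth"))
model.eval()
```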
132 |
133 | ## Evaluation
134 | We provide code to evaluate the prediction results:
135 | ```
136 | # for a single csv file
137 | python ./evaluation/score_single.py --csv_path ./inference_result/sample.csv
138 | # for many csv files, recording scores into a new csv file
139 | python ./evaluation/score_group.py
140 | # for GPT score (needs an OpenAI API key)
141 | python ./evaluation/score_gpt.py ./inference_result/sample.csv
142 | ```
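The scoring scripts presumably rely on *pycocoevalcap* (listed in the requirements above); a minimal standalone sketch of the underlying metric calls is shown below. The toy captions are made up for illustration, and METEOR additionally needs a Java runtime:

```python
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge

# Ground-truth and predicted commentary, keyed by sample id (toy example).
gts = {"0": ["[PLAYER] ([TEAM]) scores from close range."]}
res = {"0": ["[PLAYER] ([TEAM]) fires home from inside the box."]}

for name, scorer in [("BLEU", Bleu(4)), ("METEOR", Meteor()),
                     ("ROUGE-L", Rouge()), ("CIDEr", Cider())]:
    score, _ = scorer.compute_score(gts, res)
    print(name, score)  # Bleu returns a list of BLEU-1..BLEU-4 scores
```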
143 |
144 | ## TODO
145 | - [x] Commentary Model & Training & Inference Code
146 | - [x] Release Checkpoints
147 | - [x] Release Meta Data
148 | - [x] Alignment Model & Training & Inference Code
149 | - [x] Evaluation Code
150 | - [x] Release Demo
151 |
152 | ## Citation
153 | If you use this code for your research or project, please cite:
154 |
155 | @inproceedings{rao2024matchtimeautomaticsoccergame,
156 | title = {MatchTime: Towards Automatic Soccer Game Commentary Generation},
157 | author = {Rao, Jiayuan and Wu, Haoning and Liu, Chang and Wang, Yanfeng and Xie, Weidi},
158 | booktitle = {Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing},
159 | year = {2024}
160 | }
161 |
162 | ## Acknowledgements
163 | Many thanks to the code bases from [Video-LLaMA](https://github.com/DAMO-NLP-SG/Video-LLaMA) and source data from [SoccerNet-Caption](https://arxiv.org/abs/2304.04565).
164 |
165 | ## Contact
166 | If you have any questions, please feel free to contact jy_rao@sjtu.edu.cn or haoningwu3639@gmail.com.
167 |
--------------------------------------------------------------------------------
/models/matchvoice_model.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoModelForCausalLM
2 | import torch
3 | from torch import nn
4 | import einops
5 | import contextlib
6 | from models.Qformer import BertConfig, BertLMHeadModel
7 | from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList
8 | from typing import List
9 | import pickle as pkl
10 | import sys
11 | import io
12 |
13 | def process_output_tokens(predict_model, tokens):
14 | output_texts = []
15 | for output_token in tokens:
16 | output_text = predict_model.tokenizer.decode(output_token)
17 | end_token_index = output_text.find('<|end_of_text|>')
18 | if end_token_index != -1:
19 | output_text = output_text[:end_token_index]
20 | output_texts.append(output_text)
21 | return output_texts
22 |
23 | class RestrictTokenGenerationLogitsProcessor(LogitsProcessor):
24 | def __init__(self, allowed_token_id_list: List[int]):
25 | super().__init__()
26 | self.allowed_token_id_list = allowed_token_id_list
27 |
28 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
29 | mask = torch.full_like(scores, -float('inf'))
30 | for allowed_id in self.allowed_token_id_list:
31 | mask[:, allowed_id] = scores[:, allowed_id]
32 | return mask
33 |
34 | class matchvoice_model(nn.Module):
35 | def __init__(self,
36 | # LLM part
37 | llm_ckpt = "meta-llama/Meta-Llama-3-8B",
38 | tokenizer_ckpt = "meta-llama/Meta-Llama-3-8B",
39 | # Q-former part
40 | max_frame_pos = 128,
41 | window = 30,
42 | num_query_tokens = 32,
43 | num_video_query_token = 32,
44 | num_features = 512,
45 | device = "cuda:0",
46 | inference = False,
47 | **kwargs,
48 | ):
49 | super().__init__()
50 | if len(kwargs):
51 | print(f'kwargs not used: {kwargs}')
52 | self.device = device
53 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_ckpt)
54 | self.tokenizer.add_tokens(["[PLAYER]","[TEAM]","[COACH]","[REFEREE]","([TEAM])"], special_tokens=True)
55 | self.llama_model = AutoModelForCausalLM.from_pretrained(llm_ckpt, torch_dtype=torch.bfloat16)
56 | self.llama_model.resize_token_embeddings(len(self.tokenizer))
57 | self.ln_vision = LayerNorm(num_features)
58 |         self.num_query_tokens = num_query_tokens
59 | self.num_video_query_token = num_video_query_token
60 | self.inference = inference
61 |
62 | # Initialize video Q-former
63 | self.video_Qformer,self.video_query_tokens = self.init_video_Qformer(num_query_token = num_video_query_token,
64 | vision_width=num_features,
65 | num_hidden_layers =2)
66 | self.video_Qformer.cls = None
67 | self.video_Qformer.bert.embeddings.word_embeddings = None
68 | self.video_Qformer.bert.embeddings.position_embeddings = None
69 | for layer in self.video_Qformer.bert.encoder.layer:
70 | layer.output = None
71 | layer.intermediate = None
72 |
73 | # llama projection
74 | self.llama_proj = nn.Linear(
75 | self.video_Qformer.config.hidden_size, self.llama_model.config.hidden_size
76 | )
77 | # video frame positional embedding
78 | self.video_frame_position_embedding = nn.Embedding(max_frame_pos, num_features)
79 | self.window = window
80 |
81 | # move to device
82 | self.llama_model = self.llama_model.to(self.device)
83 | for name, param in self.llama_model.named_parameters():
84 | param.requires_grad = False
85 | self.video_Qformer = self.video_Qformer.to(self.device)
86 | self.llama_proj = self.llama_proj.to(self.device)
87 | self.ln_vision = self.ln_vision.to(self.device)
88 | for name, param in self.ln_vision.named_parameters():
89 | param.requires_grad = False
90 | self.ln_vision = self.ln_vision.eval()
91 | self.video_frame_position_embedding = self.video_frame_position_embedding.to(self.device)
92 |
93 |         # Trick for inference: restrict generation to soccer-relevant tokens. You can delete this LogitsProcessorList part (including its use in the generation function).
94 | file_path = './soccer_words_llama3.pkl'
95 | with open(file_path, 'rb') as file:
96 | self.token_ids_list = pkl.load(file)
97 | self.token_ids_list.append(128000)
98 | self.token_ids_list.append(128001)
99 | self.processor = RestrictTokenGenerationLogitsProcessor(allowed_token_id_list=self.token_ids_list)
100 |         self.logits_processors = LogitsProcessorList()
101 |         self.logits_processors.append(self.processor)
102 |
103 | @classmethod
104 | def init_video_Qformer(cls, num_query_token, vision_width, num_hidden_layers =2):
105 | encoder_config = BertConfig.from_pretrained("bert-base-uncased")
106 | encoder_config.num_hidden_layers = num_hidden_layers
107 | encoder_config.encoder_width = vision_width
108 | # insert cross-attention layer every other block
109 | encoder_config.add_cross_attention = True
110 | encoder_config.cross_attention_freq = 1
111 | encoder_config.query_length = num_query_token
112 | Qformer = BertLMHeadModel(config=encoder_config)
113 | query_tokens = nn.Parameter(
114 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
115 | )
116 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
117 | return Qformer, query_tokens
118 |
119 |
120 | def maybe_autocast(self, dtype=torch.float16):
121 | enable_autocast = self.device != torch.device("cpu")
122 | if enable_autocast:
123 | return torch.cuda.amp.autocast(dtype=dtype)
124 | else:
125 | return contextlib.nullcontext()
126 |
127 | def forward(self, samples, validating=False):
128 | video_features = samples['features'].to(self.device)
129 | targets = samples['labels']
130 | atts_llama = samples['attention_mask']
131 | inputs_ids = samples['input_ids']
132 | # print(samples["caption_info"])
133 | batch_size = None
134 | time_length = None
135 | try:
136 | batch_size, time_length, _ = video_features.size()
137 |         except ValueError:
138 | batch_size, time_length, _, _ = video_features.size()
139 |
140 | if len(video_features.size()) != 4:
141 | video_features = video_features.unsqueeze(-2)
142 | video_features = self.ln_vision(video_features)
143 | video_features = einops.rearrange(video_features, 'b t n f -> (b t) n f', b=batch_size, t=time_length)
144 |
145 | position_ids = torch.arange(time_length, dtype=torch.long, device=video_features.device)
146 | position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
147 | frame_position_embeddings = self.video_frame_position_embedding(position_ids)
148 | frame_position_embeddings = frame_position_embeddings.unsqueeze(-2)
149 | frame_hidden_state = einops.rearrange(video_features, '(b t) n f -> b t n f',b=batch_size,t=time_length)
150 | frame_hidden_state = frame_position_embeddings + frame_hidden_state
151 |
152 | frame_hidden_state = einops.rearrange(frame_hidden_state, 'b t q h -> b (t q) h',b=batch_size,t=time_length)
153 | frame_atts = torch.ones(frame_hidden_state.size()[:-1], dtype=torch.long).to(frame_hidden_state)
154 | video_query_tokens = self.video_query_tokens.expand(frame_hidden_state.shape[0], -1, -1).to(frame_hidden_state.device)
155 |
156 | video_query_output = self.video_Qformer.bert(
157 | query_embeds=video_query_tokens,
158 | encoder_hidden_states=frame_hidden_state,
159 | encoder_attention_mask=frame_atts,
160 | return_dict=True,
161 | )
162 | video_hidden = video_query_output.last_hidden_state
163 |
164 | inputs_llama = self.llama_proj(video_hidden)
165 | if self.inference:
166 | return self.generate_text(inputs_llama)
167 |
168 | if validating:
169 | temp_res_text = self.generate_text(inputs_llama)
170 | anonymized = [sublist[3] for sublist in samples["caption_info"]]
171 | return temp_res_text, anonymized
172 |
173 | visual_label = torch.full((batch_size, self.num_video_query_token), -100, dtype=targets.dtype)
174 | concat_targets = torch.cat((visual_label, targets), dim=1).to(self.device)
175 | temp_input_ids = inputs_ids.clone().to(self.device)
176 | targets_embeds = self.llama_model.model.embed_tokens(temp_input_ids)
177 | embedding_cat = torch.cat((inputs_llama, targets_embeds), dim=1)
178 | mask_prefix = torch.ones(batch_size, self.num_video_query_token, dtype=atts_llama.dtype)
179 | mask = torch.concat((mask_prefix, atts_llama), dim=1).to(self.device)
180 |
181 | original_stdout = sys.stdout
182 | sys.stdout = io.StringIO()
183 | with self.maybe_autocast():
184 | outputs = self.llama_model(
185 | inputs_embeds=embedding_cat,
186 | attention_mask=mask,
187 | return_dict=True,
188 | labels=concat_targets,
189 | )
190 | sys.stdout = original_stdout
191 | loss = outputs.loss
192 | return loss
193 |
194 | def generate_text(self, inputs_llama):
195 | start_embeds = self.llama_model.model.embed_tokens(torch.tensor([128000]).to(self.device))
196 | inputs_llama_with_s = torch.cat([inputs_llama, start_embeds.expand(inputs_llama.size(0), -1, -1)], dim=1).to(dtype=torch.bfloat16)
197 | temp_res_tokens = self.llama_model.generate(
198 |             logits_processor=self.logits_processors,
199 | renormalize_logits=True,
200 | inputs_embeds=inputs_llama_with_s,
201 | max_new_tokens=128,
202 | num_beams=5,
203 | do_sample=True,
204 | min_length=5,
205 | top_p=0.9,
206 | repetition_penalty=1.0,
207 | length_penalty=1,
208 | temperature=1.0,
209 | )
210 | res_text = process_output_tokens(self, temp_res_tokens)
211 | return res_text
212 |
213 | class LayerNorm(nn.LayerNorm):
214 | def forward(self, x: torch.Tensor):
215 | orig_type = x.dtype
216 | ret = super().forward(x.type(torch.float32))
217 | return ret.type(orig_type)
--------------------------------------------------------------------------------
/alignment/do_alignment.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | import os
5 | import random
6 | from tqdm import tqdm
7 | import argparse
8 | import json
9 | import clip
10 | import glob
11 | from matchtime_model import ContrastiveLearningModel
12 |
13 | def parse_labels_caption_without_gt(file_path, league, game, finding_words):
14 | """
15 | Parses a Labels-caption.json file and extracts the required data.
16 | Parameters:
17 | file_path (str): The path to the Labels-caption.json file.
18 | league (str): The league name.
19 | game (str): The game name.
20 | Returns:
21 | list: A list of tuples containing (half, timestamp, type, anonymized, league, game).
22 | """
23 | with open(file_path, 'r') as file:
24 | data = json.load(file)
25 |
26 | result = []
27 | for annotation in data.get('annotations', []):
28 | try:
29 | gameTime, _ = annotation.get(finding_words, ' - ').split(' - ')
30 | half = int(gameTime.split(' ')[0])
31 | if half not in [1, 2]:
32 | continue
33 | minutes, seconds = map(int, _.split(':'))
34 | timestamp = minutes * 60 + seconds
35 | label = annotation.get('label', '')
36 | anonymized = annotation.get('anonymized', '')
37 | result.append((half, timestamp, label, anonymized, league, game))
38 | except ValueError:
39 | continue
40 | return result
41 |
42 | class TimeStampDataset_Align(Dataset):
43 |     # This needs CLIP features for all videos at 2 FPS; check https://github.com/openai/CLIP for details on extracting the features. Source videos can be obtained from https://www.soccer-net.org/data
44 | def __init__(self,
45 | feature_root = "./features/CLIP",
46 | ann_path = "./dataset/SN-caption/train/england_epl_2014-2015/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley/Labels-caption.json",
47 | fps = 2,
48 | window = 45,
49 | finding_words = "gameTime"
50 | ):
51 | league = os.path.basename(os.path.dirname(os.path.dirname(ann_path)))
52 | game = os.path.basename(os.path.dirname(ann_path))
53 | self.caption = parse_labels_caption_without_gt(ann_path, league, game, finding_words)
54 | self.feature_root = feature_root
55 | feature_folder = os.path.join(self.feature_root, league, game)
56 | file_path_1 = [os.path.join(feature_folder, file) for file in os.listdir(feature_folder) if file.startswith("1") and file.endswith(".npy")][0]
57 | file_path_2 = [os.path.join(feature_folder, file) for file in os.listdir(feature_folder) if file.startswith("2") and file.endswith(".npy")][0]
58 | self.features_1 = np.load(file_path_1)
59 | self.features_2 = np.load(file_path_2)
60 |
61 | self.fps = fps
62 | self.window = window
63 | self.candidate_intervals = list(range(-self.window, self.window + 1))
64 |
65 |
66 | def __len__(self):
67 | return len(self.caption)
68 |
69 | def __getitem__(self, index):
70 | num_retries = 10 # skip error videos
71 | for _ in range(num_retries):
72 | try:
73 | half, timestamp, type, anonymized, league, game = self.caption[index]
74 |
75 | candidate_features = []
76 | anchor_caption = None
77 | feature_timestamp = timestamp * self.fps
78 | if half == 1:
79 | for t in self.candidate_intervals:
80 | feature_offset = t*self.fps
81 | candidate_feature = torch.from_numpy(self.features_1[feature_timestamp+feature_offset-1:feature_timestamp+feature_offset, :])
82 | assert candidate_feature.shape[0] == 1
83 | candidate_features.append(candidate_feature)
84 | elif half == 2:
85 | for t in self.candidate_intervals:
86 | feature_offset = t*self.fps
87 | candidate_feature = torch.from_numpy(self.features_2[feature_timestamp+feature_offset-1:feature_timestamp+feature_offset, :])
88 | assert candidate_feature.shape[0] == 1
89 | candidate_features.append(candidate_feature)
90 |
91 | anchor_caption = anonymized
92 | anchor_timestamp = timestamp
93 | assert len(self.candidate_intervals) == len(candidate_features)
94 |
95 |             except Exception:
96 | index = random.randint(0, len(self) - 1)
97 | continue
98 | break
99 | else:
100 | raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
101 | batch = dict(
102 | candidate_features = torch.stack(candidate_features),
103 | anchor_caption = torch.tensor(clip.tokenize(anchor_caption, context_length=128)),
104 | anchor_timestamp = anchor_timestamp,
105 |
106 | anchor_half = half,
107 | anchor_gameTime = f"{half} - {timestamp // 60:02}:{timestamp % 60:02}",
108 | anchor_caption_text = anonymized
109 | )
110 | return batch
111 |
112 | def collator(self, instances):
113 | anchor_captions = torch.stack([instance["anchor_caption"][0][:77] for instance in instances])
114 | candidate_features = torch.stack([instance['candidate_features'] for instance in instances])
115 | anchor_timestamp = [instance['anchor_timestamp'] for instance in instances]
116 |
117 | anchor_gameTime = [instance['anchor_gameTime'] for instance in instances]
118 | anchor_half = [instance['anchor_half'] for instance in instances]
119 | anchor_caption_text = [instance['anchor_caption_text'] for instance in instances]
120 |
121 | batch = dict(
122 | candidate_features = candidate_features.to(torch.bfloat16),
123 | anchor_caption = anchor_captions,
124 | anchor_timestamp = anchor_timestamp,
125 |
126 | anchor_gameTime = anchor_gameTime,
127 | anchor_half = anchor_half,
128 | anchor_caption_text = anchor_caption_text
129 | )
130 | return batch
131 |
132 | def give_sec(gameTime_input):
133 | try:
134 | gameTime, _ = gameTime_input.split(' - ')
135 | half = int(gameTime.split(' ')[0])
136 | if half not in [1, 2]:
137 | return None
138 | minutes, seconds = map(int, _.split(':'))
139 | timestamp = minutes * 60 + seconds
140 | return timestamp
141 | except:
142 | return None
143 |
144 |
145 | def contrastive_align(model, ann_path, device, window, output_json_path, finding_words):
146 | align_dataset = TimeStampDataset_Align(ann_path=ann_path, window=window, finding_words=finding_words)
147 | align_dataloader = DataLoader(align_dataset, batch_size=100, shuffle=False, collate_fn=align_dataset.collator, pin_memory=True)
148 | model.eval()
149 | all_results = []
150 | for batch in align_dataloader:
151 | candidate_features = batch['candidate_features'].to(device=device, dtype=torch.bfloat16)
152 | anchor_caption = batch['anchor_caption'].to(device=device)
153 | anchor_timestamp = batch['anchor_timestamp']
154 | anchor_gameTime = batch['anchor_gameTime']
155 | anchor_caption_text = batch['anchor_caption_text']
156 | anchor_half = batch['anchor_half']
157 |
158 | _, logits = model(anchor_caption, candidate_features)
159 | _, max_indices = torch.max(logits, dim=1)
160 | aligned_res = [time + pos_in_list.item()-window for pos_in_list, time in zip(max_indices, anchor_timestamp)]
161 | modified_timelist = [(gameTime, caption_text, f"{anchor_half} - {aligned_result // 60:02}:{aligned_result % 60:02}") for gameTime, caption_text, aligned_result, anchor_half in zip(anchor_gameTime, anchor_caption_text, aligned_res, anchor_half)]
162 | print("Should modify",len(modified_timelist))
163 | all_results.extend(modified_timelist)
164 |
165 | os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
166 |
167 |
168 | with open(ann_path, 'r') as file:
169 | data = json.load(file)
170 | for annotation in data['annotations']:
171 | annotation['contrastive_aligned_gameTime'] = ""
172 |
173 | for annotation in data['annotations']:
174 | for gameTime, caption_text, aligned_gameTime in all_results:
175 | if annotation[finding_words] == gameTime and annotation['anonymized'] == caption_text:
176 | annotation['contrastive_aligned_gameTime'] = aligned_gameTime
177 |
178 | if finding_words == "event_aligned_gameTime" and give_sec(annotation['event_aligned_gameTime']) and give_sec(annotation['contrastive_aligned_gameTime']) and give_sec(annotation['gameTime']):
179 | contrastive_aligned_sec = give_sec(annotation['contrastive_aligned_gameTime'])
180 | original_sec = give_sec(annotation['gameTime'])
181 | event_aligned_sec = give_sec(annotation['event_aligned_gameTime'])
182 |                 # print(event_aligned_sec)
183 | if abs(contrastive_aligned_sec - original_sec) > 45 and abs(event_aligned_sec - original_sec) <= 15:
184 | annotation['contrastive_aligned_gameTime'] = annotation['event_aligned_gameTime']
185 |
186 | with open(output_json_path, 'w') as file:
187 | json.dump(data, file, indent=4)
188 |
189 |
190 | def main(args):
191 | model = ContrastiveLearningModel(device=args.device)
192 | model.load_state_dict(torch.load(args.ckpt_path))
193 | json_files = glob.glob(os.path.join(args.ann_root, '**/*.json'), recursive=True)
194 | absolute_paths = [os.path.abspath(path) for path in json_files]
195 | replaced_paths = [os.path.join(os.path.dirname(path.replace(os.path.abspath(args.ann_root), os.path.abspath(args.json_out_dir))), "Labels-caption.json") for path in absolute_paths]
196 | for original_path, output_path in tqdm(zip(absolute_paths, replaced_paths)):
197 | try:
198 | contrastive_align(model, original_path, args.device, args.window, output_path, args.finding_words)
199 |         except Exception as e:
200 |             print(f"Skipping {original_path}: {e}")
201 |
202 |
203 | if __name__ == "__main__":
204 |     parser = argparse.ArgumentParser(description="Align captions with a trained contrastive learning model")
205 | parser.add_argument('--device', type=str, default='cuda:0', help='Device to use for training')
206 | parser.add_argument('--feature_dim', type=int, default=512, help='Feature dimension size for input')
207 | parser.add_argument('--embedding_dim', type=int, default=128, help='Embedding dimension size for output')
208 | parser.add_argument('--ckpt_path', type=str, default='./ckpt/matchtime.pth')
209 | parser.add_argument('--window', type=int, default=45)
210 | parser.add_argument('--json_out_dir', type=str, default="./dataset/matchtime_aligned/train")
211 | parser.add_argument('--ann_root', type=str, default="./dataset/SN-Caption/train")
212 | parser.add_argument('--finding_words', type=str, default="gameTime")
213 |
214 |
215 | args = parser.parse_args()
216 | main(args)
217 |
--------------------------------------------------------------------------------
/models/Qformer.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from salesforce@LAVIS. Below is the original copyright:
3 | * Copyright (c) 2023, salesforce.com, inc.
4 | * All rights reserved.
5 | * SPDX-License-Identifier: BSD-3-Clause
6 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7 | * By Junnan Li
8 | * Based on huggingface code base
9 | * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
10 | """
11 |
12 | import math
13 | import os
14 | import warnings
15 | from dataclasses import dataclass
16 | from typing import Optional, Tuple, Dict, Any
17 |
18 | import torch
19 | from torch import Tensor, device, dtype, nn
20 | import torch.utils.checkpoint
21 | from torch import nn
22 | from torch.nn import CrossEntropyLoss
23 | import torch.nn.functional as F
24 |
25 | from transformers.activations import ACT2FN
26 | from transformers.file_utils import (
27 | ModelOutput,
28 | )
29 | from transformers.modeling_outputs import (
30 | BaseModelOutputWithPastAndCrossAttentions,
31 | BaseModelOutputWithPoolingAndCrossAttentions,
32 | CausalLMOutputWithCrossAttentions,
33 | MaskedLMOutput,
34 | MultipleChoiceModelOutput,
35 | NextSentencePredictorOutput,
36 | QuestionAnsweringModelOutput,
37 | SequenceClassifierOutput,
38 | TokenClassifierOutput,
39 | )
40 | from transformers.modeling_utils import (
41 | PreTrainedModel,
42 | apply_chunking_to_forward,
43 | find_pruneable_heads_and_indices,
44 | prune_linear_layer,
45 | )
46 | from transformers.utils import logging
47 | from transformers.models.bert.configuration_bert import BertConfig
48 |
49 | logger = logging.get_logger(__name__)
50 |
51 |
52 | class BertEmbeddings(nn.Module):
53 | """Construct the embeddings from word and position embeddings."""
54 |
55 | def __init__(self, config):
56 | super().__init__()
57 | self.word_embeddings = nn.Embedding(
58 | config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
59 | )
60 | self.position_embeddings = nn.Embedding(
61 | config.max_position_embeddings, config.hidden_size
62 | )
63 |
64 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
65 | # any TensorFlow checkpoint file
66 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
67 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
68 |
69 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized
70 | self.register_buffer(
71 | "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
72 | )
73 | self.position_embedding_type = getattr(
74 | config, "position_embedding_type", "absolute"
75 | )
76 |
77 | self.config = config
78 |
79 | def forward(
80 | self,
81 | input_ids=None,
82 | position_ids=None,
83 | query_embeds=None,
84 | past_key_values_length=0,
85 | ):
86 | if input_ids is not None:
87 | seq_length = input_ids.size()[1]
88 | else:
89 | seq_length = 0
90 |
91 | if position_ids is None:
92 | position_ids = self.position_ids[
93 | :, past_key_values_length : seq_length + past_key_values_length
94 | ].clone()
95 |
96 | if input_ids is not None:
97 | embeddings = self.word_embeddings(input_ids)
98 | if self.position_embedding_type == "absolute":
99 | position_embeddings = self.position_embeddings(position_ids)
100 | embeddings = embeddings + position_embeddings
101 |
102 | if query_embeds is not None:
103 | embeddings = torch.cat((query_embeds, embeddings), dim=1)
104 | else:
105 | embeddings = query_embeds
106 | # print(embeddings.device)
107 | embeddings = self.LayerNorm(embeddings)
108 | embeddings = self.dropout(embeddings)
109 | return embeddings
110 |
111 |
112 | class BertSelfAttention(nn.Module):
113 | def __init__(self, config, is_cross_attention):
114 | super().__init__()
115 | self.config = config
116 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
117 | config, "embedding_size"
118 | ):
119 | raise ValueError(
120 | "The hidden size (%d) is not a multiple of the number of attention "
121 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)
122 | )
123 |
124 | self.num_attention_heads = config.num_attention_heads
125 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
126 | self.all_head_size = self.num_attention_heads * self.attention_head_size
127 |
128 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
129 | if is_cross_attention:
130 | self.key = nn.Linear(config.encoder_width, self.all_head_size)
131 | self.value = nn.Linear(config.encoder_width, self.all_head_size)
132 | else:
133 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
134 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
135 |
136 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
137 | self.position_embedding_type = getattr(
138 | config, "position_embedding_type", "absolute"
139 | )
140 | if (
141 | self.position_embedding_type == "relative_key"
142 | or self.position_embedding_type == "relative_key_query"
143 | ):
144 | self.max_position_embeddings = config.max_position_embeddings
145 | self.distance_embedding = nn.Embedding(
146 | 2 * config.max_position_embeddings - 1, self.attention_head_size
147 | )
148 | self.save_attention = False
149 |
150 | def save_attn_gradients(self, attn_gradients):
151 | self.attn_gradients = attn_gradients
152 |
153 | def get_attn_gradients(self):
154 | return self.attn_gradients
155 |
156 | def save_attention_map(self, attention_map):
157 | self.attention_map = attention_map
158 |
159 | def get_attention_map(self):
160 | return self.attention_map
161 |
162 | def transpose_for_scores(self, x):
163 | new_x_shape = x.size()[:-1] + (
164 | self.num_attention_heads,
165 | self.attention_head_size,
166 | )
167 | x = x.view(*new_x_shape)
168 | return x.permute(0, 2, 1, 3)
169 |
170 | def forward(
171 | self,
172 | hidden_states,
173 | attention_mask=None,
174 | head_mask=None,
175 | encoder_hidden_states=None,
176 | encoder_attention_mask=None,
177 | past_key_value=None,
178 | output_attentions=False,
179 | ):
180 |
181 | # If this is instantiated as a cross-attention module, the keys
182 | # and values come from an encoder; the attention mask needs to be
183 | # such that the encoder's padding tokens are not attended to.
184 | is_cross_attention = encoder_hidden_states is not None
185 |
186 | if is_cross_attention:
187 | key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
188 | value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
189 | attention_mask = encoder_attention_mask
190 | elif past_key_value is not None:
191 | key_layer = self.transpose_for_scores(self.key(hidden_states))
192 | value_layer = self.transpose_for_scores(self.value(hidden_states))
193 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
194 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
195 | else:
196 | key_layer = self.transpose_for_scores(self.key(hidden_states))
197 | value_layer = self.transpose_for_scores(self.value(hidden_states))
198 |
199 | mixed_query_layer = self.query(hidden_states)
200 |
201 | query_layer = self.transpose_for_scores(mixed_query_layer)
202 |
203 | past_key_value = (key_layer, value_layer)
204 |
205 | # Take the dot product between "query" and "key" to get the raw attention scores.
206 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
207 |
208 | if (
209 | self.position_embedding_type == "relative_key"
210 | or self.position_embedding_type == "relative_key_query"
211 | ):
212 | seq_length = hidden_states.size()[1]
213 | position_ids_l = torch.arange(
214 | seq_length, dtype=torch.long, device=hidden_states.device
215 | ).view(-1, 1)
216 | position_ids_r = torch.arange(
217 | seq_length, dtype=torch.long, device=hidden_states.device
218 | ).view(1, -1)
219 | distance = position_ids_l - position_ids_r
220 | positional_embedding = self.distance_embedding(
221 | distance + self.max_position_embeddings - 1
222 | )
223 | positional_embedding = positional_embedding.to(
224 | dtype=query_layer.dtype
225 | ) # fp16 compatibility
226 |
227 | if self.position_embedding_type == "relative_key":
228 | relative_position_scores = torch.einsum(
229 | "bhld,lrd->bhlr", query_layer, positional_embedding
230 | )
231 | attention_scores = attention_scores + relative_position_scores
232 | elif self.position_embedding_type == "relative_key_query":
233 | relative_position_scores_query = torch.einsum(
234 | "bhld,lrd->bhlr", query_layer, positional_embedding
235 | )
236 | relative_position_scores_key = torch.einsum(
237 | "bhrd,lrd->bhlr", key_layer, positional_embedding
238 | )
239 | attention_scores = (
240 | attention_scores
241 | + relative_position_scores_query
242 | + relative_position_scores_key
243 | )
244 |
245 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
246 | if attention_mask is not None:
247 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
248 | # print(attention_scores.device, attention_mask.device)
249 | attention_scores = attention_scores + attention_mask
250 |
251 | # Normalize the attention scores to probabilities.
252 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
253 |
254 | if is_cross_attention and self.save_attention:
255 | self.save_attention_map(attention_probs)
256 | attention_probs.register_hook(self.save_attn_gradients)
257 |
258 | # This is actually dropping out entire tokens to attend to, which might
259 | # seem a bit unusual, but is taken from the original Transformer paper.
260 | attention_probs_dropped = self.dropout(attention_probs)
261 |
262 | # Mask heads if we want to
263 | if head_mask is not None:
264 | attention_probs_dropped = attention_probs_dropped * head_mask
265 |
266 | context_layer = torch.matmul(attention_probs_dropped, value_layer)
267 |
268 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
269 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
270 | context_layer = context_layer.view(*new_context_layer_shape)
271 |
272 | outputs = (
273 | (context_layer, attention_probs) if output_attentions else (context_layer,)
274 | )
275 |
276 | outputs = outputs + (past_key_value,)
277 | return outputs
278 |
279 |
280 | class BertSelfOutput(nn.Module):
281 | def __init__(self, config):
282 | super().__init__()
283 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
284 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
285 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
286 |
287 | def forward(self, hidden_states, input_tensor):
288 | hidden_states = self.dense(hidden_states)
289 | hidden_states = self.dropout(hidden_states)
290 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
291 | return hidden_states
292 |
293 |
294 | class BertAttention(nn.Module):
295 | def __init__(self, config, is_cross_attention=False):
296 | super().__init__()
297 | self.self = BertSelfAttention(config, is_cross_attention)
298 | self.output = BertSelfOutput(config)
299 | self.pruned_heads = set()
300 |
301 | def prune_heads(self, heads):
302 | if len(heads) == 0:
303 | return
304 | heads, index = find_pruneable_heads_and_indices(
305 | heads,
306 | self.self.num_attention_heads,
307 | self.self.attention_head_size,
308 | self.pruned_heads,
309 | )
310 |
311 | # Prune linear layers
312 | self.self.query = prune_linear_layer(self.self.query, index)
313 | self.self.key = prune_linear_layer(self.self.key, index)
314 | self.self.value = prune_linear_layer(self.self.value, index)
315 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
316 |
317 | # Update hyper params and store pruned heads
318 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
319 | self.self.all_head_size = (
320 | self.self.attention_head_size * self.self.num_attention_heads
321 | )
322 | self.pruned_heads = self.pruned_heads.union(heads)
323 |
324 | def forward(
325 | self,
326 | hidden_states,
327 | attention_mask=None,
328 | head_mask=None,
329 | encoder_hidden_states=None,
330 | encoder_attention_mask=None,
331 | past_key_value=None,
332 | output_attentions=False,
333 | ):
334 | self_outputs = self.self(
335 | hidden_states,
336 | attention_mask,
337 | head_mask,
338 | encoder_hidden_states,
339 | encoder_attention_mask,
340 | past_key_value,
341 | output_attentions,
342 | )
343 | attention_output = self.output(self_outputs[0], hidden_states)
344 |
345 | outputs = (attention_output,) + self_outputs[
346 | 1:
347 | ] # add attentions if we output them
348 | return outputs
349 |
350 |
351 | class BertIntermediate(nn.Module):
352 | def __init__(self, config):
353 | super().__init__()
354 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
355 | if isinstance(config.hidden_act, str):
356 | self.intermediate_act_fn = ACT2FN[config.hidden_act]
357 | else:
358 | self.intermediate_act_fn = config.hidden_act
359 |
360 | def forward(self, hidden_states):
361 | hidden_states = self.dense(hidden_states)
362 | hidden_states = self.intermediate_act_fn(hidden_states)
363 | return hidden_states
364 |
365 |
366 | class BertOutput(nn.Module):
367 | def __init__(self, config):
368 | super().__init__()
369 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
370 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
371 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
372 |
373 | def forward(self, hidden_states, input_tensor):
374 | hidden_states = self.dense(hidden_states)
375 | hidden_states = self.dropout(hidden_states)
376 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
377 | return hidden_states
378 |
379 |
380 | class BertLayer(nn.Module):
381 | def __init__(self, config, layer_num):
382 | super().__init__()
383 | self.config = config
384 | self.chunk_size_feed_forward = config.chunk_size_feed_forward
385 | self.seq_len_dim = 1
386 | self.attention = BertAttention(config)
387 | self.layer_num = layer_num
388 | if (
389 | self.config.add_cross_attention
390 | and layer_num % self.config.cross_attention_freq == 0
391 | ):
392 | self.crossattention = BertAttention(
393 | config, is_cross_attention=self.config.add_cross_attention
394 | )
395 | self.has_cross_attention = True
396 | else:
397 | self.has_cross_attention = False
398 | self.intermediate = BertIntermediate(config)
399 | self.output = BertOutput(config)
400 |
401 | self.intermediate_query = BertIntermediate(config)
402 | self.output_query = BertOutput(config)
403 |
404 | def forward(
405 | self,
406 | hidden_states,
407 | attention_mask=None,
408 | head_mask=None,
409 | encoder_hidden_states=None,
410 | encoder_attention_mask=None,
411 | past_key_value=None,
412 | output_attentions=False,
413 | query_length=0,
414 | ):
415 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
416 | self_attn_past_key_value = (
417 | past_key_value[:2] if past_key_value is not None else None
418 | )
419 | self_attention_outputs = self.attention(
420 | hidden_states,
421 | attention_mask,
422 | head_mask,
423 | output_attentions=output_attentions,
424 | past_key_value=self_attn_past_key_value,
425 | )
426 | attention_output = self_attention_outputs[0]
427 | outputs = self_attention_outputs[1:-1]
428 |
429 | present_key_value = self_attention_outputs[-1]
430 |
431 | if query_length > 0:
432 | query_attention_output = attention_output[:, :query_length, :]
433 |
434 | if self.has_cross_attention:
435 | assert (
436 | encoder_hidden_states is not None
437 | ), "encoder_hidden_states must be given for cross-attention layers"
438 | cross_attention_outputs = self.crossattention(
439 | query_attention_output,
440 | attention_mask,
441 | head_mask,
442 | encoder_hidden_states,
443 | encoder_attention_mask,
444 | output_attentions=output_attentions,
445 | )
446 | query_attention_output = cross_attention_outputs[0]
447 | outputs = (
448 | outputs + cross_attention_outputs[1:-1]
449 | ) # add cross attentions if we output attention weights
450 |
451 | layer_output = apply_chunking_to_forward(
452 | self.feed_forward_chunk_query,
453 | self.chunk_size_feed_forward,
454 | self.seq_len_dim,
455 | query_attention_output,
456 | )
457 | if attention_output.shape[1] > query_length:
458 | layer_output_text = apply_chunking_to_forward(
459 | self.feed_forward_chunk,
460 | self.chunk_size_feed_forward,
461 | self.seq_len_dim,
462 | attention_output[:, query_length:, :],
463 | )
464 | layer_output = torch.cat([layer_output, layer_output_text], dim=1)
465 | else:
466 | layer_output = apply_chunking_to_forward(
467 | self.feed_forward_chunk,
468 | self.chunk_size_feed_forward,
469 | self.seq_len_dim,
470 | attention_output,
471 | )
472 | outputs = (layer_output,) + outputs
473 |
474 | outputs = outputs + (present_key_value,)
475 |
476 | return outputs
477 |
478 | def feed_forward_chunk(self, attention_output):
479 | intermediate_output = self.intermediate(attention_output)
480 | layer_output = self.output(intermediate_output, attention_output)
481 | return layer_output
482 |
483 | def feed_forward_chunk_query(self, attention_output):
484 | intermediate_output = self.intermediate_query(attention_output)
485 | layer_output = self.output_query(intermediate_output, attention_output)
486 | return layer_output
487 |
488 |
489 | class BertEncoder(nn.Module):
490 | def __init__(self, config):
491 | super().__init__()
492 | self.config = config
493 | self.layer = nn.ModuleList(
494 | [BertLayer(config, i) for i in range(config.num_hidden_layers)]
495 | )
496 |
497 | def forward(
498 | self,
499 | hidden_states,
500 | attention_mask=None,
501 | head_mask=None,
502 | encoder_hidden_states=None,
503 | encoder_attention_mask=None,
504 | past_key_values=None,
505 | use_cache=None,
506 | output_attentions=False,
507 | output_hidden_states=False,
508 | return_dict=True,
509 | query_length=0,
510 | ):
511 | all_hidden_states = () if output_hidden_states else None
512 | all_self_attentions = () if output_attentions else None
513 | all_cross_attentions = (
514 | () if output_attentions and self.config.add_cross_attention else None
515 | )
516 |
517 | next_decoder_cache = () if use_cache else None
518 |
519 | for i in range(self.config.num_hidden_layers):
520 | layer_module = self.layer[i]
521 | if output_hidden_states:
522 | all_hidden_states = all_hidden_states + (hidden_states,)
523 |
524 | layer_head_mask = head_mask[i] if head_mask is not None else None
525 | past_key_value = past_key_values[i] if past_key_values is not None else None
526 |
527 | if getattr(self.config, "gradient_checkpointing", False) and self.training:
528 |
529 | if use_cache:
530 | logger.warn(
531 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
532 | )
533 | use_cache = False
534 |
535 | def create_custom_forward(module):
536 | def custom_forward(*inputs):
537 | return module(
538 | *inputs, past_key_value, output_attentions, query_length
539 | )
540 |
541 | return custom_forward
542 |
543 | layer_outputs = torch.utils.checkpoint.checkpoint(
544 | create_custom_forward(layer_module),
545 | hidden_states,
546 | attention_mask,
547 | layer_head_mask,
548 | encoder_hidden_states,
549 | encoder_attention_mask,
550 | )
551 | else:
552 | layer_outputs = layer_module(
553 | hidden_states,
554 | attention_mask,
555 | layer_head_mask,
556 | encoder_hidden_states,
557 | encoder_attention_mask,
558 | past_key_value,
559 | output_attentions,
560 | query_length,
561 | )
562 |
563 | hidden_states = layer_outputs[0]
564 | if use_cache:
565 | next_decoder_cache += (layer_outputs[-1],)
566 | if output_attentions:
567 | all_self_attentions = all_self_attentions + (layer_outputs[1],)
568 | all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
569 |
570 | if output_hidden_states:
571 | all_hidden_states = all_hidden_states + (hidden_states,)
572 |
573 | if not return_dict:
574 | return tuple(
575 | v
576 | for v in [
577 | hidden_states,
578 | next_decoder_cache,
579 | all_hidden_states,
580 | all_self_attentions,
581 | all_cross_attentions,
582 | ]
583 | if v is not None
584 | )
585 | return BaseModelOutputWithPastAndCrossAttentions(
586 | last_hidden_state=hidden_states,
587 | past_key_values=next_decoder_cache,
588 | hidden_states=all_hidden_states,
589 | attentions=all_self_attentions,
590 | cross_attentions=all_cross_attentions,
591 | )
592 |
593 |
594 | class BertPooler(nn.Module):
595 | def __init__(self, config):
596 | super().__init__()
597 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
598 | self.activation = nn.Tanh()
599 |
600 | def forward(self, hidden_states):
601 | # We "pool" the model by simply taking the hidden state corresponding
602 | # to the first token.
603 | first_token_tensor = hidden_states[:, 0]
604 | pooled_output = self.dense(first_token_tensor)
605 | pooled_output = self.activation(pooled_output)
606 | return pooled_output
607 |
608 |
609 | class BertPredictionHeadTransform(nn.Module):
610 | def __init__(self, config):
611 | super().__init__()
612 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
613 | if isinstance(config.hidden_act, str):
614 | self.transform_act_fn = ACT2FN[config.hidden_act]
615 | else:
616 | self.transform_act_fn = config.hidden_act
617 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
618 |
619 | def forward(self, hidden_states):
620 | hidden_states = self.dense(hidden_states)
621 | hidden_states = self.transform_act_fn(hidden_states)
622 | hidden_states = self.LayerNorm(hidden_states)
623 | return hidden_states
624 |
625 |
626 | class BertLMPredictionHead(nn.Module):
627 | def __init__(self, config):
628 | super().__init__()
629 | self.transform = BertPredictionHeadTransform(config)
630 |
631 | # The output weights are the same as the input embeddings, but there is
632 | # an output-only bias for each token.
633 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
634 |
635 | self.bias = nn.Parameter(torch.zeros(config.vocab_size))
636 |
637 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
638 | self.decoder.bias = self.bias
639 |
640 | def forward(self, hidden_states):
641 | hidden_states = self.transform(hidden_states)
642 | hidden_states = self.decoder(hidden_states)
643 | return hidden_states
644 |
645 |
646 | class BertOnlyMLMHead(nn.Module):
647 | def __init__(self, config):
648 | super().__init__()
649 | self.predictions = BertLMPredictionHead(config)
650 |
651 | def forward(self, sequence_output):
652 | prediction_scores = self.predictions(sequence_output)
653 | return prediction_scores
654 |
655 |
656 | class BertPreTrainedModel(PreTrainedModel):
657 | """
658 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
659 | models.
660 | """
661 |
662 | config_class = BertConfig
663 | base_model_prefix = "bert"
664 | _keys_to_ignore_on_load_missing = [r"position_ids"]
665 |
666 | def _init_weights(self, module):
667 | """Initialize the weights"""
668 | if isinstance(module, (nn.Linear, nn.Embedding)):
669 | # Slightly different from the TF version which uses truncated_normal for initialization
670 | # cf https://github.com/pytorch/pytorch/pull/5617
671 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
672 | elif isinstance(module, nn.LayerNorm):
673 | module.bias.data.zero_()
674 | module.weight.data.fill_(1.0)
675 | if isinstance(module, nn.Linear) and module.bias is not None:
676 | module.bias.data.zero_()
677 |
678 |
679 | class BertModel(BertPreTrainedModel):
680 | """
681 | The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
682 | cross-attention is added between the self-attention layers, following the architecture described in `Attention is
683 | all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
684 | Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
685 | argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
686 | input to the forward pass.
687 | """
688 |
689 | def __init__(self, config, add_pooling_layer=False):
690 | super().__init__(config)
691 | self.config = config
692 |
693 | self.embeddings = BertEmbeddings(config)
694 |
695 | self.encoder = BertEncoder(config)
696 |
697 | self.pooler = BertPooler(config) if add_pooling_layer else None
698 |
699 | self.init_weights()
700 |
701 | def get_input_embeddings(self):
702 | return self.embeddings.word_embeddings
703 |
704 | def set_input_embeddings(self, value):
705 | self.embeddings.word_embeddings = value
706 |
707 | def _prune_heads(self, heads_to_prune):
708 | """
709 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
710 | class PreTrainedModel
711 | """
712 | for layer, heads in heads_to_prune.items():
713 | self.encoder.layer[layer].attention.prune_heads(heads)
714 |
715 | def get_extended_attention_mask(
716 | self,
717 | attention_mask: Tensor,
718 | input_shape: Tuple[int],
719 | device: device,
720 | is_decoder: bool,
721 | has_query: bool = False,
722 | ) -> Tensor:
723 | """
724 | Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
725 |
726 | Arguments:
727 | attention_mask (:obj:`torch.Tensor`):
728 | Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
729 | input_shape (:obj:`Tuple[int]`):
730 | The shape of the input to the model.
731 | device: (:obj:`torch.device`):
732 | The device of the input to the model.
733 |
734 | Returns:
735 | :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
736 | """
737 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
738 | # ourselves in which case we just need to make it broadcastable to all heads.
739 | if attention_mask.dim() == 3:
740 | extended_attention_mask = attention_mask[:, None, :, :]
741 | elif attention_mask.dim() == 2:
742 | # Provided a padding mask of dimensions [batch_size, seq_length]
743 | # - if the model is a decoder, apply a causal mask in addition to the padding mask
744 | # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
745 | if is_decoder:
746 | batch_size, seq_length = input_shape
747 |
748 | seq_ids = torch.arange(seq_length, device=device)
749 | causal_mask = (
750 | seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
751 | <= seq_ids[None, :, None]
752 | )
753 |
754 | # add a prefix ones mask to the causal mask
755 | # causal and attention masks must have same type with pytorch version < 1.3
756 | causal_mask = causal_mask.to(attention_mask.dtype)
757 |
758 | if causal_mask.shape[1] < attention_mask.shape[1]:
759 | prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
760 | if has_query: # UniLM style attention mask
761 | causal_mask = torch.cat(
762 | [
763 | torch.zeros(
764 | (batch_size, prefix_seq_len, seq_length),
765 | device=device,
766 | dtype=causal_mask.dtype,
767 | ),
768 | causal_mask,
769 | ],
770 | axis=1,
771 | )
772 | causal_mask = torch.cat(
773 | [
774 | torch.ones(
775 | (batch_size, causal_mask.shape[1], prefix_seq_len),
776 | device=device,
777 | dtype=causal_mask.dtype,
778 | ),
779 | causal_mask,
780 | ],
781 | axis=-1,
782 | )
783 | extended_attention_mask = (
784 | causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
785 | )
786 | else:
787 | extended_attention_mask = attention_mask[:, None, None, :]
788 | else:
789 | raise ValueError(
790 | "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
791 | input_shape, attention_mask.shape
792 | )
793 | )
794 |
795 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
796 | # masked positions, this operation will create a tensor which is 0.0 for
797 | # positions we want to attend and -10000.0 for masked positions.
798 | # Since we are adding it to the raw scores before the softmax, this is
799 | # effectively the same as removing these entirely.
800 | extended_attention_mask = extended_attention_mask.to(
801 | dtype=self.dtype
802 | ) # fp16 compatibility
803 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
804 | return extended_attention_mask
805 |
806 | def forward(
807 | self,
808 | input_ids=None,
809 | attention_mask=None,
810 | position_ids=None,
811 | head_mask=None,
812 | query_embeds=None,
813 | encoder_hidden_states=None,
814 | encoder_attention_mask=None,
815 | past_key_values=None,
816 | use_cache=None,
817 | output_attentions=None,
818 | output_hidden_states=None,
819 | return_dict=None,
820 | is_decoder=False,
821 | ):
822 | r"""
823 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
824 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
825 | the model is configured as a decoder.
826 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
827 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
828 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
829 | - 1 for tokens that are **not masked**,
830 | - 0 for tokens that are **masked**.
831 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
832 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
833 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
834 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
835 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
836 | use_cache (:obj:`bool`, `optional`):
837 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
838 | decoding (see :obj:`past_key_values`).
839 | """
840 | output_attentions = (
841 | output_attentions
842 | if output_attentions is not None
843 | else self.config.output_attentions
844 | )
845 | output_hidden_states = (
846 | output_hidden_states
847 | if output_hidden_states is not None
848 | else self.config.output_hidden_states
849 | )
850 | return_dict = (
851 | return_dict if return_dict is not None else self.config.use_return_dict
852 | )
853 |
854 | # use_cache = use_cache if use_cache is not None else self.config.use_cache
855 |
856 | if input_ids is None:
857 | assert (
858 | query_embeds is not None
859 | ), "You have to specify query_embeds when input_ids is None"
860 |
861 | # past_key_values_length
862 | past_key_values_length = (
863 | past_key_values[0][0].shape[2] - self.config.query_length
864 | if past_key_values is not None
865 | else 0
866 | )
867 |
868 | query_length = query_embeds.shape[1] if query_embeds is not None else 0
869 |
870 | embedding_output = self.embeddings(
871 | input_ids=input_ids,
872 | position_ids=position_ids,
873 | query_embeds=query_embeds,
874 | past_key_values_length=past_key_values_length,
875 | )
876 |
877 | input_shape = embedding_output.size()[:-1]
878 | batch_size, seq_length = input_shape
879 | device = embedding_output.device
880 |
881 | if attention_mask is None:
882 | attention_mask = torch.ones(
883 | ((batch_size, seq_length + past_key_values_length)), device=device
884 | )
885 |
886 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
887 | # ourselves in which case we just need to make it broadcastable to all heads.
888 | if is_decoder:
889 | extended_attention_mask = self.get_extended_attention_mask(
890 | attention_mask,
891 | input_ids.shape,
892 | device,
893 | is_decoder,
894 | has_query=(query_embeds is not None),
895 | )
896 | else:
897 | extended_attention_mask = self.get_extended_attention_mask(
898 | attention_mask, input_shape, device, is_decoder
899 | )
900 |
901 | # If a 2D or 3D attention mask is provided for the cross-attention
902 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
903 | if encoder_hidden_states is not None:
904 | if type(encoder_hidden_states) == list:
905 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
906 | 0
907 | ].size()
908 | else:
909 | (
910 | encoder_batch_size,
911 | encoder_sequence_length,
912 | _,
913 | ) = encoder_hidden_states.size()
914 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
915 |
916 | if type(encoder_attention_mask) == list:
917 | encoder_extended_attention_mask = [
918 | self.invert_attention_mask(mask) for mask in encoder_attention_mask
919 | ]
920 | elif encoder_attention_mask is None:
921 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
922 | encoder_extended_attention_mask = self.invert_attention_mask(
923 | encoder_attention_mask
924 | )
925 | else:
926 | encoder_extended_attention_mask = self.invert_attention_mask(
927 | encoder_attention_mask
928 | )
929 | else:
930 | encoder_extended_attention_mask = None
931 |
932 | # Prepare head mask if needed
933 | # 1.0 in head_mask indicates that we keep the head
934 | # attention_probs has shape bsz x n_heads x N x N
935 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
936 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
937 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
938 |
939 | encoder_outputs = self.encoder(
940 | embedding_output,
941 | attention_mask=extended_attention_mask,
942 | head_mask=head_mask,
943 | encoder_hidden_states=encoder_hidden_states,
944 | encoder_attention_mask=encoder_extended_attention_mask,
945 | past_key_values=past_key_values,
946 | use_cache=use_cache,
947 | output_attentions=output_attentions,
948 | output_hidden_states=output_hidden_states,
949 | return_dict=return_dict,
950 | query_length=query_length,
951 | )
952 | sequence_output = encoder_outputs[0]
953 | pooled_output = (
954 | self.pooler(sequence_output) if self.pooler is not None else None
955 | )
956 |
957 | if not return_dict:
958 | return (sequence_output, pooled_output) + encoder_outputs[1:]
959 |
960 | return BaseModelOutputWithPoolingAndCrossAttentions(
961 | last_hidden_state=sequence_output,
962 | pooler_output=pooled_output,
963 | past_key_values=encoder_outputs.past_key_values,
964 | hidden_states=encoder_outputs.hidden_states,
965 | attentions=encoder_outputs.attentions,
966 | cross_attentions=encoder_outputs.cross_attentions,
967 | )
968 |
969 |
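    | # Left-to-right language-modeling head on top of the Q-Former BertModel (pooling layer omitted).
    | # Generation can be conditioned on learned query tokens via query_embeds / past_key_values,
    | # which carry the visual information in BLIP-2-style setups.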
970 | class BertLMHeadModel(BertPreTrainedModel):
971 |
972 | _keys_to_ignore_on_load_unexpected = [r"pooler"]
973 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
974 |
975 | def __init__(self, config):
976 | super().__init__(config)
977 |
978 | self.bert = BertModel(config, add_pooling_layer=False)
979 | self.cls = BertOnlyMLMHead(config)
980 |
981 | self.init_weights()
982 |
983 | def get_output_embeddings(self):
984 | return self.cls.predictions.decoder
985 |
986 | def set_output_embeddings(self, new_embeddings):
987 | self.cls.predictions.decoder = new_embeddings
988 |
989 | def forward(
990 | self,
991 | input_ids=None,
992 | attention_mask=None,
993 | position_ids=None,
994 | head_mask=None,
995 | query_embeds=None,
996 | encoder_hidden_states=None,
997 | encoder_attention_mask=None,
998 | labels=None,
999 | past_key_values=None,
1000 | use_cache=True,
1001 | output_attentions=None,
1002 | output_hidden_states=None,
1003 | return_dict=None,
1004 | return_logits=False,
1005 | is_decoder=True,
1006 | reduction="mean",
1007 | ):
1008 | r"""
1009 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1010 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1011 | the model is configured as a decoder.
1012 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1013 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1014 | the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1015 | - 1 for tokens that are **not masked**,
1016 | - 0 for tokens that are **masked**.
1017 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1018 | Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1019 | ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
1020 | ignored (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
1021 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1022 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1023 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1024 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1025 | instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1026 | use_cache (:obj:`bool`, `optional`):
1027 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1028 | decoding (see :obj:`past_key_values`).
1029 | Returns:
1030 | Example::
1031 | >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1032 | >>> import torch
1033 | >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1034 | >>> config = BertConfig.from_pretrained("bert-base-cased")
1035 | >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1036 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1037 | >>> outputs = model(**inputs)
1038 | >>> prediction_logits = outputs.logits
1039 | """
1040 | return_dict = (
1041 | return_dict if return_dict is not None else self.config.use_return_dict
1042 | )
1043 | if labels is not None:
1044 | use_cache = False
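    | # once a cache exists, the query tokens are already represented in past_key_values,
    | # so they are not re-embedded on later decoding steps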
1045 | if past_key_values is not None:
1046 | query_embeds = None
1047 |
1048 | outputs = self.bert(
1049 | input_ids,
1050 | attention_mask=attention_mask,
1051 | position_ids=position_ids,
1052 | head_mask=head_mask,
1053 | query_embeds=query_embeds,
1054 | encoder_hidden_states=encoder_hidden_states,
1055 | encoder_attention_mask=encoder_attention_mask,
1056 | past_key_values=past_key_values,
1057 | use_cache=use_cache,
1058 | output_attentions=output_attentions,
1059 | output_hidden_states=output_hidden_states,
1060 | return_dict=return_dict,
1061 | is_decoder=is_decoder,
1062 | )
1063 |
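    | # keep only the text-token positions; the learned query positions are stripped
    | # before the language-modeling head scores next-token predictions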
1064 | sequence_output = outputs[0]
1065 | if query_embeds is not None:
1066 | sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1067 |
1068 | prediction_scores = self.cls(sequence_output)
1069 |
1070 | if return_logits:
1071 | return prediction_scores[:, :-1, :].contiguous()
1072 |
1073 | lm_loss = None
1074 | if labels is not None:
1075 | # we are doing next-token prediction; shift prediction scores and input ids by one
1076 | shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1077 | labels = labels[:, 1:].contiguous()
1078 | loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1079 | lm_loss = loss_fct(
1080 | shifted_prediction_scores.view(-1, self.config.vocab_size),
1081 | labels.view(-1),
1082 | )
1083 | if reduction == "none":
1084 | lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1085 |
1086 | if not return_dict:
1087 | output = (prediction_scores,) + outputs[2:]
1088 | return ((lm_loss,) + output) if lm_loss is not None else output
1089 |
1090 | return CausalLMOutputWithCrossAttentions(
1091 | loss=lm_loss,
1092 | logits=prediction_scores,
1093 | past_key_values=outputs.past_key_values,
1094 | hidden_states=outputs.hidden_states,
1095 | attentions=outputs.attentions,
1096 | cross_attentions=outputs.cross_attentions,
1097 | )
1098 |
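    | # hook used by generate(): extends the attention mask to also cover the query token
    | # positions and, once a cache exists, feeds back only the most recent token id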
1099 | def prepare_inputs_for_generation(
1100 | self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1101 | ):
1102 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1103 | if attention_mask is None:
1104 | attention_mask = input_ids.new_ones(input_ids.shape)
1105 | query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1106 | attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1107 |
1108 | # cut decoder_input_ids if past is used
1109 | if past is not None:
1110 | input_ids = input_ids[:, -1:]
1111 |
1112 | return {
1113 | "input_ids": input_ids,
1114 | "query_embeds": query_embeds,
1115 | "attention_mask": attention_mask,
1116 | "past_key_values": past,
1117 | "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1118 | "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1119 | "is_decoder": True,
1120 | }
1121 |
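    | # beam-search hook: permutes every cached key/value tensor along the batch dimension
    | # so that the cache follows the surviving beams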
1122 | def _reorder_cache(self, past, beam_idx):
1123 | reordered_past = ()
1124 | for layer_past in past:
1125 | reordered_past += (
1126 | tuple(
1127 | past_state.index_select(0, beam_idx) for past_state in layer_past
1128 | ),
1129 | )
1130 | return reordered_past
1131 |
1132 |
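    | # Masked-language-modeling variant retained from the upstream Q-Former implementation;
    | # it reuses the same BertModel backbone and scores the non-query positions bidirectionally.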
1133 | class BertForMaskedLM(BertPreTrainedModel):
1134 |
1135 | _keys_to_ignore_on_load_unexpected = [r"pooler"]
1136 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1137 |
1138 | def __init__(self, config):
1139 | super().__init__(config)
1140 |
1141 | self.bert = BertModel(config, add_pooling_layer=False)
1142 | self.cls = BertOnlyMLMHead(config)
1143 |
1144 | self.init_weights()
1145 |
1146 | def get_output_embeddings(self):
1147 | return self.cls.predictions.decoder
1148 |
1149 | def set_output_embeddings(self, new_embeddings):
1150 | self.cls.predictions.decoder = new_embeddings
1151 |
1152 | def forward(
1153 | self,
1154 | input_ids=None,
1155 | attention_mask=None,
1156 | position_ids=None,
1157 | head_mask=None,
1158 | query_embeds=None,
1159 | encoder_hidden_states=None,
1160 | encoder_attention_mask=None,
1161 | labels=None,
1162 | output_attentions=None,
1163 | output_hidden_states=None,
1164 | return_dict=None,
1165 | return_logits=False,
1166 | is_decoder=False,
1167 | ):
1168 | r"""
1169 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1170 | Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1171 | config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored
1172 | (masked); the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``.
1173 | """
1174 |
1175 | return_dict = (
1176 | return_dict if return_dict is not None else self.config.use_return_dict
1177 | )
1178 |
1179 | outputs = self.bert(
1180 | input_ids,
1181 | attention_mask=attention_mask,
1182 | position_ids=position_ids,
1183 | head_mask=head_mask,
1184 | query_embeds=query_embeds,
1185 | encoder_hidden_states=encoder_hidden_states,
1186 | encoder_attention_mask=encoder_attention_mask,
1187 | output_attentions=output_attentions,
1188 | output_hidden_states=output_hidden_states,
1189 | return_dict=return_dict,
1190 | is_decoder=is_decoder,
1191 | )
1192 |
1193 | # fall back to the full hidden states when no query tokens were passed (avoids an undefined variable)
1194 | sequence_output = outputs[0] if query_embeds is None else outputs[0][:, query_embeds.shape[1] :, :]
1195 | prediction_scores = self.cls(sequence_output)
1196 |
1197 | if return_logits:
1198 | return prediction_scores
1199 |
1200 | masked_lm_loss = None
1201 | if labels is not None:
1202 | loss_fct = CrossEntropyLoss() # -100 index = padding token
1203 | masked_lm_loss = loss_fct(
1204 | prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1205 | )
1206 |
1207 | if not return_dict:
1208 | output = (prediction_scores,) + outputs[2:]
1209 | return (
1210 | ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1211 | )
1212 |
1213 | return MaskedLMOutput(
1214 | loss=masked_lm_loss,
1215 | logits=prediction_scores,
1216 | hidden_states=outputs.hidden_states,
1217 | attentions=outputs.attentions,
1218 | )
1219 |
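    | # --- Illustrative usage sketch (assumptions only; not code used by this repository) ---
    | # A minimal, hypothetical example of calling BertLMHeadModel.forward with learned query
    | # tokens: the attention mask is expected to cover both the query and the text positions,
    | # mirroring prepare_inputs_for_generation above. Variable names and shapes are illustrative.
    | #
    | #   query_embeds = query_tokens.expand(input_ids.shape[0], -1, -1)           # (B, Q, H)
    | #   query_mask = torch.ones(query_embeds.shape[:-1], dtype=torch.long, device=input_ids.device)
    | #   attention_mask = torch.cat([query_mask, text_attention_mask], dim=1)     # (B, Q + T)
    | #   outputs = model(
    | #       input_ids=input_ids,
    | #       attention_mask=attention_mask,
    | #       query_embeds=query_embeds,
    | #       encoder_hidden_states=visual_features,           # (B, N, H) frame/patch features
    | #       encoder_attention_mask=visual_attention_mask,    # (B, N)
    | #       labels=labels,
    | #       is_decoder=True,
    | #   )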
--------------------------------------------------------------------------------