├── pics
│   ├── baseline.png
│   └── overall.png
├── README.md
└── rewriter.py
/pics/baseline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krystalan/K-SportsSum/HEAD/pics/baseline.png
--------------------------------------------------------------------------------
/pics/overall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krystalan/K-SportsSum/HEAD/pics/overall.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Knowledge Enhanced Sports Game Summarization
2 |
3 | ![K-SportsSum overview](pics/overall.png)
4 |
8 | This repository contains data and code for the WSDM 2022 paper [*Knowledge Enhanced Sports Game Summarization*](https://arxiv.org/abs/2111.12535).
9 |
10 | In this work, we propose K-SportsSum dataset as well as the KES model.
11 | - K-SportsSum: a dataset with 7,854 sports game summarization samples, together with a large-scale knowledge corpus covering 523 sports teams and 14k+ sports players.
12 | - KES: a new sports game summarization model based on mT5.
13 |
14 |
15 | ### 1. K-SportsSum Dataset
16 | The K-SportsSum dataset is available [here](https://drive.google.com/file/d/1RGWIz3Nw_kzfgIYo_Ke9elLfPOg0rS4V/view?usp=sharing). You can obtain the following five files from the shared link:
17 | - `train.json`, `val.json` and `test.json` are the core data files of K-SportsSum, each of which contains live commentaries and news reports of sports games.
18 | - `player_knowledge.json` contains background knowledge of 14,724 sports players.
19 | - `team_knowledge.json` contains background information of 523 sports teams.
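
If you want to quickly inspect the data, a minimal loading sketch is shown below; it only assumes the files have been downloaded into the working directory and are standard JSON (the exact fields of each record are described in the paper and not assumed here).

```python
# Sketch: load the downloaded K-SportsSum files (paths are examples).
import json

with open("train.json", "r", encoding="utf-8") as f:
    train = json.load(f)
with open("player_knowledge.json", "r", encoding="utf-8") as f:
    player_knowledge = json.load(f)

print(len(train), len(player_knowledge))  # number of samples / knowledge entries
```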
20 |
21 | ### 2. Baseline Model Construction
22 |
23 | ![Two-step baseline framework](pics/baseline.png)
24 |
28 | In this section, we introduce how to build a two-step baseline system for sports game summarization. As shown in the figure above, the baseline framework first selects important commentary sentences from the original live commentary document through a text classification model, where `Ck` denotes a commentary sentence and `Cik` denotes a selected important sentence. Each selected sentence is then converted into a news-style sentence through a generative model; `Rik` is the generated sentence corresponding to `Cik`.
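
The pipeline can be summarized with the following schematic sketch; `selector` and `rewriter` are placeholders for the models built in Sections 2.1 and 2.2, not concrete APIs.

```python
# Schematic two-step pipeline (placeholder functions, not a concrete implementation):
# 1) the selector labels each commentary sentence Ck as important (1) or not (0);
# 2) the rewriter converts each selected sentence Cik into a news-style sentence Rik.
def summarize(commentary_sentences, selector, rewriter):
    important = [c for c in commentary_sentences if selector(c) == 1]  # Cik
    news_sentences = [rewriter(c) for c in important]                  # Rik
    return "".join(news_sentences)  # concatenate into the final news report
```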
29 |
30 | #### 2.1 Selector
31 | The selector is a text classification model. In our work, we resort to the following toolkits:
32 | - [Chinese-Text-Classification-Pytorch](https://github.com/649453932/Chinese-Text-Classification-Pytorch): This toolkit provides training and inference code for pre-BERT text classification models, such as TextCNN.
33 | - [Bert-Chinese-Text-Classification-Pytorch](https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch): This toolkit provides code for BERT-era text classification models, e.g., BERT and ERNIE.
34 |
35 | These two toolkits are very useful for building a Chinese text classification system; a minimal sketch of the selector is given below.
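
For illustration, here is a minimal sketch of the selector as binary sentence classification with Hugging Face Transformers; the example sentence and the label convention (1 = important) are assumptions, and the toolkits above provide complete training pipelines.

```python
# Minimal selector sketch: classify one commentary sentence as important or not.
import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=2)

batch = tokenizer(["第30分钟,武磊门前推射,球进了!"], padding=True, truncation=True,
                  max_length=128, return_tensors="pt")
labels = torch.tensor([1])  # 1 = important commentary sentence, 0 = unimportant
loss = model(**batch, labels=labels).loss
loss.backward()  # a full training loop with an optimizer would follow
```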
36 |
37 | #### 2.2 Rewriter
38 | The rewriter is a generative model. Existing works typically employ [PTGen (See et al., ACL 2017)](https://arxiv.org/abs/1704.04368), mBART, mT5, etc.
39 |
40 | For PTGen, the [pointer_summarizer](https://github.com/atulkum/pointer_summarizer) toolkit is widely used. I also recommend the [implementation](https://github.com/xcfcode/PLM_annotator/tree/main/pgn) released by Xiachong Feng in his dialogue summarization work. Both implementations are convenient. Please note that if you choose PTGen as the rewriter, you should use a pre-trained word embedding to help the model achieve good performance ([Chinese-Word-Vectors](https://github.com/Embedding/Chinese-Word-Vectors) is helpful).
41 |
42 | For mBART, mT5, etc., we use the implementations from the [Huggingface Transformers library](https://huggingface.co/docs/transformers/index). I release the corresponding training and inference code for public use (see `rewriter.py`; `mBART-50` is used as the rewriter in this code).
43 | Requirements: pytorch-lightning 0.8.5; transformers >= 4.4; torch >= 1.7
44 | (This code is based on the [Longformer summarization script](https://github.com/allenai/longformer/blob/master/scripts/summarization.py) from AI2.)
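
`rewriter.py` expects `train.json`, `val.json` and `test.json` in the working directory; each file is a list of `{"input": commentary sentence, "output": news sentence}` pairs (see the docstring in `rewriter.py`). A minimal sketch for writing such a file from already-mapped sentence pairs (the example pair is made up):

```python
# Sketch: dump mapped (commentary, news) sentence pairs into the format rewriter.py reads.
import json

pairs = [("第30分钟,武磊门前推射,球进了!", "第30分钟,武磊推射破门。")]  # illustrative pair
data = [{"input": c, "output": r} for c, r in pairs]
with open("train.json", "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False, indent=2)
```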
45 |
46 | For training, you can run commands like this:
47 | ```sh
48 | python rewriter.py --device_id 0
49 | ```
50 | For evaluation, the command looks like this:
51 | ```sh
52 | python rewriter.py --device_id 0 --test
53 | ```
54 | Note that if you want to run inference with a trained model, remember to initialize the model with the corresponding `.ckpt` file, as sketched below.
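
A minimal sketch of loading a trained checkpoint into the rewriter outside the Trainer is shown below; it assumes default arguments (in particular a local mBART-50 directory at `model_path`), and the checkpoint path is only an example.

```python
# Sketch: initialize the rewriter with the weights of a trained .ckpt file.
import argparse
import torch
from rewriter import Summarizer

parser = Summarizer.add_model_specific_args(argparse.ArgumentParser(), ".")
args = parser.parse_args([])  # default arguments; adjust model_path etc. as needed

model = Summarizer(args)  # first builds mBART-50 from the pre-trained files
ckpt = torch.load("output/Sports1/checkpoints/example.ckpt", map_location="cpu")  # example path
model.load_state_dict(ckpt["state_dict"])  # Lightning checkpoints store weights under "state_dict"
model.eval()
```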
55 |
56 | #### 2.3 Construct Training Samples for Selector and Rewriter
57 | In order to construct training samples for the selector and rewriter, we should map each news sentence to the corresponding commentary sentence, if possible (see Section 3.1 of [SportsSum2.0](https://arxiv.org/abs/2110.05750) for more details).
58 |
59 | Thus, the core of this process is calculating ROUGE scores and BERTScore between two Chinese sentences.
60 | - ROUGE: you can use the [multilingual_rouge_scoring](https://github.com/csebuetnlp/xl-sum/tree/master/multilingual_rouge_scoring) toolkit to calculate Chinese ROUGE scores. Note that the [py-rouge](https://github.com/Diego999/py-rouge) and [rouge](https://github.com/pltrdy/rouge) toolkits are not suitable for Chinese.
61 | - BERTScore: please find more details in [bert_score](https://github.com/Tiiiger/bert_score); a minimal usage sketch follows this list.
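
For illustration, a minimal BERTScore sketch for a Chinese sentence pair (the sentences are made up; `lang="zh"` lets bert_score pick a Chinese model):

```python
# Sketch: BERTScore between a candidate news sentence and a reference commentary sentence.
from bert_score import score

cands = ["第30分钟,武磊推射破门。"]         # news sentence (example)
refs = ["第30分钟,武磊门前推射,球进了!"]   # commentary sentence (example)
P, R, F1 = score(cands, refs, lang="zh", verbose=False)
print(F1.mean().item())  # the F1 score can be used for sentence mapping
```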
62 |
63 | ### 3. KES model
64 | Code will be published once the author of this repo has time.
65 |
66 | ### Existing Works
67 | To help researchers efficiently understand and follow the sports game summarization task, we have written a Chinese survey post: [《体育赛事摘要任务概览》](https://mp.weixin.qq.com/s/EidRYB_80AhRclz-mryVhQ), where we also discuss some future directions and share our thoughts.
68 |
69 | We list and classify existing works of Sports Game Summarization:
70 |
71 | | Paper | Conference/Journal | Data/Code | Category |
72 | | :--: | :--: | :--: | :--: |
73 | | [Towards Constructing Sports News from Live Text Commentary](https://aclanthology.org/P16-1129) | ACL 2016 | - | `Dataset`, `Ext.` |
74 | | [Overview of the NLPCC-ICCPOL 2016 Shared Task: Sports News Generation from Live Webcast Scripts](https://link.springer.com/chapter/10.1007%2F978-3-319-50496-4_80) | NLPCC 2016 | [NLPCC 2016 shared task](http://tcci.ccf.org.cn/conference/2016/pages/page05_CFPTasks.html) | `Dataset` |
75 | | [Research on Summary Sentences Extraction Oriented to Live Sports Text](https://link.springer.com/chapter/10.1007%2F978-3-319-50496-4_72) | NLPCC 2016 | - | `Ext.` |
76 | | [Sports News Generation from Live Webcast Scripts Based on Rules and Templates](https://link.springer.com/chapter/10.1007%2F978-3-319-50496-4_81) | NLPCC 2016 | - | `Ext.+Temp.` |
77 | | [Content Selection for Real-time Sports News Construction from Commentary Texts](https://aclanthology.org/W17-3504/) | INLG 2017 | - | `Ext.` |
78 | | [Generate Football News from Live Webcast Scripts Based on Character-CNN with Five Strokes](http://csroc.org.tw/journal/JOC31-1/JOC3101-21.pdf) | 2020 | - | `Ext.+Temp.` |
79 | | [Generating Sports News from Live Commentary: A Chinese Dataset for Sports Game Summarization](https://aclanthology.org/2020.aacl-main.61/) | AACL 2020 | [SportsSum](https://github.com/ej0cl6/SportsSum) | `Dataset`, `Ext.+Abs.` |
80 | | [SportsSum2.0: Generating High-Quality Sports News from Live Text Commentary](https://arxiv.org/abs/2110.05750) | CIKM 2021 | [SportsSum2.0](https://github.com/krystalan/SportsSum2.0) | `Dataset`, `Ext.+Abs.` |
81 | | [Knowledge Enhanced Sports Game Summarization](https://arxiv.org/abs/2111.12535) | WSDM 2022 | [K-SportsSum](https://github.com/krystalan/K-SportsSum) | `Dataset`, `Ext.+Abs.` |
82 |
83 | The concepts used in the Category column are explained as follows:
84 | - `Dataset`: The work contributes a dataset for sports game summarization.
85 | - `Ext.`: Extractive sports game summarization method.
86 | - `Ext.+Temp.`: The method first extracts important commentary sentences and then uses human-labeled templates to convert each commentary sentence into a news sentence.
87 | - `Ext.+Abs.`: The method first extracts important commentary sentences and then uses a seq2seq model to convert each commentary sentence into a news sentence.
88 |
89 | ### Q&A
90 | Q1: What are the differences among SportsSum, SportsSum2.0, SGSum and K-SportsSum?
91 | A1: **SportsSum (Huang et al., AACL 2020)** is the first large-scale sports game summarization dataset, with 5,428 samples. Despite its great contribution, the SportsSum dataset contains about 15% noisy samples. **SportsSum2.0 (Wang et al., CIKM 2021)** therefore cleans the original SportsSum and obtains 5,402 samples (26 bad samples in SportsSum are removed). Following previous works, **SGSum (non-archival paper, not formally published)** collects and cleans a large amount of data from many more games and has 7,854 samples. **K-SportsSum (Wang et al., WSDM 2022)** shuffles and randomly re-splits **SGSum**. Furthermore, **K-SportsSum** provides a large-scale knowledge corpus about sports teams and players, which could be useful for alleviating the knowledge gap issue (see the K-SportsSum paper).
92 |
93 | Q2: There is little public code for sports game summarization.
94 | A2: Yeah, I know that. All existing works follow the pipeline paradigm to build sports game summarization systems, which may have two or three steps together with a pseudo-label construction process, so the code tends to be messy. As a remedy, we 1) release a tutorial for building a two-step baseline for sports game summarization (see Section 2 on this page); 2) are building an end-to-end model for public use (work in progress; it may be published in 2022, but there is no guarantee).
95 |
96 | Q3: About position embedding in mT5.
97 | A3: The position embedding of mT5 is set to a zero vector since mT5 uses relative position embeddings in self-attention.
98 |
99 | Q4: Any questions and suggestions?
100 | A4: Please feel free to contact me (jawang1[at]suda.edu.cn).
101 |
102 | ### Acknowledgement
103 | Jiaan Wang would like to thank **[KW Lab, Fudan Univ.](http://kw.fudan.edu.cn/)** and **[iFLYTEK AI Research, Suzhou](https://www.iflytek.com/index.html)** for their helpful discussions and GPU device support.
104 |
105 | ### Citation
106 | If you find this project useful or use the data in your work, please consider citing our paper:
107 | ```
108 | @article{Wang2022KnowledgeES,
109 | title={Knowledge Enhanced Sports Game Summarization},
110 | author={Jiaan Wang and Zhixu Li and Tingyi Zhang and Duo Zheng and Jianfeng Qu and An Liu and Lei Zhao and Zhigang Chen},
111 | journal={Proceedings of the Fifteenth ACM International Conference on Web Search and Data Mining},
112 | year={2022}
113 | }
114 | ```
115 |
--------------------------------------------------------------------------------
/rewriter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import random
4 | import numpy as np
5 |
6 | import torch
7 | from torch.utils.data import DataLoader, Dataset
8 | from transformers.optimization import get_linear_schedule_with_warmup, Adafactor
9 | import rouge
10 |
11 | import pytorch_lightning as pl
12 | from pytorch_lightning.logging import TestTubeLogger
13 | from pytorch_lightning.callbacks import ModelCheckpoint
14 | from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
15 |
16 | from transformers import MBartForConditionalGeneration, MBartTokenizer, MBart50TokenizerFast
17 |
18 | import json
19 |
20 |
21 | class SummarizationDataset(Dataset):
22 | def __init__(self, split_name, tokenizer, max_input_len, max_output_len):
23 | self.tokenizer = tokenizer
24 | self.max_input_len = max_input_len
25 | self.max_output_len = max_output_len
26 |         ## load the dataset; the data should be organized into three files: train.json, val.json and test.json
27 | with open('%s.json'%split_name, 'r', encoding='utf-8') as f:
28 | self.hf_dataset = json.load(f)
29 |
30 | def __len__(self):
31 | return len(self.hf_dataset)
32 |
33 | def __getitem__(self, idx):
34 | entry = self.hf_dataset[idx]
35 |         '''
36 |         Data format in the json files:
37 |         [
38 |             {
39 |                 "input": "a commentary sentence Ci1",
40 |                 "output": "the corresponding news sentence Ri1"
41 |             },
42 |             {
43 |                 "input": "a commentary sentence Ci2",
44 |                 "output": "the corresponding news sentence Ri2"
45 |             },
46 |             ...
47 |             {
48 |                 "input": "a commentary sentence Cim",
49 |                 "output": "the corresponding news sentence Rim"
50 |             },
51 |         ]
52 |         '''
53 | input_ids = self.tokenizer.encode(entry['input'].lower(), truncation=True, max_length=self.max_input_len)
54 | with self.tokenizer.as_target_tokenizer():
55 | output_ids = self.tokenizer.encode(entry['output'].lower(), truncation=True, max_length=self.max_output_len)
56 | return torch.tensor(input_ids), torch.tensor(output_ids)
57 |
58 | @staticmethod
59 | def collate_fn(batch):
60 |         pad_token_id = 1  # the pad token id is 1 for BART/mBART, and 0 for PEGASUS/T5
61 | input_ids, output_ids = list(zip(*batch))
62 | input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
63 | output_ids = torch.nn.utils.rnn.pad_sequence(output_ids, batch_first=True, padding_value=pad_token_id)
64 | return input_ids, output_ids
65 |
66 |
67 | class Summarizer(pl.LightningModule):
68 | def __init__(self, params):
69 | super().__init__()
70 | self.args = params
71 | self.hparams = params
72 |         self.src_lang = "zh_CN"  # mBART requires language tokens; both input and output are Chinese for sports game summarization, so both are set to zh_CN
73 | self.tgt_lang = "zh_CN"
74 | self.tokenizer = MBart50TokenizerFast.from_pretrained(self.args.model_path, src_lang=self.src_lang, tgt_lang=self.tgt_lang)
75 | self.model = MBartForConditionalGeneration.from_pretrained(self.args.model_path)
76 | self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None
77 | self.generated_id = 0
78 |
79 |         self.decoder_start_token_id = self.tokenizer.lang_code_to_id[self.tgt_lang]  # mBART-50 starts decoding with the target language code
80 | self.model.config.decoder_start_token_id = self.decoder_start_token_id
81 |
82 |
83 |
84 | def _prepare_input(self, input_ids):
85 | attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
86 | attention_mask[input_ids == self.tokenizer.pad_token_id] = 0
87 | return input_ids, attention_mask
88 |
89 | def forward(self, input_ids, output_ids):
90 | input_ids, attention_mask = self._prepare_input(input_ids)
91 |         decoder_input_ids = output_ids[:, :-1]  # teacher forcing: target without its final token
92 |         decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id)
93 |
94 |         labels = output_ids[:, 1:].clone()  # labels: target shifted left by one token
95 | outputs = self.model(
96 | input_ids,
97 | attention_mask=attention_mask,
98 | decoder_input_ids=decoder_input_ids,
99 | decoder_attention_mask=decoder_attention_mask,
100 | use_cache=False,
101 | )
102 |
103 | lm_logits = outputs[0]
104 | ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
105 | assert lm_logits.shape[-1] == self.model.config.vocab_size
106 | loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))
107 |
108 | return [loss]
109 |
110 | def training_step(self, batch, batch_nb):
111 | output = self.forward(*batch)
112 | loss = output[0]
113 | lr = loss.new_zeros(1) + self.trainer.optimizers[0].param_groups[0]['lr']
114 | tensorboard_logs = {'train_loss': loss, 'lr': lr,
115 | 'input_size': batch[0].numel(),
116 | 'output_size': batch[1].numel(),
117 | 'mem': torch.cuda.memory_allocated(loss.device) / 1024 ** 3 if torch.cuda.is_available() else 0}
118 | return {'loss': loss, 'log': tensorboard_logs}
119 |
120 | def validation_step(self, batch, batch_nb):
121 | for p in self.model.parameters():
122 | p.requires_grad = False
123 |
124 | outputs = self.forward(*batch)
125 | vloss = outputs[0]
126 | input_ids, output_ids = batch
127 | input_ids, attention_mask = self._prepare_input(input_ids)
128 |
129 |         ### set the inference parameters here, e.g., maximum generation length and number of beams
130 | generated_ids = self.model.generate(
131 | input_ids=input_ids,
132 | attention_mask=attention_mask,
133 | use_cache=True,
134 | num_beams= 5,
135 | max_length = 128,
136 | decoder_start_token_id=self.tokenizer.lang_code_to_id[self.tgt_lang],
137 | )
138 | generated_str = self.tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True)
139 | gold_str = self.tokenizer.batch_decode(output_ids.tolist(), skip_special_tokens=True)
140 |
141 |
142 | return {'vloss': vloss,
143 | 'generated': generated_str,
144 | 'gold': gold_str
145 | }
146 |
147 | def validation_epoch_end(self, outputs):
148 | for p in self.model.parameters():
149 | p.requires_grad = True
150 |
151 |         names = ['vloss']  # only the validation loss is aggregated; ROUGE is not computed in validation_step
152 | metrics = []
153 | for name in names:
154 | metric = torch.stack([x[name] for x in outputs]).mean()
155 | if self.trainer.use_ddp:
156 | torch.distributed.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
157 | metric /= self.trainer.world_size
158 | metrics.append(metric)
159 | logs = dict(zip(*[names, metrics]))
160 | # print(logs)
161 |
162 |
163 |         ## write the generated and gold outputs to files
164 | generated_str = []
165 | gold_str = []
166 | for item in outputs:
167 | generated_str.extend(item['generated'])
168 | gold_str.extend(item['gold'])
169 |
170 |
171 | with open(self.args.save_dir + '/' + self.args.save_prefix + '/generated_summary_%d.txt'%self.generated_id, 'w', encoding='utf-8') as f:
172 | for ending in generated_str:
173 | f.write(str(ending)+'\n')
174 |
175 | with open(self.args.save_dir + '/' + self.args.save_prefix + '/gold_%d.txt'%self.generated_id, 'w', encoding='utf-8') as f:
176 | for ending in gold_str:
177 | f.write(str(ending)+'\n')
178 |
179 | self.generated_id += 1
180 | return {'avg_val_loss': logs['vloss'], 'log': logs, 'progress_bar': logs}
181 |
182 | def test_step(self, batch, batch_nb):
183 | return self.validation_step(batch, batch_nb)
184 |
185 | def test_epoch_end(self, outputs):
186 | result = self.validation_epoch_end(outputs)
187 | # print(result)
188 |
189 | def configure_optimizers(self):
190 | if self.args.adafactor:
191 | optimizer = Adafactor(self.model.parameters(), lr=self.args.lr, scale_parameter=False, relative_step=False)
192 | else:
193 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
194 |         num_gpus = 1  # number of GPUs; for DDP training on multiple GPUs, also adjust the `gpus` argument of pl.Trainer in main()
195 | num_steps = self.args.dataset_size * self.args.epochs / num_gpus / self.args.grad_accum / self.args.batch_size
196 | scheduler = get_linear_schedule_with_warmup(
197 | optimizer, num_warmup_steps=self.args.warmup, num_training_steps=num_steps
198 | )
199 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
200 |
201 | def _get_dataloader(self, current_dataloader, split_name, is_train):
202 | if current_dataloader is not None:
203 | return current_dataloader
204 | dataset = SummarizationDataset(split_name = split_name, tokenizer=self.tokenizer,
205 | max_input_len=self.args.max_input_len, max_output_len=self.args.max_output_len)
206 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train) if self.trainer.use_ddp else None
207 | if split_name != 'train':
208 |
209 | return DataLoader(dataset, batch_size=self.args.val_batch_size, shuffle=(sampler is None),
210 | num_workers=self.args.num_workers, sampler=sampler,
211 | collate_fn=SummarizationDataset.collate_fn)
212 | else:
213 | return DataLoader(dataset, batch_size=self.args.batch_size, shuffle=(sampler is None),
214 | num_workers=self.args.num_workers, sampler=sampler,
215 | collate_fn=SummarizationDataset.collate_fn)
216 |
217 | @pl.data_loader
218 | def train_dataloader(self):
219 | self.train_dataloader_object = self._get_dataloader(self.train_dataloader_object, 'train', is_train=True)
220 | return self.train_dataloader_object
221 |
222 | @pl.data_loader
223 | def val_dataloader(self):
224 |         self.val_dataloader_object = self._get_dataloader(self.val_dataloader_object, 'val', is_train=False)
225 | return self.val_dataloader_object
226 |
227 | @pl.data_loader
228 | def test_dataloader(self):
229 | self.test_dataloader_object = self._get_dataloader(self.test_dataloader_object, 'test', is_train=False)
230 | return self.test_dataloader_object
231 |
232 | def configure_ddp(self, model, device_ids):
233 | model = LightningDistributedDataParallel(
234 | model,
235 | device_ids=device_ids,
236 | find_unused_parameters=False
237 | )
238 | return model
239 |
240 | @staticmethod
241 | def add_model_specific_args(parser, root_dir):
242 |         parser.add_argument("--save_dir", type=str, default='output')  # output directory
243 |         parser.add_argument("--save_prefix", type=str, default='Sports1')  # results are saved under save_dir/save_prefix
244 |         parser.add_argument("--model_path", type=str, default='model/mbart-large-50-many-to-many-mmt',  # local directory of the mBART-50 model
245 |                             help="Path to the checkpoint directory or model name")
246 |         parser.add_argument("--tokenizer", type=str, default='model/mbart-large-50-many-to-many-mmt')  # local directory of the mBART-50 tokenizer
247 |         parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")  # number of training epochs
248 |         parser.add_argument("--batch_size", type=int, default=4, help="Batch size")  # training batch size
249 |         parser.add_argument("--val_batch_size", type=int, default=4, help="Batch size")  # batch size at inference time
250 |         parser.add_argument("--grad_accum", type=int, default=1, help="number of gradient accumulation steps")  # gradient accumulation
251 |         parser.add_argument("--device_id", type=int, default=0, help="Which GPU to use")  # which GPU to train on
252 |         parser.add_argument("--warmup", type=int, default=500, help="Number of warmup steps")
253 |         parser.add_argument("--lr", type=float, default=2e-5, help="Maximum learning rate")  # learning rate
254 |         parser.add_argument("--val_every", type=float, default=1.0, help="Number of training steps between validations")  # i.e., run validation on the val set once per training epoch
255 | parser.add_argument("--num_workers", type=int, default=0, help="Number of data loader workers")
256 | parser.add_argument("--seed", type=int, default=1234, help="Seed")
257 | parser.add_argument("--disable_checkpointing", action='store_true', help="No logging or checkpointing")
258 |         parser.add_argument("--max_output_len", type=int, default=128,  # maximum output length
259 |                             help="maximum num of wordpieces/summary. Used for training and testing")
260 |         parser.add_argument("--max_input_len", type=int, default=1024,  # maximum input length; do not exceed the pre-trained model's limit (e.g., 1024 for mBART) or an error will be raised
261 |                             help="maximum num of wordpieces/input. Used for training and testing")
262 | parser.add_argument("--test", action='store_true', help="Test only, no training")
263 | parser.add_argument("--no_progress_bar", action='store_true', help="no progress bar. Good for printing")
264 | parser.add_argument("--fp32", action='store_true', help="default is fp16. Use --fp32 to switch to fp32")
265 | parser.add_argument("--resume_ckpt", type=str, help="Path of a checkpoint to resume from")
266 | parser.add_argument("--adafactor", action='store_true', help="Use adafactor optimizer")
267 |
268 | return parser
269 |
270 |
271 | def main(args):
272 | random.seed(args.seed)
273 | np.random.seed(args.seed)
274 | torch.manual_seed(args.seed)
275 | if torch.cuda.is_available():
276 | torch.cuda.manual_seed_all(args.seed)
277 |
278 | model = Summarizer(args)
279 |
280 | logger = TestTubeLogger(
281 | save_dir=args.save_dir,
282 | name=args.save_prefix,
283 | version=0 # always use version=0
284 | )
285 |
286 | checkpoint_callback = ModelCheckpoint(
287 | filepath=os.path.join(args.save_dir, args.save_prefix, "checkpoints"),
288 | save_top_k=30,
289 | verbose=True,
290 | monitor='avg_val_loss',
291 | mode='min',
292 | period=-1,
293 | prefix=''
294 | )
295 |
296 | print(args)
297 |
298 |     args.dataset_size = 20000  # number of training samples; it affects the warmup schedule, adjust as needed
299 |
300 | trainer = pl.Trainer(
301 | gpus = [args.device_id],
302 | distributed_backend = 'ddp' if torch.cuda.is_available() else None,
303 | track_grad_norm = -1,
304 | max_epochs = args.epochs,
305 | replace_sampler_ddp = False,
306 | accumulate_grad_batches = args.grad_accum,
307 | val_check_interval = args.val_every,
308 | num_sanity_val_steps=2,
309 | check_val_every_n_epoch=1,
310 | logger=logger,
311 | checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else False,
312 | show_progress_bar=not args.no_progress_bar,
313 | use_amp=not args.fp32, amp_level='O2',
314 | resume_from_checkpoint=args.resume_ckpt,
315 | )
316 | if not args.test:
317 | trainer.fit(model)
318 | trainer.test(model)
319 |
320 |
321 | if __name__ == "__main__":
322 | main_arg_parser = argparse.ArgumentParser(description="summarization")
323 | parser = Summarizer.add_model_specific_args(main_arg_parser, os.getcwd())
324 | args = parser.parse_args()
325 | main(args)
--------------------------------------------------------------------------------