├── pics
│   ├── baseline.png
│   └── overall.png
├── README.md
└── rewriter.py

/pics/baseline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krystalan/K-SportsSum/HEAD/pics/baseline.png
--------------------------------------------------------------------------------

/pics/overall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/krystalan/K-SportsSum/HEAD/pics/overall.png
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Knowledge Enhanced Sports Game Summarization

This repository contains data and code for the WSDM 2022 paper [*Knowledge Enhanced Sports Game Summarization*](https://arxiv.org/abs/2111.12535).

In this work, we propose the K-SportsSum dataset as well as the KES model.
- K-SportsSum: a dataset with 7,854 sports game summarization samples, together with a large-scale knowledge corpus covering 523 sports teams and 14k+ sports players.
- KES: a new sports game summarization model based on mT5.


### 1. K-SportsSum Dataset
The K-SportsSum dataset is available [here](https://drive.google.com/file/d/1RGWIz3Nw_kzfgIYo_Ke9elLfPOg0rS4V/view?usp=sharing). You can obtain the following five files from the shared link:
- `train.json`, `val.json` and `test.json` are the core data files of K-SportsSum, each of which contains live commentaries and news reports of sports games.
- `player_knowledge.json` contains background knowledge of 14,724 sports players.
- `team_knowledge.json` contains background information of 523 sports teams.
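
For a quick first look at the data, a minimal inspection sketch such as the one below can help. It only assumes that each split is a standard JSON file holding a list of samples; it does not assume any particular field names, since the exact schema is documented with the data release.

```python
import json

# Minimal inspection sketch: load one split of K-SportsSum and peek at its structure.
# Assumes the file is a single JSON document (a list of samples); adjust if the
# release uses another layout such as JSON lines.
with open("train.json", "r", encoding="utf-8") as f:
    train = json.load(f)

print("number of samples:", len(train))
first = train[0]
if isinstance(first, dict):
    print("fields of one sample:", list(first.keys()))
else:
    print("first sample:", first)
```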

### 2. Baseline Model Construction

![The two-step baseline framework](pics/baseline.png)

In this section, we introduce how to build a two-step baseline system for sports game summarization. As shown in the figure above, the baseline framework first selects important commentary sentences from the original live commentary document with a text classification model, where `Ck` denotes a commentary sentence and `Cik` a selected important sentence. Each selected sentence is then converted into a news-style sentence by a generative model; `Rik` is the generated sentence corresponding to `Cik`.

#### 2.1 Selector
The selector is a text classification model. In our work, we resort to the following toolkits:
- [Chinese-Text-Classification-Pytorch](https://github.com/649453932/Chinese-Text-Classification-Pytorch): training and inference code for pre-BERT text classification models, such as TextCNN.
- [Bert-Chinese-Text-Classification-Pytorch](https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch): text classification code for BERT-era models, e.g., BERT and ERNIE.

These two toolkits are very useful for building a Chinese text classification system.

#### 2.2 Rewriter
The rewriter is a generative model. Existing works typically employ [PTGen (See et al., ACL 2017)](https://arxiv.org/abs/1704.04368), mBART, mT5, etc.

For PTGen, the [pointer_summarizer](https://github.com/atulkum/pointer_summarizer) toolkit is widely used. I also recommend the [implementation](https://github.com/xcfcode/PLM_annotator/tree/main/pgn) released by Xiachong Feng in his dialogue summarization work. Both implementations are convenient. Please note that if you choose PTGen as the rewriter, you should use pre-trained word embeddings to help the model achieve good performance ([Chinese-Word-Vectors](https://github.com/Embedding/Chinese-Word-Vectors) is helpful).

For mBART, mT5, etc., we use the implementations of the [Huggingface Transformers library](https://huggingface.co/docs/transformers/index). I release the corresponding training and inference code for public use (see `rewriter.py`; `mBART-50` is used as the rewriter in this code).
Requirements: pytorch-lightning 0.8.5; transformers >= 4.4; torch >= 1.7
(This code is based on the [Longformer code](https://github.com/allenai/longformer/blob/master/scripts/summarization.py) from AI2.)

For training, you can run a command like this:
```sh
python rewriter.py --device_id 0
```
For evaluation, the command may look like this:
```sh
python rewriter.py --device_id 0 --test
```
Note that if you want to run inference with a trained model, remember to initialize the model with the corresponding `.ckpt` file.
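
The snippet below is one way to do this outside the training script: it loads the fine-tuned weights from a pytorch-lightning `.ckpt` file into a plain Transformers model and rewrites a single selected commentary sentence. It is a hedged sketch rather than part of the released code: the checkpoint path and the example sentence are hypothetical, and the key-prefix handling assumes the `Summarizer` module in `rewriter.py` stores mBART as `self.model` (as in the released code). The decoding settings mirror those used in `rewriter.py` (beam size 5, max length 128, Chinese language codes).

```python
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Minimal sketch: rewrite one selected commentary sentence with a trained rewriter.
# "model/mbart-large-50-many-to-many-mmt" matches the default --model_path in
# rewriter.py; "output/Sports1/checkpoints/example.ckpt" is a hypothetical path.
model_path = "model/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50TokenizerFast.from_pretrained(model_path, src_lang="zh_CN", tgt_lang="zh_CN")
model = MBartForConditionalGeneration.from_pretrained(model_path)

# Load fine-tuned weights from the pytorch-lightning checkpoint. Summarizer keeps
# mBART as `self.model`, so we strip the leading "model." from the state_dict keys.
ckpt = torch.load("output/Sports1/checkpoints/example.ckpt", map_location="cpu")
state_dict = {k[len("model."):]: v for k, v in ckpt["state_dict"].items() if k.startswith("model.")}
model.load_state_dict(state_dict)
model.eval()

commentary = "第30分钟,某球员在禁区外一脚远射,皮球稍稍偏出立柱。"  # a selected commentary sentence (toy example)
inputs = tokenizer(commentary, return_tensors="pt", truncation=True, max_length=1024)
generated = model.generate(
    **inputs,
    num_beams=5,
    max_length=128,
    decoder_start_token_id=tokenizer.lang_code_to_id["zh_CN"],  # same setting as rewriter.py
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```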

#### 2.3 Construct Training Samples for Selector and Rewriter
In order to construct training samples for the selector and rewriter, we should map each news sentence to its corresponding commentary sentence, if possible (for more details, please see Section 3.1 of [SportsSum2.0](https://arxiv.org/abs/2110.05750)).

Thus, the core of this process is calculating ROUGE scores and BERTScore between two Chinese sentences (a minimal mapping sketch is shown after this list):
- ROUGE: you can use the [multilingual_rouge_scoring](https://github.com/csebuetnlp/xl-sum/tree/master/multilingual_rouge_scoring) toolkit to calculate Chinese ROUGE scores. Note that the [py-rouge](https://github.com/Diego999/py-rouge) and [rouge](https://github.com/pltrdy/rouge) toolkits are not suitable for Chinese.
- BERTScore: please find more details in [bert_score](https://github.com/Tiiiger/bert_score).
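
To make the mapping concrete, here is a minimal sketch of the greedy matching step. It uses a simple character-overlap similarity as a self-contained stand-in; in practice you would plug in the Chinese ROUGE / BERTScore tools listed above, and the 0.3 threshold is only an assumed value, not the one used in the paper.

```python
from collections import Counter

# Sketch of the news-to-commentary mapping used to build pseudo training pairs.
# `char_f1` is a character-overlap stand-in for the Chinese ROUGE / BERTScore
# similarities discussed above; the 0.3 threshold is an assumption.
def char_f1(a: str, b: str) -> float:
    overlap = sum((Counter(a) & Counter(b)).values())
    if overlap == 0:
        return 0.0
    p, r = overlap / len(b), overlap / len(a)
    return 2 * p * r / (p + r)

def map_news_to_commentary(news_sents, commentary_sents, threshold=0.3):
    """For each news sentence, find the most similar commentary sentence.

    Returns (selector_labels, rewriter_pairs):
    - selector_labels[i] = 1 if commentary sentence i is mapped by some news sentence;
    - rewriter_pairs = [(commentary sentence, news sentence), ...].
    """
    selector_labels = [0] * len(commentary_sents)
    rewriter_pairs = []
    for news in news_sents:
        scores = [char_f1(news, c) for c in commentary_sents]
        best = max(range(len(commentary_sents)), key=lambda i: scores[i])
        if scores[best] >= threshold:
            selector_labels[best] = 1
            rewriter_pairs.append((commentary_sents[best], news))
    return selector_labels, rewriter_pairs
```

The resulting `selector_labels` can serve as training targets for the Section 2.1 selector, and `rewriter_pairs` matches the `input`/`output` format expected by `rewriter.py`.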

### 3. KES model
Code will be published once the author of this repo has time.

### Existing Works
To help researchers efficiently understand and follow the sports game summarization task, we have written a Chinese survey post: [《体育赛事摘要任务概览》](https://mp.weixin.qq.com/s/EidRYB_80AhRclz-mryVhQ) (An Overview of Sports Game Summarization), where we also discuss some future directions and give our thoughts.

We list and classify existing works on sports game summarization:

| Paper | Conference/Journal | Data/Code | Category |
| :--: | :--: | :--: | :--: |
| [Towards Constructing Sports News from Live Text Commentary](https://aclanthology.org/P16-1129) | ACL 2016 | - | `Dataset`, `Ext.` |
| [Overview of the NLPCC-ICCPOL 2016 Shared Task: Sports News Generation from Live Webcast Scripts](https://link.springer.com/chapter/10.1007%2F978-3-319-50496-4_80) | NLPCC 2016 | [NLPCC 2016 shared task](http://tcci.ccf.org.cn/conference/2016/pages/page05_CFPTasks.html) | `Dataset` |
| [Research on Summary Sentences Extraction Oriented to Live Sports Text](https://link.springer.com/chapter/10.1007%2F978-3-319-50496-4_72) | NLPCC 2016 | - | `Ext.` |
| [Sports News Generation from Live Webcast Scripts Based on Rules and Templates](https://link.springer.com/chapter/10.1007%2F978-3-319-50496-4_81) | NLPCC 2016 | - | `Ext.+Temp.` |
| [Content Selection for Real-time Sports News Construction from Commentary Texts](https://aclanthology.org/W17-3504/) | INLG 2017 | - | `Ext.` |
| [Generate Football News from Live Webcast Scripts Based on Character-CNN with Five Strokes](http://csroc.org.tw/journal/JOC31-1/JOC3101-21.pdf) | 2020 | - | `Ext.+Temp.` |
| [Generating Sports News from Live Commentary: A Chinese Dataset for Sports Game Summarization](https://aclanthology.org/2020.aacl-main.61/) | AACL 2020 | [SportsSum](https://github.com/ej0cl6/SportsSum) | `Dataset`, `Ext.+Abs.` |
| [SportsSum2.0: Generating High-Quality Sports News from Live Text Commentary](https://arxiv.org/abs/2110.05750) | CIKM 2021 | [SportsSum2.0](https://github.com/krystalan/SportsSum2.0) | `Dataset`, `Ext.+Abs.` |
| [Knowledge Enhanced Sports Game Summarization](https://arxiv.org/abs/2111.12535) | WSDM 2022 | [K-SportsSum](https://github.com/krystalan/K-SportsSum) | `Dataset`, `Ext.+Abs.` |

The categories are defined as follows:
- `Dataset`: the work contributes a dataset for sports game summarization.
- `Ext.`: extractive sports game summarization method.
- `Ext.+Temp.`: the method first extracts important commentary sentences and then uses human-labeled templates to convert each commentary sentence into a news sentence.
- `Ext.+Abs.`: the method first extracts important commentary sentences and then uses a seq2seq model to convert each commentary sentence into a news sentence.

### Q&A
Q1: What are the differences among SportsSum, SportsSum2.0, SGSum and K-SportsSum?
A1: **SportsSum (Huang et al., AACL 2020)** is the first large-scale sports game summarization dataset, with 5,428 samples. Despite this great contribution, about 15% of the SportsSum samples are noisy. Thus, **SportsSum2.0 (Wang et al., CIKM 2021)** cleans the original SportsSum and obtains 5,402 samples (26 bad samples in SportsSum are removed). Following previous works, **SGSum (non-archival paper, not formally published)** collects and cleans a large amount of data from massive games, yielding 7,854 samples. **K-SportsSum (Wang et al., WSDM 2022)** shuffles and randomly re-splits **SGSum**. Furthermore, **K-SportsSum** provides a large-scale knowledge corpus about sports teams and players, which can help alleviate the knowledge gap issue (see the K-SportsSum paper).

Q2: There is little public code for sports game summarization.
A2: Yeah, I know. All existing works follow the pipeline paradigm to build sports game summarization systems; they may involve two or three steps together with a pseudo-label construction process, so the code tends to be messy. As a remedy, we 1) release a tutorial for building a two-step baseline for sports game summarization (see Section 2 of this page); 2) plan to build an end-to-end model for public use (work in progress, maybe published in 2022, but there is no guarantee).

Q3: About position embeddings in mT5.
A3: The position embedding of mT5 is set to a zero vector, since mT5 uses relative position embeddings in self-attention.

Q4: Any questions and suggestions?
A4: Please feel free to contact me (jawang1[at]suda.edu.cn).

### Acknowledgement
Jiaan Wang would like to thank **[KW Lab, Fudan Univ.](http://kw.fudan.edu.cn/)** and **[iFLYTEK AI Research, Suzhou](https://www.iflytek.com/index.html)** for their helpful discussions and GPU device support.

### Citation
If you find this project useful or use the data in your work, please consider citing our paper:
```
@article{Wang2022KnowledgeES,
  title={Knowledge Enhanced Sports Game Summarization},
  author={Jiaan Wang and Zhixu Li and Tingyi Zhang and Duo Zheng and Jianfeng Qu and An Liu and Lei Zhao and Zhigang Chen},
  journal={Proceedings of the Fifteenth ACM International Conference on Web Search and Data Mining},
  year={2022}
}
```
--------------------------------------------------------------------------------

/rewriter.py:
--------------------------------------------------------------------------------
import os
import argparse
import random
import numpy as np

import torch
from torch.utils.data import DataLoader, Dataset
from transformers.optimization import get_linear_schedule_with_warmup, Adafactor

import pytorch_lightning as pl
from pytorch_lightning.logging import TestTubeLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel

from transformers import MBartForConditionalGeneration, MBartTokenizer, MBart50TokenizerFast

import json


class SummarizationDataset(Dataset):
    def __init__(self, split_name, tokenizer, max_input_len, max_output_len):
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        # Read the dataset; the data should be prepared as three files: train.json, val.json and test.json.
        with open('%s.json' % split_name, 'r', encoding='utf-8') as f:
            self.hf_dataset = json.load(f)

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        entry = self.hf_dataset[idx]
        '''
        Data format of the json files:
        [
            {
                "input": "commentary sentence Ci1",
                "output": "corresponding news sentence Ri1"
            },
            {
                "input": "commentary sentence Ci2",
                "output": "corresponding news sentence Ri2"
            },
            ...
            {
                "input": "commentary sentence Cim",
                "output": "corresponding news sentence Rim"
            },
        ]
        '''
        input_ids = self.tokenizer.encode(entry['input'].lower(), truncation=True, max_length=self.max_input_len)
        with self.tokenizer.as_target_tokenizer():
            output_ids = self.tokenizer.encode(entry['output'].lower(), truncation=True, max_length=self.max_output_len)
        return torch.tensor(input_ids), torch.tensor(output_ids)

    @staticmethod
    def collate_fn(batch):
        pad_token_id = 1  # for bart/mbart the pad token id is 1; for pegasus/t5 it is 0
        input_ids, output_ids = list(zip(*batch))
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
        output_ids = torch.nn.utils.rnn.pad_sequence(output_ids, batch_first=True, padding_value=pad_token_id)
        return input_ids, output_ids


class Summarizer(pl.LightningModule):
    def __init__(self, params):
        super().__init__()
        self.args = params
        self.hparams = params
        # mBART requires language tokens; both the input and output of sports game
        # summarization are Chinese, so source and target languages are both set to Chinese.
        self.src_lang = "zh_CN"
        self.tgt_lang = "zh_CN"
        self.tokenizer = MBart50TokenizerFast.from_pretrained(self.args.model_path, src_lang=self.src_lang, tgt_lang=self.tgt_lang)
        self.model = MBartForConditionalGeneration.from_pretrained(self.args.model_path)
        self.train_dataloader_object = self.val_dataloader_object = self.test_dataloader_object = None
        self.generated_id = 0

        self.decoder_start_token_id = self.tokenizer.lang_code_to_id[self.tgt_lang]
        self.model.config.decoder_start_token_id = self.decoder_start_token_id

    def _prepare_input(self, input_ids):
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device)
        attention_mask[input_ids == self.tokenizer.pad_token_id] = 0
        return input_ids, attention_mask

    def forward(self, input_ids, output_ids):
        input_ids, attention_mask = self._prepare_input(input_ids)
        decoder_input_ids = output_ids[:, :-1]
        decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id)

        labels = output_ids[:, 1:].clone()
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            use_cache=False,
        )

        lm_logits = outputs[0]
        ce_loss_fct = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
        assert lm_logits.shape[-1] == self.model.config.vocab_size
        loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))

        return [loss]

    def training_step(self, batch, batch_nb):
        output = self.forward(*batch)
        loss = output[0]
        lr = loss.new_zeros(1) + self.trainer.optimizers[0].param_groups[0]['lr']
        tensorboard_logs = {'train_loss': loss, 'lr': lr,
                            'input_size': batch[0].numel(),
                            'output_size': batch[1].numel(),
                            'mem': torch.cuda.memory_allocated(loss.device) / 1024 ** 3 if torch.cuda.is_available() else 0}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        for p in self.model.parameters():
            p.requires_grad = False

        outputs = self.forward(*batch)
        vloss = outputs[0]
        input_ids, output_ids = batch
        input_ids, attention_mask = self._prepare_input(input_ids)

        # Inference settings: maximum generation length, number of beams, etc.
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=True,
            num_beams=5,
            max_length=128,
            decoder_start_token_id=self.tokenizer.lang_code_to_id[self.tgt_lang],
        )
        generated_str = self.tokenizer.batch_decode(generated_ids.tolist(), skip_special_tokens=True)
        gold_str = self.tokenizer.batch_decode(output_ids.tolist(), skip_special_tokens=True)

        return {'vloss': vloss,
                'generated': generated_str,
                'gold': gold_str
                }

    def validation_epoch_end(self, outputs):
        for p in self.model.parameters():
            p.requires_grad = True

        # Only the validation loss is aggregated here; Chinese ROUGE / BERTScore should be
        # computed offline on the generated files (see Section 2.3 of the README).
        names = ['vloss']
        metrics = []
        for name in names:
            metric = torch.stack([x[name] for x in outputs]).mean()
            if self.trainer.use_ddp:
                torch.distributed.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
                metric /= self.trainer.world_size
            metrics.append(metric)
        logs = dict(zip(names, metrics))

        # Write the generated and gold sentences to files.
        generated_str = []
        gold_str = []
        for item in outputs:
            generated_str.extend(item['generated'])
            gold_str.extend(item['gold'])

        with open(self.args.save_dir + '/' + self.args.save_prefix + '/generated_summary_%d.txt' % self.generated_id, 'w', encoding='utf-8') as f:
            for ending in generated_str:
                f.write(str(ending) + '\n')

        with open(self.args.save_dir + '/' + self.args.save_prefix + '/gold_%d.txt' % self.generated_id, 'w', encoding='utf-8') as f:
            for ending in gold_str:
                f.write(str(ending) + '\n')

        self.generated_id += 1
        return {'avg_val_loss': logs['vloss'], 'log': logs, 'progress_bar': logs}

    def test_step(self, batch, batch_nb):
        return self.validation_step(batch, batch_nb)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs)

    def configure_optimizers(self):
        if self.args.adafactor:
            optimizer = Adafactor(self.model.parameters(), lr=self.args.lr, scale_parameter=False, relative_step=False)
        else:
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
        # Number of GPUs used for training; if you switch to multi-GPU DDP training,
        # also adjust the `gpus` argument passed to pl.Trainer in main().
        num_gpus = 1
        num_steps = self.args.dataset_size * self.args.epochs / num_gpus / self.args.grad_accum / self.args.batch_size
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.args.warmup, num_training_steps=num_steps
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def _get_dataloader(self, current_dataloader, split_name, is_train):
        if current_dataloader is not None:
            return current_dataloader
        dataset = SummarizationDataset(split_name=split_name, tokenizer=self.tokenizer,
                                       max_input_len=self.args.max_input_len, max_output_len=self.args.max_output_len)
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train) if self.trainer.use_ddp else None
        if split_name != 'train':
            return DataLoader(dataset, batch_size=self.args.val_batch_size, shuffle=(sampler is None),
                              num_workers=self.args.num_workers, sampler=sampler,
                              collate_fn=SummarizationDataset.collate_fn)
        else:
            return DataLoader(dataset, batch_size=self.args.batch_size, shuffle=(sampler is None),
                              num_workers=self.args.num_workers, sampler=sampler,
                              collate_fn=SummarizationDataset.collate_fn)

    @pl.data_loader
    def train_dataloader(self):
        self.train_dataloader_object = self._get_dataloader(self.train_dataloader_object, 'train', is_train=True)
        return self.train_dataloader_object

    @pl.data_loader
    def val_dataloader(self):
        self.val_dataloader_object = self._get_dataloader(self.val_dataloader_object, 'val', is_train=False)
        return self.val_dataloader_object

    @pl.data_loader
    def test_dataloader(self):
        self.test_dataloader_object = self._get_dataloader(self.test_dataloader_object, 'test', is_train=False)
        return self.test_dataloader_object

    def configure_ddp(self, model, device_ids):
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            find_unused_parameters=False
        )
        return model

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        parser.add_argument("--save_dir", type=str, default='output')  # output folder
        parser.add_argument("--save_prefix", type=str, default='Sports1')  # results are saved under save_dir/save_prefix
        parser.add_argument("--model_path", type=str, default='model/mbart-large-50-many-to-many-mmt',  # local mBART-50 directory
                            help="Path to the checkpoint directory or model name")
        parser.add_argument("--tokenizer", type=str, default='model/mbart-large-50-many-to-many-mmt')  # local mBART-50 directory
        parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs")
        parser.add_argument("--batch_size", type=int, default=4, help="Training batch size")
        parser.add_argument("--val_batch_size", type=int, default=4, help="Batch size at inference time")
        parser.add_argument("--grad_accum", type=int, default=1, help="Number of gradient accumulation steps")
        parser.add_argument("--device_id", type=int, default=0, help="Index of the GPU used for training")
        parser.add_argument("--warmup", type=int, default=500, help="Number of warmup steps")
        parser.add_argument("--lr", type=float, default=2e-5, help="Maximum learning rate")
        parser.add_argument("--val_every", type=float, default=1.0,
                            help="Validation interval as a fraction of a training epoch (1.0 = validate once per epoch)")
        parser.add_argument("--num_workers", type=int, default=0, help="Number of data loader workers")
        parser.add_argument("--seed", type=int, default=1234, help="Seed")
        parser.add_argument("--disable_checkpointing", action='store_true', help="No logging or checkpointing")
        parser.add_argument("--max_output_len", type=int, default=128,
                            help="Maximum number of wordpieces in the output. Used for training and testing")
        parser.add_argument("--max_input_len", type=int, default=1024,
                            help="Maximum number of wordpieces in the input; do not exceed the pre-trained model's limit "
                                 "(e.g., 1024 for mBART), otherwise an error is raised. Used for training and testing")
        parser.add_argument("--test", action='store_true', help="Test only, no training")
        parser.add_argument("--no_progress_bar", action='store_true', help="No progress bar. Good for printing")
        parser.add_argument("--fp32", action='store_true', help="Default is fp16. Use --fp32 to switch to fp32")
        parser.add_argument("--resume_ckpt", type=str, help="Path of a checkpoint to resume from")
        parser.add_argument("--adafactor", action='store_true', help="Use adafactor optimizer")

        return parser


def main(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    model = Summarizer(args)

    logger = TestTubeLogger(
        save_dir=args.save_dir,
        name=args.save_prefix,
        version=0  # always use version=0
    )

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(args.save_dir, args.save_prefix, "checkpoints"),
        save_top_k=30,
        verbose=True,
        monitor='avg_val_loss',
        mode='min',
        period=-1,
        prefix=''
    )

    print(args)

    # Number of samples in the training set; it affects the learning-rate warmup
    # schedule, so please adjust it to match your own training data.
    args.dataset_size = 20000

    trainer = pl.Trainer(
        gpus=[args.device_id],
        distributed_backend='ddp' if torch.cuda.is_available() else None,
        track_grad_norm=-1,
        max_epochs=args.epochs,
        replace_sampler_ddp=False,
        accumulate_grad_batches=args.grad_accum,
        val_check_interval=args.val_every,
        num_sanity_val_steps=2,
        check_val_every_n_epoch=1,
        logger=logger,
        checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else False,
        show_progress_bar=not args.no_progress_bar,
        use_amp=not args.fp32, amp_level='O2',
        resume_from_checkpoint=args.resume_ckpt,
    )
    if not args.test:
        trainer.fit(model)
    trainer.test(model)


if __name__ == "__main__":
    main_arg_parser = argparse.ArgumentParser(description="summarization")
    parser = Summarizer.add_model_specific_args(main_arg_parser, os.getcwd())
    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------