├── .gitignore ├── LICENSE ├── README.md ├── datamodule.py ├── main.py ├── model.py └── postprocess.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 jzm-chairman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # sohu2022-nlp-rank1
2 | 
3 | First-place open-source solution (experiment code) for the NLP track of the 2022 Sohu Campus Algorithm Competition.
4 | 
5 | [Solution write-up (in Chinese)](https://zhuanlan.zhihu.com/p/533808475)
6 | 
7 | The code is written with the pytorch-lightning framework. **Note: this is the code I used for experiments and iteration during the preliminary round, not the code submitted for the semi-final and final rounds; a few tricks are missing, so it scores slightly below the final submission.**
8 | 
9 | Core code snippets:
10 | 
11 | Input construction in `datamodule.py`:
12 | 
13 | ```python
14 | def _setup(self, data):
15 |     output = []
16 |     for item in tqdm(data):
17 |         output_item = {}
18 |         text = item["content"]
19 |         if not text or not item["entity"]:
20 |             continue
21 |         prompt = "".join([f"{entity}{self.mask_symbol}" for entity in item["entity"]])
22 |         inputs = self.tokenizer.__call__(text=text, text_pair=prompt, add_special_tokens=True, max_length=self.hparams.max_length, truncation="only_first")
23 |         inputs["is_masked"] = [int(i == self.tokenizer.mask_token_id) for i in inputs["input_ids"]]
24 |         inputs["first_mask"] = [int(i == 0) for i in inputs["token_type_ids"]]
25 |         output_item["inputs"] = inputs
26 |         if isinstance(item["entity"], dict):
27 |             labels = list(map(lambda x: x + 2, item["entity"].values()))
28 |             output_item["labels"] = labels
29 |         output.append(output_item)
30 | ```
31 | 
32 | The `forward` pass in `model.py`:
33 | 
34 | ```python
35 | def forward(self, inputs, output_hidden_states=False):
36 |     is_masked = inputs['is_masked'].bool()
37 |     first_mask = inputs.get('first_mask', None)
38 |     inputs = {k: v for k, v in inputs.items() if k in ["input_ids", "attention_mask", "token_type_ids"]}
39 |     backbone_outputs = self.xlnet(**inputs, output_hidden_states=True)
40 |     masked_outputs = backbone_outputs.last_hidden_state[is_masked]
41 |     ...
42 |     logits = self.classifier(masked_outputs)
43 |     if not output_hidden_states:
44 |         return logits
45 |     ...
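# (The elided lines above: when pooling_layers > 1, the mask-token representations from
#  the top pooling_layers hidden states are averaged before classification, and with
#  output_hidden_states=True the model also returns a mean pooling over the first-segment
#  tokens selected by first_mask; see the full forward() in model.py below.)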
46 | ```
47 | 
--------------------------------------------------------------------------------
/datamodule.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pickle
4 | import random
5 | 
6 | # import roformer
7 | import torch
8 | from torch.utils.data import Dataset, DataLoader
9 | import pytorch_lightning as pl
10 | import transformers
11 | from tqdm import tqdm
12 | 
13 | 
14 | class BasicDataset(Dataset):
15 |     def __init__(self, data: list):
16 |         self.data = data
17 | 
18 |     def __len__(self):
19 |         return len(self.data)
20 | 
21 |     def __getitem__(self, item):
22 |         return self.data[item]
23 | 
24 | class DataModule(pl.LightningDataModule):
25 |     def __init__(self, **kwargs):
26 |         super().__init__()
27 |         self.save_hyperparameters()
28 |         self.train_dataset = self.valid_dataset = self.test_dataset = None
29 |         if self.hparams.model_type == "xlnet":
30 |             self.tokenizer = transformers.XLNetTokenizerFast.from_pretrained(self.hparams.model_name)
31 |             self.mask_symbol = "<mask>"  # XLNet's mask token
32 |         elif self.hparams.model_type == "roformer":
33 |             self.tokenizer = transformers.BertTokenizerFast.from_pretrained(self.hparams.model_name)
34 |             self.mask_symbol = "[MASK][PAD]"
35 |         else:
36 |             self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.hparams.model_name)
37 |             self.mask_symbol = "[MASK],"
38 | 
39 |     def _setup(self, data):
40 |         output = []
41 |         for item in tqdm(data):
42 |             output_item = {}
43 |             text = item["content"]
44 |             if not text or not item["entity"]:
45 |                 continue
46 |             prompt = "".join([f"{entity}{self.mask_symbol}" for entity in item["entity"]])
47 |             inputs = self.tokenizer.__call__(text=text, text_pair=prompt, add_special_tokens=True, max_length=self.hparams.max_length, truncation="only_first")
48 |             inputs["is_masked"] = [int(i == self.tokenizer.mask_token_id) for i in inputs["input_ids"]]
49 |             inputs["first_mask"] = [int(i == 0) for i in inputs["token_type_ids"]]
50 |             output_item["inputs"] = inputs
51 |             if isinstance(item["entity"], dict):
52 |                 labels = list(map(lambda x: x + 2, item["entity"].values()))
53 |                 output_item["labels"] = labels
54 |             output.append(output_item)
55 |         return output
56 | 
57 |     def prepare_data(self) -> None:
58 |         load = lambda file: list(map(json.loads, open(file, "r+", encoding="utf-8").readlines()))
59 |         self.train_cache_file = self.hparams.train_data_path.replace(".txt", f"_{self.hparams.model_type}.pkl")
60 |         self.test_cache_file = self.hparams.test_data_path.replace(".txt", f"_{self.hparams.model_type}.pkl")
61 |         if not os.path.exists(self.train_cache_file):
62 |             train_data = self._setup(load(self.hparams.train_data_path))
63 |             pickle.dump(train_data, open(self.train_cache_file, "wb"))
64 |         if not os.path.exists(self.test_cache_file):
65 |             test_data = self._setup(load(self.hparams.test_data_path))
66 |             pickle.dump(test_data, open(self.test_cache_file, "wb"))
67 |         if self.hparams.pseudo_data_path:
68 |             pass
69 | 
70 |     def setup(self, stage=None):
71 |         load_pkl = lambda file: pickle.load(open(file, "rb"))
72 |         if stage in ["fit", "validate"]:
73 |             if self.train_dataset is None or self.valid_dataset is None:
74 |                 train_data = load_pkl(self.train_cache_file)
75 |                 if self.hparams.shuffle_valid:
76 |                     random.shuffle(train_data)
77 |                 expanded_train_data = train_data[self.hparams.valid_size:]
78 |                 if self.hparams.pseudo_data_path:
79 |                     pass
80 |                 self.train_dataset = BasicDataset(train_data[self.hparams.valid_size:])
81 |                 self.valid_dataset = BasicDataset(train_data[:self.hparams.valid_size])
82 |         elif stage in
["test", "predict"]: 83 | if self.test_dataset is None: 84 | self.test_dataset = BasicDataset(load_pkl(self.test_cache_file)) 85 | else: 86 | raise NotImplementedError 87 | 88 | def collate_fn(self, batch): 89 | output = {"inputs": {key: [] for key in batch[0]["inputs"]}, "labels": []} 90 | for item in batch: 91 | for key in item["inputs"]: 92 | output["inputs"][key].append(torch.tensor(item["inputs"][key])) 93 | output["labels"].extend(item["labels"]) 94 | for key in output["inputs"]: 95 | output["inputs"][key] = torch.nn.utils.rnn.pad_sequence(output["inputs"][key], batch_first=True, padding_value=0) 96 | output["labels"] = torch.tensor(output["labels"]) 97 | return output 98 | 99 | def test_collate_fn(self, batch): 100 | output = {"inputs": {key: [] for key in batch[0]["inputs"]}} 101 | for item in batch: 102 | for key in item["inputs"]: 103 | output["inputs"][key].append(torch.tensor(item["inputs"][key])) 104 | for key in output["inputs"]: 105 | output["inputs"][key] = torch.nn.utils.rnn.pad_sequence(output["inputs"][key], batch_first=True, padding_value=0) 106 | return output 107 | 108 | def train_dataloader(self): 109 | return DataLoader(self.train_dataset, batch_size=self.hparams.train_batch_size, shuffle=True, collate_fn=self.collate_fn) 110 | 111 | def val_dataloader(self): 112 | return DataLoader(self.valid_dataset, batch_size=self.hparams.test_batch_size, shuffle=False, collate_fn=self.collate_fn) 113 | 114 | def test_dataloader(self): 115 | return DataLoader(self.test_dataset, batch_size=self.hparams.test_batch_size, shuffle=False, collate_fn=self.test_collate_fn) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from argparse import ArgumentParser 4 | import pytorch_lightning as pl 5 | 6 | from datamodule import DataModule 7 | from model import SentimentClassifier, SWASentimentClassifier 8 | from postprocess import postprocess 9 | 10 | 11 | def get_callbacks(args): 12 | callbacks = [ 13 | pl.callbacks.ModelCheckpoint(dirpath=args.output_path, every_n_epochs=1, save_on_train_epoch_end=False, monitor="val_acc", save_last=True, save_top_k=10, mode="max", auto_insert_metric_name=True) 14 | ] 15 | return callbacks 16 | 17 | def console_args(): 18 | parser = ArgumentParser() 19 | parser.add_argument("--mode", type=str, default="train", help="train or test") 20 | parser.add_argument("--model_type", type=str, default="xlnet") 21 | parser.add_argument("--model_name", type=str, default="hfl/chinese-xlnet-base", help="model name") 22 | parser.add_argument("--device", type=str, default="cuda", help="device") 23 | parser.add_argument("--gpus", type=int, default=1, help="number of gpus") 24 | 25 | parser.add_argument("--root_path", type=str, default="resources/nlp_data", help="root path") 26 | parser.add_argument("--train_data_path", type=str, default="resources/nlp_data/train.txt", help="train data path") 27 | parser.add_argument("--pseudo_data_path", type=str, help="pseudo data path", required=False) 28 | parser.add_argument("--test_data_path", type=str, default="resources/nlp_data/test.txt", help="test data path") 29 | parser.add_argument("--valid_size", type=int, default=2000, help="valid size") 30 | parser.add_argument("--num_workers", type=int, default=8, help="num workers") 31 | parser.add_argument("--train_batch_size", type=int, default=6, help="train batch size") 32 | parser.add_argument("--test_batch_size", type=int, 
default=32, help="test batch size") 33 | parser.add_argument("--max_length", type=int, default=900, help="max length") 34 | 35 | parser.add_argument("--lr", type=float, default=5e-5, help="learning rate") 36 | parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay") 37 | parser.add_argument("--eps", type=float, default=1e-8, help="eps") 38 | parser.add_argument("--num_warmup_steps", type=int, help="warmup steps", required=False) 39 | parser.add_argument("--warmup_proportion", type=float, default=0.1) 40 | 41 | parser.add_argument("--output_path", type=str, default="output", help="output path") 42 | parser.add_argument("--label_smoothing", type=float, default=0.05) 43 | parser.add_argument("--num_classes", type=int, default=5) 44 | parser.add_argument("--dropout", type=float, default=0.1) 45 | parser.add_argument("--layer_norm", type=bool, default=False) 46 | parser.add_argument("--regression", type=bool, default=False) 47 | parser.add_argument("--r_drop", type=bool, default=False) 48 | parser.add_argument("--kl_weight", type=float, default=1.0) 49 | parser.add_argument("--pooling_layers", type=int, default=1) 50 | parser.add_argument("--attack_epsilon", type=float, default=0.1) 51 | 52 | parser.add_argument("--gradient_clip_val", default=1.0, type=float) 53 | parser.add_argument("--gradient_clip_algorithm", default="norm", type=str) 54 | parser.add_argument("--accumulate_grad_batches", default=3, type=int) 55 | 56 | parser.add_argument("--max_epochs", type=int, default=10, help="epochs") 57 | parser.add_argument("--precision", type=int, default=16, help="precision") 58 | parser.add_argument("--seed", type=int, default=19260817, help="seed") 59 | parser.add_argument("--ckpt_path", type=str, required=False) 60 | parser.add_argument("--is_extra_output", type=bool, default=False) 61 | parser.add_argument("--use_swa", type=bool, default=False) 62 | parser.add_argument("--shuffle_valid", type=bool, default=True) 63 | parser.add_argument("--adv_train", type=bool, default=False) 64 | parser.add_argument("--optimize_f1", default=False, type=bool) 65 | 66 | return parser.parse_args() 67 | 68 | if __name__ == "__main__": 69 | args = console_args() 70 | if not os.path.exists(args.output_path): 71 | os.makedirs(args.output_path) 72 | pl.seed_everything(args.seed) 73 | if args.adv_train: 74 | args.manual_gradient_clip_val = args.gradient_clip_val 75 | args.manual_gradient_clip_algorithm = args.gradient_clip_algorithm 76 | args.gradient_clip_val = args.gradient_clip_algorithm = None 77 | if args.mode == "train": 78 | args.num_training_steps = math.ceil(((len(open(args.train_data_path, "r+", encoding="utf-8").readlines()) - args.valid_size) / args.train_batch_size) * args.max_epochs / args.accumulate_grad_batches) 79 | if args.num_warmup_steps is None: 80 | args.num_warmup_steps = round(args.num_training_steps * args.warmup_proportion) 81 | datamodule = DataModule(**vars(args)) 82 | if args.use_swa: 83 | model = SWASentimentClassifier(**vars(args)) 84 | else: 85 | model = SentimentClassifier(**vars(args)) 86 | model.datamodule = datamodule 87 | trainer: pl.Trainer = pl.Trainer.from_argparse_args(args, callbacks=get_callbacks(args), detect_anomaly=False) 88 | if args.mode == "train": 89 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=args.ckpt_path) 90 | elif args.mode == "test": 91 | if args.optimize_f1: 92 | trainer.validate(model=model, datamodule=datamodule, ckpt_path=args.ckpt_path) 93 | trainer.test(model=model, datamodule=datamodule, ckpt_path=args.ckpt_path) 
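        # With --optimize_f1, validate() runs first so validation_epoch_end() can fit
        # per-class logit weights (a Nelder-Mead search on macro-F1); test_step() then
        # applies those weights before postprocess() turns prediction.pkl into the
        # tab-separated submission file (section1.txt).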
94 | postprocess(prediction_file=os.path.join(args.output_path, "prediction.pkl"), raw_file=args.test_data_path, output_file=os.path.join(args.output_path, "section1.txt")) 95 | else: 96 | raise NotImplementedError -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import scipy 7 | import torch 8 | import transformers 9 | from torch import nn 10 | from sklearn.metrics import precision_recall_fscore_support, accuracy_score, f1_score 11 | 12 | 13 | # import roformer 14 | 15 | class FGM: 16 | def __init__(self, model): 17 | self.model = model 18 | self.backup = {} 19 | 20 | def attack(self, epsilon=0.25, emb_name='word_embeddings'): 21 | # emb_name这个参数要换成你模型中embedding的参数名 22 | for name, param in self.model.named_parameters(): 23 | if param.requires_grad and emb_name in name: 24 | self.backup[name] = param.data.clone() 25 | norm = torch.norm(param.grad) 26 | if norm != 0 and not torch.isnan(norm): 27 | r_at = epsilon * param.grad / norm 28 | param.data.add_(r_at) 29 | 30 | def restore(self, emb_name='word_embeddings'): 31 | # emb_name这个参数要换成你模型中embedding的参数名 32 | for name, param in self.model.named_parameters(): 33 | if param.requires_grad and emb_name in name: 34 | assert name in self.backup 35 | param.data = self.backup[name] 36 | self.backup = {} 37 | 38 | class SentimentClassifier(pl.LightningModule): 39 | def __init__(self, **kwargs): 40 | super().__init__() 41 | self.save_hyperparameters() 42 | self.automatic_optimization = not self.hparams.adv_train 43 | if self.hparams.model_type == "xlnet": 44 | self.xlnet: transformers.models.xlnet.XLNetModel = transformers.XLNetModel.from_pretrained(self.hparams.model_name) 45 | self.hidden_size = self.xlnet.config.d_model 46 | elif self.hparams.model_type == "roformer": 47 | self.xlnet: roformer.RoFormerModel = roformer.RoFormerModel.from_pretrained(self.hparams.model_name, max_position_embeddings=1536) 48 | self.hidden_size = self.xlnet.config.hidden_size 49 | else: 50 | self.xlnet = transformers.AutoModel.from_pretrained(self.hparams.model_name) 51 | self.hidden_size = self.xlnet.config.hidden_size 52 | if self.hparams.regression: 53 | self.criterion = nn.MSELoss() 54 | self.output_dim = 1 55 | else: 56 | self.criterion = nn.CrossEntropyLoss(label_smoothing=self.hparams.label_smoothing) 57 | self.output_dim = self.hparams.num_classes 58 | if self.hparams.layer_norm: 59 | self.classifier = nn.Sequential( 60 | nn.Linear(self.hidden_size, self.hidden_size), 61 | nn.LayerNorm(self.hidden_size), 62 | nn.LeakyReLU(), 63 | nn.Dropout(p=self.hparams.dropout), 64 | nn.Linear(self.hidden_size, self.output_dim), 65 | ) 66 | else: 67 | self.classifier = nn.Sequential( 68 | nn.Linear(self.hidden_size, self.hidden_size), 69 | nn.LeakyReLU(), 70 | nn.Dropout(p=self.hparams.dropout), 71 | nn.Linear(self.hidden_size, self.output_dim), 72 | ) 73 | self.kld = nn.KLDivLoss(reduction="batchmean") 74 | self.attacker = FGM(self) if self.hparams.adv_train else None 75 | self.class_weights = None 76 | 77 | def ttl(self, t): 78 | return t.detach().cpu().numpy() 79 | 80 | def logits_to_prediction(self, logits): 81 | if not self.hparams.regression: 82 | return torch.argmax(logits, dim=1) 83 | prediction = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device) 84 | prediction[logits < -1.5] = 0 85 | prediction[(logits >= -1.5) & (logits < -0.5)] 
= 1
86 |         prediction[(logits >= -0.5) & (logits < 0.5)] = 2
87 |         prediction[(logits >= 0.5) & (logits < 1.5)] = 3
88 |         prediction[logits >= 1.5] = 4
89 |         return prediction
90 | 
91 |     def forward(self, inputs, output_hidden_states=False):
92 |         is_masked = inputs['is_masked'].bool()
93 |         first_mask = inputs.get('first_mask', None)
94 |         inputs = {k: v for k, v in inputs.items() if k in ["input_ids", "attention_mask", "token_type_ids"]}
95 |         backbone_outputs = self.xlnet(**inputs, output_hidden_states=True)
96 |         masked_outputs = backbone_outputs.last_hidden_state[is_masked]
97 |         if self.hparams.pooling_layers > 1:
98 |             for i in range(2, self.hparams.pooling_layers + 1):
99 |                 masked_outputs += backbone_outputs.hidden_states[-i][is_masked]
100 |             masked_outputs /= self.hparams.pooling_layers
101 |         logits = self.classifier(masked_outputs)
102 |         if not output_hidden_states:
103 |             return logits
104 |         hidden_states = ((hs := backbone_outputs.hidden_states)[-1] + hs[-2]) / 2
105 |         pooling_output = torch.einsum("bsh,bs,b->bh", hidden_states, first_mask.float(), 1 / first_mask.float().sum(dim=1))
106 |         return logits, pooling_output
107 | 
108 |     def training_step(self, batch, batch_idx):
109 |         labels = batch.pop("labels")
110 |         inputs = batch["inputs"]
111 |         logits = self(inputs).squeeze(-1)
112 |         if self.hparams.regression:
113 |             labels = labels.float() - 2
114 |         if self.hparams.r_drop:
115 |             loss1 = self.criterion(logits, labels)
116 |             logits_extra = self(inputs).squeeze(-1)  # second stochastic forward pass for R-Drop
117 |             loss2 = self.criterion(logits_extra, labels)
118 |             kl_loss1 = self.kld(torch.log_softmax(logits_extra, dim=-1), torch.softmax(logits, dim=-1))
119 |             kl_loss2 = self.kld(torch.log_softmax(logits, dim=-1), torch.softmax(logits_extra, dim=-1))
120 |             loss = (loss1 + loss2) / 2 + self.hparams.kl_weight * (kl_loss1 + kl_loss2) / 2
121 |         else:
122 |             loss = self.criterion(logits, labels)
123 |         if self.hparams.regression:
124 |             labels = labels.round().long() + 2
125 |         self.log("train_acc", ((prediction := self.logits_to_prediction(logits)) == labels).sum() / labels.size(0), prog_bar=True, on_step=False, on_epoch=True)
126 |         # self.log("train_loss", loss.item(), prog_bar=True, on_step=False, on_epoch=True)
127 |         output = {"prediction": self.ttl(prediction), "labels": self.ttl(labels)}
128 |         if self.automatic_optimization:
129 |             return {"loss": loss} | output
130 |         optimizer = self.optimizers(use_pl_optimizer=True)
131 |         lr_scheduler = self.lr_schedulers()
132 |         optimizer.zero_grad()
133 |         loss /= 2  # split the loss evenly between the clean pass and the FGM adversarial pass
134 |         self.manual_backward(loss)
135 |         self.attacker.attack(epsilon=self.hparams.attack_epsilon)
136 |         adv_logits = self(inputs).squeeze(-1)
137 |         adv_loss = self.criterion(adv_logits, labels) / 2
138 |         self.manual_backward(adv_loss)
139 |         self.attacker.restore()
140 |         self.clip_gradients(optimizer, gradient_clip_val=self.hparams.manual_gradient_clip_val, gradient_clip_algorithm=self.hparams.manual_gradient_clip_algorithm)
141 |         optimizer.step()
142 |         lr_scheduler.step()
143 |         self.log("loss", (loss + adv_loss).item(), prog_bar=True, on_step=True, on_epoch=True)
144 |         return output
145 | 
146 |     def training_epoch_end(self, outputs):
147 |         predictions = np.concatenate([x["prediction"] for x in outputs])
148 |         labels = np.concatenate([x["labels"] for x in outputs])
149 |         print()
150 |         accuracy = accuracy_score(labels, predictions)
151 |         precision, recall, fscore, _ = precision_recall_fscore_support(labels, predictions, average="macro")
152 |         print(f"Epoch {self.current_epoch} Train | Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall:
{recall:.4f}, F1: {fscore:.4f}") 153 | 154 | def validation_step(self, batch, batch_idx): 155 | labels = batch.pop("labels") 156 | inputs = batch["inputs"] 157 | logits = self(inputs).squeeze(-1) 158 | if self.hparams.regression: 159 | labels = labels.float() - 2 160 | loss = self.criterion(logits, labels) 161 | if self.hparams.regression: 162 | labels = labels.round().long() + 2 163 | prediction = self.logits_to_prediction(logits) 164 | self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True) 165 | output = {"prediction": self.ttl(prediction), "labels": self.ttl(labels)} 166 | if self.hparams.mode == "test" and self.hparams.optimize_f1: 167 | output = output | {"logits": self.ttl(logits)} 168 | return output 169 | 170 | def validation_epoch_end(self, outputs): 171 | predictions = np.concatenate([x["prediction"] for x in outputs]) 172 | labels = np.concatenate([x["labels"] for x in outputs]) 173 | # print(predictions, labels) 174 | # print(predictions.shape, labels.shape) 175 | print() 176 | accuracy = accuracy_score(labels, predictions) 177 | precision, recall, fscore, _ = precision_recall_fscore_support(labels, predictions, average="macro") 178 | self.log("val_f1", fscore, prog_bar=False, on_step=False, on_epoch=True) 179 | self.log("val_acc", accuracy, prog_bar=False, on_step=False, on_epoch=True) 180 | print(f"Epoch {self.current_epoch} Validate | Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {fscore:.4f}") 181 | 182 | if self.hparams.mode == "test" and self.hparams.optimize_f1: 183 | # print(outputs[:3]) 184 | logits = np.concatenate([x["logits"] for x in outputs]) 185 | weighted_prediction = lambda logits, weight: np.argmax(np.einsum("bn,n->bn", logits, weight), axis=1) 186 | f1_loss_func = lambda weight: -f1_score(labels, weighted_prediction(logits, weight), average="macro") 187 | class_weights = scipy.optimize.minimize(f1_loss_func, np.ones(logits.shape[1]), method="nelder-mead", options={"maxiter": 5 * 1000, "disp": True}).x 188 | predictions = weighted_prediction(logits, class_weights) 189 | accuracy = accuracy_score(labels, predictions) 190 | precision, recall, fscore, _ = precision_recall_fscore_support(labels, predictions, average="macro") 191 | print(f"Epoch {self.current_epoch} Validate (Optimized) | Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {fscore:.4f}") 192 | self.class_weights = torch.tensor(class_weights, dtype=torch.float32, device=self.device) 193 | print(class_weights) 194 | 195 | def test_step(self, batch, batch_idx): 196 | inputs = batch["inputs"] 197 | logits = self(inputs).squeeze(-1) 198 | if self.hparams.optimize_f1: 199 | logits = torch.einsum("bn,n->bn", logits, self.class_weights) 200 | prediction = self.logits_to_prediction(logits) 201 | return {"prediction": self.ttl(prediction)} | ({"logits": self.ttl(logits)} if self.hparams.is_extra_output else {}) 202 | 203 | def test_epoch_end(self, outputs): 204 | predictions = np.concatenate([x["prediction"] for x in outputs]).tolist() 205 | pickle.dump(predictions, open(os.path.join(self.hparams.output_path, "prediction.pkl"), "wb")) 206 | if self.hparams.is_extra_output: 207 | # pooling_outputs = np.concatenate([x["pooling_output"] for x in outputs], axis=0).tolist() 208 | logits = np.concatenate([x["logits"] for x in outputs], axis=0).tolist() 209 | # pickle.dump({"outputs": pooling_outputs, "logits": logits}, open(os.path.join(self.hparams.output_path, "extra_output.pkl"), "wb")) 210 | pickle.dump({"logits": logits}, 
open(os.path.join(self.hparams.output_path, "extra_output.pkl"), "wb")) 211 | 212 | def configure_optimizers(self): 213 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay, eps=self.hparams.eps) 214 | scheduler = transformers.optimization.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.hparams.num_warmup_steps, num_training_steps=self.hparams.num_training_steps) 215 | return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler,'interval': 'step'}} 216 | 217 | class SWASupportModel(nn.Module): 218 | def __init__(self, backbone, classifier): 219 | super(SWASupportModel, self).__init__() 220 | self.backbone = backbone 221 | self.classifier = classifier 222 | 223 | def forward(self, inputs): 224 | is_masked = inputs.pop('is_masked').bool() 225 | first_mask = inputs.pop("first_mask", None) 226 | backbone_outputs = self.backbone(**inputs, output_hidden_states=True) 227 | masked_outputs = backbone_outputs.last_hidden_state[is_masked] 228 | logits = self.classifier(masked_outputs) 229 | return logits 230 | 231 | class SWASentimentClassifier(SentimentClassifier): 232 | def __init__(self, **kwargs): 233 | super(SWASentimentClassifier, self).__init__(**kwargs) 234 | self.swa_model = None 235 | if self.hparams.mode == "test": 236 | self.check_if_swa_ready() 237 | 238 | def check_if_swa_ready(self): 239 | if self.swa_model is None: 240 | self.model = SWASupportModel(self.xlnet, self.classifier) 241 | self.swa_model = torch.optim.swa_utils.AveragedModel(self.model, avg_fn=self.average_function) 242 | 243 | def average_function(self, ax: torch.Tensor, x: torch.Tensor, num: int) -> torch.Tensor: 244 | return ax + (x - ax) / (num + 1) 245 | 246 | def on_train_epoch_start(self) -> None: 247 | self.check_if_swa_ready() 248 | 249 | def validation_step(self, batch, batch_idx): 250 | labels = batch.pop("labels") 251 | inputs = batch["inputs"] 252 | logits = self.swa_model(inputs) 253 | loss = self.criterion(logits, labels) 254 | prediction = torch.argmax(logits, dim=1) 255 | self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True) 256 | output = {"prediction": self.ttl(prediction), "labels": self.ttl(labels)} 257 | if self.hparams.mode == "test" and self.hparams.optimize_f1: 258 | output = output | {"logits": self.ttl(logits)} 259 | return output 260 | 261 | def test_step(self, batch, batch_idx): 262 | inputs = batch["inputs"] 263 | logits = self.swa_model(inputs) 264 | if self.hparams.optimize_f1: 265 | logits = torch.einsum("bn,n->bn", logits, self.class_weights) 266 | prediction = torch.argmax(logits, dim=1) 267 | return {"prediction": self.ttl(prediction)} | ({"logits": self.ttl(logits)} if self.hparams.is_extra_output else {}) 268 | 269 | def on_validation_epoch_start(self) -> None: 270 | self.check_if_swa_ready() 271 | self.swa_model.update_parameters(self.model) 272 | -------------------------------------------------------------------------------- /postprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm, trange 8 | 9 | def postprocess(prediction_file, raw_file, output_file): 10 | prediction = pickle.load(open(prediction_file, 'rb')) 11 | test_raw = list(map(json.loads, open(raw_file, "r+", encoding="utf-8").readlines())) 12 | f = open(output_file, 'w+', encoding="utf-8") 13 | f.write("id\tresult\n") 14 | step = 0 15 | for item in test_raw: 
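        # One prediction was stored per entity, in the order the datamodule emitted them;
        # the 5-way class index (0..4) is shifted back to the -2..2 sentiment scale here,
        # and the assert after the loop checks that every prediction has been consumed.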
16 |         output = {}
17 |         for entity in item["entity"]:
18 |             output[entity] = prediction[step] - 2
19 |             step += 1
20 |         f.write(f"{item['id']}\t{output}\n")
21 |     f.close()
22 |     assert len(prediction) == step
23 | 
24 | def get_text_feature(raw_file, feature_file, output_path):
25 |     lines = list(map(json.loads, open(raw_file, "r+", encoding="utf-8").readlines()))
26 |     features = pickle.load(open(feature_file, 'rb'))
27 |     logits = np.array(features["logits"])
28 |     print("Files Loaded")
29 |     # text_features = np.zeros((len(lines), 768))
30 |     item_id_list = []
31 |     aggregation_feature_list = [f"sentiment_{agg}_{i}" for agg in ["mean", "max", "min", "std"] for i in range(5)]
32 |     sentiment_features = np.zeros((len(lines), 5 * 4))
33 |     step = 0
34 |     for i in trange(len(lines)):
35 |         line: dict = lines[i]
36 |         item_id_list.append(line["id"])
37 |         if not line["entity"] or not line["content"]:
38 |             continue
39 |         logits_part = logits[step:step+len(line["entity"])]
40 |         sentiment_features[i, 0:5] = logits_part.mean(axis=0)
41 |         sentiment_features[i, 5:10] = logits_part.max(axis=0)
42 |         sentiment_features[i, 10:15] = logits_part.min(axis=0)
43 |         sentiment_features[i, 15:20] = logits_part.std(axis=0)
44 |         step += len(line["entity"])
45 |     # assert step_1 == len(features["outputs"])
46 |     assert step == len(features["logits"])
47 |     output_table = pd.DataFrame()
48 |     output_table["itemId"] = item_id_list
49 |     output_table[aggregation_feature_list] = sentiment_features
50 |     output_table.to_csv(os.path.join(output_path, "sentiment.csv"), index=False)
51 |     output_table.to_feather(os.path.join(output_path, "sentiment.feather"))
52 | 
53 | if __name__ == "__main__":
54 |     # output_path = "output"
55 |     # postprocess(prediction_file=(output_file := os.path.join(output_path, "prediction.pkl")), raw_file="resources/nlp_data/test.txt", output_file=os.path.join(output_path, "section1_epoch2.txt"))
56 |     output_path = "output_rec"
57 |     get_text_feature(raw_file="resources/rec_data/recommend_content_entity_0317.txt", feature_file=os.path.join(output_path, "extra_output.pkl"), output_path=output_path)
--------------------------------------------------------------------------------
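For orientation, here is a minimal sketch (not part of the repository) of how the prompt construction in `datamodule.py` pairs one mask token per entity with the per-entity labels. It assumes the default `hfl/chinese-xlnet-base` checkpoint from `main.py`, a made-up training item, and `<mask>` as XLNet's mask token string:

```python
import transformers

# Made-up item in the training-file format: one sentiment label in -2..2 per entity.
item = {"content": "这部手机的屏幕很好,但是电池续航太差了。",
        "entity": {"屏幕": 1, "电池": -2}}

tokenizer = transformers.XLNetTokenizerFast.from_pretrained("hfl/chinese-xlnet-base")
mask_symbol = "<mask>"  # the mask token appended after each entity for model_type == "xlnet"

prompt = "".join(f"{entity}{mask_symbol}" for entity in item["entity"])
inputs = tokenizer(text=item["content"], text_pair=prompt, add_special_tokens=True,
                   max_length=900, truncation="only_first")

# is_masked marks the positions the classifier head reads: one per entity, in entity order.
is_masked = [int(i == tokenizer.mask_token_id) for i in inputs["input_ids"]]
labels = [v + 2 for v in item["entity"].values()]  # shift -2..2 to the 0..4 class ids

assert sum(is_masked) == len(labels)
```

`postprocess.py` undoes the `+ 2` shift when it writes the submission file, so the order of masks here is exactly the order in which predictions are consumed there.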