├── __init__.py
├── review
│   ├── __init__.py
│   ├── train
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── topic_summarization.ipynb
│   │   └── test_new_algo.ipynb
│   ├── config.py
│   ├── README.md
│   ├── text.py
│   ├── model.py
│   └── dataset
│       ├── prepare_dataset.ipynb
│       └── analyze_archive.ipynb
├── .gitignore
├── resources
│   ├── mail.svg
│   └── github.svg
├── AUTHORS.md
├── README.md
├── environment.yml
└── LICENSE

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/review/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/review/train/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.iml
3 | 
4 | */.ipynb_checkpoints
5 | __pycache__
6 | *.pyc
--------------------------------------------------------------------------------
/review/config.py:
--------------------------------------------------------------------------------
1 | # Paths config
2 | base_path = "~/review"
3 | dataset_path = f"{base_path}/dataset"
4 | weights_path = f"{base_path}/weights"
5 | log_path = f"{base_path}/logs"
6 | 
7 | # Model lookup is in /model and ~/.pubtrends/model folders
8 | model_name = 'learn_simple_berta.pth'
--------------------------------------------------------------------------------
/resources/mail.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/resources/github.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/AUTHORS.md:
--------------------------------------------------------------------------------
1 | Authors
2 | =======
3 | 
4 | * [![icon][mail]](mailto:nikiannanik@gmail.com)
5 |   [![icon][github]](https://github.com/Javanochka)
6 |   Anna Nikiforovskaya
7 | * [![icon][mail]](mailto:alexey@shpilman.com)
8 |   [![icon][github]](https://github.com/ashpilman)
9 |   Aleksei Shpilman
10 | * [![icon][mail]](mailto:oleg.shpynov@gmail.com)
11 |   [![icon][github]](https://github.com/olegs)
12 |   Oleg Shpynov
13 | 
14 | 
15 | Icons by Feather
16 | 
17 | [mail]: resources/mail.svg
18 | [github]: resources/github.svg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![JetBrains Research](https://jb.gg/badges/research.svg)](https://confluence.jetbrains.com/display/ALL/JetBrains+on+GitHub)
2 | [![DOI](https://zenodo.org/badge/299675473.svg)](https://doi.org/10.5281/zenodo.15131486)
3 | 
4 | # pubtrends-review
5 | 
6 | This is the source code for the paper ["Automatic generation of reviews of scientific papers"](https://arxiv.org/abs/2010.04147),\
7 | accepted as a full paper at the [International Conference on Machine Learning and Applications (ICMLA) 2020](https://ieeexplore.ieee.org/document/9356351).
8 | 
9 | Refer to [review/README.md](review/README.md) for instructions on dataset creation and model training.
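For quick orientation, here is a minimal sketch of how a trained model is applied to new text. It mirrors `review/text.py` and `review/model.py`; the checkpoint name, `article_len=512`, and the top-5 sentence cutoff are illustrative assumptions rather than the exact PubTrends pipeline.

```python
# Hedged usage sketch: checkpoint name, article_len and top-k value are assumptions.
import torch

from review.model import load_model, setup_cuda_device
from review.text import text_to_data

model = load_model('bert', froze_strategy='froze_all', article_len=512)
model.load('learn_simple_berta')  # resolves to <cfg.weights_path>/learn_simple_berta.pth
model, device = setup_cuda_device(model)
model.eval()

text = "..."  # concatenated abstracts of the papers to summarize
chunks = text_to_data(text, max_len=512, tokenizer=model.tokenizer)

scores_by_sentence = {}
with torch.no_grad():
    for ids, attention_mask, token_type_ids, offset, sents in chunks:
        # evaluate() returns one tensor per example with one score per [CLS],
        # i.e. one score per packed sentence
        scores = model.evaluate(
            torch.tensor([ids], device=device),
            torch.tensor([attention_mask], device=device),
            torch.tensor([token_type_ids], device=device),
        )[0]
        scores_by_sentence.update(zip(sents, scores.tolist()))

# Keep the highest-scoring sentences as the extractive review
summary = sorted(scores_by_sentence, key=scores_by_sentence.get, reverse=True)[:5]
print('\n'.join(summary))
```

Because `text_to_data` encodes chunks with a small sentence overlap, a sentence can be scored twice; the dictionary above simply keeps the latest score.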
10 | 
11 | # Authors
12 | See [AUTHORS.md](AUTHORS.md) for a list of authors and contributors.
13 | 
14 | 
--------------------------------------------------------------------------------
/review/README.md:
--------------------------------------------------------------------------------
1 | * Download the [PubMedCentral Author Manuscript Collection](https://ftp.ncbi.nlm.nih.gov/pub/pmc/manuscript/xml/)
2 |   into the `~/pmc_dataset` folder.\
3 |   Expected folder content:
4 |   ```
5 |   author_manuscript_xml.PMC001xxxxxx.baseline.2022-06-16.filelist.csv
6 |   author_manuscript_xml.PMC001xxxxxx.baseline.2022-06-16.filelist.txt
7 |   author_manuscript_xml.PMC001xxxxxx.baseline.2022-06-16.tar.gz
8 |   ...
9 |   ```
10 | * Extract all downloaded `tar.gz` files.\
11 |   Expected extracted folders:
12 |   ```
13 |   PMC001xxxxxx
14 |   PMC002xxxxxx
15 |   PMC003xxxxxx
16 |   PMC004xxxxxx
17 |   PMC005xxxxxx
18 |   PMC006xxxxxx
19 |   PMC007xxxxxx
20 |   PMC008xxxxxx
21 |   PMC009xxxxxx
22 |   ```
23 | * Launch `dataset/analyze_archive.ipynb` to analyze the downloaded archive.
24 | * Launch `dataset/prepare_dataset.ipynb` to create the tables required for model training.
25 | * Refer to `train/README.md` for instructions on model training.
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | # Keep this file as small as possible! Only explicit imports here!
2 | name: pubtrends
3 | channels:
4 |   - conda-forge
5 |   - bokeh
6 |   - pytorch
7 | dependencies:
8 |   - bokeh=3.6.3
9 |   - cachetools=5.5.2
10 |   - celery=5.4.0
11 |   - ipywidgets=8.1.5
12 |   - holoviews=1.20.1
13 |   - huggingface_hub=0.29.1
14 |   - flask=3.1.0
15 |   - matplotlib=3.10.1
16 |   - nltk=3.9.1
17 |   - pandas=2.2.3
18 |   - parameterized=0.9.0
19 |   - python=3.10.16
20 |   - pytest=8.3.5
21 |   - psycopg2=2.9.9
22 |   - wordcloud=1.9.4
23 |   - networkx=3.4.2
24 |   - python-louvain=0.16
25 |   - scikit-learn=1.6.1
26 |   - pytorch=2.5.1
27 |   - transformers=4.49.0
28 |   - tensorboardx=2.6.2.2
29 |   - psutil=7.0.0
30 |   - unidecode=1.3.8
31 |   - pip
32 |   - pip:
33 |     - email-validator==2.2.0
34 |     - flask-admin==1.6.0
35 |     - flask-security==5.6.0
36 |     - flask-sqlalchemy==3.1.1
37 |     - gunicorn==23.0.0
38 |     - lazy==1.6
39 |     - rouge==1.0.1
40 |     - redis==5.2.1
41 |     - vine==5.1.0
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 JetBrains-Research
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/review/train/README.md:
--------------------------------------------------------------------------------
1 | ## Instructions
2 | 
3 | These instructions describe how to create the dataset and how to train different models on it.
4 | Currently, the easiest way to do this is via the Python notebook `test_new_algo.ipynb`, which is organized into 18 steps.
5 | 
6 | Instructions for running everything from the command line will be added later.
7 | 
8 | ### Create a dataset
9 | 
10 | You can create a dataset with or without additional features.
11 | 
12 | To generate a dataset, follow steps 3-8 (plus step 10 if you want to split the dataset into train/test/validation right away).
13 | Some of these steps are not needed when creating a dataset without features. The model with features is still under
14 | development; you can try it, but for now it performs worse.
15 | 
16 | ### Train the model
17 | 
18 | Steps 1-2, 9, 10 (to get the dataset), and 11-14 are required.
19 | 
20 | You can save the trained model at any time with `model.save("name_of_the_file")`.
21 | Also, try different parameters for the model (you can freeze some layers or keep them all trainable, and you can try
22 | a model with features).
23 | 
24 | ### See how the model works
25 | 
26 | This is what step 15 is for. You can try the model on different examples.
27 | 
28 | ### Evaluation
29 | 
30 | These are the last steps. Steps 16 and 17 compute the MSE score and show how well the model fits the dataset,
31 | while step 18 tests summarizing several scientific papers into a single review.
32 | 
33 | For topic summarization: prepare the test dataset by running `topic_summarization.ipynb`.
--------------------------------------------------------------------------------
/review/text.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | 
3 | 
4 | def split_text(text):
5 |     return nltk.tokenize.sent_tokenize(text)
6 | 
7 | 
8 | def get_token_id(tokenizer, tkn):
9 |     return tokenizer.convert_tokens_to_ids([tkn])[0]
10 | 
11 | 
12 | def preprocess_text(text, max_len, tokenizer):
13 |     """
14 |     Preprocess text for a BERT / RoBERTa model.
15 |     NOTE: not all of the text may fit because of max_len.
16 |     IMPORTANT: the preprocessed text may contain more than one [CLS] token and more than two sentences.
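    Each sentence is wrapped as tokenizer.BOS ... tokenizer.EOS before encoding, so one
    max_len chunk usually packs several sentences, and every packed sentence contributes
    its own BOS/[CLS] token whose output is later scored by the classifier (see model.py).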
17 |     :param text: list(str) - sentences to encode
18 |     :param max_len: maximum number of tokens in one encoded chunk
19 |     :param tokenizer: BERT or RoBERTa tokenizer
20 |     :return:
21 |         ids | token ids of length max_len, padded with the PAD token id
22 |         attention_mask | list(int), 1 for real tokens, 0 for padding
23 |         token_type_ids | 0/1 alternating between consecutive sentences
24 |         n_sents | number of actual sentences encoded
25 |     """
26 |     sents = [
27 |         [tokenizer.BOS] + tokenizer.tokenize(sent) + [tokenizer.EOS] for sent in text
28 |     ]
29 | 
30 |     ids, token_type_ids, segment_signature = [], [], 0
31 |     n_sents = 0
32 |     for i, s in enumerate(sents):
33 |         if len(ids) + len(s) <= max_len:
34 |             n_sents += 1
35 |             ids.extend(tokenizer.convert_tokens_to_ids(s))
36 |             token_type_ids.extend([segment_signature] * len(s))
37 |             segment_signature = (segment_signature + 1) % 2
38 |         else:
39 |             break
40 |     attention_mask = [1] * len(ids)
41 | 
42 |     pad_len = max(0, max_len - len(ids))
43 |     ids += [get_token_id(tokenizer, tokenizer.PAD)] * pad_len
44 |     attention_mask += [0] * pad_len
45 |     token_type_ids += [segment_signature] * pad_len
46 |     assert len(ids) == len(attention_mask)
47 |     assert len(ids) == len(token_type_ids)
48 |     return ids, attention_mask, token_type_ids, n_sents
49 | 
50 | 
51 | # Overlap helps to keep context and connect different parts of the abstract
52 | OVERLAP = 5
53 | 
54 | 
55 | def text_to_data(text, max_len, tokenizer):
56 |     """
57 |     This is the main entry point, which is called from the PubTrends review feature.
58 |     """
59 |     text = split_text(text)
60 |     total_sents = 0
61 |     data = []
62 |     while total_sents < len(text):
63 |         offset = max(0, total_sents - OVERLAP)
64 |         # BERT cannot encode all the text at once:
65 |         # only a limited number of sentences fits into a single model run.
66 |         ids, attention_mask, token_type_ids, n_sents = \
67 |             preprocess_text(text[offset:], max_len, tokenizer)
68 |         if offset + n_sents <= total_sents:
69 |             total_sents += 1
70 |             continue
71 |         data.append((ids, attention_mask, token_type_ids, offset, text[offset: offset + n_sents]))
72 |         total_sents = offset + n_sents
73 |     return data
--------------------------------------------------------------------------------
/review/train/topic_summarization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "source": [
6 |     "# Notebook for testing the model on topic summarization"
7 |    ],
8 |    "metadata": {
9 |     "collapsed": false,
10 |     "pycharm": {
11 |      "name": "#%% md\n"
12 |     }
13 |    }
14 |   },
15 |   {
16 |    "cell_type": "code",
17 |    "execution_count": null,
18 |    "metadata": {
19 |     "pycharm": {
20 |      "name": "#%%\n"
21 |     }
22 |    },
23 |    "outputs": [],
24 |    "source": [
25 |     "import pandas as pd\n",
26 |     "from review.train.preprocess import parse_sents, standardize\n",
27 |     "from nltk.tokenize import sent_tokenize"
28 |    ]
29 |   },
30 |   {
31 |    "cell_type": "code",
32 |    "execution_count": null,
33 |    "metadata": {
34 |     "pycharm": {
35 |      "name": "#%%\n"
36 |     }
37 |    },
38 |    "outputs": [],
39 |    "source": [
40 |     "def prepare_text(text):\n",
41 |     "    text = parse_sents(sent_tokenize(text))\n",
42 |     "    text = standardize(text)\n",
43 |     "    text = ' '.join(text)\n",
44 |     "    return text"
45 |    ]
46 |   },
47 |   {
48 |    "cell_type": "markdown",
49 |    "metadata": {
50 |     "pycharm": {
51 |      "name": "#%% md\n"
52 |     }
53 |    },
54 |    "source": [
55 |     "Prepare topic abstracts\n"
56 |    ]
57 |   },
58 |   {
59 |    "cell_type": "code",
60 |    "execution_count": null,
61 |    "metadata": {
62 |     "pycharm": {
63 |      "name": "#%%\n"
64 |     }
65 |    },
66 |    "outputs": [],
67 | "source": [ 68 | "topics_df = pd.read_csv(\"topic/topic_abstracts.csv\")" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "pycharm": { 76 | "name": "#%%\n" 77 | } 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "topics_df = topics_df.groupby(by=['topic_id']).agg({'text': lambda x: ''.join(x)}).reset_index()\n", 82 | "topics_df = topics_df.rename(columns={'topic_id':'id', 'text': 'paper_top50'})" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": { 89 | "pycharm": { 90 | "name": "#%%\n" 91 | } 92 | }, 93 | "outputs": [], 94 | "source": [ 95 | "topics_df" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": { 102 | "pycharm": { 103 | "name": "#%%\n" 104 | } 105 | }, 106 | "outputs": [], 107 | "source": [ 108 | "def create_dummy_dataframe(df):\n", 109 | " df['paper_top50'] = df['paper_top50'].apply(lambda text: prepare_text(text))\n", 110 | " df['abstract'] = 'dummy abstract'\n", 111 | " df['gold_ids_top6'] = str([0])\n", 112 | " return df" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": { 119 | "pycharm": { 120 | "name": "#%%\n" 121 | } 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "test_df = create_dummy_dataframe(topics_df)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "pycharm": { 133 | "name": "#%%\n" 134 | } 135 | }, 136 | "outputs": [], 137 | "source": [ 138 | "test_df" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | "pycharm": { 146 | "name": "#%%\n" 147 | } 148 | }, 149 | "outputs": [], 150 | "source": [ 151 | "test_df.to_csv(\"data/pubmedtop50_test_topic.csv\")" 152 | ] 153 | } 154 | ], 155 | "metadata": { 156 | "kernelspec": { 157 | "display_name": "Python 3 (ipykernel)", 158 | "language": "python", 159 | "name": "python3" 160 | }, 161 | "language_info": { 162 | "codemirror_mode": { 163 | "name": "ipython", 164 | "version": 3 165 | }, 166 | "file_extension": ".py", 167 | "mimetype": "text/x-python", 168 | "name": "python", 169 | "nbconvert_exporter": "python", 170 | "pygments_lexer": "ipython3", 171 | "version": "3.10.4" 172 | } 173 | }, 174 | "nbformat": 4, 175 | "nbformat_minor": 2 176 | } -------------------------------------------------------------------------------- /review/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | from transformers import BertModel, BertTokenizer 7 | from transformers import RobertaModel, RobertaTokenizer 8 | 9 | import review.config as cfg 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | EMBEDDINGS_ADDITIONAL = 20 14 | 15 | FEATURES_NUMBER = 10 16 | FEATURES_NN_INTERMEDIATE = 100 17 | FEATURES_NN_OUT = 50 18 | FEATURES_DROPOUT = 0.1 19 | 20 | DECODER_DROPOUT = 0.1 21 | 22 | 23 | def setup_cuda_device(model): 24 | logging.info('Setup single-device settings...') 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | model = model.to(device) 27 | return model, device 28 | 29 | 30 | def load_model(model_type, froze_strategy, article_len, features=False): 31 | logger.info(f'Loading model {model_type} {froze_strategy} {article_len} {features}') 32 | model = Summarizer(model_type, article_len, features) 33 | model.expand_posembs_ifneed() 34 | 
model.froze_backbone(froze_strategy)
35 |     model.unfroze_head()
36 |     logger.info(f'Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
37 |     return model
38 | 
39 | 
40 | def get_token_id(tokenizer, tkn):
41 |     return tokenizer.convert_tokens_to_ids([tkn])[0]
42 | 
43 | 
44 | class Summarizer(nn.Module):
45 |     """
46 |     This is the main summarization model.
47 |     See https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_tf_bert.py
48 |     It operates on the same input format as the original BERT model used underneath.
49 |     See forward and evaluate for parameter descriptions.
50 |     """
51 |     enc_output: torch.Tensor
52 |     dec_ids_mask: torch.Tensor
53 |     encdec_ids_mask: torch.Tensor
54 | 
55 |     def __init__(self, model_type, article_len, additional_features, num_features=FEATURES_NUMBER):
56 |         super(Summarizer, self).__init__()
57 | 
58 |         print(f'Initialize backbone and tokenizer for {model_type}')
59 |         self.article_len = article_len
60 |         if model_type == 'bert':
61 |             self.backbone = self.initialize_bert()
62 |             self.tokenizer = self.initialize_bert_tokenizer()
63 |         elif model_type == 'roberta':
64 |             self.backbone = self.initialize_roberta()
65 |             self.tokenizer = self.initialize_roberta_tokenizer()
66 |         else:
67 |             raise Exception(f"Wrong model_type argument: {model_type}")
68 |         self.backbone.resize_token_embeddings(EMBEDDINGS_ADDITIONAL + self.tokenizer.vocab_size)
69 | 
70 |         if additional_features:
71 |             print('Adding a two-layer fully connected network for additional features')
72 |             self.features = nn.Sequential(
73 |                 nn.Linear(num_features, FEATURES_NN_INTERMEDIATE),
74 |                 nn.LeakyReLU(),
75 |                 nn.Dropout(FEATURES_DROPOUT),
76 |                 nn.Linear(FEATURES_NN_INTERMEDIATE, FEATURES_NN_OUT)
77 |             )
78 |         else:
79 |             self.features = None
80 | 
81 |         print('Initialize backbone embeddings extraction')
82 | 
83 |         def backbone_forward(input_ids, attention_mask, token_type_ids, position_ids):
84 |             return self.backbone(
85 |                 input_ids=input_ids,
86 |                 attention_mask=attention_mask,
87 |                 token_type_ids=token_type_ids,
88 |                 position_ids=position_ids
89 |             )
90 | 
91 |         self.encoder = lambda *args: backbone_forward(*args)[0]
92 | 
93 |         print('Initialize decoder')
94 |         if additional_features:
95 |             self.decoder = Classifier(768 + FEATURES_NN_OUT)  # Default BERT output with additional features
96 |         else:
97 |             self.decoder = Classifier(768)  # Default BERT output
98 | 
99 |     def expand_posembs_ifneed(self):
100 |         print('Expand positional embeddings if needed')
101 |         print('Positional embeddings', self.backbone.config.max_position_embeddings, self.article_len)
102 |         if self.article_len > self.backbone.config.max_position_embeddings:
103 |             old_maxlen = self.backbone.config.max_position_embeddings
104 |             old_w = self.backbone.embeddings.position_embeddings.weight
105 |             logging.info(f'Backbone pos embeddings expanded from {old_maxlen} up to {self.article_len}')
106 |             self.backbone.embeddings.position_embeddings = nn.Embedding(
107 |                 self.article_len, self.backbone.config.hidden_size
108 |             )
109 |             self.backbone.embeddings.position_embeddings.weight[:old_maxlen].data.copy_(old_w)
110 |             self.backbone.config.max_position_embeddings = self.article_len
111 |             print('New positional embeddings', self.backbone.config.max_position_embeddings)
112 | 
113 |     @staticmethod
114 |     def initialize_bert():
115 |         return BertModel.from_pretrained(
116 |             "bert-base-uncased", output_hidden_states=False
117 |         )
118 | 
119 |     @staticmethod
120 |     def initialize_bert_tokenizer():
121 |         tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
122 |         tokenizer.BOS = "[CLS]"
123 |         tokenizer.EOS = "[SEP]"
124 |         tokenizer.PAD = "[PAD]"
125 |         return tokenizer
126 | 
127 |     @staticmethod
128 |     def initialize_roberta():
129 |         backbone = RobertaModel.from_pretrained(
130 |             'roberta-base', output_hidden_states=False
131 |         )
132 |         print("Initialize token type embeddings; RoBERTa doesn't have them by default")
133 |         backbone.config.type_vocab_size = 2
134 |         backbone.embeddings.token_type_embeddings = nn.Embedding(2, backbone.config.hidden_size)
135 |         backbone.embeddings.token_type_embeddings.weight.data.normal_(
136 |             mean=0.0, std=backbone.config.initializer_range
137 |         )
138 |         return backbone
139 | 
140 |     @staticmethod
141 |     def initialize_roberta_tokenizer():
142 |         tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
143 |         tokenizer.BOS = "<s>"
144 |         tokenizer.EOS = "</s>"
145 |         tokenizer.PAD = "<pad>"
146 |         return tokenizer
147 | 
148 |     def save(self, save_filename):
149 |         """ Save model in filename
150 | 
151 |         :param save_filename: str
152 |         """
153 |         state = dict(
154 |             encoder_dict=self.backbone.state_dict(),
155 |             decoder_dict=self.decoder.state_dict()
156 |         )
157 |         if self.features:
158 |             state['features_dict'] = self.features.state_dict()
159 |         models_folder = os.path.expanduser(cfg.weights_path)
160 |         if not os.path.exists(models_folder):
161 |             os.makedirs(models_folder)
162 |         torch.save(state, f"{models_folder}/{save_filename}.pth")
163 | 
164 |     def load(self, load_filename):
165 |         path = f"{os.path.expanduser(cfg.weights_path)}/{load_filename}.pth"
166 |         state = torch.load(path, map_location=lambda storage, location: storage)
167 |         self.backbone.load_state_dict(state['encoder_dict'])
168 |         self.decoder.load_state_dict(state['decoder_dict'])
169 |         if self.features:
170 |             self.features.load_state_dict(state['features_dict'])
171 | 
172 |     def froze_backbone(self, froze_strategy):
173 |         if froze_strategy == 'froze_all':
174 |             for param in self.backbone.parameters():
175 |                 param.requires_grad_(False)
176 | 
177 |         elif froze_strategy == 'unfroze_last':
178 |             for name, param in self.backbone.named_parameters():
179 |                 param.requires_grad_(
180 |                     'encoder.layer.11' in name or
181 |                     'encoder.layer.10' in name or
182 |                     'encoder.layer.9' in name
183 |                 )
184 | 
185 |         elif froze_strategy == 'unfroze_all':
186 |             for param in self.backbone.parameters():
187 |                 param.requires_grad_(True)
188 | 
189 |         else:
190 |             raise Exception(f'Unsupported froze strategy {froze_strategy}')
191 | 
192 |     def unfroze_head(self):
193 |         for param in self.decoder.parameters():
194 |             param.requires_grad_(True)
195 | 
196 |     def forward(self, input_ids, attention_mask, token_type_ids, input_features=None):
197 |         """
198 |         :param input_ids: torch.Size([batch_size, article_len])
199 |             Indices of input sequence tokens in the vocabulary.
200 |         :param attention_mask: torch.Size([batch_size, article_len])
201 |             Mask to avoid performing attention on padding token indices.
202 |             Mask values selected in `[0, 1]`:
203 |             - 1 for tokens that are **not masked**,
204 |             - 0 for tokens that are **masked**.
205 |         :param token_type_ids: torch.Size([batch_size, article_len])
206 |             Segment token indices to indicate first and second portions of the inputs.
207 |             Indices are selected in `[0, 1]`:
208 |             - 0 corresponds to a *sentence A* token,
209 |             - 1 corresponds to a *sentence B* token.
210 |         :return: scores | torch.Size([batch_size, summary_len])
211 |         """
212 | 
213 |         # The output of [CLS] is inferred by all other words in this sentence.
214 |         # This makes [CLS] a good representation for sentence-level classification.
215 |         cls_mask = (input_ids == get_token_id(self.tokenizer, self.tokenizer.BOS))
216 |         logger.debug(f'cls_mask {cls_mask.shape}')
217 | 
218 |         # Indices of positions of each input sequence token in the position embeddings.
219 |         # position ids | torch.Size([batch_size, article_len])
220 |         pos_ids = torch.arange(
221 |             0,
222 |             self.article_len,
223 |             dtype=torch.long,
224 |             device=input_ids.device
225 |         ).unsqueeze(0).repeat(len(input_ids), 1)
226 |         logger.debug(f'pos_ids {pos_ids.shape}')
227 |         # extract bert embeddings | torch.Size([batch_size, article_len, d_bert])
228 |         # For each input token, the BERT backbone produces a 768-dimensional output vector,
229 |         # but for sentence-level classification we only need the outputs at the [CLS]
230 |         # positions (one per packed sentence); they are selected below via cls_mask,
231 |         # and the outputs of all other tokens are discarded.
232 |         enc_output = self.encoder(input_ids, attention_mask, token_type_ids, pos_ids)
233 | 
234 |         if self.features:
235 |             out_features = self.features(input_features)
236 |             scores = self.decoder(torch.cat([enc_output[cls_mask], out_features], dim=-1))
237 |         else:
238 |             logger.debug(f'enc_output {enc_output.shape}')
239 |             logger.debug(f'enc_output[cls_mask] {enc_output[cls_mask].shape}')
240 |             scores = self.decoder(enc_output[cls_mask])
241 |             logger.debug(f'scores {scores.shape}')
242 | 
243 |         return scores
244 | 
245 |     def evaluate(self, input_ids, attention_mask, token_type_ids, input_features=None):
246 |         """See forward for parameters and output description"""
247 | 
248 |         # The output of [CLS] is inferred by all other words in this sentence.
249 |         # This makes [CLS] a good representation for sentence-level classification.
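        # Unlike forward(), the examples are scored one at a time in the loop below:
        # the number of [CLS] positions (sentences) differs between examples, so the
        # selected [CLS] outputs cannot be stacked into one rectangular batch tensor.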
250 | cls_mask = (input_ids == get_token_id(self.tokenizer, self.tokenizer.BOS)) 251 | 252 | # position ids | torch.Size([batch_size, article_len]) 253 | pos_ids = torch.arange( 254 | 0, 255 | self.article_len, 256 | dtype=torch.long, 257 | device=input_ids.device 258 | ).unsqueeze(0).repeat(len(input_ids), 1) 259 | 260 | # extract bert embeddings | torch.Size([batch_size, article_len, d_bert]) 261 | enc_output = self.encoder(input_ids, attention_mask, token_type_ids, pos_ids) 262 | 263 | scores = [] 264 | for eo, cm in zip(enc_output, cls_mask): 265 | if self.features: 266 | out_features = self.features(input_features) 267 | score = self.decoder.evaluate(torch.cat([eo[cm], out_features], dim=-1)) 268 | else: 269 | score = self.decoder.evaluate(eo[cm]) 270 | scores.append(score) 271 | return scores 272 | 273 | 274 | class Classifier(nn.Module): 275 | def __init__(self, hidden_size): 276 | super(Classifier, self).__init__() 277 | self.dropout = nn.Dropout(DECODER_DROPOUT) 278 | self.linear = nn.Linear(hidden_size, 1) 279 | self.sigmoid = nn.Sigmoid() 280 | 281 | def forward(self, x): 282 | return self.sigmoid(self.linear(self.dropout(x)).squeeze(-1)) 283 | 284 | def evaluate(self, x): 285 | return self.sigmoid(self.linear(self.dropout(x)).squeeze(-1)) 286 | -------------------------------------------------------------------------------- /review/dataset/prepare_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%% md\n" 8 | } 9 | }, 10 | "source": [ 11 | "\n", 12 | "# Make dataset\n", 13 | "\n", 14 | "This notebook contains the code to analyse content of the PubMedCentral Author Manuscript Collection. \\\n", 15 | "See: https://www.ncbi.nlm.nih.gov/pmc/about/mscollection/\n", 16 | "\n", 17 | "Files should be downloaded from https://ftp.ncbi.nlm.nih.gov/pub/pmc/manuscript/xml/ into `~/pmc_dataset` folder.\n", 18 | "\n", 19 | "Resulting tables will be created under `~/review/dataset` folder (see `config.py`).\n", 20 | "\n", 21 | "Please ensure that env variable `PYTHONPATH` includes project folder to be able to import `review.config` module.\n", 22 | "\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "pycharm": { 30 | "name": "#%%\n" 31 | } 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "% matplotlib inline\n", 36 | "% config InlineBackend.figure_format='retina'" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "outputs": [], 43 | "source": [ 44 | "import logging\n", 45 | "import os\n", 46 | "import sys\n", 47 | "\n", 48 | "import pandas as pd\n", 49 | "from IPython.display import display\n", 50 | "\n", 51 | "logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')\n", 52 | "\n", 53 | "import review.config as cfg\n", 54 | "\n", 55 | "pmc_dataset_root = os.path.expanduser('~/pmc_dataset')\n", 56 | "dataset_root = os.path.expanduser(cfg.dataset_path)" 57 | ], 58 | "metadata": { 59 | "collapsed": false, 60 | "pycharm": { 61 | "name": "#%%\n" 62 | } 63 | } 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": { 68 | "pycharm": { 69 | "name": "#%% md\n" 70 | } 71 | }, 72 | "source": [ 73 | "### Collecting articles" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": { 80 | "pycharm": { 81 | "name": "#%%\n" 82 | } 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "from glob import 
glob\n", 87 | "from lxml import etree\n", 88 | "from tqdm.auto import tqdm\n", 89 | "\n", 90 | "dict_articles = {}\n", 91 | "\n", 92 | "for filelist in tqdm(glob(os.path.join(pmc_dataset_root, '*filelist.txt'))):\n", 93 | " with open(filelist, 'r') as f:\n", 94 | " for line in f:\n", 95 | " if 'LastUpdated' in line:\n", 96 | " continue\n", 97 | " filename, pmcid, pmid, mid, date, time = line.split()\n", 98 | " dict_articles[pmid] = filename" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": { 105 | "pycharm": { 106 | "name": "#%%\n" 107 | } 108 | }, 109 | "outputs": [], 110 | "source": [ 111 | "print(list(dict_articles.items())[:10])" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": { 118 | "pycharm": { 119 | "name": "#%%\n" 120 | } 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "import nltk\n", 125 | "\n", 126 | "\n", 127 | "def split_text(text):\n", 128 | " sents = nltk.tokenize.sent_tokenize(text)\n", 129 | " res_sents = []\n", 130 | " i = 0\n", 131 | " while i < len(sents):\n", 132 | " check = False\n", 133 | " if i + 1 < len(sents):\n", 134 | " check = sents[i + 1].strip()[0].islower() or sents[i + 1].strip()[0].isdigit()\n", 135 | " made = sents[i]\n", 136 | " while i + 1 < len(sents) and (made.endswith('Fig.') or check):\n", 137 | " made += \" \" + \" \".join(sents[i + 1].strip().split())\n", 138 | " i += 1\n", 139 | " if i + 1 < len(sents):\n", 140 | " check = sents[i + 1].strip()[0].islower() or sents[i + 1].strip()[0].isdigit()\n", 141 | " res_sents.append(\" \".join(made.strip().split()))\n", 142 | " i += 1\n", 143 | " return res_sents\n", 144 | "\n", 145 | "\n", 146 | "def get_sentences(node):\n", 147 | " def helper(node, is_disc):\n", 148 | " if node.tag == 'xref':\n", 149 | " ntail = ''\n", 150 | " if node.tail is not None:\n", 151 | " ntail = node.tail\n", 152 | " res = f' xref_{node.get(\"ref-type\")}_{node.get(\"rid\")} ' + ntail\n", 153 | " if res is None:\n", 154 | " return '', ''\n", 155 | " if is_disc:\n", 156 | " return '', res\n", 157 | " return res, ''\n", 158 | " if node.tag == 'title':\n", 159 | " if node.tail is None:\n", 160 | " return '', ''\n", 161 | " if is_disc:\n", 162 | " return '', node.tail\n", 163 | " return node.tail, ''\n", 164 | " if not is_disc and node.find('title') is not None:\n", 165 | " title = \"\".join(node.find('title').itertext()).lower()\n", 166 | " if 'discussion' in title:\n", 167 | " is_disc = True\n", 168 | " st_text = ''\n", 169 | " if node.text is not None:\n", 170 | " st_text = node.text\n", 171 | " if is_disc:\n", 172 | " n_disc = st_text\n", 173 | " n_gen = \"\"\n", 174 | " else:\n", 175 | " n_gen = st_text\n", 176 | " n_disc = \"\"\n", 177 | " for ch in node.getchildren():\n", 178 | " gen, disc = helper(ch, is_disc)\n", 179 | " n_gen += gen\n", 180 | " n_disc += disc\n", 181 | " tail = \"\"\n", 182 | " if node.tail is not None:\n", 183 | " tail = node.tail\n", 184 | " if is_disc:\n", 185 | " n_disc += tail\n", 186 | " else:\n", 187 | " n_gen += tail\n", 188 | " return n_gen, n_disc\n", 189 | "\n", 190 | " gen_res, disc_res = helper(node.find('body'), False)\n", 191 | " gen_res = split_text(gen_res)\n", 192 | " disc_res = split_text(disc_res)\n", 193 | "\n", 194 | " abstract = \"\"\n", 195 | "\n", 196 | " try:\n", 197 | " abstract = \"\".join(node.find('front').find('article-meta').find('abstract').itertext())\n", 198 | " abstract = \" \".join(abstract.strip().split())\n", 199 | " except Exception:\n", 200 | " pass\n", 201 
| " return gen_res, disc_res, abstract" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": { 208 | "pycharm": { 209 | "name": "#%%\n" 210 | } 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "tree = etree.parse(f\"{pmc_dataset_root}/PMC004xxxxxx/PMC4239434.xml\")\n", 215 | "sents = get_sentences(tree.getroot())\n", 216 | "print(sents)\n", 217 | "print(len(sents[0]), len(sents[1]), len(sents[2]))" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": { 224 | "pycharm": { 225 | "name": "#%%\n" 226 | } 227 | }, 228 | "outputs": [], 229 | "source": [ 230 | "def get_all_refs(node):\n", 231 | " def get_cit_id_type(node):\n", 232 | " if node.find('element-citation') is None:\n", 233 | " return None\n", 234 | " if node.find('element-citation').find('pub-id') is None:\n", 235 | " return None\n", 236 | " return node.find('element-citation').find('pub-id').get('pub-id-type')\n", 237 | "\n", 238 | " def get_citation_info(node):\n", 239 | " if node is None:\n", 240 | " return {}\n", 241 | " res = {}\n", 242 | " for ch in node.getchildren():\n", 243 | " if ch.tag == 'ref':\n", 244 | " id_type = get_cit_id_type(ch)\n", 245 | " if id_type is not None and id_type == 'pmid':\n", 246 | " res[ch.get('id')] = {\n", 247 | " 'publication-type': ch.find('element-citation').get('publication-type'),\n", 248 | " 'pmid': ch.find('element-citation').find('pub-id').text\n", 249 | " }\n", 250 | " return res\n", 251 | "\n", 252 | " def get_figs_info(node):\n", 253 | " if node is None:\n", 254 | " return {}\n", 255 | " res = {}\n", 256 | " for ch in node.getchildren():\n", 257 | " if ch.tag == 'fig' and ch.find('caption') is not None:\n", 258 | " res[ch.get('id')] = \" \".join(''.join(ch.find('caption').itertext()).strip().split())\n", 259 | " return res\n", 260 | "\n", 261 | " def get_tables_info(node):\n", 262 | " if node is None:\n", 263 | " return {}\n", 264 | " res = {}\n", 265 | " for ch in node.getchildren():\n", 266 | " if ch.tag == 'table-wrap' and ch.find('caption') is not None:\n", 267 | " res[ch.get('id')] = \" \".join(''.join(ch.find('caption').itertext()).strip().split())\n", 268 | " return res\n", 269 | "\n", 270 | " citations = get_citation_info(node.find('back').find('ref-list'))\n", 271 | " figs = get_figs_info(node.find('floats-group'))\n", 272 | " tables = get_tables_info(node.find('floats-group'))\n", 273 | " return citations, figs, tables" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": { 280 | "pycharm": { 281 | "name": "#%%\n" 282 | } 283 | }, 284 | "outputs": [], 285 | "source": [ 286 | "get_all_refs(tree.getroot())" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "metadata": { 293 | "pycharm": { 294 | "name": "#%%\n" 295 | } 296 | }, 297 | "outputs": [], 298 | "source": [ 299 | "import re\n", 300 | "\n", 301 | "pattern = re.compile(\"(?<=xref_bibr_)[\\d\\w]+\")\n", 302 | "\n", 303 | "\n", 304 | "def count_reverse(sents_gen, sents_disc, pmid):\n", 305 | " result = []\n", 306 | " for i, sent in enumerate(sents_gen):\n", 307 | " results = re.findall(pattern, sent)\n", 308 | " result.extend(list(map(lambda x: (pmid, 'general', str(i), x), results)))\n", 309 | " for i, sent in enumerate(sents_disc):\n", 310 | " results = re.findall(pattern, sent)\n", 311 | " result.extend(list(map(lambda x: (pmid, 'discussion', str(i), x), results)))\n", 312 | " return result" 313 | ] 314 | }, 315 | { 316 | "cell_type": 
"code", 317 | "execution_count": null, 318 | "metadata": { 319 | "pycharm": { 320 | "name": "#%%\n" 321 | } 322 | }, 323 | "outputs": [], 324 | "source": [ 325 | "gen_sents, disc_sents, abst = get_sentences(tree.getroot())\n", 326 | "count_reverse(gen_sents, disc_sents, '2000292')" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": null, 332 | "outputs": [], 333 | "source": [ 334 | "def is_review(tree):\n", 335 | " try:\n", 336 | " return any('Review' in sg.find('subject').text for sg in\n", 337 | " tree.find('front').find('article-meta').find('article-categories').findall('subj-group'))\n", 338 | " except:\n", 339 | " return False\n", 340 | "\n", 341 | "\n", 342 | "# Test\n", 343 | "is_review(etree.parse(f\"{pmc_dataset_root}/PMC001xxxxxx/PMC1817751.xml\").getroot())" 344 | ], 345 | "metadata": { 346 | "collapsed": false, 347 | "pycharm": { 348 | "name": "#%%\n" 349 | } 350 | } 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": { 355 | "pycharm": { 356 | "name": "#%% md\n" 357 | } 358 | }, 359 | "source": [ 360 | "## Create tables required for model learning" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "metadata": { 367 | "scrolled": true, 368 | "pycharm": { 369 | "name": "#%%\n" 370 | } 371 | }, 372 | "outputs": [], 373 | "source": [ 374 | "! mkdir -p {dataset_root}\n", 375 | "\n", 376 | "print('Headers')\n", 377 | "with open(f'{dataset_root}/review_files.csv', 'w') as f:\n", 378 | " print('pmid', file=f)\n", 379 | "with open(f'{dataset_root}/citations.csv', 'w') as f:\n", 380 | " print('\\t'.join(['pmid', 'ref_id', 'pub_type', 'ref_pmid']), file=f)\n", 381 | "with open(f'{dataset_root}/sentences.csv', 'w') as f:\n", 382 | " print('\\t'.join(['pmid', 'sent_id', 'type', 'sentence']), file=f)\n", 383 | "with open(f'{dataset_root}/abstracts.csv', 'w') as f:\n", 384 | " print('\\t'.join(['pmid', 'abstract']), file=f)\n", 385 | "with open(f'{dataset_root}/figures.csv', 'w') as f:\n", 386 | " print('\\t'.join(['pmid', 'fig_id', 'caption']), file=f)\n", 387 | "with open(f'{dataset_root}/tables.csv', 'w') as f:\n", 388 | " print('\\t'.join(['pmid', 'tab_id', 'caption']), file=f)\n", 389 | "with open(f'{dataset_root}/reverse_ref.csv', 'w') as f:\n", 390 | " print('\\t'.join(['pmid', 'sent_type', 'sent_id', 'ref_id']), file=f)\n", 391 | "\n", 392 | "print('Processing articles')\n", 393 | "for id, filename in tqdm(list(dict_articles.items())):\n", 394 | " try:\n", 395 | " tree = etree.parse(pmc_dataset_root + \"/\" + filename).getroot()\n", 396 | " gen_sents, disc_sents, abstract = get_sentences(tree)\n", 397 | " cits, figs, tables = get_all_refs(tree)\n", 398 | " except Exception as e:\n", 399 | " print(\"\\rsomething went wrong\", id, filename, e)\n", 400 | " continue\n", 401 | " if is_review(tree):\n", 402 | " with open(f'{dataset_root}/review_files.csv', 'a') as f:\n", 403 | " print(id, file=f)\n", 404 | " with open(f'{dataset_root}/citations.csv', 'a') as f:\n", 405 | " for i, dic in cits.items():\n", 406 | " print('\\t'.join([id, str(i), dic['publication-type'], dic['pmid']]), file=f)\n", 407 | " with open(f'{dataset_root}/sentences.csv', 'a') as f:\n", 408 | " for i, sent in enumerate(gen_sents):\n", 409 | " print('\\t'.join([id, str(i), 'general', sent]), file=f)\n", 410 | " for i, sent in enumerate(disc_sents):\n", 411 | " print('\\t'.join([id, str(i), 'discussion', sent]), file=f)\n", 412 | " if abstract != '':\n", 413 | " with open(f'{dataset_root}/abstracts.csv', 'a') as f:\n", 414 | " 
print('\\t'.join([id, abstract]), file=f)\n", 415 | " with open(f'{dataset_root}/figures.csv', 'a') as f:\n", 416 | " for i, text in figs.items():\n", 417 | " print('\\t'.join([id, i, text]), file=f)\n", 418 | " with open(f'{dataset_root}/tables.csv', 'a') as f:\n", 419 | " for i, text in tables.items():\n", 420 | " print('\\t'.join([id, i, text]), file=f)\n", 421 | " with open(f'{dataset_root}/reverse_ref.csv', 'a') as f:\n", 422 | " res = count_reverse(gen_sents, disc_sents, id)\n", 423 | " for row in res:\n", 424 | " print('\\t'.join(list(row)), file=f)" 425 | ] 426 | }, 427 | { 428 | "cell_type": "markdown", 429 | "source": [ 430 | "## Check dataset loading" 431 | ], 432 | "metadata": { 433 | "collapsed": false, 434 | "pycharm": { 435 | "name": "#%% md\n" 436 | } 437 | } 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "outputs": [], 443 | "source": [ 444 | "def sizeof_fmt(num, suffix='B'):\n", 445 | " \"\"\"Used memory analysis utility\"\"\"\n", 446 | " for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:\n", 447 | " if abs(num) < 1024.0:\n", 448 | " return \"%3.1f %s%s\" % (num, unit, suffix)\n", 449 | " num /= 1024.0\n", 450 | " return \"%.1f %s%s\" % (num, 'Yi', suffix)" 451 | ], 452 | "metadata": { 453 | "collapsed": false, 454 | "pycharm": { 455 | "name": "#%%\n" 456 | } 457 | } 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": null, 462 | "outputs": [], 463 | "source": [ 464 | "logging.info('Loading citations_df')\n", 465 | "citations_df = pd.read_csv(os.path.join(dataset_root, \"citations.csv\"), sep='\\t')\n", 466 | "logging.info(sizeof_fmt(sys.getsizeof(citations_df)))\n", 467 | "display(citations_df.head())\n", 468 | "del citations_df" 469 | ], 470 | "metadata": { 471 | "collapsed": false, 472 | "pycharm": { 473 | "name": "#%%\n" 474 | } 475 | } 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "outputs": [], 481 | "source": [ 482 | "logging.info('Loading review_files_df')\n", 483 | "review_files_df = pd.read_csv(os.path.join(dataset_root, \"review_files.csv\"), sep='\\t')\n", 484 | "logging.info(sizeof_fmt(sys.getsizeof(review_files_df)))\n", 485 | "display(review_files_df.head())\n", 486 | "del review_files_df" 487 | ], 488 | "metadata": { 489 | "collapsed": false, 490 | "pycharm": { 491 | "name": "#%%\n" 492 | } 493 | } 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": null, 498 | "outputs": [], 499 | "source": [ 500 | "logging.info('Loading reverse_ref_df')\n", 501 | "reverse_ref_df = pd.read_csv(os.path.join(dataset_root, \"reverse_ref.csv\"), sep='\\t')\n", 502 | "logging.info(sizeof_fmt(sys.getsizeof(reverse_ref_df)))\n", 503 | "display(reverse_ref_df.head())\n", 504 | "del reverse_ref_df" 505 | ], 506 | "metadata": { 507 | "collapsed": false, 508 | "pycharm": { 509 | "name": "#%%\n" 510 | } 511 | } 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "outputs": [], 517 | "source": [ 518 | "logging.info('Loading abstracts_df')\n", 519 | "abstracts_df = pd.read_csv(os.path.join(dataset_root, \"abstracts.csv\"), sep='\\t')\n", 520 | "logging.info(sizeof_fmt(sys.getsizeof(abstracts_df)))\n", 521 | "display(abstracts_df.head())\n", 522 | "del abstracts_df" 523 | ], 524 | "metadata": { 525 | "collapsed": false, 526 | "pycharm": { 527 | "name": "#%%\n" 528 | } 529 | } 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "outputs": [], 535 | "source": [ 536 | "logging.info('Loading figures_df')\n", 537 | 
"figures_df = pd.read_csv(os.path.join(dataset_root, \"figures.csv\"), sep='\\t')\n", 538 | "logging.info(sizeof_fmt(sys.getsizeof(figures_df)))\n", 539 | "display(figures_df.head())\n", 540 | "del figures_df" 541 | ], 542 | "metadata": { 543 | "collapsed": false, 544 | "pycharm": { 545 | "name": "#%%\n" 546 | } 547 | } 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": null, 552 | "outputs": [], 553 | "source": [ 554 | "logging.info('Loading tables_df')\n", 555 | "tables_df = pd.read_csv(os.path.join(dataset_root, \"tables.csv\"), sep='\\t')\n", 556 | "logging.info(sizeof_fmt(sys.getsizeof(tables_df)))\n", 557 | "display(tables_df.head())\n", 558 | "del tables_df" 559 | ], 560 | "metadata": { 561 | "collapsed": false, 562 | "pycharm": { 563 | "name": "#%%\n" 564 | } 565 | } 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": null, 570 | "outputs": [], 571 | "source": [ 572 | "logging.info('Loading sentences_df')\n", 573 | "sentences_df = pd.read_csv(os.path.join(dataset_root, \"sentences.csv\"), sep='\\t')\n", 574 | "logging.info(sizeof_fmt(sys.getsizeof(sentences_df)))\n", 575 | "display(sentences_df.head())\n", 576 | "del sentences_df" 577 | ], 578 | "metadata": { 579 | "collapsed": false, 580 | "pycharm": { 581 | "name": "#%%\n" 582 | } 583 | } 584 | }, 585 | { 586 | "cell_type": "code", 587 | "execution_count": null, 588 | "outputs": [], 589 | "source": [], 590 | "metadata": { 591 | "collapsed": false, 592 | "pycharm": { 593 | "name": "#%%\n" 594 | } 595 | } 596 | } 597 | ], 598 | "metadata": { 599 | "kernelspec": { 600 | "display_name": "Python 3 (ipykernel)", 601 | "language": "python", 602 | "name": "python3" 603 | }, 604 | "language_info": { 605 | "codemirror_mode": { 606 | "name": "ipython", 607 | "version": 3 608 | }, 609 | "file_extension": ".py", 610 | "mimetype": "text/x-python", 611 | "name": "python", 612 | "nbconvert_exporter": "python", 613 | "pygments_lexer": "ipython3", 614 | "version": "3.10.4" 615 | } 616 | }, 617 | "nbformat": 4, 618 | "nbformat_minor": 4 619 | } -------------------------------------------------------------------------------- /review/dataset/analyze_archive.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%% md\n" 8 | } 9 | }, 10 | "source": [ 11 | "# Analyze archive\n", 12 | "\n", 13 | "This notebook contains the code to analyse content of the PubMedCentral Author Manuscript Collection. \\\n", 14 | "See: https://www.ncbi.nlm.nih.gov/pmc/about/mscollection/\n", 15 | "\n", 16 | "Files can be downloaded here: https://ftp.ncbi.nlm.nih.gov/pub/pmc/manuscript/xml/ \\\n", 17 | "**Please ensure** that files are downloaded into `~/pmc_dataset` folder to proceed." 
18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "pycharm": { 24 | "name": "#%% md\n" 25 | } 26 | }, 27 | "source": [ 28 | "## Collecting files" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": { 35 | "pycharm": { 36 | "name": "#%%\n" 37 | } 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "% matplotlib inline\n", 42 | "% config InlineBackend.figure_format='retina'\n", 43 | "\n", 44 | "import functools\n", 45 | "import os\n", 46 | "from collections import Counter\n", 47 | "from glob import glob\n", 48 | "\n", 49 | "import matplotlib.pyplot as plt\n", 50 | "from lxml import etree\n", 51 | "from tqdm.auto import tqdm\n", 52 | "\n", 53 | "dict_articles = {}\n", 54 | "pmc_dataset_root = os.path.expanduser('~/pmc_dataset')\n", 55 | "\n", 56 | "\n", 57 | "for filelist in tqdm(glob(os.path.join(pmc_dataset_root, '*filelist.txt'))):\n", 58 | " with open(filelist, 'r') as f:\n", 59 | " for line in f:\n", 60 | " if 'LastUpdated' in line:\n", 61 | " continue\n", 62 | " filename, pmcid, pmid, mid, date, time = line.split()\n", 63 | " dict_articles[pmid] = filename" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "pycharm": { 71 | "name": "#%%\n" 72 | } 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "print('Total papers', len(dict_articles))" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": { 83 | "pycharm": { 84 | "name": "#%%\n" 85 | } 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "list(dict_articles.values())[:10]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": { 96 | "pycharm": { 97 | "name": "#%%\n" 98 | } 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "def count_tags(node):\n", 103 | " stat = Counter()\n", 104 | "\n", 105 | " def dfs(root):\n", 106 | " stat[root.tag] += 1\n", 107 | " for child in root.getchildren():\n", 108 | " dfs(child)\n", 109 | "\n", 110 | " dfs(node)\n", 111 | " return stat" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": { 118 | "pycharm": { 119 | "name": "#%%\n" 120 | } 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | "def get_title(tree):\n", 125 | " return etree.tostring(tree.getroot().find(\"front\").find(\"article-meta\").find(\"title-group\").find(\"article-title\"))" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": { 131 | "pycharm": { 132 | "name": "#%% md\n" 133 | } 134 | }, 135 | "source": [ 136 | "## Collecting review papers" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": { 143 | "pycharm": { 144 | "name": "#%%\n" 145 | } 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "review_filenames = set()\n", 150 | "for filename in tqdm(dict_articles.values()):\n", 151 | " tree = etree.parse(os.path.join(pmc_dataset_root, filename))\n", 152 | " title = str(get_title(tree))\n", 153 | " if not title:\n", 154 | " print(f\"\\r{filename}\")\n", 155 | " if \"review\" in title.lower():\n", 156 | " review_filenames.add(filename)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "pycharm": { 164 | "name": "#%%\n" 165 | } 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "print('Review papers', len(review_filenames))" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "metadata": { 175 | "pycharm": { 176 | "name": "#%% md\n" 177 | } 178 | 
}, 179 | "source": [ 180 | "## Collecting tag statistics" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "metadata": { 187 | "pycharm": { 188 | "name": "#%%\n" 189 | } 190 | }, 191 | "outputs": [], 192 | "source": [ 193 | "tag_stat = {}\n", 194 | "tag_stat['review'] = Counter()\n", 195 | "tag_stat['ordinary'] = Counter()\n", 196 | "for filename in tqdm(dict_articles.values()):\n", 197 | " tree = etree.parse(os.path.join(pmc_dataset_root, filename))\n", 198 | " cur_stat = count_tags(tree.getroot())\n", 199 | " if filename in review_filenames:\n", 200 | " tag_stat['review'] += cur_stat\n", 201 | " else:\n", 202 | " tag_stat['ordinary'] += cur_stat" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "metadata": { 209 | "pycharm": { 210 | "name": "#%%\n" 211 | } 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "for s, cnt in zip(['ordinary', 'review'], [len(dict_articles) - len(review_filenames), len(review_filenames)]):\n", 216 | " with open(f'{s}_tag_stat.txt', 'w') as f:\n", 217 | " srt = sorted(tag_stat[s].items(), key=lambda x: x[1])\n", 218 | " srt = list(map(lambda x: (x[0], x[1] / cnt), srt))\n", 219 | " print(f'Number: {cnt}', file=f)\n", 220 | " for val, count in srt:\n", 221 | " print(f'{val} {count}', file=f)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": null, 227 | "metadata": { 228 | "pycharm": { 229 | "name": "#%%\n" 230 | } 231 | }, 232 | "outputs": [], 233 | "source": [ 234 | "def tag_depth(node):\n", 235 | " def dfs(root):\n", 236 | " d = 1\n", 237 | " for child in root.getchildren():\n", 238 | " d = max(d, dfs(child) + 1)\n", 239 | " return d\n", 240 | "\n", 241 | " return dfs(node)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": { 248 | "pycharm": { 249 | "name": "#%%\n" 250 | } 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "d_stat = {}\n", 255 | "d_stat['review'] = {}\n", 256 | "d_stat['ordinary'] = {}\n", 257 | "for filename in tqdm(dict_articles.values()):\n", 258 | " tree = etree.parse(os.path.join(pmc_dataset_root, filename))\n", 259 | " cur_stat = tag_depth(tree.getroot())\n", 260 | " if filename in review_filenames:\n", 261 | " d_stat['review'][filename] = cur_stat\n", 262 | " else:\n", 263 | " d_stat['ordinary'][filename] = cur_stat" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": { 270 | "pycharm": { 271 | "name": "#%%\n" 272 | } 273 | }, 274 | "outputs": [], 275 | "source": [ 276 | "print(list(d_stat['review'].items())[:10])" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": { 283 | "pycharm": { 284 | "name": "#%%\n" 285 | } 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "for s in ['ordinary', 'review']:\n", 290 | " with open(f'{s}_tag_depth.txt', 'w') as f:\n", 291 | " srt = sorted(d_stat[s].items(), key=lambda x: x[1])\n", 292 | " for val, count in srt:\n", 293 | " print(f'{val} {count}', file=f)" 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "metadata": { 300 | "pycharm": { 301 | "name": "#%%\n" 302 | } 303 | }, 304 | "outputs": [], 305 | "source": [ 306 | "plt.title('Tag depths review papers')\n", 307 | "plt.hist(d_stat['review'].values(), bins=range(5, 20))\n", 308 | "plt.show()" 309 | ] 310 | }, 311 | { 312 | "cell_type": "code", 313 | "execution_count": null, 314 | "metadata": { 315 | "pycharm": { 316 | "name": "#%%\n" 317 | } 318 
| }, 319 | "outputs": [], 320 | "source": [ 321 | "plt.title('Tag depth ordinary papers')\n", 322 | "plt.hist(d_stat['ordinary'].values(), bins=range(5, 20))\n", 323 | "plt.show()" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": null, 329 | "metadata": { 330 | "pycharm": { 331 | "name": "#%%\n" 332 | } 333 | }, 334 | "outputs": [], 335 | "source": [ 336 | "tree = etree.parse(os.path.join(pmc_dataset_root, list(dict_articles.values())[0]))" 337 | ] 338 | }, 339 | { 340 | "cell_type": "markdown", 341 | "metadata": { 342 | "pycharm": { 343 | "name": "#%% md\n" 344 | } 345 | }, 346 | "source": [ 347 | "### Collecting paragraphs statistics" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "metadata": { 354 | "pycharm": { 355 | "name": "#%%\n" 356 | } 357 | }, 358 | "outputs": [], 359 | "source": [ 360 | "def get_paragraph_info(root):\n", 361 | " num = 0\n", 362 | " sum_pos = -1\n", 363 | " disc_pos = -1\n", 364 | " lens = Counter()\n", 365 | " for ch in root.find('body').getchildren():\n", 366 | " if ch.tag == 'sec':\n", 367 | " num += 1\n", 368 | " try:\n", 369 | " lens[num] = len(etree.tostring(ch))\n", 370 | " except Exception:\n", 371 | " lens[num] = 0\n", 372 | " print(\"\\n!\")\n", 373 | " str_title = str(etree.tostring(ch.find('title'))).lower()\n", 374 | " if 'summary' in str_title:\n", 375 | " sum_pos = num\n", 376 | " if 'discussion' in str_title:\n", 377 | " disc_pos = num\n", 378 | " return num, sum_pos, disc_pos, lens" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": { 385 | "pycharm": { 386 | "name": "#%%\n" 387 | } 388 | }, 389 | "outputs": [], 390 | "source": [ 391 | "review_filenames = set()\n", 392 | "\n", 393 | "with open('review_tag_depth.txt', 'r') as f:\n", 394 | " for line in f:\n", 395 | " filename, _ = line.split()\n", 396 | " review_filenames.add(filename)" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "metadata": { 403 | "scrolled": true, 404 | "pycharm": { 405 | "name": "#%%\n" 406 | } 407 | }, 408 | "outputs": [], 409 | "source": [ 410 | "para_stats = {}\n", 411 | "para_stats['review'] = {}\n", 412 | "para_stats['ordinary'] = {}\n", 413 | "\n", 414 | "for filename in tqdm(dict_articles.values()):\n", 415 | " tree = etree.parse(os.path.join(pmc_dataset_root, filename))\n", 416 | " try:\n", 417 | " cur_stat = get_paragraph_info(tree.getroot())\n", 418 | " except Exception:\n", 419 | " print(f\"\\n{filename}\")\n", 420 | " continue\n", 421 | " if filename in review_filenames:\n", 422 | " para_stats['review'][filename] = cur_stat\n", 423 | " else:\n", 424 | " para_stats['ordinary'][filename] = cur_stat" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": null, 430 | "metadata": { 431 | "pycharm": { 432 | "name": "#%%\n" 433 | } 434 | }, 435 | "outputs": [], 436 | "source": [ 437 | "list(para_stats['review'].items())[:10]" 438 | ] 439 | }, 440 | { 441 | "cell_type": "markdown", 442 | "metadata": { 443 | "pycharm": { 444 | "name": "#%% md\n" 445 | } 446 | }, 447 | "source": [ 448 | "#### Number of sections" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": null, 454 | "metadata": { 455 | "pycharm": { 456 | "name": "#%%\n" 457 | } 458 | }, 459 | "outputs": [], 460 | "source": [ 461 | "para_nums = list(map(lambda x: x[1][0], para_stats['review'].items()))\n", 462 | "print(para_nums[:10])\n", 463 | "plt.title('Number of sections in review papers')\n", 464 | 
"plt.hist(para_nums, bins=range(1, 20))\n", 465 | "plt.show()" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": null, 471 | "metadata": { 472 | "pycharm": { 473 | "name": "#%%\n" 474 | } 475 | }, 476 | "outputs": [], 477 | "source": [ 478 | "para_nums = list(map(lambda x: x[1][0], para_stats['ordinary'].items()))\n", 479 | "print(para_nums[:10])\n", 480 | "plt.title('Number of sections in ordinary papers')\n", 481 | "plt.hist(para_nums, bins=range(1, 20))\n", 482 | "plt.show()" 483 | ] 484 | }, 485 | { 486 | "cell_type": "markdown", 487 | "metadata": { 488 | "pycharm": { 489 | "name": "#%% md\n" 490 | } 491 | }, 492 | "source": [ 493 | "### Discussion section position" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": null, 499 | "metadata": { 500 | "pycharm": { 501 | "name": "#%%\n" 502 | } 503 | }, 504 | "outputs": [], 505 | "source": [ 506 | "sum_stat = list(map(lambda x: x[1][1], para_stats['review'].items()))\n", 507 | "print(sum_stat[:10])\n", 508 | "plt.title('Position of discussion section in review papers')\n", 509 | "plt.hist(sum_stat, bins=range(1, 20))\n", 510 | "plt.show()" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "metadata": { 517 | "pycharm": { 518 | "name": "#%%\n" 519 | } 520 | }, 521 | "outputs": [], 522 | "source": [ 523 | "sum_stat = list(map(lambda x: x[1][1], para_stats['ordinary'].items()))\n", 524 | "print(sum_stat[:10])\n", 525 | "plt.title('Position of discussion section in ordinary papers')\n", 526 | "plt.hist(sum_stat, bins=range(1, 20))\n", 527 | "plt.show()" 528 | ] 529 | }, 530 | { 531 | "cell_type": "markdown", 532 | "metadata": { 533 | "pycharm": { 534 | "name": "#%% md\n" 535 | } 536 | }, 537 | "source": [ 538 | "### Position of discussion papers" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": null, 544 | "metadata": { 545 | "pycharm": { 546 | "name": "#%%\n" 547 | } 548 | }, 549 | "outputs": [], 550 | "source": [ 551 | "sum_stat = list(map(lambda x: x[1][2], para_stats['review'].items()))\n", 552 | "print(sum_stat[:10])\n", 553 | "plt.title('Position of discussion section in review papers')\n", 554 | "plt.hist(sum_stat, bins=range(-1, 20))\n", 555 | "plt.show()" 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": null, 561 | "metadata": { 562 | "pycharm": { 563 | "name": "#%%\n" 564 | } 565 | }, 566 | "outputs": [], 567 | "source": [ 568 | "sum_stat = list(map(lambda x: x[1][2], para_stats['ordinary'].items()))\n", 569 | "print(sum_stat[:10])\n", 570 | "plt.title('Position of discussion section in ordinary papers')\n", 571 | "plt.hist(sum_stat, bins=range(-1, 20))\n", 572 | "plt.show()" 573 | ] 574 | }, 575 | { 576 | "cell_type": "markdown", 577 | "metadata": { 578 | "pycharm": { 579 | "name": "#%% md\n" 580 | } 581 | }, 582 | "source": [ 583 | "### Average number of sections" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": null, 589 | "metadata": { 590 | "pycharm": { 591 | "name": "#%%\n" 592 | } 593 | }, 594 | "outputs": [], 595 | "source": [ 596 | "len_stat = functools.reduce(lambda x, y: x + y, map(lambda x: x[1][3], para_stats['review'].items()))\n", 597 | "plt.title('Average number of sections in review papers')\n", 598 | "plt.bar(len_stat.keys(), list(map(lambda x: x / len(para_stats['review'].items()), len_stat.values())))\n", 599 | "plt.show()" 600 | ] 601 | }, 602 | { 603 | "cell_type": "code", 604 | "execution_count": null, 605 | "metadata": { 606 | "pycharm": 
{
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"len_stat = functools.reduce(lambda x, y: x + y, map(lambda x: x[1][3], para_stats['ordinary'].items()))\n",
"plt.title('Average number of sections in ordinary papers')\n",
"plt.bar(list(map(lambda x: min(35, x), len_stat.keys())),\n",
"        list(map(lambda x: x / len(para_stats['ordinary'].items()), len_stat.values())))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Position of conclusion section"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"xml = '<sec><title>Some</title> example <italic>text</italic></sec>'\n",
"tree = etree.fromstring(xml)\n",
"# itertext() concatenates the text of all nested elements\n",
"print(''.join(tree.itertext()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Ordinary papers with only one detected section\n",
"list(filter(lambda x: x[1][0] == 1, para_stats['ordinary'].items()))[:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"def get_conc_info(root):\n",
"    # Returns the position of the conclusion section among top-level sections, or -1 if it is absent\n",
"    conc_pos = -1\n",
"    num = 0\n",
"    for ch in root.find('body').getchildren():\n",
"        if ch.tag == 'sec':\n",
"            num += 1\n",
"            str_title = str(etree.tostring(ch.find('title'))).lower()\n",
"            if 'conclusion' in str_title:\n",
"                conc_pos = num\n",
"    return conc_pos"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"conc_stats = {}\n",
"conc_stats['review'] = {}\n",
"conc_stats['ordinary'] = {}\n",
"\n",
"for filename in tqdm(dict_articles.values()):\n",
"    tree = etree.parse(os.path.join(pmc_dataset_root, filename))\n",
"    try:\n",
"        cur_stat = get_conc_info(tree.getroot())\n",
"    except Exception:\n",
"        print(f\"\\nFailed to process {filename}\")\n",
"        continue\n",
"    if filename in review_filenames:\n",
"        conc_stats['review'][filename] = cur_stat\n",
"    else:\n",
"        conc_stats['ordinary'][filename] = cur_stat"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"conc_stat = list(map(lambda x: x[1], conc_stats['review'].items()))\n",
"print(conc_stat[:10])\n",
"plt.title('Position of conclusion section in review papers')\n",
"plt.hist(conc_stat, bins=range(-1, 20))\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"conc_stat = list(map(lambda x: x[1], conc_stats['ordinary'].items()))\n",
"print(conc_stat[:10])\n",
"plt.title('Position of conclusion section in ordinary papers')\n",
"plt.hist(conc_stat, bins=range(-1, 20))\n",
"plt.show()"
]
},
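{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"A quick sanity check (a sketch, not part of the original analysis): `get_conc_info` returns -1 when no conclusion section is found, so the share of papers with a detected conclusion section can be read directly from `conc_stats`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Sketch: share of papers with a detected conclusion section (-1 encodes 'not found')\n",
"for pub_type in ('review', 'ordinary'):\n",
"    stats = conc_stats[pub_type]\n",
"    found = sum(1 for pos in stats.values() if pos != -1)\n",
"    print(f'{pub_type}: {found} of {len(stats)} papers have a conclusion section')"
]
}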
| "display_name": "Python 3 (ipykernel)", 745 | "language": "python", 746 | "name": "python3" 747 | }, 748 | "language_info": { 749 | "codemirror_mode": { 750 | "name": "ipython", 751 | "version": 3 752 | }, 753 | "file_extension": ".py", 754 | "mimetype": "text/x-python", 755 | "name": "python", 756 | "nbconvert_exporter": "python", 757 | "pygments_lexer": "ipython3", 758 | "version": "3.10.4" 759 | } 760 | }, 761 | "nbformat": 4, 762 | "nbformat_minor": 4 763 | } -------------------------------------------------------------------------------- /review/train/test_new_algo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# Training&Evaluation of a developed algorithm\n", 7 | "\n", 8 | "This notebook contains source code for several possible architectures to train & evaluate them." 9 | ], 10 | "metadata": { 11 | "collapsed": false, 12 | "pycharm": { 13 | "name": "#%% md\n" 14 | } 15 | } 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "outputs": [], 21 | "source": [ 22 | "import logging\n", 23 | "import os\n", 24 | "import random\n", 25 | "import sys\n", 26 | "\n", 27 | "import numpy as np\n", 28 | "import pandas as pd\n", 29 | "import torch\n", 30 | "from tqdm.auto import tqdm\n", 31 | "\n", 32 | "from review import config as cfg\n", 33 | "\n", 34 | "logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')" 35 | ], 36 | "metadata": { 37 | "collapsed": false, 38 | "pycharm": { 39 | "name": "#%%\n" 40 | } 41 | } 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "outputs": [], 47 | "source": [ 48 | "print('CUDA version', torch.version.cuda)\n", 49 | "if torch.cuda.is_available():\n", 50 | " print('GPU:', torch.cuda.get_device_name(0))\n", 51 | "else:\n", 52 | " print('CPU')\n", 53 | "\n", 54 | "\n", 55 | "def setup_cuda_device(model):\n", 56 | " logging.info('Setup single-device settings...')\n", 57 | " device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 58 | " model = model.to(device)\n", 59 | " return model, device" 60 | ], 61 | "metadata": { 62 | "collapsed": false, 63 | "pycharm": { 64 | "name": "#%%\n" 65 | } 66 | } 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "outputs": [], 72 | "source": [ 73 | "logging.info('Fix seed')\n", 74 | "seed = 42\n", 75 | "random.seed(seed)\n", 76 | "np.random.seed(seed)\n", 77 | "torch.manual_seed(seed)\n", 78 | "torch.cuda.manual_seed_all(seed)" 79 | ], 80 | "metadata": { 81 | "collapsed": false, 82 | "pycharm": { 83 | "name": "#%%\n" 84 | } 85 | } 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "source": [ 90 | "# Loading data and preparation of dataset\n", 91 | "\n", 92 | "`load_data` loads all the needed datafiles for building train dataset.\\\n", 93 | "The several next steps are only should be done if no train/test/val datasets are saved." 
94 | ], 95 | "metadata": { 96 | "collapsed": false, 97 | "pycharm": { 98 | "name": "#%% md\n" 99 | } 100 | } 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "outputs": [], 106 | "source": [ 107 | "def sizeof_fmt(num, suffix='B'):\n", 108 | " for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:\n", 109 | " if abs(num) < 1024.0:\n", 110 | " return \"%3.1f %s%s\" % (num, unit, suffix)\n", 111 | " num /= 1024.0\n", 112 | " return \"%.1f %s%s\" % (num, 'Yi', suffix)\n", 113 | "\n", 114 | "\n", 115 | "def load_data(additional_features):\n", 116 | " dataset_root = os.path.expanduser(cfg.dataset_path)\n", 117 | " logging.info(f'Loading references dataset from {dataset_root}')\n", 118 | "\n", 119 | " logging.info('Loading citations_df')\n", 120 | " citations_df = pd.read_csv(os.path.join(dataset_root, \"citations.csv\"), sep='\\t')\n", 121 | " logging.info(sizeof_fmt(sys.getsizeof(citations_df)))\n", 122 | "\n", 123 | " logging.info('Loading sentences_df')\n", 124 | " sentences_df = pd.read_csv(os.path.join(dataset_root, \"sentences.csv\"), sep='\\t')\n", 125 | " logging.info(sizeof_fmt(sys.getsizeof(sentences_df)))\n", 126 | "\n", 127 | " logging.info('Loading review_files_df')\n", 128 | " review_files_df = pd.read_csv(os.path.join(dataset_root, \"review_files.csv\"), sep='\\t')\n", 129 | " logging.info(sizeof_fmt(sys.getsizeof(review_files_df)))\n", 130 | "\n", 131 | " logging.info('Loading reverse_ref_df')\n", 132 | " reverse_ref_df = pd.read_csv(os.path.join(dataset_root, \"reverse_ref.csv\"), sep='\\t')\n", 133 | " logging.info(sizeof_fmt(sys.getsizeof(reverse_ref_df)))\n", 134 | "\n", 135 | " if not additional_features:\n", 136 | " return citations_df, sentences_df, review_files_df, reverse_ref_df, None, None, None\n", 137 | "\n", 138 | " logging.info('Loading abstracts_df')\n", 139 | " abstracts_df = pd.read_csv(os.path.join(dataset_root, \"abstracts.csv\"), sep='\\t')\n", 140 | " logging.info(sizeof_fmt(sys.getsizeof(abstracts_df)))\n", 141 | "\n", 142 | " logging.info('Loading figures_df')\n", 143 | " figures_df = pd.read_csv(os.path.join(dataset_root, \"figures.csv\"), sep='\\t')\n", 144 | " logging.info(sizeof_fmt(sys.getsizeof(figures_df)))\n", 145 | "\n", 146 | " logging.info('Loading tables_df')\n", 147 | " tables_df = pd.read_csv(os.path.join(dataset_root, \"tables.csv\"), sep='\\t')\n", 148 | " logging.info(sizeof_fmt(sys.getsizeof(tables_df)))\n", 149 | "\n", 150 | " return citations_df, sentences_df, review_files_df, reverse_ref_df, abstracts_df, figures_df, tables_df" 151 | ], 152 | "metadata": { 153 | "collapsed": false, 154 | "pycharm": { 155 | "name": "#%%\n" 156 | } 157 | } 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "outputs": [], 163 | "source": [ 164 | "ADDITIONAL_FEATURES = False\n", 165 | "\n", 166 | "citations_df, sentences_df, review_files_df, reverse_ref_df, abstracts_df, figures_df, tables_df = load_data(\n", 167 | " additional_features=ADDITIONAL_FEATURES)\n", 168 | "\n", 169 | "logging.info('Done loading references dataset')" 170 | ], 171 | "metadata": { 172 | "collapsed": false, 173 | "pycharm": { 174 | "name": "#%%\n" 175 | } 176 | } 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "outputs": [], 182 | "source": [ 183 | "import matplotlib.pyplot as plt\n", 184 | "from collections import Counter\n", 185 | "\n", 186 | "res = Counter(list(sentences_df['pmid'].values))\n", 187 | "plt.hist(res.values(), bins=range(-1, 400))\n", 188 | "plt.title('Length of 
papers (number of sentences per paper)')\n",
"plt.show()\n",
"del res"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"REF_SENTS_DF_PATH = f\"{os.path.expanduser(cfg.base_path)}/ref_sents.csv\"\n",
"# Delete the cached file to force regeneration; comment this line out to reuse the cache\n",
"! rm {REF_SENTS_DF_PATH}\n",
"\n",
"if os.path.exists(REF_SENTS_DF_PATH):\n",
"    ref_sents_df = pd.read_csv(REF_SENTS_DF_PATH, sep='\\t')\n",
"else:\n",
"    logging.info('Creating reference sentences dataset')\n",
"    ref_sents_df = pd.merge(citations_df, reverse_ref_df, left_on=['pmid', 'ref_id'], right_on=['pmid', 'ref_id'])\n",
"    ref_sents_df = pd.merge(ref_sents_df, sentences_df, left_on=['pmid', 'sent_type', 'sent_id'],\n",
"                            right_on=['pmid', 'type', 'sent_id'])\n",
"    ref_sents_df = ref_sents_df[ref_sents_df['pmid'].isin(review_files_df['pmid'].values)]\n",
"    ref_sents_df = ref_sents_df.drop_duplicates()\n",
"    logging.info(f'Number of unique ref_pmids: {len(set(ref_sents_df[\"ref_pmid\"]))}')\n",
"    ref_sents_df = ref_sents_df[['pmid', 'ref_id', 'pub_type', 'ref_pmid', 'sent_type', 'sent_id', 'sentence']]\n",
"    ref_sents_df.to_csv(REF_SENTS_DF_PATH, sep='\\t', index=False)\n",
"\n",
"logging.info('Cleanup memory')\n",
"del citations_df\n",
"del review_files_df\n",
"\n",
"display(ref_sents_df.head())"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# Rouge\n",
"The `get_rouge` function computes the similarity of two sentences as the mean of the ROUGE-1 and ROUGE-2 F-scores.\n",
"`rouge-l` is another possible option."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from rouge import Rouge\n",
"from transformers import BertTokenizer\n",
"\n",
"ROUGE_METER = Rouge()\n",
"TOKENIZER = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n",
"\n",
"\n",
"def get_rouge(sent1, sent2):\n",
"    if sent1 is None or sent2 is None:\n",
"        return None\n",
"    sent_1 = TOKENIZER.tokenize(sent1)\n",
"    if len(sent_1) == 0:\n",
"        return None\n",
"    sent_1 = \" \".join(list(filter(lambda x: x.isalpha() or x in '.!,?', sent_1)))\n",
"    sent_2 = TOKENIZER.tokenize(sent2)\n",
"    if len(sent_2) == 0:\n",
"        return None\n",
"    sent_2 = \" \".join(list(filter(lambda x: x.isalpha() or x in '.!,?', sent_2)))\n",
"    if len(sent_1) == 0 or len(sent_2) == 0:\n",
"        return None\n",
"    rouges = ROUGE_METER.get_scores(sent_1, sent_2)[0]\n",
"    rouges = [rouges[f'rouge-{x}'][\"f\"] for x in ('1', '2')]  # add 'l' to include ROUGE-L\n",
"    return np.mean(rouges) * 100\n",
"\n",
"\n",
"def mean_rouge(sent, text):\n",
"    if len(text) == 0:\n",
"        return None\n",
"    try:\n",
"        return sum(get_rouge(sent, ref_sent) for ref_sent in text) / len(text)\n",
"    except Exception as e:\n",
"        logging.error(f'Exception at mean_rouge {e}')\n",
"        return None\n",
"\n",
"\n",
"def min_rouge(sent, text):\n",
"    try:\n",
"        score = 100000000  # sentinel, larger than any possible score\n",
"        for ref_sent in text:\n",
"            score = min(get_rouge(sent, ref_sent), score)\n",
"        if score == 100000000:\n",
"            return None\n",
"        return score\n",
"    except Exception as e:\n",
"        logging.error(f'Exception at min_rouge {e}')\n",
"        return None\n",
"\n",
"\n",
"def max_rouge(sent, text):\n",
"    try:\n",
"        score = -100000  # sentinel, smaller than any possible score\n",
"        for ref_sent in text:\n",
"            score = max(get_rouge(sent, ref_sent), score)\n",
"        if score == -100000:\n",
"            return None\n",
"        return score\n",
"    except Exception as e:\n",
"        logging.error(f'Exception at max_rouge {e}')\n",
"        return None"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"get_rouge(\"I am scout. True!\", \"No you are not a scout.\")"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# Build train / test / validate datasets\n",
"\n",
"To create a dataset with additional features, use `preprocess_paper_with_features`; otherwise, use `preprocess_paper`.\n",
"\n",
"If neither the datasets nor `ref_sents_df` have been created yet, `ref_sents_df` is built first.\\\n",
"For each paper `pmid` it holds the list of sentences from review papers in which the paper with this `pmid` is cited."
342 | ], 343 | "metadata": { 344 | "collapsed": false, 345 | "pycharm": { 346 | "name": "#%% md\n" 347 | } 348 | } 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "outputs": [], 354 | "source": [ 355 | "from unidecode import unidecode\n", 356 | "from nltk.tokenize import sent_tokenize\n", 357 | "import re\n", 358 | "\n", 359 | "REPLACE_SYMBOLS = {\n", 360 | " '—': '-',\n", 361 | " '–': '-',\n", 362 | " '―': '-',\n", 363 | " '…': '...',\n", 364 | " '´´': \"´\",\n", 365 | " '´´´': \"´´\",\n", 366 | " \"''\": \"'\",\n", 367 | " \"'''\": \"'\",\n", 368 | " \"``\": \"`\",\n", 369 | " \"```\": \"`\",\n", 370 | " \":\": \" : \",\n", 371 | "}\n", 372 | "\n", 373 | "\n", 374 | "def parse_sents(data):\n", 375 | " sents = sum([sent_tokenize(text) for text in data], [])\n", 376 | " sents = list(filter(lambda x: len(x) > 3, sents))\n", 377 | " return sents\n", 378 | "\n", 379 | "\n", 380 | "def sent_standardize(sent):\n", 381 | " sent = unidecode(sent)\n", 382 | " sent = re.sub(r\"\\[(xref_\\w*_\\w\\d*]*)(, xref_\\w*_\\w\\d*)*\\]\", \" \", sent) # delete [xref,...]\n", 383 | " sent = re.sub(r\"\\( (xref_\\w*_\\w\\d*)(; xref_\\w*_\\w\\d*)* \\)\", \" \", sent) # delete (xref; ...)\n", 384 | " sent = re.sub(r\"\\[xref_\\w*_\\w\\d*\\]\", \" \", sent) # delete [xref]\n", 385 | " sent = re.sub(r\"xref_\\w*_\\w\\d*\", \" \", sent) # delete [[xref]]\n", 386 | " for k, v in REPLACE_SYMBOLS.items():\n", 387 | " sent = sent.replace(k, v)\n", 388 | " return sent.strip()\n", 389 | "\n", 390 | "\n", 391 | "def standardize(text):\n", 392 | " return [x for x in (sent_standardize(sent) for sent in text) if len(x) > 3]\n", 393 | "\n", 394 | "\n", 395 | "PAPER_MIN_SENTENCES = 50\n", 396 | "PAPER_MAX_SENTENCES = 100\n", 397 | "\n", 398 | "\n", 399 | "def preprocess_paper(\n", 400 | " paper_id, sentences_df, ref_sents_df,\n", 401 | " paper_min_sents=PAPER_MIN_SENTENCES, paper_max_sents=PAPER_MAX_SENTENCES):\n", 402 | " paper = sentences_df[sentences_df['pmid'] == paper_id]['sentence']\n", 403 | " paper = standardize(paper)\n", 404 | "\n", 405 | " ref_sents = ref_sents_df[ref_sents_df['ref_pmid'] == paper_id]['sentence']\n", 406 | " ref_sents = standardize(ref_sents)\n", 407 | "\n", 408 | " if len(paper) < paper_min_sents:\n", 409 | " return None\n", 410 | "\n", 411 | " if len(paper) > paper_max_sents:\n", 412 | " paper = list(paper[:paper_min_sents]) + list(paper[-paper_min_sents:])\n", 413 | "\n", 414 | " preprocessed_score = [\n", 415 | " sum(get_rouge(sent, ref_sent) for ref_sent in ref_sents) / len(ref_sents) for sent in paper\n", 416 | " ]\n", 417 | " return paper, preprocessed_score\n", 418 | "\n", 419 | "\n", 420 | "def preprocess_paper_with_features(paper_id, sentences_df, ref_sents_df,\n", 421 | " abstracts_df, figures_df, reverse_ref_df, tables_df,\n", 422 | " paper_min_sents=PAPER_MIN_SENTENCES, paper_max_sents=PAPER_MAX_SENTENCES):\n", 423 | " preprocessed_score = []\n", 424 | " features = []\n", 425 | "\n", 426 | " papers = sentences_df[sentences_df['pmid'] == paper_id]['sentence']\n", 427 | " papers = standardize(papers)\n", 428 | "\n", 429 | " sent_ids = sentences_df[sentences_df['pmid'] == paper_id]['sent_id']\n", 430 | " sent_types = sentences_df[sentences_df['pmid'] == paper_id]['type']\n", 431 | "\n", 432 | " ref_sents = ref_sents_df[ref_sents_df['ref_pmid'] == paper_id]['sentence']\n", 433 | " ref_sents = standardize(ref_sents)\n", 434 | "\n", 435 | " fig_captions = figures_df[figures_df['pmid'] == paper_id]['caption']\n", 436 | " fig_captions = 
standardize(fig_captions)\n",
"\n",
"    tab_captions = tables_df[tables_df['pmid'] == paper_id]['caption']\n",
"    tab_captions = standardize(tab_captions)\n",
"\n",
"    abstract = abstracts_df[abstracts_df['pmid'] == paper_id]['abstract']\n",
"    if len(abstract) != 0:\n",
"        abstract = standardize(abstract)\n",
"\n",
"    tmp_df = reverse_ref_df[reverse_ref_df['pmid'] == paper_id]\n",
"\n",
"    if len(papers) < paper_min_sents:\n",
"        return None\n",
"\n",
"    if len(papers) > paper_max_sents:\n",
"        papers = list(papers[:paper_min_sents]) + list(papers[-paper_min_sents:])\n",
"        sent_ids = list(sent_ids[:paper_min_sents]) + list(sent_ids[-paper_min_sents:])\n",
"        sent_types = list(sent_types[:paper_min_sents]) + list(sent_types[-paper_min_sents:])\n",
"\n",
"    for sent, sent_type, sent_id in zip(papers, sent_types, sent_ids):\n",
"        score = mean_rouge(sent, ref_sents)\n",
"        if score is None:\n",
"            return None\n",
"\n",
"        # Guard against papers without an abstract (the selection above can be empty)\n",
"        r_abs = get_rouge(sent, abstract[0]) if len(abstract) != 0 else None\n",
"        num_refs = len(tmp_df[(tmp_df['sent_type'] == sent_type) & (tmp_df['sent_id'] == sent_id)])\n",
"        preprocessed_score.append(score)\n",
"        features.append((sent_id, int(sent_type == \"general\"), r_abs, num_refs,\n",
"                         mean_rouge(sent, fig_captions), mean_rouge(sent, tab_captions),\n",
"                         min_rouge(sent, fig_captions), min_rouge(sent, tab_captions),\n",
"                         max_rouge(sent, fig_captions), max_rouge(sent, tab_captions)))\n",
"    return papers, preprocessed_score, features"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
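{
"cell_type": "markdown",
"source": [
"A quick demo (a sketch, with a made-up example sentence): `sent_standardize` strips `xref` citation placeholders and normalizes unicode punctuation."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Sketch: the sentence below is made up; xref placeholders are removed and\n",
"# unicode dashes / ellipses are converted to ASCII equivalents.\n",
"example = 'Tumor growth was reduced [xref_bibr_b12] in treated mice — see details…'\n",
"print(sent_standardize(example))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},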
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"FEATURES_NUMBER = 10\n",
"\n",
"\n",
"def process_reference_sentences_dataset(sentences_df, ref_sents_df, additional_features):\n",
"    res = {}\n",
"    inter = set(sentences_df['pmid'].values) & set(ref_sents_df['ref_pmid'].values)\n",
"    for pmid in tqdm(inter):\n",
"        try:\n",
"            if additional_features:\n",
"                temp = preprocess_paper_with_features(\n",
"                    pmid, sentences_df, ref_sents_df, abstracts_df, figures_df, reverse_ref_df, tables_df\n",
"                )\n",
"            else:\n",
"                temp = preprocess_paper(pmid, sentences_df, ref_sents_df)\n",
"        except Exception as e:\n",
"            logging.warning(f'Error during processing {pmid} {e}')\n",
"            continue\n",
"        if temp is None:\n",
"            logging.warning(f'temp is None for {pmid}')\n",
"            continue\n",
"        res[pmid] = temp\n",
"        print(f\"\\r{pmid} {np.mean(res[pmid][1])}\", end=\"\")\n",
"\n",
"    logging.info(f'Successfully preprocessed {len(res)} of {len(inter)} papers')\n",
"\n",
"    logging.info('Creating train dataset')\n",
"    feature_names = [\n",
"        'sent_id', 'sent_type', 'r_abs', 'num_refs',\n",
"        'mean_r_fig', 'mean_r_tab', 'min_r_fig', 'min_r_tab', 'max_r_fig', 'max_r_tab'\n",
"    ]\n",
"    assert len(feature_names) == FEATURES_NUMBER\n",
"    train_dic = dict(\n",
"        pmid=[], sentence=[], score=[], sent_id=[], sent_type=[], r_abs=[], num_refs=[],\n",
"        mean_r_fig=[], mean_r_tab=[], min_r_fig=[], min_r_tab=[], max_r_fig=[], max_r_tab=[]\n",
"    )\n",
"\n",
"    for pmid, stat in tqdm(res.items()):\n",
"        if len(stat) == 2:\n",
"            # (paper_sentences, scores) -- no additional features\n",
"            for sent, score in zip(*stat):\n",
"                train_dic['pmid'].append(pmid)\n",
"                train_dic['sentence'].append(sent)\n",
"                train_dic['score'].append(score)\n",
"        else:\n",
"            for sent, score, features in zip(*stat):\n",
"                train_dic['pmid'].append(pmid)\n",
"                train_dic['sentence'].append(sent)\n",
"                train_dic['score'].append(score)\n",
"                for name, val in zip(feature_names, features):\n",
"                    train_dic[name].append(val)\n",
"\n",
"    train_df = pd.DataFrame({k: v for k, v in train_dic.items() if v})\n",
"    logging.info(f'Full train dataset {len(train_df)}')\n",
"    return train_df"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"TRAIN_DATASET_PATH = f'{os.path.expanduser(cfg.base_path)}/dataset.csv'\n",
"# Delete the cached dataset to force regeneration; comment this line out to reuse it\n",
"! rm {TRAIN_DATASET_PATH}\n",
"\n",
"if os.path.exists(TRAIN_DATASET_PATH):\n",
"    train_df = pd.read_csv(TRAIN_DATASET_PATH)\n",
"else:\n",
"    train_df = process_reference_sentences_dataset(\n",
"        sentences_df, ref_sents_df, additional_features=ADDITIONAL_FEATURES\n",
"    )\n",
"    train_df.to_csv(TRAIN_DATASET_PATH, index=False)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"display(train_df.head())"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from collections import Counter\n",
"\n",
"res = Counter(list(train_df['score'].values))\n",
"\n",
"# How often identical score values repeat (possibly duplicated sentences)\n",
"plt.hist(res.values(), bins=range(2, 20))\n",
"plt.title('Repetitions of identical scores in the train dataset')\n",
"plt.show()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# Splitting data into train/test/val"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"train_ids, test_ids = train_test_split(list(set(train_df['pmid'].values)), test_size=0.2)\n",
"test_ids, val_ids = train_test_split(test_ids, test_size=0.4)\n",
"\n",
"train = train_df[train_df['pmid'].isin(train_ids)]\n",
"logging.info(f'Train {len(train)}')\n",
"display(train.head(1))\n",
"\n",
"test = train_df[train_df['pmid'].isin(test_ids)]\n",
"logging.info(f'Test {len(test)}')\n",
"display(test.head(1))\n",
"\n",
"val = train_df[train_df['pmid'].isin(val_ids)]\n",
"logging.info(f'Validate {len(val)}')\n",
"display(val.head(1))\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
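{
"cell_type": "markdown",
"source": [
"An added sanity check (sketch): since the split is done by `pmid`, the three subsets should not share any paper ids."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Sketch: verify that no pmid leaks between train / test / val\n",
"assert not set(train_ids) & set(test_ids)\n",
"assert not set(train_ids) & set(val_ids)\n",
"assert not set(test_ids) & set(val_ids)\n",
"logging.info('No pmid overlap between train / test / val')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},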
"pycharm": { 649 | "name": "#%% md\n" 650 | } 651 | } 652 | }, 653 | { 654 | "cell_type": "code", 655 | "execution_count": null, 656 | "outputs": [], 657 | "source": [ 658 | "from transformers import BertModel\n", 659 | "\n", 660 | "backbone = BertModel.from_pretrained(\n", 661 | " \"bert-base-uncased\", output_hidden_states=False\n", 662 | ")\n", 663 | "print('BERT pretrained')\n", 664 | "print(f'Parameters {sum(p.numel() for p in backbone.parameters() if p.requires_grad)}')\n", 665 | "print(', '.join(n for n, p in backbone.named_parameters()))\n", 666 | "print(backbone)\n", 667 | "\n", 668 | "# backbone = RobertaModel.from_pretrained(\n", 669 | "# 'roberta-base', output_hidden_states=False\n", 670 | "# )\n", 671 | "# print('ROBERTA pretrained')\n", 672 | "# print(f'Parameters {sum(p.numel() for p in backbone.parameters() if p.requires_grad)}')\n", 673 | "# print(', '.join(n for n, p in backbone.named_parameters()))\n", 674 | "# print(backbone)" 675 | ], 676 | "metadata": { 677 | "collapsed": false, 678 | "pycharm": { 679 | "name": "#%%\n" 680 | } 681 | } 682 | }, 683 | { 684 | "cell_type": "markdown", 685 | "source": [ 686 | "# Main model classes\n", 687 | "\n", 688 | "The model in main pubtrends application is loaded using `load_model` function from `review.model` module.\n", 689 | "\n", 690 | "It has several options to set up:\n", 691 | "* with or without features (right now without features works better),\n", 692 | "* `BERT` or `roberta` as basis (no big difference),\n", 693 | "\n", 694 | "You can also choose `frozen_strategy`:\n", 695 | "* `froze_all` in case you don't want to improve bert layers but only the summarization layer,\n", 696 | "* `unfroze_last` -- modifies bert weights and still training not very slow,\n", 697 | "* `unfroze_all` -- the training is slow, the results may better though" 698 | ], 699 | "metadata": { 700 | "collapsed": false, 701 | "pycharm": { 702 | "name": "#%% md\n" 703 | } 704 | } 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": null, 709 | "outputs": [], 710 | "source": [ 711 | "import os\n", 712 | "import torch.nn as nn\n", 713 | "from transformers import BertModel, RobertaModel\n", 714 | "from transformers import BertTokenizer, RobertaTokenizer\n", 715 | "\n", 716 | "\n", 717 | "def get_token_id(tokenizer, tkn):\n", 718 | " return tokenizer.convert_tokens_to_ids([tkn])[0]\n", 719 | "\n", 720 | "\n", 721 | "EMBEDDINGS_ADDITIONAL = 20\n", 722 | "\n", 723 | "FEATURES_NN_INTERMEDIATE = 100\n", 724 | "FEATURES_NN_OUT = 50\n", 725 | "FEATURES_DROPOUT = 0.1\n", 726 | "\n", 727 | "DECODER_DROPOUT = 0.1\n", 728 | "\n", 729 | "\n", 730 | "class Summarizer(nn.Module):\n", 731 | " \"\"\"\n", 732 | " This is the main summarization model.\n", 733 | " See https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_tf_bert.py\n", 734 | " It operates the same input format as original BERT used underneath.\n", 735 | " See forward, evaluate params description.\n", 736 | " \"\"\"\n", 737 | " enc_output: torch.Tensor\n", 738 | " dec_ids_mask: torch.Tensor\n", 739 | " encdec_ids_mask: torch.Tensor\n", 740 | "\n", 741 | " def __init__(self, model_type, article_len, additional_features, num_features=FEATURES_NUMBER):\n", 742 | " super(Summarizer, self).__init__()\n", 743 | "\n", 744 | " print(f'Initialize backbone and tokenizer for {model_type}')\n", 745 | " self.article_len = article_len\n", 746 | " if model_type == 'bert':\n", 747 | " self.backbone = self.initialize_bert()\n", 748 | " self.tokenizer = 
self.initialize_bert_tokenizer()\n",
"        elif model_type == 'roberta':\n",
"            self.backbone = self.initialize_roberta()\n",
"            self.tokenizer = self.initialize_roberta_tokenizer()\n",
"        else:\n",
"            raise Exception(f\"Wrong model_type argument: {model_type}\")\n",
"        self.backbone.resize_token_embeddings(EMBEDDINGS_ADDITIONAL + self.tokenizer.vocab_size)\n",
"\n",
"        if additional_features:\n",
"            print('Adding a two-layer fully connected NN for additional features')\n",
"            self.features = nn.Sequential(\n",
"                nn.Linear(num_features, FEATURES_NN_INTERMEDIATE),\n",
"                nn.LeakyReLU(),\n",
"                nn.Dropout(FEATURES_DROPOUT),\n",
"                nn.Linear(FEATURES_NN_INTERMEDIATE, FEATURES_NN_OUT)\n",
"            )\n",
"        else:\n",
"            self.features = None\n",
"\n",
"        print('Initialize backbone embeddings pooling')\n",
"\n",
"        def backbone_forward(input_ids, attention_mask, token_type_ids, position_ids):\n",
"            return self.backbone(\n",
"                input_ids=input_ids,\n",
"                attention_mask=attention_mask,\n",
"                token_type_ids=token_type_ids,\n",
"                position_ids=position_ids\n",
"            )\n",
"\n",
"        self.encoder = lambda *args: backbone_forward(*args)[0]\n",
"\n",
"        print('Initialize decoder')\n",
"        if additional_features:\n",
"            self.decoder = Classifier(768 + FEATURES_NN_OUT)  # Default BERT output with additional features\n",
"        else:\n",
"            self.decoder = Classifier(768)  # Default BERT output\n",
"\n",
"    def expand_positional_embs_if_need(self):\n",
"        print('Expand positional embeddings if needed')\n",
"        print('Positional embeddings', self.backbone.config.max_position_embeddings, self.article_len)\n",
"        if self.article_len > self.backbone.config.max_position_embeddings:\n",
"            old_maxlen = self.backbone.config.max_position_embeddings\n",
"            old_w = self.backbone.embeddings.position_embeddings.weight\n",
"            logging.info(f'Backbone pos embeddings expanded from {old_maxlen} up to {self.article_len}')\n",
"            self.backbone.embeddings.position_embeddings = nn.Embedding(\n",
"                self.article_len, self.backbone.config.hidden_size\n",
"            )\n",
"            self.backbone.embeddings.position_embeddings.weight[:old_maxlen].data.copy_(old_w)\n",
"            self.backbone.config.max_position_embeddings = self.article_len\n",
"        print('New positional embeddings', self.backbone.config.max_position_embeddings)\n",
"\n",
"    @staticmethod\n",
"    def initialize_bert():\n",
"        return BertModel.from_pretrained(\n",
"            \"bert-base-uncased\", output_hidden_states=False\n",
"        )\n",
"\n",
"    @staticmethod\n",
"    def initialize_bert_tokenizer():\n",
"        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n",
"        tokenizer.BOS = \"[CLS]\"\n",
"        tokenizer.EOS = \"[SEP]\"\n",
"        tokenizer.PAD = \"[PAD]\"\n",
"        return tokenizer\n",
"\n",
"    @staticmethod\n",
"    def initialize_roberta():\n",
"        backbone = RobertaModel.from_pretrained(\n",
"            'roberta-base', output_hidden_states=False\n",
"        )\n",
"        print('Initialize token type embeddings: by default RoBERTa does not have them')\n",
"        backbone.config.type_vocab_size = 2\n",
"        backbone.embeddings.token_type_embeddings = nn.Embedding(2, backbone.config.hidden_size)\n",
"        backbone.embeddings.token_type_embeddings.weight.data.normal_(\n",
" mean=0.0, std=backbone.config.initializer_range\n", 823 | " )\n", 824 | " return backbone\n", 825 | "\n", 826 | " @staticmethod\n", 827 | " def initialize_roberta_tokenizer():\n", 828 | " tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)\n", 829 | " tokenizer.BOS = \"\"\n", 830 | " tokenizer.EOS = \"\"\n", 831 | " tokenizer.PAD = \"\"\n", 832 | " return tokenizer\n", 833 | "\n", 834 | " def save(self, save_filename):\n", 835 | " \"\"\" Save model in filename\n", 836 | "\n", 837 | " :param save_filename: str\n", 838 | " \"\"\"\n", 839 | " state = dict(\n", 840 | " encoder_dict=self.backbone.state_dict(),\n", 841 | " decoder_dict=self.decoder.state_dict()\n", 842 | " )\n", 843 | " if self.features:\n", 844 | " state['features_dict'] = self.features.state_dict()\n", 845 | " models_folder = os.path.expanduser(cfg.weights_path)\n", 846 | " if not os.path.exists(models_folder):\n", 847 | " os.makedirs(models_folder)\n", 848 | " torch.save(state, f\"{models_folder}/{save_filename}.pth\")\n", 849 | "\n", 850 | " def load(self, load_filename):\n", 851 | " path = f\"{os.path.expanduser(cfg.weights_path)}/{load_filename}.pth\"\n", 852 | " state = torch.load(path, map_location=lambda storage, location: storage)\n", 853 | " self.backbone.load_state_dict(state['encoder_dict'])\n", 854 | " self.decoder.load_state_dict(state['decoder_dict'])\n", 855 | " if self.features:\n", 856 | " self.features.load_state_dict(state['features_dict'])\n", 857 | "\n", 858 | " def froze_backbone(self, froze_strategy):\n", 859 | " if froze_strategy == 'froze_all':\n", 860 | " for param in self.backbone.parameters():\n", 861 | " param.requires_grad_(False)\n", 862 | "\n", 863 | " elif froze_strategy == 'unfroze_last':\n", 864 | " for name, param in self.backbone.named_parameters():\n", 865 | " param.requires_grad_(\n", 866 | " 'encoder.layer.11' in name or\n", 867 | " 'encoder.layer.10' in name or\n", 868 | " 'encoder.layer.9' in name\n", 869 | " )\n", 870 | "\n", 871 | " elif froze_strategy == 'unfroze_all':\n", 872 | " for param in self.backbone.parameters():\n", 873 | " param.requires_grad_(True)\n", 874 | "\n", 875 | " else:\n", 876 | " raise Exception(f'Unsupported froze strategy {froze_strategy}')\n", 877 | "\n", 878 | " def unfroze_head(self):\n", 879 | " for param in self.decoder.parameters():\n", 880 | " param.requires_grad_(True)\n", 881 | "\n", 882 | " def forward(self, input_ids, attention_mask, token_type_ids, input_features=None):\n", 883 | " \"\"\"\n", 884 | " :param input_ids: torch.Size([batch_size, article_len])\n", 885 | " Indices of input sequence tokens in the vocabulary.\n", 886 | " :param attention_mask: torch.Size([batch_size, article_len])\n", 887 | " Mask to avoid performing attention on padding token indices.\n", 888 | " Mask values selected in `[0, 1]`:\n", 889 | " - 1 for tokens that are **not masked**,\n", 890 | " - 0 for tokens that are **masked**.\n", 891 | " :param token_type_ids: torch.Size([batch_size, article_len])\n", 892 | " Segment token indices to indicate first and second portions of the inputs.\n", 893 | " Indices are selected in `[0, 1]`:\n", 894 | " - 0 corresponds to a *sentence A* token,\n", 895 | " - 1 corresponds to a *sentence B* token.\n", 896 | " :return: scores | torch.Size([batch_size, summary_len])\n", 897 | " \"\"\"\n", 898 | "\n", 899 | " # The output of [CLS] is inferred by all other words in this sentence.\n", 900 | " # This makes [CLS] a good representation for sentence-level classification.\n", 901 | " cls_mask = (input_ids == 
get_token_id(self.tokenizer, self.tokenizer.BOS))\n", 902 | " print(f'cls_mask {cls_mask.shape}')\n", 903 | "\n", 904 | " # Indices of positions of each input sequence tokens in the position embeddings.\n", 905 | " # position ids | torch.Size([batch_size, article_len])\n", 906 | " pos_ids = torch.arange(\n", 907 | " 0,\n", 908 | " self.article_len,\n", 909 | " dtype=torch.long,\n", 910 | " device=input_ids.device\n", 911 | " ).unsqueeze(0).repeat(len(input_ids), 1)\n", 912 | " print(f'pos_ids {pos_ids.shape}')\n", 913 | " # extract bert embeddings | torch.Size([batch_size, article_len, d_bert])\n", 914 | " # for each word in the input, the BERT base internally creates a 768-dimensional output,\n", 915 | " # but for tasks like classification, we do not actually require the output for all the embeddings.\n", 916 | " # So by default, BERT considers only the output corresponding to the first token [CLS]\n", 917 | " # and drops the output vectors corresponding to all the other tokens.\n", 918 | " enc_output = self.encoder(input_ids, attention_mask, token_type_ids, pos_ids)\n", 919 | "\n", 920 | " if self.features:\n", 921 | " out_features = self.features(input_features)\n", 922 | " scores = self.decoder(torch.cat([enc_output[cls_mask], out_features], dim=-1))\n", 923 | " else:\n", 924 | " print('enc_output', enc_output.shape)\n", 925 | " print('enc_output[cls_mask]', enc_output[cls_mask].shape)\n", 926 | " scores = self.decoder(enc_output[cls_mask])\n", 927 | " print('scores', scores.shape)\n", 928 | "\n", 929 | " return scores\n", 930 | "\n", 931 | " def evaluate(self, input_ids, attention_mask, token_type_ids, input_features=None):\n", 932 | " \"\"\"See forward for parameters and output description\"\"\"\n", 933 | "\n", 934 | " # The output of [CLS] is inferred by all other words in this sentence.\n", 935 | " # This makes [CLS] a good representation for sentence-level classification.\n", 936 | " cls_mask = (input_ids == get_token_id(self.tokenizer, self.tokenizer.BOS))\n", 937 | "\n", 938 | " # position ids | torch.Size([batch_size, article_len])\n", 939 | " pos_ids = torch.arange(\n", 940 | " 0,\n", 941 | " self.article_len,\n", 942 | " dtype=torch.long,\n", 943 | " device=input_ids.device\n", 944 | " ).unsqueeze(0).repeat(len(input_ids), 1)\n", 945 | "\n", 946 | " # extract bert embeddings | torch.Size([batch_size, article_len, d_bert])\n", 947 | " enc_output = self.encoder(input_ids, attention_mask, token_type_ids, pos_ids)\n", 948 | "\n", 949 | " scores = []\n", 950 | " for eo, cm in zip(enc_output, cls_mask):\n", 951 | " if self.features:\n", 952 | " out_features = self.features(input_features)\n", 953 | " score = self.decoder.evaluate(torch.cat([eo[cm], out_features], dim=-1))\n", 954 | " else:\n", 955 | " score = self.decoder.evaluate(eo[cm])\n", 956 | " scores.append(score)\n", 957 | " return scores\n", 958 | "\n", 959 | "\n", 960 | "class Classifier(nn.Module):\n", 961 | " def __init__(self, hidden_size):\n", 962 | " super(Classifier, self).__init__()\n", 963 | " self.dropout = nn.Dropout(DECODER_DROPOUT)\n", 964 | " self.linear = nn.Linear(hidden_size, 1)\n", 965 | " self.sigmoid = nn.Sigmoid()\n", 966 | "\n", 967 | " def forward(self, x):\n", 968 | " return self.sigmoid(self.linear(self.dropout(x)).squeeze(-1))\n", 969 | "\n", 970 | " def evaluate(self, x):\n", 971 | " return self.sigmoid(self.linear(self.dropout(x)).squeeze(-1))\n", 972 | "\n", 973 | "\n", 974 | "def create_model(model_type, froze_strategy, article_len, additional_features):\n", 975 | " model = 
Summarizer(model_type, article_len, additional_features)\n", 976 | " model.expand_positional_embs_if_need()\n", 977 | " # Load intermediate model\n", 978 | " # model.load('temp')\n", 979 | " model.froze_backbone(froze_strategy)\n", 980 | " model.unfroze_head()\n", 981 | " if additional_features:\n", 982 | " print('Parameters for features NN', sum(p.numel() for p in model.features.parameters() if p.requires_grad))\n", 983 | " print('Parameters for backbone', sum(p.numel() for p in model.backbone.parameters() if p.requires_grad))\n", 984 | " print('Parameters for classifier', sum(p.numel() for p in model.decoder.parameters() if p.requires_grad))\n", 985 | " return model" 986 | ], 987 | "metadata": { 988 | "collapsed": false, 989 | "pycharm": { 990 | "name": "#%%\n" 991 | } 992 | } 993 | }, 994 | { 995 | "cell_type": "markdown", 996 | "source": [ 997 | "# Preprocessing text for BERT" 998 | ], 999 | "metadata": { 1000 | "collapsed": false, 1001 | "pycharm": { 1002 | "name": "#%% md\n" 1003 | } 1004 | } 1005 | }, 1006 | { 1007 | "cell_type": "code", 1008 | "execution_count": null, 1009 | "outputs": [], 1010 | "source": [ 1011 | "# TODO: likely incorrect training scheme\n", 1012 | "# BERT was pretrained to support only a single [CLS] token in the input.\n", 1013 | "# [SEP] and token_type_ids are used to process two connected sentences in a way:\n", 1014 | "# [CLS] ... sent1 ... [SEP] ... sent2 ... [SEP] [PAD] ... with token_type_ids\n", 1015 | "# 0 ... ... ... ... 0 1 ... ... ... ... ... 1 ... ...\n", 1016 | "# Current learning strategy may produce incorrect results.\n", 1017 | "# We can concat max N sentences into a single input and use average score for them.\n", 1018 | "\n", 1019 | "def preprocess_text(text, max_len, tokenizer):\n", 1020 | " \"\"\"\n", 1021 | " Preprocess text for BERT / ROBERTA model.\n", 1022 | " NOTE: not all the text can be processed because of max_len.\n", 1023 | " IMPORTANT: preprocessed text may contain more than one [CLS] token!!! 
and more than two sentences!!!\n", 1024 | " :param text: list(list(str))\n", 1025 | " :param max_len: maximum length of preprocessing\n", 1026 | " :param tokenizer: BERT or ROBERTA tokenizer\n", 1027 | " :return:\n", 1028 | " ids | tokenized ids of length max_len, 0 if padding\n", 1029 | " attention_mask | list(str) 1 if real token, not padding\n", 1030 | " token_type_ids | 0-1 for different sentences\n", 1031 | " n_sents | number of actual sentences encoded\n", 1032 | " \"\"\"\n", 1033 | " sents = [\n", 1034 | " [tokenizer.BOS] + tokenizer.tokenize(sent) + [tokenizer.EOS] for sent in text\n", 1035 | " ]\n", 1036 | " logging.debug(f'sents {sents}')\n", 1037 | " ids, token_type_ids, segment_signature = [], [], 0\n", 1038 | " n_sents = 0\n", 1039 | " for i, s in enumerate(sents):\n", 1040 | " logging.debug(f'sentence {i} {s}')\n", 1041 | " logging.debug(f'ids {len(ids)}')\n", 1042 | " logging.debug(f'segments {len(token_type_ids)}')\n", 1043 | " logging.debug(f'segment_signature {segment_signature}')\n", 1044 | " if len(ids) + len(s) <= max_len:\n", 1045 | " n_sents += 1\n", 1046 | " ids.extend(tokenizer.convert_tokens_to_ids(s))\n", 1047 | " token_type_ids.extend([segment_signature] * len(s))\n", 1048 | " segment_signature = (segment_signature + 1) % 2\n", 1049 | " else:\n", 1050 | " logging.debug(f'break, len(s)={len(s)}')\n", 1051 | " break\n", 1052 | " attention_mask = [1] * len(ids)\n", 1053 | "\n", 1054 | " logging.debug('Padding data')\n", 1055 | " pad_len = max(0, max_len - len(ids))\n", 1056 | " ids += [get_token_id(tokenizer, tokenizer.PAD)] * pad_len\n", 1057 | " attention_mask += [0] * pad_len\n", 1058 | " token_type_ids += [segment_signature] * pad_len\n", 1059 | " assert len(ids) == len(attention_mask)\n", 1060 | " assert len(ids) == len(token_type_ids)\n", 1061 | " return ids, attention_mask, token_type_ids, n_sents" 1062 | ], 1063 | "metadata": { 1064 | "collapsed": false, 1065 | "pycharm": { 1066 | "name": "#%%\n", 1067 | "is_executing": true 1068 | } 1069 | } 1070 | }, 1071 | { 1072 | "cell_type": "code", 1073 | "execution_count": null, 1074 | "outputs": [], 1075 | "source": [ 1076 | "print('Inspect preprocessing text for BERT')\n", 1077 | "\n", 1078 | "tokenizer = Summarizer.initialize_bert_tokenizer()\n", 1079 | "text = list(train_df.head(5)['sentence'])\n", 1080 | "input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(text, 512, tokenizer)\n", 1081 | "\n", 1082 | "print(f'input_ids {input_ids}')\n", 1083 | "print(f'attention_mask {attention_mask}')\n", 1084 | "print(f'token_type_ids {token_type_ids}')\n", 1085 | "print(f'n_sents {n_sents}')" 1086 | ], 1087 | "metadata": { 1088 | "collapsed": false, 1089 | "pycharm": { 1090 | "name": "#%%\n" 1091 | } 1092 | } 1093 | }, 1094 | { 1095 | "cell_type": "markdown", 1096 | "source": [ 1097 | "## Train and evaluate functions" 1098 | ], 1099 | "metadata": { 1100 | "collapsed": false, 1101 | "pycharm": { 1102 | "name": "#%% md\n" 1103 | } 1104 | } 1105 | }, 1106 | { 1107 | "cell_type": "code", 1108 | "execution_count": null, 1109 | "outputs": [], 1110 | "source": [ 1111 | "import torch.nn as nn\n", 1112 | "from torch.optim.optimizer import Optimizer\n", 1113 | "import math\n", 1114 | "\n", 1115 | "\n", 1116 | "def get_enc_lr(optimizer):\n", 1117 | " return optimizer.param_groups[0]['lr']\n", 1118 | "\n", 1119 | "\n", 1120 | "def get_dec_lr(optimizer):\n", 1121 | " return optimizer.param_groups[1]['lr']\n", 1122 | "\n", 1123 | "\n", 1124 | "def backward_step(loss: torch.Tensor, model: nn.Module, clip: float):\n", 1125 
| " loss.backward()\n", 1126 | " total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n", 1127 | " return total_norm\n", 1128 | "\n", 1129 | "\n", 1130 | "def batch_to_device(batch, additional_features, device):\n", 1131 | " batch_on_device = [(x.to(device) if isinstance(x, torch.Tensor) else x) for x in batch]\n", 1132 | " if additional_features:\n", 1133 | " input_ids, attention_mask, token_type_ids, target_scores, input_features = batch_on_device\n", 1134 | " input_features = torch.cat(input_features).to(device)\n", 1135 | " else:\n", 1136 | " input_ids, attention_mask, token_type_ids, target_scores = batch_on_device\n", 1137 | " input_features = None\n", 1138 | " return input_ids, attention_mask, token_type_ids, target_scores, input_features\n", 1139 | "\n", 1140 | "\n", 1141 | "def train_fun(\n", 1142 | " model,\n", 1143 | " dataloader,\n", 1144 | " optimizer,\n", 1145 | " scheduler,\n", 1146 | " criter,\n", 1147 | " device,\n", 1148 | " writer,\n", 1149 | " additional_features\n", 1150 | "):\n", 1151 | " # Put the model into training mode. Don't be mislead--the call to\n", 1152 | " # `train` just changes the *mode*, it doesn't *perform* the training.\n", 1153 | " # `dropout` and `batchnorm` layers behave differently during training\n", 1154 | " # vs. test (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)\n", 1155 | " model.train()\n", 1156 | " loss_val = 0\n", 1157 | " target_val = 0\n", 1158 | " mean_sents = 0\n", 1159 | " szs = 0\n", 1160 | "\n", 1161 | " for idx_batch, batch in enumerate(dataloader):\n", 1162 | " input_ids, attention_mask, token_type_ids, target_scores, input_features = batch_to_device(\n", 1163 | " batch, additional_features, device\n", 1164 | " )\n", 1165 | "\n", 1166 | " sizes = [dc.shape[0] for dc in target_scores]\n", 1167 | " mean_sents += sum(sizes)\n", 1168 | " szs += len(sizes)\n", 1169 | "\n", 1170 | " # forward pass\n", 1171 | " print(f'batch {idx_batch}')\n", 1172 | " print('forward')\n", 1173 | " print(f'input_ids {input_ids.shape}')\n", 1174 | " print(f'attention_mask {attention_mask.shape}')\n", 1175 | " print(f'token_type_ids {token_type_ids.shape}')\n", 1176 | " if additional_features:\n", 1177 | " print(f'input_features {input_features.shape}')\n", 1178 | "\n", 1179 | " model_scores = model(input_ids, attention_mask, token_type_ids, input_features)\n", 1180 | " print(f'models_scores {model_scores.shape}')\n", 1181 | " print(f'target_scores {len(target_scores)}')\n", 1182 | "\n", 1183 | " target_scores = torch.cat(target_scores).to(device)\n", 1184 | " target_val += sum(target_scores) / len(target_scores)\n", 1185 | " try:\n", 1186 | " # loss\n", 1187 | " loss = criter(model_scores, target_scores, )\n", 1188 | " loss_val += loss.item()\n", 1189 | " print(f'loss {loss}')\n", 1190 | " except Exception:\n", 1191 | " print(idx_batch, model_scores.shape, target_scores.shape, token_type_ids)\n", 1192 | " return\n", 1193 | "\n", 1194 | " # backward\n", 1195 | " print('backward')\n", 1196 | " grad_norm = backward_step(loss, model, optimizer.clip_value)\n", 1197 | " grad_norm = 0 if (math.isinf(grad_norm) or math.isnan(grad_norm)) else grad_norm\n", 1198 | "\n", 1199 | " # record a loss value\n", 1200 | " print(f'{idx_batch} / {len(dataloader)} train loss {loss.item()}')\n", 1201 | " print(f'\\r{idx_batch} / {len(dataloader)} train loss {loss.item()}', end='')\n", 1202 | " writer.add_scalar(f\"Train/loss\", loss.item(), writer.train_step)\n", 1203 | " writer.add_scalar(\"Train/grad_norm\", 
grad_norm, writer.train_step)\n", 1204 | " writer.add_scalar(\"Train/lr_enc\", get_enc_lr(optimizer), writer.train_step)\n", 1205 | " writer.add_scalar(\"Train/lr_dec\", get_dec_lr(optimizer), writer.train_step)\n", 1206 | " writer.train_step += 1\n", 1207 | "\n", 1208 | " # make a gradient step\n", 1209 | " if (idx_batch + 1) % optimizer.accumulation_interval == 0 or (idx_batch + 1) == len(dataloader):\n", 1210 | " print('optimizer step')\n", 1211 | " optimizer.step()\n", 1212 | "\n", 1213 | " # Always clear any previously calculated gradients before performing a\n", 1214 | " # backward pass. PyTorch doesn't do this automatically because\n", 1215 | " # accumulating the gradients is \"convenient while training RNNs\".\n", 1216 | " # (source: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch)\n", 1217 | " optimizer.zero_grad()\n", 1218 | "\n", 1219 | " print('scheduler step')\n", 1220 | " scheduler.step()\n", 1221 | "\n", 1222 | " print(\"\\rTrain loss:\", loss_val / len(dataloader), f\"{100 * loss_val / target_val:.5f}%\")\n", 1223 | " # print(\"Train mean sent len:\", mean_sents / szs)\n", 1224 | "\n", 1225 | " # save model, just in case\n", 1226 | " model.save('temp')\n", 1227 | " print('save model to temp')\n", 1228 | "\n", 1229 | " return model, optimizer, scheduler, writer\n", 1230 | "\n", 1231 | "\n", 1232 | "def evaluate_fun(\n", 1233 | " model,\n", 1234 | " dataloader,\n", 1235 | " criter,\n", 1236 | " device,\n", 1237 | " writer,\n", 1238 | " additional_features\n", 1239 | "):\n", 1240 | " # Put the model in evaluation mode--the dropout layers behave differently\n", 1241 | " # during evaluation.\n", 1242 | " model.eval()\n", 1243 | " loss_val = 0\n", 1244 | " target_val = 0\n", 1245 | " mean_sents = 0\n", 1246 | " szs = 0\n", 1247 | "\n", 1248 | " for idx_batch, batch in enumerate(dataloader):\n", 1249 | " input_ids, attention_mask, token_type_ids, target_scores, input_features = batch_to_device(\n", 1250 | " batch, additional_features, device\n", 1251 | " )\n", 1252 | " sizes = [dc.shape[0] for dc in target_scores]\n", 1253 | " mean_sents += sum(sizes)\n", 1254 | " szs += len(sizes)\n", 1255 | " target_scores = torch.cat(target_scores).to(device)\n", 1256 | "\n", 1257 | " # evaluate pass\n", 1258 | " print('evaluate')\n", 1259 | " print(f'input_ids {input_ids.shape}')\n", 1260 | " print(f'attention_mask {attention_mask.shape}')\n", 1261 | " print(f'token_type_ids {token_type_ids.shape}')\n", 1262 | " if additional_features:\n", 1263 | " print(f'input_features {input_features.shape}')\n", 1264 | " # Tell pytorch not to bother with constructing the compute graph during\n", 1265 | " # the forward pass, since this is only needed for backprop (training).\n", 1266 | " with torch.no_grad():\n", 1267 | " model_scores = model(input_ids, attention_mask, token_type_ids, input_features)\n", 1268 | " print(f'model_scores {model_scores.shape}')\n", 1269 | " print(f'target_scores {len(target_scores)}')\n", 1270 | "\n", 1271 | " # loss\n", 1272 | " loss = criter(model_scores, target_scores, )\n", 1273 | " target_val += sum(target_scores) / len(target_scores)\n", 1274 | "\n", 1275 | " # record a loss value\n", 1276 | " print(f'{idx_batch} / {len(dataloader)} val loss {loss.item()}')\n", 1277 | " print(f'\\r{idx_batch} / {len(dataloader)} val loss {loss.item()}', end='')\n", 1278 | " loss_val += loss.item()\n", 1279 | " writer.add_scalar(f\"Eval/loss\", loss.item(), writer.train_step)\n", 1280 | " writer.train_step += 1\n", 1281 | "\n", 1282 | " 
print(\"\\rValidate loss:\", loss_val / len(dataloader), f\"{100 * loss_val / target_val:.5f}%\")\n", 1283 | " # print(\"Validate mean sent len:\", mean_sents / szs)\n", 1284 | "\n", 1285 | " # save model, just in case\n", 1286 | " model.save('validated_weights')\n", 1287 | " print('save model to validated_weights')\n", 1288 | "\n", 1289 | " return model" 1290 | ], 1291 | "metadata": { 1292 | "collapsed": false, 1293 | "pycharm": { 1294 | "name": "#%%\n" 1295 | } 1296 | } 1297 | }, 1298 | { 1299 | "cell_type": "markdown", 1300 | "source": [ 1301 | "# Dataset and dataloader classes" 1302 | ], 1303 | "metadata": { 1304 | "collapsed": false, 1305 | "pycharm": { 1306 | "name": "#%% md\n" 1307 | } 1308 | } 1309 | }, 1310 | { 1311 | "cell_type": "code", 1312 | "execution_count": null, 1313 | "outputs": [], 1314 | "source": [ 1315 | "from torch.utils.data import Dataset, DataLoader\n", 1316 | "\n", 1317 | "# Overlap helps to keep context and connect different parts of abstract\n", 1318 | "OVERLAP = 5\n", 1319 | "\n", 1320 | "\n", 1321 | "class TrainDataset(Dataset):\n", 1322 | " \"\"\" Custom Train Dataset for data with additional features.\n", 1323 | " First preprocess all the data and then give out the batches.\n", 1324 | " It implements overlapping between batches to keep context between train examples.\n", 1325 | " \"\"\"\n", 1326 | "\n", 1327 | " def __init__(self, dataframe, tokenizer, article_len, additional_features):\n", 1328 | " self.df = dataframe\n", 1329 | " self.data = []\n", 1330 | " self.pmids = list(set(dataframe['pmid'].values))\n", 1331 | " self.tokenizer = tokenizer\n", 1332 | " self.article_len = article_len\n", 1333 | " self.additional_features = additional_features\n", 1334 | " # Create a list of test inputs for each pmid\n", 1335 | " for pmid in tqdm(self.pmids):\n", 1336 | " ex = self.df[self.df['pmid'] == pmid]\n", 1337 | " text = ex['sentence'].values\n", 1338 | " features = np.nan_to_num(\n", 1339 | " ex[['sent_id', 'sent_type', 'r_abs', 'num_refs',\n", 1340 | " 'mean_r_fig', 'mean_r_tab',\n", 1341 | " 'min_r_fig', 'min_r_tab',\n", 1342 | " 'max_r_fig', 'max_r_tab']].values.astype(float)\n", 1343 | " ) if additional_features else None\n", 1344 | " # Preprocessing BERT cannot encode all the text,\n", 1345 | " # only limited number of sentences per single model run is supported.\n", 1346 | " total_sents = 0\n", 1347 | " while total_sents < len(text):\n", 1348 | " offset = max(0, total_sents - OVERLAP)\n", 1349 | " input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(\n", 1350 | " text[offset:], self.article_len, self.tokenizer\n", 1351 | " )\n", 1352 | " if n_sents <= OVERLAP:\n", 1353 | " total_sents += 1\n", 1354 | " continue\n", 1355 | " total_sents = offset + n_sents\n", 1356 | " target_scores = ex['score'].values[offset: offset + n_sents] / 100\n", 1357 | " input_features = features[offset: offset + n_sents] if additional_features else None\n", 1358 | " # print(f'Train dataset example {len(self.data)}\\n'\n", 1359 | " # f'input_ids {input_ids}\\n'\n", 1360 | " # f'attention_mask {attention_mask}\\n'\n", 1361 | " # f'token_type_ids {token_type_ids}\\n'\n", 1362 | " # f'target_scores {target_scores}\\n'\n", 1363 | " # f'features {input_features}')\n", 1364 | " if additional_features:\n", 1365 | " self.data.append((input_ids, attention_mask, token_type_ids, target_scores, input_features))\n", 1366 | " else:\n", 1367 | " self.data.append((input_ids, attention_mask, token_type_ids, target_scores))\n", 1368 | "\n", 1369 | " logging.info(f'Train dataset 
size {len(self.data)}')\n", 1370 | "\n", 1371 | " def __getitem__(self, idx):\n", 1372 | " return self.data[idx]\n", 1373 | "\n", 1374 | " def __len__(self):\n", 1375 | " return len(self.data)\n", 1376 | "\n", 1377 | "\n", 1378 | "class EvalDataset(Dataset):\n", 1379 | " \"\"\" Custom Valid/Test Dataset\n", 1380 | " \"\"\"\n", 1381 | "\n", 1382 | " def __init__(self, dataframe, tokenizer, article_len, additional_features):\n", 1383 | " self.df = dataframe\n", 1384 | " self.pmids = list(set(dataframe['pmid'].values))\n", 1385 | " logging.info(f'Eval dataset size {len(self.pmids)}')\n", 1386 | " self.tokenizer = tokenizer\n", 1387 | " self.article_len = article_len\n", 1388 | " self.additional_features = additional_features\n", 1389 | "\n", 1390 | " def __getitem__(self, idx):\n", 1391 | " pmid = self.pmids[idx]\n", 1392 | " ex = self.df[self.df['pmid'] == pmid]\n", 1393 | " paper = ex['sentence'].values\n", 1394 | " features = np.nan_to_num(\n", 1395 | " ex[['sent_id', 'sent_type', 'r_abs',\n", 1396 | " 'num_refs', 'mean_r_fig', 'mean_r_tab',\n", 1397 | " 'min_r_fig', 'min_r_tab',\n", 1398 | " 'max_r_fig', 'max_r_tab']].values.astype(float)\n", 1399 | " ) if self.additional_features else None\n", 1400 | " input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(\n", 1401 | " paper, self.article_len, self.tokenizer\n", 1402 | " )\n", 1403 | "\n", 1404 | " # form target\n", 1405 | " target_scores = ex['score'].values[:n_sents] / 100\n", 1406 | " input_features = features[:n_sents] if self.additional_features else None\n", 1407 | " # print(f'Eval dataset example {idx}\\n'\n", 1408 | " # f'input_ids {input_ids}\\n'\n", 1409 | " # f'attention_mask {attention_mask}\\n'\n", 1410 | " # f'token_type_ids {token_type_ids}\\n'\n", 1411 | " # f'target_scores {target_scores}\\n'\n", 1412 | " # f'features {input_features}')\n", 1413 | " if self.additional_features:\n", 1414 | " return input_ids, attention_mask, token_type_ids, target_scores, input_features\n", 1415 | " else:\n", 1416 | " return input_ids, attention_mask, token_type_ids, target_scores\n", 1417 | "\n", 1418 | " def __len__(self):\n", 1419 | " return len(self.pmids)\n", 1420 | "\n", 1421 | "\n", 1422 | "def create_collate_fn(additional_features):\n", 1423 | " \"\"\"Create Function to pull batch for train / eval.\"\"\"\n", 1424 | "\n", 1425 | " def _collate_fn(batch_data):\n", 1426 | " \"\"\"\n", 1427 | " :param batch_data: list of `TrainDataset` or `EvalDataset` Examples\n", 1428 | " :return: one batch of data\n", 1429 | " \"\"\"\n", 1430 | " data = list(zip(*batch_data))\n", 1431 | " result = [\n", 1432 | " torch.tensor(data[0], dtype=torch.long),\n", 1433 | " torch.tensor(data[1], dtype=torch.long),\n", 1434 | " torch.tensor(data[2], dtype=torch.long),\n", 1435 | " [torch.tensor(e, dtype=torch.float) for e in data[3]]\n", 1436 | " ]\n", 1437 | " if additional_features:\n", 1438 | " result.append([torch.tensor(e, dtype=torch.float) for e in data[4]])\n", 1439 | " return result\n", 1440 | "\n", 1441 | " return _collate_fn\n", 1442 | "\n", 1443 | "\n", 1444 | "# The DataLoader needs to know our batch size for training, so we specify it\n", 1445 | "# here. 
1444 | "# The DataLoader needs to know our batch size for training, so we specify it\n",
1445 | "# here. For fine-tuning BERT on a specific task, the authors recommend a batch\n",
1446 | "# size of 16 or 32.\n",
1447 | "BATCH_SIZE = 16\n",
1448 | "\n",
1449 | "\n",
1450 | "def get_dataloaders(train, val, batch_size, article_len, tokenizer, additional_features):\n",
1451 | "    logging.info('Creating train dataset...')\n",
1452 | "    train_ds = TrainDataset(train, tokenizer, article_len, additional_features)\n",
1453 | "\n",
1454 | "    logging.info('Applying loader functions to train...')\n",
1455 | "    train_dl = DataLoader(\n",
1456 | "        dataset=train_ds, batch_size=batch_size, shuffle=False,\n",
1457 | "        pin_memory=True, collate_fn=create_collate_fn(additional_features), num_workers=1\n",
1458 | "    )\n",
1459 | "\n",
1460 | "    logging.info('Creating val dataset...')\n",
1461 | "    val_ds = EvalDataset(val, tokenizer, article_len, additional_features)\n",
1462 | "\n",
1463 | "    logging.info('Applying loader functions to val...')\n",
1464 | "    val_dl = DataLoader(\n",
1465 | "        dataset=val_ds, batch_size=batch_size, shuffle=False,\n",
1466 | "        pin_memory=True, collate_fn=create_collate_fn(additional_features), num_workers=1\n",
1467 | "    )\n",
1468 | "\n",
1469 | "    return train_dl, val_dl"
1470 | ],
1471 | "metadata": {
1472 | "collapsed": false,
1473 | "pycharm": {
1474 | "name": "#%%\n"
1475 | }
1476 | }
1477 | },
1478 | {
1479 | "cell_type": "code",
1480 | "execution_count": null,
1481 | "outputs": [],
1482 | "source": [
1483 | "train_dl, _ = get_dataloaders(\n",
1484 | "    train, val, 1, 512, create_model('bert', 'froze_all', 512, False).tokenizer, False\n",
1485 | ")\n",
1486 | "for batch in train_dl:\n",
1487 | "    break\n",
1488 | "batch"
1489 | ],
1490 | "metadata": {
1491 | "collapsed": false,
1492 | "pycharm": {
1493 | "name": "#%%\n"
1494 | }
1495 | }
1496 | },
1497 | {
1498 | "cell_type": "markdown",
1499 | "source": [
1500 | "# Custom scheduler used in training"
1501 | ],
1502 | "metadata": {
1503 | "collapsed": false,
1504 | "pycharm": {
1505 | "name": "#%% md\n"
1506 | }
1507 | }
1508 | },
1509 | {
1510 | "cell_type": "code",
1511 | "execution_count": null,
1512 | "outputs": [],
1513 | "source": [
1514 | "from torch.optim.lr_scheduler import _LRScheduler, ExponentialLR\n",
1515 | "\n",
1516 | "\n",
1517 | "class CustomScheduler(_LRScheduler):\n",
1518 | "    timestep: int = 0\n",
1519 | "\n",
1520 | "    def __init__(self, optimizer, gamma, warmup=None):\n",
1521 | "        self.optimizer = optimizer\n",
1522 | "        self.after_warmup = ExponentialLR(optimizer, gamma=gamma)\n",
1523 | "        self.initial_lrs = [p_group['lr'] for p_group in self.optimizer.param_groups]\n",
1524 | "        self.warmup = 0 if warmup is None else warmup\n",
1525 | "        super(CustomScheduler, self).__init__(optimizer)\n",
1526 | "\n",
1527 | "    def get_lr(self):\n",
1528 | "        return [self.timestep * group_init_lr / self.warmup for group_init_lr in\n",
1529 | "                self.initial_lrs] if self.timestep < self.warmup else self.after_warmup.get_lr()\n",
1530 | "\n",
1531 | "    def step(self, epoch=None):\n",
1532 | "        if self.timestep < self.warmup:\n",
1533 | "            self.timestep += 1\n",
1534 | "            super(CustomScheduler, self).step(epoch)\n",
1535 | "        else:\n",
1536 | "            self.after_warmup.step(epoch)\n",
1537 | "\n",
1538 | "\n",
1539 | "class NoamScheduler(_LRScheduler):\n",
1540 | "    \"\"\"\n",
1541 | "    The Noam schedule has a linear warm-up period followed by an inverse square-root decay of the learning rate.\n",
1542 | "    This is a PyTorch implementation of the schedule introduced in the paper \"Attention Is All You Need\".\n",
1543 | "    \"\"\"\n",
1544 | "\n",
1545 | "    def __init__(self, optimizer, warmup):\n",
1546 | "        assert warmup > 0\n",
1547 | "        self.optimizer = optimizer\n",
1548 | "        self.initial_lrs = [p_group['lr'] for p_group in self.optimizer.param_groups]\n",
1549 | "        self.warmup = warmup\n",
1550 | "        self.timestep = 0\n",
1551 | "        super(NoamScheduler, self).__init__(optimizer)\n",
1552 | "\n",
1553 | "    def get_lr(self):\n",
1554 | "        noam_lr = self.get_noam_lr()\n",
1555 | "        return [group_init_lr * noam_lr for group_init_lr in self.initial_lrs]\n",
1556 | "\n",
1557 | "    def get_noam_lr(self):\n",
1558 | "        return min(self.timestep ** -0.5, self.timestep * self.warmup ** -1.5)\n",
1559 | "\n",
1560 | "    def step(self, epoch=None):\n",
1561 | "        self.timestep += 1\n",
1562 | "        super(NoamScheduler, self).step(epoch)"
1563 | ],
1564 | "metadata": {
1565 | "collapsed": false,
1566 | "pycharm": {
1567 | "name": "#%%\n"
1568 | }
1569 | }
1570 | },
1571 | {
1572 | "cell_type": "markdown",
1573 | "source": [
1574 | "# Prepare and configure training"
1575 | ],
1576 | "metadata": {
1577 | "collapsed": false,
1578 | "pycharm": {
1579 | "name": "#%% md\n"
1580 | }
1581 | }
1582 | },
1583 | {
1584 | "cell_type": "code",
1585 | "execution_count": null,
1586 | "outputs": [],
1587 | "source": [
1588 | "from torch.optim import AdamW\n",
1589 | "from torch.nn import MSELoss\n",
1590 | "from tensorboardX import SummaryWriter\n",
1591 | "\n",
1592 | "ENCODER_LEARNING_RATE = 0.0001\n",
1593 | "DECODER_LEARNING_RATE = 0.001\n",
1594 | "\n",
1595 | "WARMUP = 5\n",
1596 | "WEIGHT_DECAY = 0.01\n",
1597 | "CLIP_VALUE = 1.0\n",
1598 | "ACCUMULATION_INTERVAL = 1\n",
1599 | "\n",
1600 | "# Number of training epochs. The BERT authors recommend between 2 and 4 for full\n",
1601 | "# fine-tuning; here we run 10 and rely on the per-epoch validation loss to spot\n",
1602 | "# over-fitting of the training data.\n",
1603 | "EPOCHS_NUMBER = 10\n",
1604 | "\n",
1605 | "\n",
1606 | "def prepare_learning_tools(\n",
1607 | "    model,\n",
1608 | "    enc_lr=ENCODER_LEARNING_RATE,\n",
1609 | "    dec_lr=DECODER_LEARNING_RATE,\n",
1610 | "    warmup=WARMUP,\n",
1611 | "    weight_decay=WEIGHT_DECAY,\n",
1612 | "    clip_value=CLIP_VALUE,\n",
1613 | "    accumulation_interval=ACCUMULATION_INTERVAL\n",
1614 | "):\n",
1615 | "    # TODO fix for Roberta model (the parameter name prefix differs)\n",
1616 | "    enc_parameters = [\n",
1617 | "        param for name, param in model.named_parameters()\n",
1618 | "        if param.requires_grad and name.startswith('bert.')\n",
1619 | "    ]\n",
1620 | "    # A parameter must not appear in two AdamW groups, so take only non-encoder weights here\n",
1621 | "    dec_parameters = [param for name, param in model.named_parameters()\n",
1622 | "                      if param.requires_grad and not name.startswith('bert.')]\n",
1623 | "    optimizer = AdamW([\n",
1624 | "        dict(params=enc_parameters, lr=enc_lr),\n",
1625 | "        dict(params=dec_parameters, lr=dec_lr),\n",
1626 | "    ], weight_decay=weight_decay)\n",
1627 | "    optimizer.clip_value = clip_value\n",
1628 | "    optimizer.accumulation_interval = accumulation_interval\n",
1629 | "\n",
1630 | "    scheduler = NoamScheduler(optimizer, warmup=warmup)\n",
1631 | "    criter = MSELoss()\n",
1632 | "\n",
1633 | "    return optimizer, scheduler, criter\n",
1634 | "\n",
1635 | "\n",
1636 | "def load_or_train_model(model, device, additional_features):\n",
1637 | "    model_name = f'learn_simple_berta_{additional_features}.pth'.lower()\n",
1638 | "    MODEL_PATH = f'{os.path.expanduser(cfg.weights_path)}/{model_name}'\n",
1639 | "    # To force retraining from scratch, uncomment: !
rm {MODEL_PATH}\n", 1640 | "\n", 1641 | " if os.path.exists(MODEL_PATH):\n", 1642 | " logging.info(f'Loading model {MODEL_PATH}')\n", 1643 | " model.load(model_name)\n", 1644 | " model, device = setup_cuda_device(model)\n", 1645 | " return model\n", 1646 | " else:\n", 1647 | " logging.info('Create dataloaders...')\n", 1648 | " train_loader, valid_loader = get_dataloaders(\n", 1649 | " train, val, BATCH_SIZE, ARTICLE_LENGTH, model.tokenizer, additional_features\n", 1650 | " )\n", 1651 | "\n", 1652 | " writer = SummaryWriter(log_dir=os.path.expanduser(cfg.log_path))\n", 1653 | " writer.train_step, writer.eval_step = 0, 0\n", 1654 | "\n", 1655 | " optimizer, scheduler, criter = prepare_learning_tools(model)\n", 1656 | "\n", 1657 | " logging.info(f\"Start training {EPOCHS_NUMBER} epochs...\")\n", 1658 | " for epoch in tqdm(range(1, EPOCHS_NUMBER + 1)):\n", 1659 | " print(f'Epoch {epoch}')\n", 1660 | " model, optimizer, scheduler, writer = train_fun(\n", 1661 | " model, train_loader, optimizer, scheduler,\n", 1662 | " criter, device, writer, additional_features\n", 1663 | " )\n", 1664 | " model = evaluate_fun(\n", 1665 | " model, valid_loader, criter, device, writer, additional_features\n", 1666 | " )\n", 1667 | " logging.info(f\"Done training {EPOCHS_NUMBER} epochs...\")\n", 1668 | " logging.info(f'Save trained model to {model_name}')\n", 1669 | " model.save(model_name)\n", 1670 | "\n", 1671 | " return model" 1672 | ], 1673 | "metadata": { 1674 | "collapsed": false, 1675 | "pycharm": { 1676 | "name": "#%%\n" 1677 | } 1678 | } 1679 | }, 1680 | { 1681 | "cell_type": "markdown", 1682 | "source": [ 1683 | "# Create and train model" 1684 | ], 1685 | "metadata": { 1686 | "collapsed": false, 1687 | "pycharm": { 1688 | "name": "#%% md\n" 1689 | } 1690 | } 1691 | }, 1692 | { 1693 | "cell_type": "code", 1694 | "execution_count": null, 1695 | "outputs": [], 1696 | "source": [ 1697 | "ARTICLE_LENGTH = 512\n", 1698 | "\n", 1699 | "model = create_model(\"bert\", \"froze_all\", ARTICLE_LENGTH, additional_features=ADDITIONAL_FEATURES)\n", 1700 | "model, device = setup_cuda_device(model)" 1701 | ], 1702 | "metadata": { 1703 | "collapsed": false, 1704 | "pycharm": { 1705 | "name": "#%%\n" 1706 | } 1707 | } 1708 | }, 1709 | { 1710 | "cell_type": "code", 1711 | "execution_count": null, 1712 | "outputs": [], 1713 | "source": [ 1714 | "model = load_or_train_model(model, device, additional_features=ADDITIONAL_FEATURES)" 1715 | ], 1716 | "metadata": { 1717 | "collapsed": false, 1718 | "pycharm": { 1719 | "name": "#%%\n", 1720 | "is_executing": true 1721 | } 1722 | } 1723 | }, 1724 | { 1725 | "cell_type": "markdown", 1726 | "source": [ 1727 | "# Example of model predictions" 1728 | ], 1729 | "metadata": { 1730 | "collapsed": false, 1731 | "pycharm": { 1732 | "name": "#%% md\n" 1733 | } 1734 | } 1735 | }, 1736 | { 1737 | "cell_type": "code", 1738 | "execution_count": null, 1739 | "outputs": [], 1740 | "source": [ 1741 | "print('Prepare data for model')\n", 1742 | "ex = val[val['pmid'] == val['pmid'].values[0]]\n", 1743 | "print('ex', len(ex))\n", 1744 | "text = ex['sentence'].values\n", 1745 | "\n", 1746 | "input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(\n", 1747 | " text, ARTICLE_LENGTH, model.tokenizer\n", 1748 | ")\n", 1749 | "print('n_sents', n_sents)\n", 1750 | "res_sents = text[:n_sents]\n", 1751 | "scores = ex['score'].values[:n_sents]\n", 1752 | "\n", 1753 | "input_ids = torch.tensor([input_ids]).to(device)\n", 1754 | "attention_mask = torch.tensor([attention_mask]).to(device)\n", 1755 | 
"token_type_ids = torch.tensor([token_type_ids]).to(device)\n", 1756 | "\n", 1757 | "if ADDITIONAL_FEATURES:\n", 1758 | " features = np.nan_to_num(\n", 1759 | " ex[['sent_id', 'sent_type', 'r_abs', 'num_refs',\n", 1760 | " 'mean_r_fig', 'mean_r_tab',\n", 1761 | " 'min_r_fig', 'min_r_tab',\n", 1762 | " 'max_r_fig', 'max_r_tab']].values.astype(float)\n", 1763 | " )\n", 1764 | " features = features[:n_sents]\n", 1765 | " features = [torch.tensor(e, dtype=torch.float) for e in features]\n", 1766 | " features = torch.stack(features).to(device)\n", 1767 | "else:\n", 1768 | " features = None\n", 1769 | "\n", 1770 | "print('Apply model')\n", 1771 | "model_scores = model(input_ids, attention_mask, token_type_ids, features)\n", 1772 | "\n", 1773 | "to_show_df = pd.DataFrame(\n", 1774 | " dict(sentence=res_sents, ideal_score=scores / 100, res_score=model_scores.cpu().detach().numpy())\n", 1775 | ")\n", 1776 | "display(to_show_df.head())\n", 1777 | "print(((to_show_df['ideal_score'].values - to_show_df['res_score'].values) ** 2).mean() ** 0.5)" 1778 | ], 1779 | "metadata": { 1780 | "collapsed": false, 1781 | "pycharm": { 1782 | "name": "#%%\n" 1783 | } 1784 | } 1785 | }, 1786 | { 1787 | "cell_type": "markdown", 1788 | "source": [ 1789 | "# Evaluate model performance" 1790 | ], 1791 | "metadata": { 1792 | "collapsed": false, 1793 | "pycharm": { 1794 | "name": "#%% md\n" 1795 | } 1796 | } 1797 | }, 1798 | { 1799 | "cell_type": "code", 1800 | "execution_count": null, 1801 | "outputs": [], 1802 | "source": [ 1803 | "if 'model' not in globals():\n", 1804 | " model = create_model(\"bert\", \"froze_all\", ARTICLE_LENGTH, additional_features=ADDITIONAL_FEATURES)\n", 1805 | " model, device = setup_cuda_device(model)\n", 1806 | " model = load_or_train_model(model, device, additional_features=ADDITIONAL_FEATURES)" 1807 | ], 1808 | "metadata": { 1809 | "collapsed": false, 1810 | "pycharm": { 1811 | "name": "#%%\n" 1812 | } 1813 | } 1814 | }, 1815 | { 1816 | "cell_type": "code", 1817 | "execution_count": null, 1818 | "outputs": [], 1819 | "source": [ 1820 | "print('Prepare refs_and_scores dataset')\n", 1821 | "REF_SCORES_PATH = os.path.expanduser(f\"{cfg.base_path}/refs_and_scores.csv\")\n", 1822 | "! 
rm {REF_SCORES_PATH}\n", 1823 | "\n", 1824 | "if os.path.exists(REF_SCORES_PATH):\n", 1825 | " final_ref_show_df = pd.read_csv(REF_SCORES_PATH)\n", 1826 | "else:\n", 1827 | " to_show_ref = pd.merge(train_df, ref_sents_df[['ref_pmid', 'sentence']],\n", 1828 | " left_on=['pmid'], right_on=['ref_pmid'])\n", 1829 | " to_show_ref = to_show_ref.rename(columns=dict(sentence_x='sentence', sentence_y='ref_sentence'))\n", 1830 | " to_show_ref = to_show_ref[['pmid', 'sentence', 'ref_sentence', 'score']]\n", 1831 | " final_ref_show_dic = dict(pmid=[], sentence=[], ref_sentence=[], score=[])\n", 1832 | " ite = [(pmid, sent) for pmid, sent in to_show_ref[['pmid', 'sentence']].values]\n", 1833 | "\n", 1834 | " for pmid, sent in tqdm(set(ite)):\n", 1835 | " refs_df = to_show_ref[(to_show_ref['pmid'] == pmid) & (to_show_ref['sentence'] == sent)]\n", 1836 | " final_ref_show_dic['pmid'].append(pmid)\n", 1837 | " final_ref_show_dic['sentence'].append(sent)\n", 1838 | " final_ref_show_dic['ref_sentence'].append(\" \".join(refs_df['ref_sentence'].values))\n", 1839 | " final_ref_show_dic['score'].append(refs_df['score'].values[0])\n", 1840 | " final_ref_show_df = pd.DataFrame(final_ref_show_dic)\n", 1841 | " final_ref_show_df.to_csv(REF_SCORES_PATH, index=False)\n", 1842 | "\n", 1843 | "display(final_ref_show_df)" 1844 | ], 1845 | "metadata": { 1846 | "collapsed": false, 1847 | "pycharm": { 1848 | "name": "#%%\n" 1849 | } 1850 | } 1851 | }, 1852 | { 1853 | "cell_type": "code", 1854 | "execution_count": null, 1855 | "outputs": [], 1856 | "source": [ 1857 | "print('Prepare dataset to estimate performance')\n", 1858 | "to_test = final_ref_show_df[final_ref_show_df['pmid'].isin(set(val['pmid'].values))]\n", 1859 | "if ADDITIONAL_FEATURES:\n", 1860 | " to_test = pd.merge(to_test, train_df[['pmid', 'sentence',\n", 1861 | " 'sent_id', 'sent_type', 'r_abs', 'num_refs',\n", 1862 | " 'mean_r_fig', 'mean_r_tab',\n", 1863 | " 'min_r_fig', 'min_r_tab',\n", 1864 | " 'max_r_fig', 'max_r_tab']],\n", 1865 | " left_on=['pmid', 'sentence'], right_on=['pmid', 'sentence'])\n", 1866 | "display(to_test)" 1867 | ], 1868 | "metadata": { 1869 | "collapsed": false, 1870 | "pycharm": { 1871 | "name": "#%%\n" 1872 | } 1873 | } 1874 | }, 1875 | { 1876 | "cell_type": "code", 1877 | "execution_count": null, 1878 | "outputs": [], 1879 | "source": [ 1880 | "print('Using model for predictions')\n", 1881 | "res = dict(pmid=[], sentence=[], ref_sentences=[], score=[], res_score=[])\n", 1882 | "\n", 1883 | "for pmid in tqdm(set(to_test['pmid'].values)):\n", 1884 | " ex = to_test[to_test['pmid'] == pmid]\n", 1885 | " text = ex['sentence'].values\n", 1886 | " input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(\n", 1887 | " text, ARTICLE_LENGTH, model.tokenizer\n", 1888 | " )\n", 1889 | " res_sents = text[:n_sents]\n", 1890 | " scores = ex['score'].values[:n_sents] / 100\n", 1891 | " input_ids = torch.tensor([input_ids]).to(device)\n", 1892 | " attention_mask = torch.tensor([attention_mask]).to(device)\n", 1893 | " token_type_ids = torch.tensor([token_type_ids]).to(device)\n", 1894 | " if ADDITIONAL_FEATURES:\n", 1895 | " features = np.nan_to_num(\n", 1896 | " ex[['sent_id', 'sent_type', 'r_abs', 'num_refs',\n", 1897 | " 'mean_r_fig', 'mean_r_tab',\n", 1898 | " 'min_r_fig', 'min_r_tab',\n", 1899 | " 'max_r_fig', 'max_r_tab']].values.astype(float)\n", 1900 | " )\n", 1901 | " features = features[:n_sents]\n", 1902 | " input_features = [torch.tensor(e, dtype=torch.float) for e in features]\n", 1903 | " input_features = 
torch.stack(input_features).to(device)\n",
1904 | "        model_scores = model(input_ids, attention_mask, token_type_ids, input_features)\n",
1905 | "    else:\n",
1906 | "        model_scores = model(input_ids, attention_mask, token_type_ids)\n",
1907 | "    for sent, sc, res_sc in zip(res_sents, scores, model_scores.cpu().detach().numpy()):\n",
1908 | "        res['pmid'].append(pmid)\n",
1909 | "        res['sentence'].append(sent)\n",
1910 | "        res['ref_sentences'].append(ex['ref_sentence'].values[0])\n",
1911 | "        res['score'].append(sc)\n",
1912 | "        res['res_score'].append(res_sc)\n",
1913 | "\n",
1914 | "res_df = pd.DataFrame(res)\n",
1915 | "display(res_df.head())\n",
1916 | "res_df.to_csv(f\"{cfg.base_path}/saved_example_refs.csv\", index=False)\n",
1917 | "\n",
1918 | "print('RMSE score', ((res_df['score'].values - res_df['res_score'].values) ** 2).mean() ** 0.5)"
1919 | ],
1920 | "metadata": {
1921 | "collapsed": false,
1922 | "pycharm": {
1923 | "name": "#%%\n"
1924 | }
1925 | }
1926 | },
1927 | {
1928 | "cell_type": "markdown",
1929 | "source": [
1930 | "# Quality analysis"
1931 | ],
1932 | "metadata": {
1933 | "collapsed": false,
1934 | "pycharm": {
1935 | "name": "#%% md\n"
1936 | }
1937 | }
1938 | },
1939 | {
1940 | "cell_type": "code",
1941 | "execution_count": null,
1942 | "outputs": [],
1943 | "source": [
1944 | "if 'model' not in globals():\n",
1945 | "    model = create_model(\"bert\", \"froze_all\", ARTICLE_LENGTH, additional_features=ADDITIONAL_FEATURES)\n",
1946 | "    model, device = setup_cuda_device(model)\n",
1947 | "    model = load_or_train_model(model, device, additional_features=ADDITIONAL_FEATURES)"
1948 | ],
1949 | "metadata": {
1950 | "collapsed": false,
1951 | "pycharm": {
1952 | "name": "#%%\n"
1953 | }
1954 | }
1955 | },
1956 | {
1957 | "cell_type": "code",
1958 | "execution_count": null,
1959 | "outputs": [],
1960 | "source": [
1961 | "print('Searching for review papers')\n",
1962 | "inter = set(sentences_df['pmid'].values) & set(ref_sents_df['ref_pmid'].values)\n",
1963 | "review_papers = list(set(ref_sents_df[ref_sents_df['ref_pmid'].isin(inter)]['pmid'].values))\n",
1964 | "print('Review papers:', len(review_papers))"
1965 | ],
1966 | "metadata": {
1967 | "collapsed": false,
1968 | "pycharm": {
1969 | "name": "#%%\n"
1970 | }
1971 | }
1972 | },
1973 | {
1974 | "cell_type": "code",
1975 | "execution_count": null,
1976 | "outputs": [],
1977 | "source": [
1978 | "test_stat = dict(rev_pmid=[], sent_num=[], rouge=[], true_rouge=[], diff_papers=[])\n",
1979 | "\n",
1980 | "for rev_id in tqdm(review_papers):\n",
1981 | "    paper_ref = sentences_df[sentences_df['pmid'] == rev_id]['sentence'].values\n",
1982 | "    papers_to_check = list(set(ref_sents_df[ref_sents_df['pmid'] == rev_id]['ref_pmid'].values))\n",
1983 | "    result = {'pmid': [], 'sentence': [], 'score': []}\n",
1984 | "    for paper_id in papers_to_check:\n",
1985 | "        ex = test[test['pmid'] == paper_id]\n",
1986 | "        text = ex['sentence'].values\n",
1987 | "\n",
1988 | "        features = np.nan_to_num(\n",
1989 | "            ex[['sent_id', 'sent_type', 'r_abs',\n",
1990 | "                'num_refs', 'mean_r_fig', 'mean_r_tab',\n",
1991 | "                'min_r_fig', 'min_r_tab',\n",
1992 | "                'max_r_fig', 'max_r_tab']].values.astype(float)\n",
1993 | "        ) if ADDITIONAL_FEATURES else None\n",
1994 | "        total_sents = 0\n",
1995 | "        while total_sents < len(text):\n",
1996 | "            offset = max(0, total_sents - OVERLAP)\n",
1997 | "            input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(\n",
1998 | "                text[offset:], ARTICLE_LENGTH, model.tokenizer\n",
1999 | "            )\n",
2000 | "            if n_sents <= OVERLAP:\n",
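"                # The window did not advance beyond the overlap (e.g. a single very long\n",
"                # sentence): step forward one sentence so the loop is guaranteed to terminate.\n",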
2001 | " total_sents += 1\n", 2002 | " continue\n", 2003 | " old_total = total_sents\n", 2004 | " total_sents = offset + n_sents\n", 2005 | " input_ids = torch.tensor([input_ids]).to(device)\n", 2006 | " attention_mask = torch.tensor([attention_mask]).to(device)\n", 2007 | " token_type_ids = torch.tensor([token_type_ids]).to(device)\n", 2008 | " if ADDITIONAL_FEATURES:\n", 2009 | " input_features = [torch.tensor(e, dtype=torch.float) for e in features[offset:total_sents]]\n", 2010 | " input_features = torch.stack(input_features).to(device)\n", 2011 | " model_scores = model(input_ids, attention_mask, token_type_ids, input_features)\n", 2012 | " else:\n", 2013 | " model_scores = model(input_ids, attention_mask, token_type_ids)\n", 2014 | "\n", 2015 | " result['pmid'].extend([paper_id] * (total_sents - old_total))\n", 2016 | " result['sentence'].extend(list(text[old_total:total_sents]))\n", 2017 | " result['score'].extend(list(model_scores.cpu().detach().numpy())[old_total - offset:])\n", 2018 | "\n", 2019 | " res_df = pd.DataFrame(result)\n", 2020 | " sorted_arr = sorted(list(res_df['score'].values))\n", 2021 | " for i in range(5, 103, 5):\n", 2022 | " if len(sorted_arr) < i:\n", 2023 | " break\n", 2024 | " threshold = sorted_arr[-i]\n", 2025 | " final_text = res_df[res_df['score'] >= threshold][['pmid', 'sentence']]\n", 2026 | " mean_score = 0\n", 2027 | " num = 0\n", 2028 | " for sent in final_text['sentence'].values:\n", 2029 | " for ref_sent in paper_ref:\n", 2030 | " try:\n", 2031 | " mean_score += get_rouge(sent, ref_sent)\n", 2032 | " num += 1\n", 2033 | " except Exception:\n", 2034 | " continue\n", 2035 | " mean_score /= num\n", 2036 | " real_score = get_rouge(\" \".join(final_text['sentence'].values), \" \".join(paper_ref))\n", 2037 | " test_stat['rev_pmid'].append(rev_id)\n", 2038 | " test_stat['sent_num'].append(i)\n", 2039 | " print(len(\" \".join(final_text['sentence'].values)), len(\" \".join(paper_ref)))\n", 2040 | "\n", 2041 | " test_stat['rouge'].append(mean_score)\n", 2042 | " test_stat['true_rouge'].append(real_score)\n", 2043 | " test_stat['diff_papers'].append(len(set(final_text['pmid'])))\n", 2044 | "\n", 2045 | "test_stat_df = pd.DataFrame(test_stat)\n", 2046 | "test_stat_df" 2047 | ], 2048 | "metadata": { 2049 | "collapsed": false, 2050 | "pycharm": { 2051 | "name": "#%%\n" 2052 | } 2053 | } 2054 | }, 2055 | { 2056 | "cell_type": "code", 2057 | "execution_count": null, 2058 | "outputs": [], 2059 | "source": [ 2060 | "print(*[len(arr) for key, arr in test_stat.items()])" 2061 | ], 2062 | "metadata": { 2063 | "collapsed": false, 2064 | "pycharm": { 2065 | "name": "#%%\n" 2066 | } 2067 | } 2068 | }, 2069 | { 2070 | "cell_type": "code", 2071 | "execution_count": null, 2072 | "outputs": [], 2073 | "source": [ 2074 | "test_stat_df.to_csv(f\"{cfg.base_path}/simple_right_test_on_review_{ADDITIONAL_FEATURES}.csv\", index=False)" 2075 | ], 2076 | "metadata": { 2077 | "collapsed": false, 2078 | "pycharm": { 2079 | "name": "#%%\n" 2080 | } 2081 | } 2082 | }, 2083 | { 2084 | "cell_type": "code", 2085 | "execution_count": null, 2086 | "outputs": [], 2087 | "source": [ 2088 | "rouge_means = []\n", 2089 | "rouge_err = []\n", 2090 | "papers_means = []\n", 2091 | "papers_err = []\n", 2092 | "\n", 2093 | "for i in range(5, 103, 5):\n", 2094 | " tmp = test_stat_df.groupby(['sent_num']).get_group(i)\n", 2095 | " rouge_means.append(tmp['rouge'].mean())\n", 2096 | " rouge_err.append(tmp['rouge'].std())\n", 2097 | " papers_means.append(tmp['diff_papers'].mean())\n", 2098 | " 
2107 | {
2108 | "cell_type": "code",
2109 | "execution_count": null,
2110 | "outputs": [],
2111 | "source": [
2112 | "plt.errorbar(list(range(5, 103, 5)), rouge_means, yerr=rouge_err, fmt='-o')\n",
2113 | "plt.title('Mean ROUGE value')\n",
2114 | "plt.show()"
2115 | ],
2116 | "metadata": {
2117 | "collapsed": false,
2118 | "pycharm": {
2119 | "name": "#%%\n"
2120 | }
2121 | }
2122 | },
2123 | {
2124 | "cell_type": "code",
2125 | "execution_count": null,
2126 | "outputs": [],
2127 | "source": [
2128 | "plt.errorbar(list(range(5, 103, 5)), papers_means, yerr=papers_err, fmt='-o')\n",
2129 | "plt.title('Mean number of papers')\n",
2130 | "plt.show()"
2131 | ],
2132 | "metadata": {
2133 | "collapsed": false,
2134 | "pycharm": {
2135 | "name": "#%%\n"
2136 | }
2137 | }
2138 | },
2139 | {
2140 | "cell_type": "markdown",
2141 | "source": [
2142 | "# Compare models with and without features"
2143 | ],
2144 | "metadata": {
2145 | "collapsed": false,
2146 | "pycharm": {
2147 | "name": "#%% md\n"
2148 | }
2149 | }
2150 | },
2151 | {
2152 | "cell_type": "code",
2153 | "execution_count": null,
2154 | "outputs": [],
2155 | "source": [
2156 | "import pandas as pd\n",
2157 | "\n",
2158 | "df1 = pd.read_csv(f\"{cfg.base_path}/simple_right_test_on_review_{False}.csv\")\n",
2159 | "df1 = df1.assign(model=['BERTSUM'] * len(df1))\n",
2160 | "df2 = pd.read_csv(f\"{cfg.base_path}/simple_right_test_on_review_{True}.csv\")\n",
2161 | "df2 = df2.assign(model=['BERTSUM with features'] * len(df2))\n",
2162 | "draw_df = pd.concat([df1, df2])"
2163 | ],
2164 | "metadata": {
2165 | "collapsed": false,
2166 | "pycharm": {
2167 | "name": "#%%\n"
2168 | }
2169 | }
2170 | },
2171 | {
2172 | "cell_type": "code",
2173 | "execution_count": null,
2174 | "outputs": [],
2175 | "source": [
2176 | "import seaborn as sns\n",
2177 | "\n",
2178 | "sns.catplot(x=\"sent_num\", y=\"rouge\", kind=\"box\", hue='model', aspect=1.7, color='lightblue',\n",
2179 | "            data=draw_df).set_axis_labels(\"Number of sentences\", \"ROUGE, %\")\n",
2180 | "plt.show()"
2181 | ],
2182 | "metadata": {
2183 | "collapsed": false,
2184 | "pycharm": {
2185 | "name": "#%%\n"
2186 | }
2187 | }
2188 | },
2189 | {
2190 | "cell_type": "code",
2191 | "execution_count": null,
2192 | "outputs": [],
2193 | "source": [],
2194 | "metadata": {
2195 | "collapsed": false,
2196 | "pycharm": {
2197 | "name": "#%%\n"
2198 | }
2199 | }
2200 | }
2201 | ],
2202 | "metadata": {
2203 | "kernelspec": {
2204 | "display_name": "Python 3 (ipykernel)",
2205 | "language": "python",
2206 | "name": "python3"
2207 | },
2208 | "language_info": {
2209 | "codemirror_mode": {
2210 | "name": "ipython",
2211 | "version": 3
2212 | },
2213 | "file_extension": ".py",
2214 | "mimetype": "text/x-python",
2215 | "name": "python",
2216 | "nbconvert_exporter": "python",
2217 | "pygments_lexer": "ipython3",
2218 | "version": "3.10.4"
2219 | }
2220 | },
2221 | "nbformat": 4,
2222 | "nbformat_minor": 4
2223 | }
--------------------------------------------------------------------------------