├── __init__.py
├── review
├── __init__.py
├── train
│ ├── __init__.py
│ ├── README.md
│ ├── topic_summarization.ipynb
│ └── test_new_algo.ipynb
├── config.py
├── README.md
├── text.py
├── model.py
└── dataset
│ ├── prepare_dataset.ipynb
│ └── analyze_archive.ipynb
├── .gitignore
├── resources
├── mail.svg
└── github.svg
├── AUTHORS.md
├── README.md
├── environment.yml
└── LICENSE
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/review/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/review/train/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.iml
3 |
4 | */.ipynb_checkpoints
5 | __pycache__
6 | .pyc
7 |
--------------------------------------------------------------------------------
/review/config.py:
--------------------------------------------------------------------------------
# Filesystem layout of the review package. Paths start with "~" and are
# expanded with os.path.expanduser at the places where they are used.
base_path = "~/review"
dataset_path = "/".join((base_path, "dataset"))
weights_path = "/".join((base_path, "weights"))
log_path = "/".join((base_path, "logs"))

# Model lookup is in /model and ~/.pubtrends/model folders
model_name = 'learn_simple_berta.pth'
9 |
--------------------------------------------------------------------------------
/resources/mail.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/resources/github.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/AUTHORS.md:
--------------------------------------------------------------------------------
1 | Authors
2 | =======
3 |
4 | * [![icon][mail]](mailto:nikiannanik@gmail.com)
5 | [![icon][github]](https://github.com/Javanochka)
6 | Anna Nikiforovskaya
7 | * [![icon][mail]](mailto:alexey@shpilman.com)
8 | [![icon][github]](https://github.com/ashpilman)
9 | Aleksei Shpilman
10 | * [![icon][mail]](mailto:oleg.shpynov@gmail.com)
11 | [![icon][github]](https://github.com/olegs)
12 | Oleg Shpynov
13 |
14 |
15 | Icons by Feather
16 |
17 | [mail]: resources/mail.svg
18 | [github]: resources/github.svg
19 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://confluence.jetbrains.com/display/ALL/JetBrains+on+GitHub)
2 | [](https://doi.org/10.5281/zenodo.15131486)
3 |
4 | # pubtrends-review
5 |
6 | This is the source code for the paper ["Automatic generation of reviews of scientific papers"](https://arxiv.org/abs/2010.04147).\
7 | It was accepted as a full paper at the [International Conference on Machine Learning and Applications (ICMLA) 2019](https://ieeexplore.ieee.org/document/9356351).
8 |
9 | Refer to [review/README.md](review/README.md) for instructions on dataset creation and model training.
10 |
11 | # Authors
12 | See [AUTHORS.md](AUTHORS.md) for a list of authors and contributors.
13 |
14 |
--------------------------------------------------------------------------------
/review/README.md:
--------------------------------------------------------------------------------
1 | * Download [PubMedCentral Author Manuscript Collection](https://ftp.ncbi.nlm.nih.gov/pub/pmc/manuscript/xml/)
2 | into `~/pmc_dataset` folder.\
3 | Expected folder content:
4 | ```
5 | author_manuscript_xml.PMC001xxxxxx.baseline.2022-06-16.filelist.csv
6 | author_manuscript_xml.PMC001xxxxxx.baseline.2022-06-16.filelist.txt
7 | author_manuscript_xml.PMC001xxxxxx.baseline.2022-06-16.tar.gz
8 | ...
9 | ```
10 | * Extract all downloaded `tar.gz` files.\
11 | Expected extracted folders:
12 | ```
13 | PMC001xxxxxx
14 | PMC002xxxxxx
15 | PMC003xxxxxx
16 | PMC004xxxxxx
17 | PMC005xxxxxx
18 | PMC006xxxxxx
19 | PMC007xxxxxx
20 | PMC008xxxxxx
21 | PMC009xxxxxx
22 | ```
23 | * Launch `dataset/analyze_archive.ipynb` to analyze downloaded archive.
24 | * Launch `dataset/prepare_dataset.ipynb` to create tables required for model training.
25 | * Refer to `train/README.md` for instructions on model training.
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | # Keep this file as small as possible! Only explicit imports here!
2 | name: pubtrends
3 | channels:
4 | - conda-forge
5 | - bokeh
6 | - pytorch
7 | dependencies:
8 | - bokeh=3.6.3
9 | - cachetools=5.5.2
10 | - celery=5.4.0
11 | - ipywidgets=8.1.5
12 | - holoviews=1.20.1
13 | - huggingface_hub=0.29.1
14 | - flask=3.1.0
15 | - matplotlib=3.10.1
16 | - nltk=3.9.1
17 | - pandas=2.2.3
18 | - parameterized=0.9.0
19 | - python=3.10.16
20 | - pytest=8.3.5
21 | - psycopg2=2.9.9
22 | - wordcloud=1.9.4
23 | - networkx=3.4.2
24 | - python-louvain=0.16
25 | - scikit-learn=1.6.1
26 | - pytorch=2.5.1
27 | - transformers=4.49.0
28 | - tensorboardx=2.6.2.2
29 | - psutil=7.0.0
30 | - unidecode=1.3.8
31 | - pip
32 | - pip:
33 | - email-validator==2.2.0
34 | - flask-admin==1.6.0
35 | - flask-security==5.6.0
36 | - flask-sqlalchemy==3.1.1
37 | - gunicorn==23.0.0
38 | - lazy==1.6
39 | - rouge==1.0.1
40 | - redis==5.2.1
41 | - vine==5.1.0
42 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 JetBrains-Research
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/review/train/README.md:
--------------------------------------------------------------------------------
1 | ## Instructions
2 |
3 | These instructions explain how to create the dataset and how to train different models on it.
4 | Currently, the easiest way to do this is the Python notebook `test_new_algo.ipynb`, which consists of 18 steps.
5 |
6 | Instructions for running everything from the command line will be added soon.
7 |
8 | ### Create a dataset
9 |
10 | You can create a dataset with features or without them.
11 |
12 | To generate a dataset, see steps 3-8 (plus step 10 if you want to split the dataset into train/test/val on the spot).
13 | Some of these steps are not needed if you create a dataset without features. The model with features is still being
14 | developed; you can try it, but for now it performs worse.
15 |
16 | ### Train the model
17 |
18 | Steps 1-2, 9, 10 (to get the dataset) and 11-14 are needed.
19 |
20 | You can save the trained model at any time using the function `model.save("name_of_the_file")`.
21 | Also, try different parameters for the model (you can freeze some layers or leave them trainable, and you can try
22 | a model with features).
23 |
24 | ### See how the model works
25 |
26 | This is what step 15 is for. You can try the model on different examples.
27 |
28 | ### Evaluation
29 |
30 | These are the last steps. Steps 16 and 17 compute the MSE score and show how good it is relative to the dataset,
31 | while step 18 tests summarizing several scientific papers into a single review.
32 |
33 | For topic summarization: Prepare test dataset by running `topic_summarization.ipynb`
34 |
--------------------------------------------------------------------------------
/review/text.py:
--------------------------------------------------------------------------------
1 | import nltk
2 |
3 |
def split_text(text):
    """Split raw text into a list of sentences with NLTK's sentence tokenizer."""
    sentences = nltk.tokenize.sent_tokenize(text)
    return sentences
6 |
7 |
def get_token_id(tokenizer, tkn):
    """Return the vocabulary id that ``tokenizer`` assigns to the single token ``tkn``."""
    token_ids = tokenizer.convert_tokens_to_ids([tkn])
    return token_ids[0]
10 |
11 |
def preprocess_text(text, max_len, tokenizer):
    """
    Preprocess text for BERT / ROBERTA model.

    Every sentence is wrapped as BOS + tokens + EOS and appended while the total
    number of tokens does not exceed max_len; processing stops at the first
    sentence that does not fit, so not all the text may be encoded.
    Sentences are tokenized lazily: nothing after the first non-fitting sentence
    is tokenized, which keeps repeated calls on a shrinking tail of the same
    text (see text_to_data) from re-tokenizing the whole remainder every time.

    IMPORTANT: preprocessed text may contain more than one BOS ([CLS]) token
    and more than two sentences!

    :param text: list(str), sentences of the text
    :param max_len: maximum number of token ids in the output
    :param tokenizer: BERT or ROBERTA tokenizer with BOS, EOS and PAD attributes
    :return:
        ids | list(int) of length max_len, token ids followed by PAD ids
        attention_mask | list(int), 1 if real token, 0 if padding
        token_type_ids | list(int), alternating 0-1 for consecutive sentences
        n_sents | number of actual sentences encoded
    """
    pad_id = tokenizer.convert_tokens_to_ids([tokenizer.PAD])[0]

    ids, token_type_ids, segment_signature = [], [], 0
    n_sents = 0
    for sent in text:
        tokens = [tokenizer.BOS] + tokenizer.tokenize(sent) + [tokenizer.EOS]
        if len(ids) + len(tokens) > max_len:
            # Sentences are never truncated: the first one that does not fit ends the window.
            break
        n_sents += 1
        ids.extend(tokenizer.convert_tokens_to_ids(tokens))
        token_type_ids.extend([segment_signature] * len(tokens))
        # Flip the segment id so neighbouring sentences get different token types.
        segment_signature = (segment_signature + 1) % 2
    attention_mask = [1] * len(ids)

    pad_len = max(0, max_len - len(ids))
    ids += [pad_id] * pad_len
    attention_mask += [0] * pad_len
    # Padding continues with the segment id that the next sentence would have received.
    token_type_ids += [segment_signature] * pad_len
    assert len(ids) == len(attention_mask)
    assert len(ids) == len(token_type_ids)
    return ids, attention_mask, token_type_ids, n_sents
49 |
50 |
# Overlap helps to keep context and connect different parts of abstract.
# It is the number of already processed sentences that text_to_data feeds again
# at the beginning of the next window.
OVERLAP = 5
53 |
54 |
def text_to_data(text, max_len, tokenizer):
    """
    This is the main entry point, which will be called from PubTrends review feature.

    The text is split into sentences and encoded as a sequence of windows of at
    most max_len tokens. Each next window starts up to OVERLAP sentences before
    the first sentence that has not been encoded yet, so neighbouring windows
    share context.

    If the overlapping sentences alone fill the window, the overlap is reduced
    step by step until at least one new sentence fits, so a sentence is not lost
    just because of its context. A sentence is skipped only when it does not fit
    into max_len tokens even on its own.

    :param text: str, raw text
    :param max_len: maximum number of token ids per window
    :param tokenizer: BERT or ROBERTA tokenizer, see preprocess_text
    :return: list of tuples (ids, attention_mask, token_type_ids, offset, sentences),
        where offset is the index of the first sentence of the window and
        sentences is the list of sentences encoded in it
    """
    sentences = split_text(text)
    total_sents = 0
    data = []
    while total_sents < len(sentences):
        # Preprocessing BERT cannot encode all the text,
        # only limited number of sentences per single model run is supported.
        # Try the widest overlap first and shrink it while the window brings nothing new.
        for offset in range(max(0, total_sents - OVERLAP), total_sents + 1):
            ids, attention_mask, token_type_ids, n_sents = \
                preprocess_text(sentences[offset:], max_len, tokenizer)
            if offset + n_sents > total_sents:
                data.append(
                    (ids, attention_mask, token_type_ids, offset, sentences[offset: offset + n_sents])
                )
                total_sents = offset + n_sents
                break
        else:
            # The sentence does not fit into max_len tokens even without context: skip it.
            total_sents += 1
    return data
74 |
--------------------------------------------------------------------------------
/review/train/topic_summarization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "# Notebook for testing model"
7 | ],
8 | "metadata": {
9 | "collapsed": false,
10 | "pycharm": {
11 | "name": "#%% md\n"
12 | }
13 | }
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": null,
18 | "metadata": {
19 | "pycharm": {
20 | "name": "#%%\n"
21 | }
22 | },
23 | "outputs": [],
24 | "source": [
25 | "import pandas as pd\n",
26 | "from review.train.preprocess import parse_sents, standardize\n",
27 | "from nltk.tokenize import sent_tokenize"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {
34 | "pycharm": {
35 | "name": "#%%\n"
36 | }
37 | },
38 | "outputs": [],
39 | "source": [
40 | "def prepare_text(text):\n",
41 | " text = parse_sents(sent_tokenize(text))\n",
42 | " text = standardize(text)\n",
43 | " text = ' '.join(text)\n",
44 | " return text"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {
50 | "pycharm": {
51 | "name": "#%% md\n"
52 | }
53 | },
54 | "source": [
55 | "Prepare topic abstracts\n"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {
62 | "pycharm": {
63 | "name": "#%%\n"
64 | }
65 | },
66 | "outputs": [],
67 | "source": [
68 | "topics_df = pd.read_csv(\"topic/topic_abstracts.csv\")"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "metadata": {
75 | "pycharm": {
76 | "name": "#%%\n"
77 | }
78 | },
79 | "outputs": [],
80 | "source": [
81 | "topics_df = topics_df.groupby(by=['topic_id']).agg({'text': lambda x: ''.join(x)}).reset_index()\n",
82 | "topics_df = topics_df.rename(columns={'topic_id':'id', 'text': 'paper_top50'})"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": null,
88 | "metadata": {
89 | "pycharm": {
90 | "name": "#%%\n"
91 | }
92 | },
93 | "outputs": [],
94 | "source": [
95 | "topics_df"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {
102 | "pycharm": {
103 | "name": "#%%\n"
104 | }
105 | },
106 | "outputs": [],
107 | "source": [
108 | "def create_dummy_dataframe(df):\n",
109 | " df['paper_top50'] = df['paper_top50'].apply(lambda text: prepare_text(text))\n",
110 | " df['abstract'] = 'dummy abstract'\n",
111 | " df['gold_ids_top6'] = str([0])\n",
112 | " return df"
113 | ]
114 | },
115 | {
116 | "cell_type": "code",
117 | "execution_count": null,
118 | "metadata": {
119 | "pycharm": {
120 | "name": "#%%\n"
121 | }
122 | },
123 | "outputs": [],
124 | "source": [
125 | "test_df = create_dummy_dataframe(topics_df)"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {
132 | "pycharm": {
133 | "name": "#%%\n"
134 | }
135 | },
136 | "outputs": [],
137 | "source": [
138 | "test_df"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "metadata": {
145 | "pycharm": {
146 | "name": "#%%\n"
147 | }
148 | },
149 | "outputs": [],
150 | "source": [
151 | "test_df.to_csv(\"data/pubmedtop50_test_topic.csv\")"
152 | ]
153 | }
154 | ],
155 | "metadata": {
156 | "kernelspec": {
157 | "display_name": "Python 3 (ipykernel)",
158 | "language": "python",
159 | "name": "python3"
160 | },
161 | "language_info": {
162 | "codemirror_mode": {
163 | "name": "ipython",
164 | "version": 3
165 | },
166 | "file_extension": ".py",
167 | "mimetype": "text/x-python",
168 | "name": "python",
169 | "nbconvert_exporter": "python",
170 | "pygments_lexer": "ipython3",
171 | "version": "3.10.4"
172 | }
173 | },
174 | "nbformat": 4,
175 | "nbformat_minor": 2
176 | }
--------------------------------------------------------------------------------
/review/model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import torch
5 | import torch.nn as nn
6 | from transformers import BertModel, BertTokenizer
7 | from transformers import RobertaModel, RobertaTokenizer
8 |
9 | import review.config as cfg
10 |
11 | logger = logging.getLogger(__name__)
12 |
# Number of extra rows added to the backbone token embeddings:
# Summarizer resizes them to tokenizer.vocab_size + EMBEDDINGS_ADDITIONAL.
EMBEDDINGS_ADDITIONAL = 20

# Settings of the optional features network that Summarizer builds when
# additional_features is set: default input width, hidden width, output width
# (concatenated to the backbone output before the decoder) and dropout probability.
FEATURES_NUMBER = 10
FEATURES_NN_INTERMEDIATE = 100
FEATURES_NN_OUT = 50
FEATURES_DROPOUT = 0.1

# Dropout probability used by Classifier before its linear layer.
DECODER_DROPOUT = 0.1
21 |
22 |
def setup_cuda_device(model):
    """
    Move the model to the first CUDA device if CUDA is available, otherwise to CPU.

    :param model: torch.nn.Module
    :return: tuple (model, device), the moved model and the torch.device it was moved to
    """
    # Use the module logger (as load_model does) instead of the root logger.
    logger.info('Setup single-device settings...')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    return model, device
28 |
29 |
def load_model(model_type, froze_strategy, article_len, features=False):
    """
    Build a Summarizer and prepare it for training.

    :param model_type: 'bert' or 'roberta', see Summarizer
    :param froze_strategy: 'froze_all', 'unfroze_last' or 'unfroze_all', see Summarizer.froze_backbone
    :param article_len: number of token positions in the model input
    :param features: whether to add the additional features network
    :return: Summarizer with expanded positional embeddings (if needed),
        backbone frozen according to the strategy and trainable decoder
    """
    logger.info(f'Loading model {model_type} {froze_strategy} {article_len} {features}')
    model = Summarizer(model_type, article_len, features)
    # The method is named expand_positional_embs_if_need; the previously used
    # name expand_posembs_ifneed does not exist and raised AttributeError.
    model.expand_positional_embs_if_need()
    model.froze_backbone(froze_strategy)
    model.unfroze_head()
    logger.info(f'Parameters {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
    return model
38 |
39 |
def get_token_id(tokenizer, tkn):
    """Look up the vocabulary id of one token string using the given tokenizer."""
    single_token = [tkn]
    converted = tokenizer.convert_tokens_to_ids(single_token)
    return converted[0]
42 |
43 |
class Summarizer(nn.Module):
    """
    This is the main summarization model.
    See https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_tf_bert.py
    It operates the same input format as original BERT used underneath.
    See forward, evaluate params description.

    The model consists of a pretrained backbone (BERT or RoBERTa), an optional
    fully connected network for additional features and a Classifier decoder that
    produces one score for every BOS ([CLS]) token found in the input.
    """
    enc_output: torch.Tensor
    dec_ids_mask: torch.Tensor
    encdec_ids_mask: torch.Tensor

    def __init__(self, model_type, article_len, additional_features, num_features=FEATURES_NUMBER):
        """
        :param model_type: 'bert' or 'roberta'
        :param article_len: number of token positions in the model input
        :param additional_features: whether to create the additional features network
        :param num_features: input width of the additional features network
        :raises ValueError: if model_type is not supported
        """
        super().__init__()

        logger.info(f'Initialize backbone and tokenizer for {model_type}')
        self.article_len = article_len
        if model_type == 'bert':
            self.backbone = self.initialize_bert()
            self.tokenizer = self.initialize_bert_tokenizer()
        elif model_type == 'roberta':
            self.backbone = self.initialize_roberta()
            self.tokenizer = self.initialize_roberta_tokenizer()
        else:
            raise ValueError(f"Wrong model_type argument: {model_type}")
        self.backbone.resize_token_embeddings(EMBEDDINGS_ADDITIONAL + self.tokenizer.vocab_size)

        if additional_features:
            logger.info('Adding additional features double fully connected nn')
            self.features = nn.Sequential(
                nn.Linear(num_features, FEATURES_NN_INTERMEDIATE),
                nn.LeakyReLU(),
                nn.Dropout(FEATURES_DROPOUT),
                nn.Linear(FEATURES_NN_INTERMEDIATE, FEATURES_NN_OUT)
            )
        else:
            self.features = None

        logger.info('Initialize backbone embeddings pulling')

        def encode(input_ids, attention_mask, token_type_ids, position_ids):
            # The first element of the backbone output is the per-token hidden states.
            return self.backbone(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids
            )[0]

        self.encoder = encode

        logger.info('Initialize decoder')
        # Take the decoder input width from the backbone config instead of hard-coding 768.
        hidden_size = self.backbone.config.hidden_size
        if additional_features:
            self.decoder = Classifier(hidden_size + FEATURES_NN_OUT)  # Backbone output with additional features
        else:
            self.decoder = Classifier(hidden_size)  # Backbone output only

    def expand_positional_embs_if_need(self):
        """
        Replace the backbone positional embeddings with a longer table when
        article_len exceeds the number of positions supported by the backbone.
        The pretrained rows are copied, the new rows keep nn.Embedding default initialization.
        """
        logger.info('Expand positional embeddings if need')
        old_maxlen = self.backbone.config.max_position_embeddings
        logger.info('Positional embeddings %s %s', old_maxlen, self.article_len)
        if self.article_len > old_maxlen:
            old_w = self.backbone.embeddings.position_embeddings.weight
            logger.info(f'Backbone pos embeddings expanded from {old_maxlen} upto {self.article_len}')
            # Create the new table on the same device and with the same dtype as the old one.
            new_embeddings = nn.Embedding(
                self.article_len, self.backbone.config.hidden_size
            ).to(device=old_w.device, dtype=old_w.dtype)
            # In-place initialization of a parameter must not be tracked by autograd.
            with torch.no_grad():
                new_embeddings.weight[:old_maxlen].copy_(old_w)
            self.backbone.embeddings.position_embeddings = new_embeddings
            self.backbone.config.max_position_embeddings = self.article_len
        logger.info('New positional embeddings %s', self.backbone.config.max_position_embeddings)

    @staticmethod
    def initialize_bert():
        """Load the pretrained bert-base-uncased backbone."""
        return BertModel.from_pretrained(
            "bert-base-uncased", output_hidden_states=False
        )

    @staticmethod
    def initialize_bert_tokenizer():
        """Load the bert-base-uncased tokenizer and attach BOS / EOS / PAD token strings."""
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        tokenizer.BOS = "[CLS]"
        tokenizer.EOS = "[SEP]"
        tokenizer.PAD = "[PAD]"
        return tokenizer

    @staticmethod
    def initialize_roberta():
        """Load the pretrained roberta-base backbone and give it two token type embeddings."""
        backbone = RobertaModel.from_pretrained(
            'roberta-base', output_hidden_states=False
        )
        logger.info('initialize token type emb, by default roberta doesnt have it')
        backbone.config.type_vocab_size = 2
        backbone.embeddings.token_type_embeddings = nn.Embedding(2, backbone.config.hidden_size)
        backbone.embeddings.token_type_embeddings.weight.data.normal_(
            mean=0.0, std=backbone.config.initializer_range
        )
        return backbone

    @staticmethod
    def initialize_roberta_tokenizer():
        """Load the roberta-base tokenizer and attach BOS / EOS / PAD token strings."""
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
        # Empty strings are not RoBERTa special tokens, so the sentence markers and
        # the padding id would be resolved incorrectly. Use the special tokens
        # declared by the tokenizer itself.
        tokenizer.BOS = tokenizer.bos_token
        tokenizer.EOS = tokenizer.eos_token
        tokenizer.PAD = tokenizer.pad_token
        return tokenizer

    def save(self, save_filename):
        """ Save model in filename

        The file is written as <cfg.weights_path>/<save_filename>.pth,
        the folder is created if it does not exist.

        :param save_filename: str, file name without extension
        """
        state = dict(
            encoder_dict=self.backbone.state_dict(),
            decoder_dict=self.decoder.state_dict()
        )
        if self.features is not None:
            state['features_dict'] = self.features.state_dict()
        models_folder = os.path.expanduser(cfg.weights_path)
        os.makedirs(models_folder, exist_ok=True)
        torch.save(state, f"{models_folder}/{save_filename}.pth")

    def load(self, load_filename):
        """ Load model weights saved by save

        :param load_filename: str, file name without extension
        """
        path = f"{os.path.expanduser(cfg.weights_path)}/{load_filename}.pth"
        # Keep tensors on the storage they are deserialized to (CPU).
        state = torch.load(path, map_location=lambda storage, location: storage)
        self.backbone.load_state_dict(state['encoder_dict'])
        self.decoder.load_state_dict(state['decoder_dict'])
        if self.features is not None:
            self.features.load_state_dict(state['features_dict'])

    def froze_backbone(self, froze_strategy):
        """
        Select which backbone parameters are trainable.

        :param froze_strategy:
            'froze_all' - no backbone parameter is trainable,
            'unfroze_last' - only encoder layers 9, 10 and 11 are trainable,
            'unfroze_all' - all backbone parameters are trainable
        :raises ValueError: if the strategy is not supported
        """
        if froze_strategy == 'froze_all':
            for param in self.backbone.parameters():
                param.requires_grad_(False)

        elif froze_strategy == 'unfroze_last':
            for name, param in self.backbone.named_parameters():
                param.requires_grad_(
                    'encoder.layer.11' in name or
                    'encoder.layer.10' in name or
                    'encoder.layer.9' in name
                )

        elif froze_strategy == 'unfroze_all':
            for param in self.backbone.parameters():
                param.requires_grad_(True)

        else:
            raise ValueError(f'Unsupported froze strategy {froze_strategy}')

    def unfroze_head(self):
        """Make all decoder parameters trainable."""
        for param in self.decoder.parameters():
            param.requires_grad_(True)

    def _encode(self, input_ids, attention_mask, token_type_ids):
        """
        Run the backbone and locate sentence start tokens.

        :return: tuple (enc_output, cls_mask), where
            enc_output | backbone hidden states for every input token,
            cls_mask | boolean mask of the same shape as input_ids, True at BOS tokens
        """
        # The output of [CLS] is inferred by all other words in this sentence.
        # This makes [CLS] a good representation for sentence-level classification.
        cls_mask = (input_ids == get_token_id(self.tokenizer, self.tokenizer.BOS))

        # Indices of positions of each input sequence tokens in the position embeddings.
        # position ids | torch.Size([batch_size, article_len])
        pos_ids = torch.arange(
            0,
            self.article_len,
            dtype=torch.long,
            device=input_ids.device
        ).unsqueeze(0).repeat(len(input_ids), 1)

        # extract bert embeddings | torch.Size([batch_size, article_len, d_bert])
        enc_output = self.encoder(input_ids, attention_mask, token_type_ids, pos_ids)
        return enc_output, cls_mask

    def forward(self, input_ids, attention_mask, token_type_ids, input_features=None):
        """
        :param input_ids: torch.Size([batch_size, article_len])
            Indices of input sequence tokens in the vocabulary.
        :param attention_mask: torch.Size([batch_size, article_len])
            Mask to avoid performing attention on padding token indices.
            Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        :param token_type_ids: torch.Size([batch_size, article_len])
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in `[0, 1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
        :param input_features: input of the additional features network, used only
            when the model was created with additional features. Its output is
            concatenated to the selected BOS rows, so it has to provide one row
            per BOS token of the batch.
        :return: scores | one score per BOS token of the whole batch (flattened over the batch)
        """
        enc_output, cls_mask = self._encode(input_ids, attention_mask, token_type_ids)
        logger.debug('cls_mask %s', cls_mask.shape)
        logger.debug('enc_output %s', enc_output.shape)

        # For sentence scoring only the outputs of BOS ([CLS]) tokens are needed,
        # outputs of all the other tokens are dropped.
        sentence_embs = enc_output[cls_mask]
        logger.debug('enc_output[cls_mask] %s', sentence_embs.shape)

        if self.features is not None:
            out_features = self.features(input_features)
            scores = self.decoder(torch.cat([sentence_embs, out_features], dim=-1))
        else:
            scores = self.decoder(sentence_embs)
        logger.debug('scores %s', scores.shape)

        return scores

    def evaluate(self, input_ids, attention_mask, token_type_ids, input_features=None):
        """
        See forward for parameters description.

        :return: list with one tensor of scores per batch element
        """
        enc_output, cls_mask = self._encode(input_ids, attention_mask, token_type_ids)

        # The features do not depend on the batch element, compute them once.
        out_features = None
        if self.features is not None:
            out_features = self.features(input_features)

        scores = []
        for eo, cm in zip(enc_output, cls_mask):
            sentence_embs = eo[cm]
            if out_features is not None:
                # NOTE(review): the same features tensor is concatenated to every batch
                # element, which requires equal row counts - confirm how input_features
                # is expected to be laid out for batches larger than one.
                sentence_embs = torch.cat([sentence_embs, out_features], dim=-1)
            scores.append(self.decoder.evaluate(sentence_embs))
        return scores
272 |
273 |
class Classifier(nn.Module):
    """
    Scoring head: dropout, a linear projection to a single value and a sigmoid.
    forward and evaluate perform the same computation.
    """

    def __init__(self, hidden_size):
        """
        :param hidden_size: size of the last dimension of the input
        """
        super().__init__()
        self.dropout = nn.Dropout(DECODER_DROPOUT)
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def _score(self, x):
        """Project x to one value per row and squash it with the sigmoid."""
        projected = self.linear(self.dropout(x))
        # Drop the trailing dimension of size 1 produced by the linear layer.
        return self.sigmoid(projected.squeeze(-1))

    def forward(self, x):
        """Return scores for x, the last dimension of x is reduced."""
        return self._score(x)

    def evaluate(self, x):
        """Return scores for x, same computation as forward."""
        return self._score(x)
286 |
--------------------------------------------------------------------------------
/review/dataset/prepare_dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "pycharm": {
7 | "name": "#%% md\n"
8 | }
9 | },
10 | "source": [
11 | "\n",
12 | "# Make dataset\n",
13 | "\n",
14 | "This notebook contains the code to analyse content of the PubMedCentral Author Manuscript Collection. \\\n",
15 | "See: https://www.ncbi.nlm.nih.gov/pmc/about/mscollection/\n",
16 | "\n",
17 | "Files should be downloaded from https://ftp.ncbi.nlm.nih.gov/pub/pmc/manuscript/xml/ into `~/pmc_dataset` folder.\n",
18 | "\n",
19 | "Resulting tables will be created under `~/review/dataset` folder (see `config.py`).\n",
20 | "\n",
21 | "Please ensure that env variable `PYTHONPATH` includes project folder to be able to import `review.config` module.\n",
22 | "\n"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {
29 | "pycharm": {
30 | "name": "#%%\n"
31 | }
32 | },
33 | "outputs": [],
34 | "source": [
35 | "% matplotlib inline\n",
36 | "% config InlineBackend.figure_format='retina'"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "outputs": [],
43 | "source": [
44 | "import logging\n",
45 | "import os\n",
46 | "import sys\n",
47 | "\n",
48 | "import pandas as pd\n",
49 | "from IPython.display import display\n",
50 | "\n",
51 | "logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')\n",
52 | "\n",
53 | "import review.config as cfg\n",
54 | "\n",
55 | "pmc_dataset_root = os.path.expanduser('~/pmc_dataset')\n",
56 | "dataset_root = os.path.expanduser(cfg.dataset_path)"
57 | ],
58 | "metadata": {
59 | "collapsed": false,
60 | "pycharm": {
61 | "name": "#%%\n"
62 | }
63 | }
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {
68 | "pycharm": {
69 | "name": "#%% md\n"
70 | }
71 | },
72 | "source": [
73 | "### Collecting articles"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {
80 | "pycharm": {
81 | "name": "#%%\n"
82 | }
83 | },
84 | "outputs": [],
85 | "source": [
86 | "from glob import glob\n",
87 | "from lxml import etree\n",
88 | "from tqdm.auto import tqdm\n",
89 | "\n",
90 | "dict_articles = {}\n",
91 | "\n",
92 | "for filelist in tqdm(glob(os.path.join(pmc_dataset_root, '*filelist.txt'))):\n",
93 | " with open(filelist, 'r') as f:\n",
94 | " for line in f:\n",
95 | " if 'LastUpdated' in line:\n",
96 | " continue\n",
97 | " filename, pmcid, pmid, mid, date, time = line.split()\n",
98 | " dict_articles[pmid] = filename"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "metadata": {
105 | "pycharm": {
106 | "name": "#%%\n"
107 | }
108 | },
109 | "outputs": [],
110 | "source": [
111 | "print(list(dict_articles.items())[:10])"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {
118 | "pycharm": {
119 | "name": "#%%\n"
120 | }
121 | },
122 | "outputs": [],
123 | "source": [
124 | "import nltk\n",
125 | "\n",
126 | "\n",
127 | "def split_text(text):\n",
128 | " sents = nltk.tokenize.sent_tokenize(text)\n",
129 | " res_sents = []\n",
130 | " i = 0\n",
131 | " while i < len(sents):\n",
132 | " check = False\n",
133 | " if i + 1 < len(sents):\n",
134 | " check = sents[i + 1].strip()[0].islower() or sents[i + 1].strip()[0].isdigit()\n",
135 | " made = sents[i]\n",
136 | " while i + 1 < len(sents) and (made.endswith('Fig.') or check):\n",
137 | " made += \" \" + \" \".join(sents[i + 1].strip().split())\n",
138 | " i += 1\n",
139 | " if i + 1 < len(sents):\n",
140 | " check = sents[i + 1].strip()[0].islower() or sents[i + 1].strip()[0].isdigit()\n",
141 | " res_sents.append(\" \".join(made.strip().split()))\n",
142 | " i += 1\n",
143 | " return res_sents\n",
144 | "\n",
145 | "\n",
146 | "def get_sentences(node):\n",
147 | " def helper(node, is_disc):\n",
148 | " if node.tag == 'xref':\n",
149 | " ntail = ''\n",
150 | " if node.tail is not None:\n",
151 | " ntail = node.tail\n",
152 | " res = f' xref_{node.get(\"ref-type\")}_{node.get(\"rid\")} ' + ntail\n",
153 | " if res is None:\n",
154 | " return '', ''\n",
155 | " if is_disc:\n",
156 | " return '', res\n",
157 | " return res, ''\n",
158 | " if node.tag == 'title':\n",
159 | " if node.tail is None:\n",
160 | " return '', ''\n",
161 | " if is_disc:\n",
162 | " return '', node.tail\n",
163 | " return node.tail, ''\n",
164 | " if not is_disc and node.find('title') is not None:\n",
165 | " title = \"\".join(node.find('title').itertext()).lower()\n",
166 | " if 'discussion' in title:\n",
167 | " is_disc = True\n",
168 | " st_text = ''\n",
169 | " if node.text is not None:\n",
170 | " st_text = node.text\n",
171 | " if is_disc:\n",
172 | " n_disc = st_text\n",
173 | " n_gen = \"\"\n",
174 | " else:\n",
175 | " n_gen = st_text\n",
176 | " n_disc = \"\"\n",
177 | " for ch in node.getchildren():\n",
178 | " gen, disc = helper(ch, is_disc)\n",
179 | " n_gen += gen\n",
180 | " n_disc += disc\n",
181 | " tail = \"\"\n",
182 | " if node.tail is not None:\n",
183 | " tail = node.tail\n",
184 | " if is_disc:\n",
185 | " n_disc += tail\n",
186 | " else:\n",
187 | " n_gen += tail\n",
188 | " return n_gen, n_disc\n",
189 | "\n",
190 | " gen_res, disc_res = helper(node.find('body'), False)\n",
191 | " gen_res = split_text(gen_res)\n",
192 | " disc_res = split_text(disc_res)\n",
193 | "\n",
194 | " abstract = \"\"\n",
195 | "\n",
196 | " try:\n",
197 | " abstract = \"\".join(node.find('front').find('article-meta').find('abstract').itertext())\n",
198 | " abstract = \" \".join(abstract.strip().split())\n",
199 | " except Exception:\n",
200 | " pass\n",
201 | " return gen_res, disc_res, abstract"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "metadata": {
208 | "pycharm": {
209 | "name": "#%%\n"
210 | }
211 | },
212 | "outputs": [],
213 | "source": [
214 | "tree = etree.parse(f\"{pmc_dataset_root}/PMC004xxxxxx/PMC4239434.xml\")\n",
215 | "sents = get_sentences(tree.getroot())\n",
216 | "print(sents)\n",
217 | "print(len(sents[0]), len(sents[1]), len(sents[2]))"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": null,
223 | "metadata": {
224 | "pycharm": {
225 | "name": "#%%\n"
226 | }
227 | },
228 | "outputs": [],
229 | "source": [
230 | "def get_all_refs(node):\n",
231 | " def get_cit_id_type(node):\n",
232 | " if node.find('element-citation') is None:\n",
233 | " return None\n",
234 | " if node.find('element-citation').find('pub-id') is None:\n",
235 | " return None\n",
236 | " return node.find('element-citation').find('pub-id').get('pub-id-type')\n",
237 | "\n",
238 | " def get_citation_info(node):\n",
239 | " if node is None:\n",
240 | " return {}\n",
241 | " res = {}\n",
242 | " for ch in node.getchildren():\n",
243 | " if ch.tag == 'ref':\n",
244 | " id_type = get_cit_id_type(ch)\n",
245 | " if id_type is not None and id_type == 'pmid':\n",
246 | " res[ch.get('id')] = {\n",
247 | " 'publication-type': ch.find('element-citation').get('publication-type'),\n",
248 | " 'pmid': ch.find('element-citation').find('pub-id').text\n",
249 | " }\n",
250 | " return res\n",
251 | "\n",
252 | " def get_figs_info(node):\n",
253 | " if node is None:\n",
254 | " return {}\n",
255 | " res = {}\n",
256 | " for ch in node.getchildren():\n",
257 | " if ch.tag == 'fig' and ch.find('caption') is not None:\n",
258 | " res[ch.get('id')] = \" \".join(''.join(ch.find('caption').itertext()).strip().split())\n",
259 | " return res\n",
260 | "\n",
261 | " def get_tables_info(node):\n",
262 | " if node is None:\n",
263 | " return {}\n",
264 | " res = {}\n",
265 | " for ch in node.getchildren():\n",
266 | " if ch.tag == 'table-wrap' and ch.find('caption') is not None:\n",
267 | " res[ch.get('id')] = \" \".join(''.join(ch.find('caption').itertext()).strip().split())\n",
268 | " return res\n",
269 | "\n",
270 | " citations = get_citation_info(node.find('back').find('ref-list'))\n",
271 | " figs = get_figs_info(node.find('floats-group'))\n",
272 | " tables = get_tables_info(node.find('floats-group'))\n",
273 | " return citations, figs, tables"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": null,
279 | "metadata": {
280 | "pycharm": {
281 | "name": "#%%\n"
282 | }
283 | },
284 | "outputs": [],
285 | "source": [
286 | "get_all_refs(tree.getroot())"
287 | ]
288 | },
289 | {
290 | "cell_type": "code",
291 | "execution_count": null,
292 | "metadata": {
293 | "pycharm": {
294 | "name": "#%%\n"
295 | }
296 | },
297 | "outputs": [],
298 | "source": [
299 | "import re\n",
300 | "\n",
301 | "pattern = re.compile(\"(?<=xref_bibr_)[\\d\\w]+\")\n",
302 | "\n",
303 | "\n",
304 | "def count_reverse(sents_gen, sents_disc, pmid):\n",
305 | " result = []\n",
306 | " for i, sent in enumerate(sents_gen):\n",
307 | " results = re.findall(pattern, sent)\n",
308 | " result.extend(list(map(lambda x: (pmid, 'general', str(i), x), results)))\n",
309 | " for i, sent in enumerate(sents_disc):\n",
310 | " results = re.findall(pattern, sent)\n",
311 | " result.extend(list(map(lambda x: (pmid, 'discussion', str(i), x), results)))\n",
312 | " return result"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": null,
318 | "metadata": {
319 | "pycharm": {
320 | "name": "#%%\n"
321 | }
322 | },
323 | "outputs": [],
324 | "source": [
325 | "gen_sents, disc_sents, abst = get_sentences(tree.getroot())\n",
326 | "count_reverse(gen_sents, disc_sents, '2000292')"
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": null,
332 | "outputs": [],
333 | "source": [
334 | "def is_review(tree):\n",
335 | " try:\n",
336 | " return any('Review' in sg.find('subject').text for sg in\n",
337 | " tree.find('front').find('article-meta').find('article-categories').findall('subj-group'))\n",
338 | " except:\n",
339 | " return False\n",
340 | "\n",
341 | "\n",
342 | "# Test\n",
343 | "is_review(etree.parse(f\"{pmc_dataset_root}/PMC001xxxxxx/PMC1817751.xml\").getroot())"
344 | ],
345 | "metadata": {
346 | "collapsed": false,
347 | "pycharm": {
348 | "name": "#%%\n"
349 | }
350 | }
351 | },
352 | {
353 | "cell_type": "markdown",
354 | "metadata": {
355 | "pycharm": {
356 | "name": "#%% md\n"
357 | }
358 | },
359 | "source": [
360 | "## Create tables required for model learning"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": null,
366 | "metadata": {
367 | "scrolled": true,
368 | "pycharm": {
369 | "name": "#%%\n"
370 | }
371 | },
372 | "outputs": [],
373 | "source": [
374 | "! mkdir -p {dataset_root}\n",
375 | "\n",
376 | "print('Headers')\n",
377 | "with open(f'{dataset_root}/review_files.csv', 'w') as f:\n",
378 | " print('pmid', file=f)\n",
379 | "with open(f'{dataset_root}/citations.csv', 'w') as f:\n",
380 | " print('\\t'.join(['pmid', 'ref_id', 'pub_type', 'ref_pmid']), file=f)\n",
381 | "with open(f'{dataset_root}/sentences.csv', 'w') as f:\n",
382 | " print('\\t'.join(['pmid', 'sent_id', 'type', 'sentence']), file=f)\n",
383 | "with open(f'{dataset_root}/abstracts.csv', 'w') as f:\n",
384 | " print('\\t'.join(['pmid', 'abstract']), file=f)\n",
385 | "with open(f'{dataset_root}/figures.csv', 'w') as f:\n",
386 | " print('\\t'.join(['pmid', 'fig_id', 'caption']), file=f)\n",
387 | "with open(f'{dataset_root}/tables.csv', 'w') as f:\n",
388 | " print('\\t'.join(['pmid', 'tab_id', 'caption']), file=f)\n",
389 | "with open(f'{dataset_root}/reverse_ref.csv', 'w') as f:\n",
390 | " print('\\t'.join(['pmid', 'sent_type', 'sent_id', 'ref_id']), file=f)\n",
391 | "\n",
392 | "print('Processing articles')\n",
393 | "for id, filename in tqdm(list(dict_articles.items())):\n",
394 | " try:\n",
395 | " tree = etree.parse(pmc_dataset_root + \"/\" + filename).getroot()\n",
396 | " gen_sents, disc_sents, abstract = get_sentences(tree)\n",
397 | " cits, figs, tables = get_all_refs(tree)\n",
398 | " except Exception as e:\n",
399 | " print(\"\\rsomething went wrong\", id, filename, e)\n",
400 | " continue\n",
401 | " if is_review(tree):\n",
402 | " with open(f'{dataset_root}/review_files.csv', 'a') as f:\n",
403 | " print(id, file=f)\n",
404 | " with open(f'{dataset_root}/citations.csv', 'a') as f:\n",
405 | " for i, dic in cits.items():\n",
406 | " print('\\t'.join([id, str(i), dic['publication-type'], dic['pmid']]), file=f)\n",
407 | " with open(f'{dataset_root}/sentences.csv', 'a') as f:\n",
408 | " for i, sent in enumerate(gen_sents):\n",
409 | " print('\\t'.join([id, str(i), 'general', sent]), file=f)\n",
410 | " for i, sent in enumerate(disc_sents):\n",
411 | " print('\\t'.join([id, str(i), 'discussion', sent]), file=f)\n",
412 | " if abstract != '':\n",
413 | " with open(f'{dataset_root}/abstracts.csv', 'a') as f:\n",
414 | " print('\\t'.join([id, abstract]), file=f)\n",
415 | " with open(f'{dataset_root}/figures.csv', 'a') as f:\n",
416 | " for i, text in figs.items():\n",
417 | " print('\\t'.join([id, i, text]), file=f)\n",
418 | " with open(f'{dataset_root}/tables.csv', 'a') as f:\n",
419 | " for i, text in tables.items():\n",
420 | " print('\\t'.join([id, i, text]), file=f)\n",
421 | " with open(f'{dataset_root}/reverse_ref.csv', 'a') as f:\n",
422 | " res = count_reverse(gen_sents, disc_sents, id)\n",
423 | " for row in res:\n",
424 | " print('\\t'.join(list(row)), file=f)"
425 | ]
426 | },
427 | {
428 | "cell_type": "markdown",
429 | "source": [
430 | "## Check dataset loading"
431 | ],
432 | "metadata": {
433 | "collapsed": false,
434 | "pycharm": {
435 | "name": "#%% md\n"
436 | }
437 | }
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": null,
442 | "outputs": [],
443 | "source": [
444 | "def sizeof_fmt(num, suffix='B'):\n",
445 | " \"\"\"Used memory analysis utility\"\"\"\n",
446 | " for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:\n",
447 | " if abs(num) < 1024.0:\n",
448 | " return \"%3.1f %s%s\" % (num, unit, suffix)\n",
449 | " num /= 1024.0\n",
450 | " return \"%.1f %s%s\" % (num, 'Yi', suffix)"
451 | ],
452 | "metadata": {
453 | "collapsed": false,
454 | "pycharm": {
455 | "name": "#%%\n"
456 | }
457 | }
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": null,
462 | "outputs": [],
463 | "source": [
464 | "logging.info('Loading citations_df')\n",
465 | "citations_df = pd.read_csv(os.path.join(dataset_root, \"citations.csv\"), sep='\\t')\n",
466 | "logging.info(sizeof_fmt(sys.getsizeof(citations_df)))\n",
467 | "display(citations_df.head())\n",
468 | "del citations_df"
469 | ],
470 | "metadata": {
471 | "collapsed": false,
472 | "pycharm": {
473 | "name": "#%%\n"
474 | }
475 | }
476 | },
477 | {
478 | "cell_type": "code",
479 | "execution_count": null,
480 | "outputs": [],
481 | "source": [
482 | "logging.info('Loading review_files_df')\n",
483 | "review_files_df = pd.read_csv(os.path.join(dataset_root, \"review_files.csv\"), sep='\\t')\n",
484 | "logging.info(sizeof_fmt(sys.getsizeof(review_files_df)))\n",
485 | "display(review_files_df.head())\n",
486 | "del review_files_df"
487 | ],
488 | "metadata": {
489 | "collapsed": false,
490 | "pycharm": {
491 | "name": "#%%\n"
492 | }
493 | }
494 | },
495 | {
496 | "cell_type": "code",
497 | "execution_count": null,
498 | "outputs": [],
499 | "source": [
500 | "logging.info('Loading reverse_ref_df')\n",
501 | "reverse_ref_df = pd.read_csv(os.path.join(dataset_root, \"reverse_ref.csv\"), sep='\\t')\n",
502 | "logging.info(sizeof_fmt(sys.getsizeof(reverse_ref_df)))\n",
503 | "display(reverse_ref_df.head())\n",
504 | "del reverse_ref_df"
505 | ],
506 | "metadata": {
507 | "collapsed": false,
508 | "pycharm": {
509 | "name": "#%%\n"
510 | }
511 | }
512 | },
513 | {
514 | "cell_type": "code",
515 | "execution_count": null,
516 | "outputs": [],
517 | "source": [
518 | "logging.info('Loading abstracts_df')\n",
519 | "abstracts_df = pd.read_csv(os.path.join(dataset_root, \"abstracts.csv\"), sep='\\t')\n",
520 | "logging.info(sizeof_fmt(sys.getsizeof(abstracts_df)))\n",
521 | "display(abstracts_df.head())\n",
522 | "del abstracts_df"
523 | ],
524 | "metadata": {
525 | "collapsed": false,
526 | "pycharm": {
527 | "name": "#%%\n"
528 | }
529 | }
530 | },
531 | {
532 | "cell_type": "code",
533 | "execution_count": null,
534 | "outputs": [],
535 | "source": [
536 | "logging.info('Loading figures_df')\n",
537 | "figures_df = pd.read_csv(os.path.join(dataset_root, \"figures.csv\"), sep='\\t')\n",
538 | "logging.info(sizeof_fmt(sys.getsizeof(figures_df)))\n",
539 | "display(figures_df.head())\n",
540 | "del figures_df"
541 | ],
542 | "metadata": {
543 | "collapsed": false,
544 | "pycharm": {
545 | "name": "#%%\n"
546 | }
547 | }
548 | },
549 | {
550 | "cell_type": "code",
551 | "execution_count": null,
552 | "outputs": [],
553 | "source": [
554 | "logging.info('Loading tables_df')\n",
555 | "tables_df = pd.read_csv(os.path.join(dataset_root, \"tables.csv\"), sep='\\t')\n",
556 | "logging.info(sizeof_fmt(sys.getsizeof(tables_df)))\n",
557 | "display(tables_df.head())\n",
558 | "del tables_df"
559 | ],
560 | "metadata": {
561 | "collapsed": false,
562 | "pycharm": {
563 | "name": "#%%\n"
564 | }
565 | }
566 | },
567 | {
568 | "cell_type": "code",
569 | "execution_count": null,
570 | "outputs": [],
571 | "source": [
572 | "logging.info('Loading sentences_df')\n",
573 | "sentences_df = pd.read_csv(os.path.join(dataset_root, \"sentences.csv\"), sep='\\t')\n",
574 | "logging.info(sizeof_fmt(sys.getsizeof(sentences_df)))\n",
575 | "display(sentences_df.head())\n",
576 | "del sentences_df"
577 | ],
578 | "metadata": {
579 | "collapsed": false,
580 | "pycharm": {
581 | "name": "#%%\n"
582 | }
583 | }
584 | },
585 | {
586 | "cell_type": "code",
587 | "execution_count": null,
588 | "outputs": [],
589 | "source": [],
590 | "metadata": {
591 | "collapsed": false,
592 | "pycharm": {
593 | "name": "#%%\n"
594 | }
595 | }
596 | }
597 | ],
598 | "metadata": {
599 | "kernelspec": {
600 | "display_name": "Python 3 (ipykernel)",
601 | "language": "python",
602 | "name": "python3"
603 | },
604 | "language_info": {
605 | "codemirror_mode": {
606 | "name": "ipython",
607 | "version": 3
608 | },
609 | "file_extension": ".py",
610 | "mimetype": "text/x-python",
611 | "name": "python",
612 | "nbconvert_exporter": "python",
613 | "pygments_lexer": "ipython3",
614 | "version": "3.10.4"
615 | }
616 | },
617 | "nbformat": 4,
618 | "nbformat_minor": 4
619 | }
--------------------------------------------------------------------------------
/review/dataset/analyze_archive.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "pycharm": {
7 | "name": "#%% md\n"
8 | }
9 | },
10 | "source": [
11 | "# Analyze archive\n",
12 | "\n",
13 | "This notebook contains the code to analyze the content of the PubMed Central (PMC) Author Manuscript Collection. \\\n",
14 | "See: https://www.ncbi.nlm.nih.gov/pmc/about/mscollection/\n",
15 | "\n",
16 | "Files can be downloaded here: https://ftp.ncbi.nlm.nih.gov/pub/pmc/manuscript/xml/ \\\n",
17 | "**Please ensure** that the files are downloaded into the `~/pmc_dataset` folder before proceeding."
18 | ]
19 | },
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {
23 | "pycharm": {
24 | "name": "#%% md\n"
25 | }
26 | },
27 | "source": [
28 | "## Collecting files"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {
35 | "pycharm": {
36 | "name": "#%%\n"
37 | }
38 | },
39 | "outputs": [],
40 | "source": [
41 | "% matplotlib inline\n",
42 | "% config InlineBackend.figure_format='retina'\n",
43 | "\n",
44 | "import functools\n",
45 | "import os\n",
46 | "from collections import Counter\n",
47 | "from glob import glob\n",
48 | "\n",
49 | "import matplotlib.pyplot as plt\n",
50 | "from lxml import etree\n",
51 | "from tqdm.auto import tqdm\n",
52 | "\n",
53 | "dict_articles = {}\n",
54 | "pmc_dataset_root = os.path.expanduser('~/pmc_dataset')\n",
55 | "\n",
56 | "\n",
57 | "for filelist in tqdm(glob(os.path.join(pmc_dataset_root, '*filelist.txt'))):\n",
58 | " with open(filelist, 'r') as f:\n",
59 | " for line in f:\n",
60 | " if 'LastUpdated' in line:\n",
61 | " continue\n",
62 | " filename, pmcid, pmid, mid, date, time = line.split()\n",
63 | " dict_articles[pmid] = filename"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {
70 | "pycharm": {
71 | "name": "#%%\n"
72 | }
73 | },
74 | "outputs": [],
75 | "source": [
76 | "print('Total papers', len(dict_articles))"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {
83 | "pycharm": {
84 | "name": "#%%\n"
85 | }
86 | },
87 | "outputs": [],
88 | "source": [
89 | "list(dict_articles.values())[:10]"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {
96 | "pycharm": {
97 | "name": "#%%\n"
98 | }
99 | },
100 | "outputs": [],
101 | "source": [
102 | "def count_tags(node):\n",
103 | " stat = Counter()\n",
104 | "\n",
105 | " def dfs(root):\n",
106 | " stat[root.tag] += 1\n",
107 | " for child in root.getchildren():\n",
108 | " dfs(child)\n",
109 | "\n",
110 | " dfs(node)\n",
111 | " return stat"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {
118 | "pycharm": {
119 | "name": "#%%\n"
120 | }
121 | },
122 | "outputs": [],
123 | "source": [
124 | "def get_title(tree):\n",
125 | " return etree.tostring(tree.getroot().find(\"front\").find(\"article-meta\").find(\"title-group\").find(\"article-title\"))"
126 | ]
127 | },
128 | {
129 | "cell_type": "markdown",
130 | "metadata": {
131 | "pycharm": {
132 | "name": "#%% md\n"
133 | }
134 | },
135 | "source": [
136 | "## Collecting review papers"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "metadata": {
143 | "pycharm": {
144 | "name": "#%%\n"
145 | }
146 | },
147 | "outputs": [],
148 | "source": [
149 | "review_filenames = set()\n",
150 | "for filename in tqdm(dict_articles.values()):\n",
151 | " tree = etree.parse(os.path.join(pmc_dataset_root, filename))\n",
152 | " title = str(get_title(tree))\n",
153 | " if not title:\n",
154 | " print(f\"\\r{filename}\")\n",
155 | " if \"review\" in title.lower():\n",
156 | " review_filenames.add(filename)"
157 | ]
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "metadata": {
163 | "pycharm": {
164 | "name": "#%%\n"
165 | }
166 | },
167 | "outputs": [],
168 | "source": [
169 | "print('Review papers', len(review_filenames))"
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "metadata": {
175 | "pycharm": {
176 | "name": "#%% md\n"
177 | }
178 | },
179 | "source": [
180 | "## Collecting tag statistics"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "metadata": {
187 | "pycharm": {
188 | "name": "#%%\n"
189 | }
190 | },
191 | "outputs": [],
192 | "source": [
193 | "tag_stat = {}\n",
194 | "tag_stat['review'] = Counter()\n",
195 | "tag_stat['ordinary'] = Counter()\n",
196 | "for filename in tqdm(dict_articles.values()):\n",
197 | " tree = etree.parse(os.path.join(pmc_dataset_root, filename))\n",
198 | " cur_stat = count_tags(tree.getroot())\n",
199 | " if filename in review_filenames:\n",
200 | " tag_stat['review'] += cur_stat\n",
201 | " else:\n",
202 | " tag_stat['ordinary'] += cur_stat"
203 | ]
204 | },
205 | {
206 | "cell_type": "code",
207 | "execution_count": null,
208 | "metadata": {
209 | "pycharm": {
210 | "name": "#%%\n"
211 | }
212 | },
213 | "outputs": [],
214 | "source": [
215 | "for s, cnt in zip(['ordinary', 'review'], [len(dict_articles) - len(review_filenames), len(review_filenames)]):\n",
216 | " with open(f'{s}_tag_stat.txt', 'w') as f:\n",
217 | " srt = sorted(tag_stat[s].items(), key=lambda x: x[1])\n",
218 | " srt = list(map(lambda x: (x[0], x[1] / cnt), srt))\n",
219 | " print(f'Number: {cnt}', file=f)\n",
220 | " for val, count in srt:\n",
221 | " print(f'{val} {count}', file=f)"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": null,
227 | "metadata": {
228 | "pycharm": {
229 | "name": "#%%\n"
230 | }
231 | },
232 | "outputs": [],
233 | "source": [
234 | "def tag_depth(node):\n",
235 | " def dfs(root):\n",
236 | " d = 1\n",
237 | " for child in root.getchildren():\n",
238 | " d = max(d, dfs(child) + 1)\n",
239 | " return d\n",
240 | "\n",
241 | " return dfs(node)"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": null,
247 | "metadata": {
248 | "pycharm": {
249 | "name": "#%%\n"
250 | }
251 | },
252 | "outputs": [],
253 | "source": [
254 | "d_stat = {}\n",
255 | "d_stat['review'] = {}\n",
256 | "d_stat['ordinary'] = {}\n",
257 | "for filename in tqdm(dict_articles.values()):\n",
258 | " tree = etree.parse(os.path.join(pmc_dataset_root, filename))\n",
259 | " cur_stat = tag_depth(tree.getroot())\n",
260 | " if filename in review_filenames:\n",
261 | " d_stat['review'][filename] = cur_stat\n",
262 | " else:\n",
263 | " d_stat['ordinary'][filename] = cur_stat"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": null,
269 | "metadata": {
270 | "pycharm": {
271 | "name": "#%%\n"
272 | }
273 | },
274 | "outputs": [],
275 | "source": [
276 | "print(list(d_stat['review'].items())[:10])"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": null,
282 | "metadata": {
283 | "pycharm": {
284 | "name": "#%%\n"
285 | }
286 | },
287 | "outputs": [],
288 | "source": [
289 | "for s in ['ordinary', 'review']:\n",
290 | " with open(f'{s}_tag_depth.txt', 'w') as f:\n",
291 | " srt = sorted(d_stat[s].items(), key=lambda x: x[1])\n",
292 | " for val, count in srt:\n",
293 | " print(f'{val} {count}', file=f)"
294 | ]
295 | },
296 | {
297 | "cell_type": "code",
298 | "execution_count": null,
299 | "metadata": {
300 | "pycharm": {
301 | "name": "#%%\n"
302 | }
303 | },
304 | "outputs": [],
305 | "source": [
306 | "plt.title('Tag depths review papers')\n",
307 | "plt.hist(d_stat['review'].values(), bins=range(5, 20))\n",
308 | "plt.show()"
309 | ]
310 | },
311 | {
312 | "cell_type": "code",
313 | "execution_count": null,
314 | "metadata": {
315 | "pycharm": {
316 | "name": "#%%\n"
317 | }
318 | },
319 | "outputs": [],
320 | "source": [
321 | "plt.title('Tag depth ordinary papers')\n",
322 | "plt.hist(d_stat['ordinary'].values(), bins=range(5, 20))\n",
323 | "plt.show()"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": null,
329 | "metadata": {
330 | "pycharm": {
331 | "name": "#%%\n"
332 | }
333 | },
334 | "outputs": [],
335 | "source": [
336 | "tree = etree.parse(os.path.join(pmc_dataset_root, list(dict_articles.values())[0]))"
337 | ]
338 | },
339 | {
340 | "cell_type": "markdown",
341 | "metadata": {
342 | "pycharm": {
343 | "name": "#%% md\n"
344 | }
345 | },
346 | "source": [
347 | "### Collecting section statistics"
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": null,
353 | "metadata": {
354 | "pycharm": {
355 | "name": "#%%\n"
356 | }
357 | },
358 | "outputs": [],
359 | "source": [
360 | "def get_paragraph_info(root):\n",
361 | " num = 0\n",
362 | " sum_pos = -1\n",
363 | " disc_pos = -1\n",
364 | " lens = Counter()\n",
365 | " for ch in root.find('body').getchildren():\n",
366 | " if ch.tag == 'sec':\n",
367 | " num += 1\n",
368 | " try:\n",
369 | " lens[num] = len(etree.tostring(ch))\n",
370 | " except Exception:\n",
371 | " lens[num] = 0\n",
372 | " print(\"\\n!\")\n",
373 | " str_title = str(etree.tostring(ch.find('title'))).lower()\n",
374 | " if 'summary' in str_title:\n",
375 | " sum_pos = num\n",
376 | " if 'discussion' in str_title:\n",
377 | " disc_pos = num\n",
378 | " return num, sum_pos, disc_pos, lens"
379 | ]
380 | },
381 | {
382 | "cell_type": "code",
383 | "execution_count": null,
384 | "metadata": {
385 | "pycharm": {
386 | "name": "#%%\n"
387 | }
388 | },
389 | "outputs": [],
390 | "source": [
391 | "review_filenames = set()\n",
392 | "\n",
393 | "with open('review_tag_depth.txt', 'r') as f:\n",
394 | " for line in f:\n",
395 | " filename, _ = line.split()\n",
396 | " review_filenames.add(filename)"
397 | ]
398 | },
399 | {
400 | "cell_type": "code",
401 | "execution_count": null,
402 | "metadata": {
403 | "scrolled": true,
404 | "pycharm": {
405 | "name": "#%%\n"
406 | }
407 | },
408 | "outputs": [],
409 | "source": [
410 | "para_stats = {}\n",
411 | "para_stats['review'] = {}\n",
412 | "para_stats['ordinary'] = {}\n",
413 | "\n",
414 | "for filename in tqdm(dict_articles.values()):\n",
415 | " tree = etree.parse(os.path.join(pmc_dataset_root, filename))\n",
416 | " try:\n",
417 | " cur_stat = get_paragraph_info(tree.getroot())\n",
418 | " except Exception:\n",
419 | " print(f\"\\n{filename}\")\n",
420 | " continue\n",
421 | " if filename in review_filenames:\n",
422 | " para_stats['review'][filename] = cur_stat\n",
423 | " else:\n",
424 | " para_stats['ordinary'][filename] = cur_stat"
425 | ]
426 | },
427 | {
428 | "cell_type": "code",
429 | "execution_count": null,
430 | "metadata": {
431 | "pycharm": {
432 | "name": "#%%\n"
433 | }
434 | },
435 | "outputs": [],
436 | "source": [
437 | "list(para_stats['review'].items())[:10]"
438 | ]
439 | },
440 | {
441 | "cell_type": "markdown",
442 | "metadata": {
443 | "pycharm": {
444 | "name": "#%% md\n"
445 | }
446 | },
447 | "source": [
448 | "#### Number of sections"
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": null,
454 | "metadata": {
455 | "pycharm": {
456 | "name": "#%%\n"
457 | }
458 | },
459 | "outputs": [],
460 | "source": [
461 | "para_nums = list(map(lambda x: x[1][0], para_stats['review'].items()))\n",
462 | "print(para_nums[:10])\n",
463 | "plt.title('Number of sections in review papers')\n",
464 | "plt.hist(para_nums, bins=range(1, 20))\n",
465 | "plt.show()"
466 | ]
467 | },
468 | {
469 | "cell_type": "code",
470 | "execution_count": null,
471 | "metadata": {
472 | "pycharm": {
473 | "name": "#%%\n"
474 | }
475 | },
476 | "outputs": [],
477 | "source": [
478 | "para_nums = list(map(lambda x: x[1][0], para_stats['ordinary'].items()))\n",
479 | "print(para_nums[:10])\n",
480 | "plt.title('Number of sections in ordinary papers')\n",
481 | "plt.hist(para_nums, bins=range(1, 20))\n",
482 | "plt.show()"
483 | ]
484 | },
485 | {
486 | "cell_type": "markdown",
487 | "metadata": {
488 | "pycharm": {
489 | "name": "#%% md\n"
490 | }
491 | },
492 | "source": [
493 | "### Summary section position"
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "execution_count": null,
499 | "metadata": {
500 | "pycharm": {
501 | "name": "#%%\n"
502 | }
503 | },
504 | "outputs": [],
505 | "source": [
506 | "sum_stat = list(map(lambda x: x[1][1], para_stats['review'].items()))\n",
507 | "print(sum_stat[:10])\n",
508 | "plt.title('Position of discussion section in review papers')\n",
509 | "plt.hist(sum_stat, bins=range(1, 20))\n",
510 | "plt.show()"
511 | ]
512 | },
513 | {
514 | "cell_type": "code",
515 | "execution_count": null,
516 | "metadata": {
517 | "pycharm": {
518 | "name": "#%%\n"
519 | }
520 | },
521 | "outputs": [],
522 | "source": [
523 | "sum_stat = list(map(lambda x: x[1][1], para_stats['ordinary'].items()))\n",
524 | "print(sum_stat[:10])\n",
525 | "plt.title('Position of discussion section in ordinary papers')\n",
526 | "plt.hist(sum_stat, bins=range(1, 20))\n",
527 | "plt.show()"
528 | ]
529 | },
530 | {
531 | "cell_type": "markdown",
532 | "metadata": {
533 | "pycharm": {
534 | "name": "#%% md\n"
535 | }
536 | },
537 | "source": [
538 | "### Position of discussion section"
539 | ]
540 | },
541 | {
542 | "cell_type": "code",
543 | "execution_count": null,
544 | "metadata": {
545 | "pycharm": {
546 | "name": "#%%\n"
547 | }
548 | },
549 | "outputs": [],
550 | "source": [
551 | "sum_stat = list(map(lambda x: x[1][2], para_stats['review'].items()))\n",
552 | "print(sum_stat[:10])\n",
553 | "plt.title('Position of discussion section in review papers')\n",
554 | "plt.hist(sum_stat, bins=range(-1, 20))\n",
555 | "plt.show()"
556 | ]
557 | },
558 | {
559 | "cell_type": "code",
560 | "execution_count": null,
561 | "metadata": {
562 | "pycharm": {
563 | "name": "#%%\n"
564 | }
565 | },
566 | "outputs": [],
567 | "source": [
568 | "sum_stat = list(map(lambda x: x[1][2], para_stats['ordinary'].items()))\n",
569 | "print(sum_stat[:10])\n",
570 | "plt.title('Position of discussion section in ordinary papers')\n",
571 | "plt.hist(sum_stat, bins=range(-1, 20))\n",
572 | "plt.show()"
573 | ]
574 | },
575 | {
576 | "cell_type": "markdown",
577 | "metadata": {
578 | "pycharm": {
579 | "name": "#%% md\n"
580 | }
581 | },
582 | "source": [
583 | "### Average section length by position"
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": null,
589 | "metadata": {
590 | "pycharm": {
591 | "name": "#%%\n"
592 | }
593 | },
594 | "outputs": [],
595 | "source": [
596 | "len_stat = functools.reduce(lambda x, y: x + y, map(lambda x: x[1][3], para_stats['review'].items()))\n",
597 | "plt.title('Average number of sections in review papers')\n",
598 | "plt.bar(len_stat.keys(), list(map(lambda x: x / len(para_stats['review'].items()), len_stat.values())))\n",
599 | "plt.show()"
600 | ]
601 | },
602 | {
603 | "cell_type": "code",
604 | "execution_count": null,
605 | "metadata": {
606 | "pycharm": {
607 | "name": "#%%\n"
608 | }
609 | },
610 | "outputs": [],
611 | "source": [
612 | "len_stat = functools.reduce(lambda x, y: x + y, map(lambda x: x[1][3], para_stats['ordinary'].items()))\n",
613 | "plt.title('Average number of sections in ordinary papers')\n",
614 | "plt.bar(list(map(lambda x: min(35, x), len_stat.keys())),\n",
615 | " list(map(lambda x: x / len(para_stats['ordinary'].items()), len_stat.values())))\n",
616 | "plt.show()"
617 | ]
618 | },
619 | {
620 | "cell_type": "markdown",
621 | "metadata": {
622 | "pycharm": {
623 | "name": "#%% md\n"
624 | }
625 | },
626 | "source": [
627 | "### Position of conclusion section"
628 | ]
629 | },
630 | {
631 | "cell_type": "code",
632 | "execution_count": null,
633 | "metadata": {
634 | "pycharm": {
635 | "name": "#%%\n"
636 | }
637 | },
638 | "outputs": [],
639 | "source": [
640 | "xml = 'Some example text'\n",
641 | "tree = etree.fromstring(xml)\n",
642 | "print(''.join(tree.itertext()))"
643 | ]
644 | },
645 | {
646 | "cell_type": "code",
647 | "execution_count": null,
648 | "metadata": {
649 | "pycharm": {
650 | "name": "#%%\n"
651 | }
652 | },
653 | "outputs": [],
654 | "source": [
655 | "list(filter(lambda x: x[1][0] == 1, para_stats['ordinary'].items()))[:10]"
656 | ]
657 | },
658 | {
659 | "cell_type": "code",
660 | "execution_count": null,
661 | "metadata": {
662 | "pycharm": {
663 | "name": "#%%\n"
664 | }
665 | },
666 | "outputs": [],
667 | "source": [
668 | "def get_conc_info(root):\n",
669 | " conc_pos = -1\n",
670 | " num = 0\n",
671 | " for ch in root.find('body').getchildren():\n",
672 | " if ch.tag == 'sec':\n",
673 | " num += 1\n",
674 | " str_title = str(etree.tostring(ch.find('title'))).lower()\n",
675 | " if 'conclusion' in str_title:\n",
676 | " conc_pos = num\n",
677 | " return conc_pos"
678 | ]
679 | },
680 | {
681 | "cell_type": "code",
682 | "execution_count": null,
683 | "metadata": {
684 | "pycharm": {
685 | "name": "#%%\n"
686 | }
687 | },
688 | "outputs": [],
689 | "source": [
690 | "conc_stats = {}\n",
691 | "conc_stats['review'] = {}\n",
692 | "conc_stats['ordinary'] = {}\n",
693 | "\n",
694 | "for filename in tqdm(dict_articles.values()):\n",
695 | " tree = etree.parse(os.path.join(pmc_dataset_root, filename))\n",
696 | " try:\n",
697 | " cur_stat = get_conc_info(tree.getroot())\n",
698 | " except Exception:\n",
699 | " print(f\"\\n{filename}\")\n",
700 | " continue\n",
701 | " if filename in review_filenames:\n",
702 | " conc_stats['review'][filename] = cur_stat\n",
703 | " else:\n",
704 | " conc_stats['ordinary'][filename] = cur_stat"
705 | ]
706 | },
707 | {
708 | "cell_type": "code",
709 | "execution_count": null,
710 | "metadata": {
711 | "pycharm": {
712 | "name": "#%%\n"
713 | }
714 | },
715 | "outputs": [],
716 | "source": [
717 | "conc_stat = list(map(lambda x: x[1], conc_stats['review'].items()))\n",
718 | "print(conc_stat[:10])\n",
719 | "plt.title('Position of conclusion section in review papers')\n",
720 | "plt.hist(conc_stat, bins=range(-1, 20))\n",
721 | "plt.show()"
722 | ]
723 | },
724 | {
725 | "cell_type": "code",
726 | "execution_count": null,
727 | "metadata": {
728 | "pycharm": {
729 | "name": "#%%\n"
730 | }
731 | },
732 | "outputs": [],
733 | "source": [
734 | "conc_stat = list(map(lambda x: x[1], conc_stats['ordinary'].items()))\n",
735 | "print(conc_stat[:10])\n",
736 | "plt.title('Position of conclusion section in ordinary papers')\n",
737 | "plt.hist(conc_stat, bins=range(-1, 20))\n",
738 | "plt.show()"
739 | ]
740 | }
741 | ],
742 | "metadata": {
743 | "kernelspec": {
744 | "display_name": "Python 3 (ipykernel)",
745 | "language": "python",
746 | "name": "python3"
747 | },
748 | "language_info": {
749 | "codemirror_mode": {
750 | "name": "ipython",
751 | "version": 3
752 | },
753 | "file_extension": ".py",
754 | "mimetype": "text/x-python",
755 | "name": "python",
756 | "nbconvert_exporter": "python",
757 | "pygments_lexer": "ipython3",
758 | "version": "3.10.4"
759 | }
760 | },
761 | "nbformat": 4,
762 | "nbformat_minor": 4
763 | }
--------------------------------------------------------------------------------
/review/train/test_new_algo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "source": [
6 | "# Training & evaluation of the developed algorithm\n",
7 | "\n",
8 | "This notebook contains the source code for training and evaluating several candidate model architectures."
9 | ],
10 | "metadata": {
11 | "collapsed": false,
12 | "pycharm": {
13 | "name": "#%% md\n"
14 | }
15 | }
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "outputs": [],
21 | "source": [
22 | "import logging\n",
23 | "import os\n",
24 | "import random\n",
25 | "import sys\n",
26 | "\n",
27 | "import numpy as np\n",
28 | "import pandas as pd\n",
29 | "import torch\n",
30 | "from tqdm.auto import tqdm\n",
31 | "\n",
32 | "from review import config as cfg\n",
33 | "\n",
34 | "logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')"
35 | ],
36 | "metadata": {
37 | "collapsed": false,
38 | "pycharm": {
39 | "name": "#%%\n"
40 | }
41 | }
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "outputs": [],
47 | "source": [
48 | "print('CUDA version', torch.version.cuda)\n",
49 | "if torch.cuda.is_available():\n",
50 | " print('GPU:', torch.cuda.get_device_name(0))\n",
51 | "else:\n",
52 | " print('CPU')\n",
53 | "\n",
54 | "\n",
55 | "def setup_cuda_device(model):\n",
56 | " logging.info('Setup single-device settings...')\n",
57 | " device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
58 | " model = model.to(device)\n",
59 | " return model, device"
60 | ],
61 | "metadata": {
62 | "collapsed": false,
63 | "pycharm": {
64 | "name": "#%%\n"
65 | }
66 | }
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "outputs": [],
72 | "source": [
73 | "logging.info('Fix seed')\n",
74 | "seed = 42\n",
75 | "random.seed(seed)\n",
76 | "np.random.seed(seed)\n",
77 | "torch.manual_seed(seed)\n",
78 | "torch.cuda.manual_seed_all(seed)"
79 | ],
80 | "metadata": {
81 | "collapsed": false,
82 | "pycharm": {
83 | "name": "#%%\n"
84 | }
85 | }
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "source": [
90 | "# Loading data and preparation of dataset\n",
91 | "\n",
92 | "`load_data` loads all the data files needed to build the training dataset.\\\n",
93 | "The next several steps only need to be run if the train/test/val datasets have not been saved yet."
94 | ],
95 | "metadata": {
96 | "collapsed": false,
97 | "pycharm": {
98 | "name": "#%% md\n"
99 | }
100 | }
101 | },
102 | {
103 | "cell_type": "code",
104 | "execution_count": null,
105 | "outputs": [],
106 | "source": [
107 | "def sizeof_fmt(num, suffix='B'):\n",
108 | " for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:\n",
109 | " if abs(num) < 1024.0:\n",
110 | " return \"%3.1f %s%s\" % (num, unit, suffix)\n",
111 | " num /= 1024.0\n",
112 | " return \"%.1f %s%s\" % (num, 'Yi', suffix)\n",
113 | "\n",
114 | "\n",
115 | "def load_data(additional_features):\n",
116 | " dataset_root = os.path.expanduser(cfg.dataset_path)\n",
117 | " logging.info(f'Loading references dataset from {dataset_root}')\n",
118 | "\n",
119 | " logging.info('Loading citations_df')\n",
120 | " citations_df = pd.read_csv(os.path.join(dataset_root, \"citations.csv\"), sep='\\t')\n",
121 | " logging.info(sizeof_fmt(sys.getsizeof(citations_df)))\n",
122 | "\n",
123 | " logging.info('Loading sentences_df')\n",
124 | " sentences_df = pd.read_csv(os.path.join(dataset_root, \"sentences.csv\"), sep='\\t')\n",
125 | " logging.info(sizeof_fmt(sys.getsizeof(sentences_df)))\n",
126 | "\n",
127 | " logging.info('Loading review_files_df')\n",
128 | " review_files_df = pd.read_csv(os.path.join(dataset_root, \"review_files.csv\"), sep='\\t')\n",
129 | " logging.info(sizeof_fmt(sys.getsizeof(review_files_df)))\n",
130 | "\n",
131 | " logging.info('Loading reverse_ref_df')\n",
132 | " reverse_ref_df = pd.read_csv(os.path.join(dataset_root, \"reverse_ref.csv\"), sep='\\t')\n",
133 | " logging.info(sizeof_fmt(sys.getsizeof(reverse_ref_df)))\n",
134 | "\n",
135 | " if not additional_features:\n",
136 | " return citations_df, sentences_df, review_files_df, reverse_ref_df, None, None, None\n",
137 | "\n",
138 | " logging.info('Loading abstracts_df')\n",
139 | " abstracts_df = pd.read_csv(os.path.join(dataset_root, \"abstracts.csv\"), sep='\\t')\n",
140 | " logging.info(sizeof_fmt(sys.getsizeof(abstracts_df)))\n",
141 | "\n",
142 | " logging.info('Loading figures_df')\n",
143 | " figures_df = pd.read_csv(os.path.join(dataset_root, \"figures.csv\"), sep='\\t')\n",
144 | " logging.info(sizeof_fmt(sys.getsizeof(figures_df)))\n",
145 | "\n",
146 | " logging.info('Loading tables_df')\n",
147 | " tables_df = pd.read_csv(os.path.join(dataset_root, \"tables.csv\"), sep='\\t')\n",
148 | " logging.info(sizeof_fmt(sys.getsizeof(tables_df)))\n",
149 | "\n",
150 | " return citations_df, sentences_df, review_files_df, reverse_ref_df, abstracts_df, figures_df, tables_df"
151 | ],
152 | "metadata": {
153 | "collapsed": false,
154 | "pycharm": {
155 | "name": "#%%\n"
156 | }
157 | }
158 | },
159 | {
160 | "cell_type": "code",
161 | "execution_count": null,
162 | "outputs": [],
163 | "source": [
164 | "ADDITIONAL_FEATURES = False\n",
165 | "\n",
166 | "citations_df, sentences_df, review_files_df, reverse_ref_df, abstracts_df, figures_df, tables_df = load_data(\n",
167 | " additional_features=ADDITIONAL_FEATURES)\n",
168 | "\n",
169 | "logging.info('Done loading references dataset')"
170 | ],
171 | "metadata": {
172 | "collapsed": false,
173 | "pycharm": {
174 | "name": "#%%\n"
175 | }
176 | }
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "outputs": [],
182 | "source": [
183 | "import matplotlib.pyplot as plt\n",
184 | "from collections import Counter\n",
185 | "\n",
186 | "res = Counter(list(sentences_df['pmid'].values))\n",
187 | "plt.hist(res.values(), bins=range(-1, 400))\n",
188 | "plt.title('Length of papers')\n",
189 | "plt.show()\n",
190 | "del res"
191 | ],
192 | "metadata": {
193 | "collapsed": false,
194 | "pycharm": {
195 | "name": "#%%\n"
196 | }
197 | }
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "outputs": [],
203 | "source": [
204 | "REF_SENTS_DF_PATH = f\"{os.path.expanduser(cfg.base_path)}/ref_sents.csv\"\n",
205 | "! rm {REF_SENTS_DF_PATH}\n",
206 | "\n",
207 | "if os.path.exists(REF_SENTS_DF_PATH):\n",
208 | " ref_sents_df = pd.read_csv(REF_SENTS_DF_PATH, sep='\\t')\n",
209 | "else:\n",
210 | " logging.info('Creating reference sentences dataset')\n",
211 | " ref_sents_df = pd.merge(citations_df, reverse_ref_df, left_on=['pmid', 'ref_id'], right_on=['pmid', 'ref_id'])\n",
212 | " ref_sents_df = pd.merge(ref_sents_df, sentences_df, left_on=['pmid', 'sent_type', 'sent_id'],\n",
213 | " right_on=['pmid', 'type', 'sent_id'])\n",
214 | " ref_sents_df = ref_sents_df[ref_sents_df['pmid'].isin(review_files_df['pmid'].values)]\n",
215 | " ref_sents_df = ref_sents_df.drop_duplicates()\n",
216 | " logging.info(f'Len of unique ref_sents {len(set(ref_sents_df[\"ref_pmid\"]))}')\n",
217 | " ref_sents_df = ref_sents_df[['pmid', 'ref_id', 'pub_type', 'ref_pmid', 'sent_type', 'sent_id', 'sentence']]\n",
218 | " ref_sents_df.to_csv(REF_SENTS_DF_PATH, sep='\\t', index=False)\n",
219 | "\n",
220 | "logging.info('Cleanup memory')\n",
221 | "del citations_df\n",
222 | "del review_files_df\n",
223 | "\n",
224 | "display(ref_sents_df.head())"
225 | ],
226 | "metadata": {
227 | "collapsed": false,
228 | "pycharm": {
229 | "name": "#%%\n"
230 | }
231 | }
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "source": [
236 | "# Rouge\n",
237 | "The `get_rouge` function computes the similarity between two sentences as the mean of the ROUGE-1 and ROUGE-2 F-scores, scaled to 0-100.\n",
238 | "ROUGE-L (`rouge-l`) is another possible option."
239 | ],
240 | "metadata": {
241 | "collapsed": false,
242 | "pycharm": {
243 | "name": "#%% md\n"
244 | }
245 | }
246 | },
247 | {
248 | "cell_type": "code",
249 | "execution_count": null,
250 | "outputs": [],
251 | "source": [
252 | "from rouge import Rouge\n",
253 | "from transformers import BertTokenizer\n",
254 | "\n",
255 | "ROUGE_METER = Rouge()\n",
256 | "TOKENIZER = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n",
257 | "\n",
258 | "\n",
259 | "def get_rouge(sent1, sent2):\n",
260 | " if sent1 is None or sent2 is None:\n",
261 | " return None\n",
262 | " sent_1 = TOKENIZER.tokenize(sent1)\n",
263 | " if len(sent_1) == 0:\n",
264 | " return None\n",
265 | " sent_1 = \" \".join(list(filter(lambda x: x.isalpha() or x in '.!,?', sent_1)))\n",
266 | " sent_2 = TOKENIZER.tokenize(sent2)\n",
267 | " if len(sent_2) == 0:\n",
268 | " return None\n",
269 | " sent_2 = \" \".join(list(filter(lambda x: x.isalpha() or x in '.!,?', sent_2)))\n",
270 | " if len(sent_1) == 0 or len(sent_2) == 0:\n",
271 | " return None\n",
272 | " rouges = ROUGE_METER.get_scores(sent_1, sent_2)[0]\n",
273 | " rouges = [rouges[f'rouge-{x}'][\"f\"] for x in ('1', '2')] # , 'l')]\n",
274 | " return np.mean(rouges) * 100\n",
275 | "\n",
276 | "\n",
277 | "def mean_rouge(sent, text):\n",
278 | " if len(text) == 0:\n",
279 | " return None\n",
280 | " try:\n",
281 | " return sum(get_rouge(sent, ref_sent) for ref_sent in text) / len(text)\n",
282 | " except Exception as e:\n",
283 | " logging.error(f'Exception at mean_rouge {e}')\n",
284 | " return None\n",
285 | "\n",
286 | "\n",
287 | "def min_rouge(sent, text):\n",
288 | " try:\n",
289 | " score = 100000000\n",
290 | " for ref_sent in text:\n",
291 | " score = min(get_rouge(sent, ref_sent), score)\n",
292 | " if score == 100000000:\n",
293 | " return None\n",
294 | " return score\n",
295 | " except Exception as e:\n",
296 | " logging.error(f'Exception at min_rouge {e}')\n",
297 | " return None\n",
298 | "\n",
299 | "\n",
300 | "def max_rouge(sent, text):\n",
301 | " try:\n",
302 | " score = -100000\n",
303 | " for ref_sent in text:\n",
304 | " score = max(get_rouge(sent, ref_sent), score)\n",
305 | " if score == -100000:\n",
306 | " return None\n",
307 | " return score\n",
308 | " except Exception as e:\n",
309 | " logging.error(f'Exception at max_rouge {e}')\n",
310 | " return None"
311 | ],
312 | "metadata": {
313 | "collapsed": false,
314 | "pycharm": {
315 | "name": "#%%\n"
316 | }
317 | }
318 | },
319 | {
320 | "cell_type": "code",
321 | "execution_count": null,
322 | "outputs": [],
323 | "source": [
324 | "get_rouge(\"I am scout. True!\", \"No you are not a scout.\")"
325 | ],
326 | "metadata": {
327 | "collapsed": false,
328 | "pycharm": {
329 | "name": "#%%\n"
330 | }
331 | }
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "source": [
336 | "# Build train / test / validate datasets\n",
337 | "\n",
338 | "To create a dataset with features use `preprocess_paper_with_features`. Otherwise, use `preprocess_paper`.\n",
339 | "\n",
340 | "If the datasets have not been created yet and `ref_sents_df` does not exist either, `ref_sents_df` has to be created first.\\\n",
341 | "For each paper `pmid`, it holds the sentences from review papers in which the paper with this `pmid` is cited."
342 | ],
343 | "metadata": {
344 | "collapsed": false,
345 | "pycharm": {
346 | "name": "#%% md\n"
347 | }
348 | }
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": null,
353 | "outputs": [],
354 | "source": [
355 | "from unidecode import unidecode\n",
356 | "from nltk.tokenize import sent_tokenize\n",
357 | "import re\n",
358 | "\n",
359 | "REPLACE_SYMBOLS = {\n",
360 | " '—': '-',\n",
361 | " '–': '-',\n",
362 | " '―': '-',\n",
363 | " '…': '...',\n",
364 | " '´´': \"´\",\n",
365 | " '´´´': \"´´\",\n",
366 | " \"''\": \"'\",\n",
367 | " \"'''\": \"'\",\n",
368 | " \"``\": \"`\",\n",
369 | " \"```\": \"`\",\n",
370 | " \":\": \" : \",\n",
371 | "}\n",
372 | "\n",
373 | "\n",
374 | "def parse_sents(data):\n",
375 | " sents = sum([sent_tokenize(text) for text in data], [])\n",
376 | " sents = list(filter(lambda x: len(x) > 3, sents))\n",
377 | " return sents\n",
378 | "\n",
379 | "\n",
380 | "def sent_standardize(sent):\n",
381 | " sent = unidecode(sent)\n",
382 | " sent = re.sub(r\"\\[(xref_\\w*_\\w\\d*]*)(, xref_\\w*_\\w\\d*)*\\]\", \" \", sent) # delete [xref,...]\n",
383 | " sent = re.sub(r\"\\( (xref_\\w*_\\w\\d*)(; xref_\\w*_\\w\\d*)* \\)\", \" \", sent) # delete (xref; ...)\n",
384 | " sent = re.sub(r\"\\[xref_\\w*_\\w\\d*\\]\", \" \", sent) # delete [xref]\n",
385 | " sent = re.sub(r\"xref_\\w*_\\w\\d*\", \" \", sent) # delete [[xref]]\n",
386 | " for k, v in REPLACE_SYMBOLS.items():\n",
387 | " sent = sent.replace(k, v)\n",
388 | " return sent.strip()\n",
389 | "\n",
390 | "\n",
391 | "def standardize(text):\n",
392 | " return [x for x in (sent_standardize(sent) for sent in text) if len(x) > 3]\n",
393 | "\n",
394 | "\n",
395 | "PAPER_MIN_SENTENCES = 50\n",
396 | "PAPER_MAX_SENTENCES = 100\n",
397 | "\n",
398 | "\n",
399 | "def preprocess_paper(\n",
400 | " paper_id, sentences_df, ref_sents_df,\n",
401 | " paper_min_sents=PAPER_MIN_SENTENCES, paper_max_sents=PAPER_MAX_SENTENCES):\n",
402 | " paper = sentences_df[sentences_df['pmid'] == paper_id]['sentence']\n",
403 | " paper = standardize(paper)\n",
404 | "\n",
405 | " ref_sents = ref_sents_df[ref_sents_df['ref_pmid'] == paper_id]['sentence']\n",
406 | " ref_sents = standardize(ref_sents)\n",
407 | "\n",
408 | " if len(paper) < paper_min_sents:\n",
409 | " return None\n",
410 | "\n",
411 | " if len(paper) > paper_max_sents:\n",
412 | " paper = list(paper[:paper_min_sents]) + list(paper[-paper_min_sents:])\n",
413 | "\n",
414 | " preprocessed_score = [\n",
415 | " sum(get_rouge(sent, ref_sent) for ref_sent in ref_sents) / len(ref_sents) for sent in paper\n",
416 | " ]\n",
417 | " return paper, preprocessed_score\n",
418 | "\n",
419 | "\n",
420 | "def preprocess_paper_with_features(paper_id, sentences_df, ref_sents_df,\n",
421 | " abstracts_df, figures_df, reverse_ref_df, tables_df,\n",
422 | " paper_min_sents=PAPER_MIN_SENTENCES, paper_max_sents=PAPER_MAX_SENTENCES):\n",
423 | " preprocessed_score = []\n",
424 | " features = []\n",
425 | "\n",
426 | " papers = sentences_df[sentences_df['pmid'] == paper_id]['sentence']\n",
427 | " papers = standardize(papers)\n",
428 | "\n",
429 | " sent_ids = sentences_df[sentences_df['pmid'] == paper_id]['sent_id']\n",
430 | " sent_types = sentences_df[sentences_df['pmid'] == paper_id]['type']\n",
431 | "\n",
432 | " ref_sents = ref_sents_df[ref_sents_df['ref_pmid'] == paper_id]['sentence']\n",
433 | " ref_sents = standardize(ref_sents)\n",
434 | "\n",
435 | " fig_captions = figures_df[figures_df['pmid'] == paper_id]['caption']\n",
436 | " fig_captions = standardize(fig_captions)\n",
437 | "\n",
438 | " tab_captions = tables_df[tables_df['pmid'] == paper_id]['caption']\n",
439 | " tab_captions = standardize(tab_captions)\n",
440 | "\n",
441 | " abstract = abstracts_df[abstracts_df['pmid'] == paper_id]['abstract']\n",
442 | " if len(abstract) != 0:\n",
443 | " abstract = standardize(abstract)\n",
444 | "\n",
445 | " tmp_df = reverse_ref_df[reverse_ref_df['pmid'] == paper_id]\n",
446 | "\n",
447 | " if len(papers) < paper_min_sents:\n",
448 | " return None\n",
449 | "\n",
450 | " if len(papers) > paper_max_sents:\n",
451 | " papers = list(papers[:paper_min_sents]) + list(papers[-paper_min_sents:])\n",
452 | " sent_ids = list(sent_ids[:paper_min_sents]) + list(sent_ids[-paper_min_sents:])\n",
453 | " sent_types = list(sent_types[:paper_min_sents]) + list(sent_types[-paper_min_sents:])\n",
454 | "\n",
455 | " for sent, sent_type, sent_id in zip(papers, sent_types, sent_ids):\n",
456 | " score = mean_rouge(sent, ref_sents)\n",
457 | " if score is None:\n",
458 | " return None\n",
459 | "\n",
460 | " r_abs = get_rouge(sent, abstract[0])\n",
461 | " num_refs = len(tmp_df[(tmp_df['sent_type'] == sent_type) & (tmp_df['sent_id'] == sent_id)])\n",
462 | " preprocessed_score.append(score)\n",
463 | " features.append((sent_id, int(sent_type == \"general\"), r_abs, num_refs,\n",
464 | " mean_rouge(sent, fig_captions), mean_rouge(sent, tab_captions),\n",
465 | " min_rouge(sent, fig_captions), min_rouge(sent, tab_captions),\n",
466 | " max_rouge(sent, fig_captions), max_rouge(sent, tab_captions)))\n",
467 | " return papers, preprocessed_score, features"
468 | ],
469 | "metadata": {
470 | "collapsed": false,
471 | "pycharm": {
472 | "name": "#%%\n"
473 | }
474 | }
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": null,
479 | "outputs": [],
480 | "source": [
481 | "FEATURES_NUMBER = 10\n",
482 | "\n",
483 | "\n",
484 | "def process_reference_sentences_dataset(sentences_df, ref_sents_df, additional_features):\n",
485 | " res = {}\n",
486 | " inter = set(sentences_df['pmid'].values) & set(ref_sents_df['ref_pmid'].values)\n",
487 | " for pmid in tqdm(inter):\n",
488 | " try:\n",
489 | " if additional_features:\n",
490 | " temp = preprocess_paper_with_features(\n",
491 | " pmid, sentences_df, ref_sents_df, abstracts_df, figures_df, reverse_ref_df, tables_df\n",
492 | " )\n",
493 | " else:\n",
494 | " temp = preprocess_paper(pmid, sentences_df, ref_sents_df)\n",
495 | " except Exception as e:\n",
496 | " logging.warning(f'Error during processing {pmid} {e}')\n",
497 | " continue\n",
498 | " if temp is None:\n",
499 | " logging.warning(f'temp is None for {pmid}')\n",
500 | " continue\n",
501 | " res[pmid] = temp\n",
502 | " print(f\"\\r{pmid} {np.mean(res[pmid][1])}\", end=\"\")\n",
503 | "\n",
504 | " logging.info(f'Successfully preprocessed {len(res)} of {len(inter)} papers')\n",
505 | "\n",
506 | " logging.info(f'Creating train dataset')\n",
507 | " feature_names = [\n",
508 | " 'sent_id', 'sent_type', 'r_abs', 'num_refs',\n",
509 | " 'mean_r_fig', 'mean_r_tab', 'min_r_fig', 'min_r_tab', 'max_r_fig', 'max_r_tab'\n",
510 | " ]\n",
511 | " assert len(feature_names) == FEATURES_NUMBER\n",
512 | " train_dic = dict(\n",
513 | " pmid=[], sentence=[], score=[], sent_id=[], sent_type=[], r_abs=[], num_refs=[],\n",
514 | " mean_r_fig=[], mean_r_tab=[], min_r_fig=[], min_r_tab=[], max_r_fig=[], max_r_tab=[]\n",
515 | " )\n",
516 | "\n",
517 | " for pmid, stat in tqdm(res.items()):\n",
518 | " if len(stat) == 2:\n",
519 | " for sent, score in zip(*stat):\n",
520 | " train_dic['pmid'].append(pmid)\n",
521 | " train_dic['sentence'].append(sent)\n",
522 | " train_dic['score'].append(score)\n",
523 | " else:\n",
524 | " for sent, score, features in zip(*stat):\n",
525 | " train_dic['pmid'].append(pmid)\n",
526 | " train_dic['sentence'].append(sent)\n",
527 | " train_dic['score'].append(score)\n",
528 | " for name, val in zip(feature_names, features):\n",
529 | " train_dic[name].append(val)\n",
530 | "\n",
531 | " train_df = pd.DataFrame({k: v for k, v in train_dic.items() if v})\n",
532 | " logging.info(f'Full train dataset {len(train_df)}')\n",
533 | " return train_df"
534 | ],
535 | "metadata": {
536 | "collapsed": false,
537 | "pycharm": {
538 | "name": "#%%\n"
539 | }
540 | }
541 | },
542 | {
543 | "cell_type": "code",
544 | "execution_count": null,
545 | "outputs": [],
546 | "source": [
547 | "TRAIN_DATASET_PATH = f'{os.path.expanduser(cfg.base_path)}/dataset.csv'\n",
548 | "! rm {TRAIN_DATASET_PATH}\n",
549 | "\n",
550 | "if os.path.exists(TRAIN_DATASET_PATH):\n",
551 | " train_df = pd.read_csv(TRAIN_DATASET_PATH)\n",
552 | "else:\n",
553 | " train_df = process_reference_sentences_dataset(\n",
554 | " sentences_df, ref_sents_df, additional_features=ADDITIONAL_FEATURES\n",
555 | " )\n",
556 | " train_df.to_csv(TRAIN_DATASET_PATH, index=False)"
557 | ],
558 | "metadata": {
559 | "collapsed": false,
560 | "pycharm": {
561 | "name": "#%%\n"
562 | }
563 | }
564 | },
565 | {
566 | "cell_type": "code",
567 | "execution_count": null,
568 | "outputs": [],
569 | "source": [
570 | "display(train_df.head())"
571 | ],
572 | "metadata": {
573 | "collapsed": false,
574 | "pycharm": {
575 | "name": "#%%\n"
576 | }
577 | }
578 | },
579 | {
580 | "cell_type": "code",
581 | "execution_count": null,
582 | "outputs": [],
583 | "source": [
584 | "import matplotlib.pyplot as plt\n",
585 | "from collections import Counter\n",
586 | "\n",
587 | "res = Counter(list(train_df['score'].values))\n",
588 | "\n",
589 | "plt.hist(res.values(), bins=range(2, 20))\n",
590 | "plt.title('Train dataset scores')\n",
591 | "plt.show()"
592 | ],
593 | "metadata": {
594 | "collapsed": false,
595 | "pycharm": {
596 | "name": "#%%\n"
597 | }
598 | }
599 | },
600 | {
601 | "cell_type": "markdown",
602 | "source": [
603 | "# Splitting data into train/test/val"
604 | ],
605 | "metadata": {
606 | "collapsed": false,
607 | "pycharm": {
608 | "name": "#%% md\n"
609 | }
610 | }
611 | },
612 | {
613 | "cell_type": "code",
614 | "execution_count": null,
615 | "outputs": [],
616 | "source": [
617 | "from sklearn.model_selection import train_test_split\n",
618 | "\n",
619 | "train_ids, test_ids = train_test_split(list(set(train_df['pmid'].values)), test_size=0.2)\n",
620 | "test_ids, val_ids = train_test_split(test_ids, test_size=0.4)\n",
621 | "\n",
622 | "train = train_df[train_df['pmid'].isin(train_ids)]\n",
623 | "logging.info(f'Train {len(train)}')\n",
624 | "display(train.head(1))\n",
625 | "\n",
626 | "test = train_df[train_df['pmid'].isin(test_ids)]\n",
627 | "logging.info(f'Test {len(test)}')\n",
628 | "display(test.head(1))\n",
629 | "\n",
630 | "val = train_df[train_df['pmid'].isin(val_ids)]\n",
631 | "logging.info(f'Validate {len(val)}')\n",
632 | "display(val.head(1))\n"
633 | ],
634 | "metadata": {
635 | "collapsed": false,
636 | "pycharm": {
637 | "name": "#%%\n"
638 | }
639 | }
640 | },
641 | {
642 | "cell_type": "markdown",
643 | "source": [
644 | "# Inspect pretrained models: BERT and RoBERTa"
645 | ],
646 | "metadata": {
647 | "collapsed": false,
648 | "pycharm": {
649 | "name": "#%% md\n"
650 | }
651 | }
652 | },
653 | {
654 | "cell_type": "code",
655 | "execution_count": null,
656 | "outputs": [],
657 | "source": [
658 | "from transformers import BertModel\n",
659 | "\n",
660 | "backbone = BertModel.from_pretrained(\n",
661 | " \"bert-base-uncased\", output_hidden_states=False\n",
662 | ")\n",
663 | "print('BERT pretrained')\n",
664 | "print(f'Parameters {sum(p.numel() for p in backbone.parameters() if p.requires_grad)}')\n",
665 | "print(', '.join(n for n, p in backbone.named_parameters()))\n",
666 | "print(backbone)\n",
667 | "\n",
668 | "# backbone = RobertaModel.from_pretrained(\n",
669 | "# 'roberta-base', output_hidden_states=False\n",
670 | "# )\n",
671 | "# print('ROBERTA pretrained')\n",
672 | "# print(f'Parameters {sum(p.numel() for p in backbone.parameters() if p.requires_grad)}')\n",
673 | "# print(', '.join(n for n, p in backbone.named_parameters()))\n",
674 | "# print(backbone)"
675 | ],
676 | "metadata": {
677 | "collapsed": false,
678 | "pycharm": {
679 | "name": "#%%\n"
680 | }
681 | }
682 | },
683 | {
684 | "cell_type": "markdown",
685 | "source": [
686 | "# Main model classes\n",
687 | "\n",
688 | "The model in the main pubtrends application is loaded using the `load_model` function from the `review.model` module.\n",
689 | "\n",
690 | "It has several options to configure:\n",
691 | "* with or without additional features (right now, without features works better),\n",
692 | "* `BERT` or `RoBERTa` as the backbone (no big difference).\n",
693 | "\n",
694 | "You can also choose a `froze_strategy`:\n",
695 | "* `froze_all` -- keeps the BERT layers fixed and trains only the summarization layer,\n",
696 | "* `unfroze_last` -- also fine-tunes the last three encoder layers (9-11) of the backbone, while training is still reasonably fast,\n",
697 | "* `unfroze_all` -- fine-tunes the whole backbone; training is slow, but the results may be better."
698 | ],
699 | "metadata": {
700 | "collapsed": false,
701 | "pycharm": {
702 | "name": "#%% md\n"
703 | }
704 | }
705 | },
706 | {
707 | "cell_type": "code",
708 | "execution_count": null,
709 | "outputs": [],
710 | "source": [
711 | "import os\n",
712 | "import torch.nn as nn\n",
713 | "from transformers import BertModel, RobertaModel\n",
714 | "from transformers import BertTokenizer, RobertaTokenizer\n",
715 | "\n",
716 | "\n",
717 | "def get_token_id(tokenizer, tkn):\n",
718 | " return tokenizer.convert_tokens_to_ids([tkn])[0]\n",
719 | "\n",
720 | "\n",
721 | "EMBEDDINGS_ADDITIONAL = 20\n",
722 | "\n",
723 | "FEATURES_NN_INTERMEDIATE = 100\n",
724 | "FEATURES_NN_OUT = 50\n",
725 | "FEATURES_DROPOUT = 0.1\n",
726 | "\n",
727 | "DECODER_DROPOUT = 0.1\n",
728 | "\n",
729 | "\n",
730 | "class Summarizer(nn.Module):\n",
731 | " \"\"\"\n",
732 | " This is the main summarization model.\n",
733 | " See https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_tf_bert.py\n",
734 | " It operates the same input format as original BERT used underneath.\n",
735 | " See forward, evaluate params description.\n",
736 | " \"\"\"\n",
737 | " enc_output: torch.Tensor\n",
738 | " dec_ids_mask: torch.Tensor\n",
739 | " encdec_ids_mask: torch.Tensor\n",
740 | "\n",
741 | " def __init__(self, model_type, article_len, additional_features, num_features=FEATURES_NUMBER):\n",
742 | " super(Summarizer, self).__init__()\n",
743 | "\n",
744 | " print(f'Initialize backbone and tokenizer for {model_type}')\n",
745 | " self.article_len = article_len\n",
746 | " if model_type == 'bert':\n",
747 | " self.backbone = self.initialize_bert()\n",
748 | " self.tokenizer = self.initialize_bert_tokenizer()\n",
749 | " elif model_type == 'roberta':\n",
750 | " self.backbone = self.initialize_roberta()\n",
751 | " self.tokenizer = self.initialize_roberta_tokenizer()\n",
752 | " else:\n",
753 | " raise Exception(f\"Wrong model_type argument: {model_type}\")\n",
754 | " self.backbone.resize_token_embeddings(EMBEDDINGS_ADDITIONAL + self.tokenizer.vocab_size)\n",
755 | "\n",
756 | " if additional_features:\n",
757 | " print('Adding additional features double fully connected nn')\n",
758 | " self.features = nn.Sequential(\n",
759 | " nn.Linear(num_features, FEATURES_NN_INTERMEDIATE),\n",
760 | " nn.LeakyReLU(),\n",
761 | " nn.Dropout(FEATURES_DROPOUT),\n",
762 | " nn.Linear(FEATURES_NN_INTERMEDIATE, FEATURES_NN_OUT)\n",
763 | " )\n",
764 | " else:\n",
765 | " self.features = None\n",
766 | "\n",
767 | " print('Initialize backbone embeddings pulling')\n",
768 | "\n",
769 | " def backbone_forward(input_ids, attention_mask, token_type_ids, position_ids):\n",
770 | " return self.backbone(\n",
771 | " input_ids=input_ids,\n",
772 | " attention_mask=attention_mask,\n",
773 | " token_type_ids=token_type_ids,\n",
774 | " position_ids=position_ids\n",
775 | " )\n",
776 | "\n",
777 | " self.encoder = lambda *args: backbone_forward(*args)[0]\n",
778 | "\n",
779 | " print('Initialize decoder')\n",
780 | " if additional_features:\n",
781 | " self.decoder = Classifier(768 + FEATURES_NN_OUT) # Default BERT output with additional features\n",
782 | " else:\n",
783 | " self.decoder = Classifier(768) # Default BERT output\n",
784 | "\n",
785 | " def expand_positional_embs_if_need(self):\n",
786 | " print('Expand positional embeddings if need')\n",
787 | " print('Positional embeddings', self.backbone.config.max_position_embeddings, self.article_len)\n",
788 | " if self.article_len > self.backbone.config.max_position_embeddings:\n",
789 | " old_maxlen = self.backbone.config.max_position_embeddings\n",
790 | " old_w = self.backbone.embeddings.position_embeddings.weight\n",
791 | " logging.info(f'Backbone pos embeddings expanded from {old_maxlen} upto {self.article_len}')\n",
792 | " self.backbone.embeddings.position_embeddings = nn.Embedding(\n",
793 | " self.article_len, self.backbone.config.hidden_size\n",
794 | " )\n",
795 | " self.backbone.embeddings.position_embeddings.weight[:old_maxlen].data.copy_(old_w)\n",
796 | " self.backbone.config.max_position_embeddings = self.article_len\n",
797 | " print('New positional embeddings', self.backbone.config.max_position_embeddings)\n",
798 | "\n",
799 | " @staticmethod\n",
800 | " def initialize_bert():\n",
801 | " return BertModel.from_pretrained(\n",
802 | " \"bert-base-uncased\", output_hidden_states=False\n",
803 | " )\n",
804 | "\n",
805 | " @staticmethod\n",
806 | " def initialize_bert_tokenizer():\n",
807 | " tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n",
808 | " tokenizer.BOS = \"[CLS]\"\n",
809 | " tokenizer.EOS = \"[SEP]\"\n",
810 | " tokenizer.PAD = \"[PAD]\"\n",
811 | " return tokenizer\n",
812 | "\n",
813 | " @staticmethod\n",
814 | " def initialize_roberta():\n",
815 | " backbone = RobertaModel.from_pretrained(\n",
816 | " 'roberta-base', output_hidden_states=False\n",
817 | " )\n",
818 | " print('initialize token type emb, by default roberta doesnt have it')\n",
819 | " backbone.config.type_vocab_size = 2\n",
820 | " backbone.embeddings.token_type_embeddings = nn.Embedding(2, backbone.config.hidden_size)\n",
821 | " backbone.embeddings.token_type_embeddings.weight.data.normal_(\n",
822 | " mean=0.0, std=backbone.config.initializer_range\n",
823 | " )\n",
824 | " return backbone\n",
825 | "\n",
826 | " @staticmethod\n",
827 | " def initialize_roberta_tokenizer():\n",
828 | " tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)\n",
829 | " tokenizer.BOS = \"\"\n",
830 | " tokenizer.EOS = \"\"\n",
831 | " tokenizer.PAD = \"\"\n",
832 | " return tokenizer\n",
833 | "\n",
834 | " def save(self, save_filename):\n",
835 | " \"\"\" Save model in filename\n",
836 | "\n",
837 | " :param save_filename: str\n",
838 | " \"\"\"\n",
839 | " state = dict(\n",
840 | " encoder_dict=self.backbone.state_dict(),\n",
841 | " decoder_dict=self.decoder.state_dict()\n",
842 | " )\n",
843 | " if self.features:\n",
844 | " state['features_dict'] = self.features.state_dict()\n",
845 | " models_folder = os.path.expanduser(cfg.weights_path)\n",
846 | " if not os.path.exists(models_folder):\n",
847 | " os.makedirs(models_folder)\n",
848 | " torch.save(state, f\"{models_folder}/{save_filename}.pth\")\n",
849 | "\n",
850 | " def load(self, load_filename):\n",
851 | " path = f\"{os.path.expanduser(cfg.weights_path)}/{load_filename}.pth\"\n",
852 | " state = torch.load(path, map_location=lambda storage, location: storage)\n",
853 | " self.backbone.load_state_dict(state['encoder_dict'])\n",
854 | " self.decoder.load_state_dict(state['decoder_dict'])\n",
855 | " if self.features:\n",
856 | " self.features.load_state_dict(state['features_dict'])\n",
857 | "\n",
858 | " def froze_backbone(self, froze_strategy):\n",
859 | " if froze_strategy == 'froze_all':\n",
860 | " for param in self.backbone.parameters():\n",
861 | " param.requires_grad_(False)\n",
862 | "\n",
863 | " elif froze_strategy == 'unfroze_last':\n",
864 | " for name, param in self.backbone.named_parameters():\n",
865 | " param.requires_grad_(\n",
866 | " 'encoder.layer.11' in name or\n",
867 | " 'encoder.layer.10' in name or\n",
868 | " 'encoder.layer.9' in name\n",
869 | " )\n",
870 | "\n",
871 | " elif froze_strategy == 'unfroze_all':\n",
872 | " for param in self.backbone.parameters():\n",
873 | " param.requires_grad_(True)\n",
874 | "\n",
875 | " else:\n",
876 | " raise Exception(f'Unsupported froze strategy {froze_strategy}')\n",
877 | "\n",
878 | " def unfroze_head(self):\n",
879 | " for param in self.decoder.parameters():\n",
880 | " param.requires_grad_(True)\n",
881 | "\n",
882 | " def forward(self, input_ids, attention_mask, token_type_ids, input_features=None):\n",
883 | " \"\"\"\n",
884 | " :param input_ids: torch.Size([batch_size, article_len])\n",
885 | " Indices of input sequence tokens in the vocabulary.\n",
886 | " :param attention_mask: torch.Size([batch_size, article_len])\n",
887 | " Mask to avoid performing attention on padding token indices.\n",
888 | " Mask values selected in `[0, 1]`:\n",
889 | " - 1 for tokens that are **not masked**,\n",
890 | " - 0 for tokens that are **masked**.\n",
891 | " :param token_type_ids: torch.Size([batch_size, article_len])\n",
892 | " Segment token indices to indicate first and second portions of the inputs.\n",
893 | " Indices are selected in `[0, 1]`:\n",
894 | " - 0 corresponds to a *sentence A* token,\n",
895 | " - 1 corresponds to a *sentence B* token.\n",
896 | " :return: scores | torch.Size([batch_size, summary_len])\n",
897 | " \"\"\"\n",
898 | "\n",
899 | " # The output of [CLS] is inferred by all other words in this sentence.\n",
900 | " # This makes [CLS] a good representation for sentence-level classification.\n",
901 | " cls_mask = (input_ids == get_token_id(self.tokenizer, self.tokenizer.BOS))\n",
902 | " print(f'cls_mask {cls_mask.shape}')\n",
903 | "\n",
904 | " # Indices of the positions of the input sequence tokens in the position embeddings.\n",
905 | " # position ids | torch.Size([batch_size, article_len])\n",
906 | " pos_ids = torch.arange(\n",
907 | " 0,\n",
908 | " self.article_len,\n",
909 | " dtype=torch.long,\n",
910 | " device=input_ids.device\n",
911 | " ).unsqueeze(0).repeat(len(input_ids), 1)\n",
912 | " print(f'pos_ids {pos_ids.shape}')\n",
913 | " # extract bert embeddings | torch.Size([batch_size, article_len, d_bert])\n",
914 | " # for each word in the input, the BERT base internally creates a 768-dimensional output,\n",
915 | " # but for tasks like classification, we do not actually require the output for all the embeddings.\n",
916 | " # So by default, BERT considers only the output corresponding to the first token [CLS]\n",
917 | " # and drops the output vectors corresponding to all the other tokens.\n",
918 | " enc_output = self.encoder(input_ids, attention_mask, token_type_ids, pos_ids)\n",
919 | "\n",
920 | " if self.features:\n",
921 | " out_features = self.features(input_features)\n",
922 | " scores = self.decoder(torch.cat([enc_output[cls_mask], out_features], dim=-1))\n",
923 | " else:\n",
924 | " print('enc_output', enc_output.shape)\n",
925 | " print('enc_output[cls_mask]', enc_output[cls_mask].shape)\n",
926 | " scores = self.decoder(enc_output[cls_mask])\n",
927 | " print('scores', scores.shape)\n",
928 | "\n",
929 | " return scores\n",
930 | "\n",
931 | " def evaluate(self, input_ids, attention_mask, token_type_ids, input_features=None):\n",
932 | " \"\"\"See forward for parameters and output description\"\"\"\n",
933 | "\n",
934 | " # The output of [CLS] is inferred by all other words in this sentence.\n",
935 | " # This makes [CLS] a good representation for sentence-level classification.\n",
936 | " cls_mask = (input_ids == get_token_id(self.tokenizer, self.tokenizer.BOS))\n",
937 | "\n",
938 | " # position ids | torch.Size([batch_size, article_len])\n",
939 | " pos_ids = torch.arange(\n",
940 | " 0,\n",
941 | " self.article_len,\n",
942 | " dtype=torch.long,\n",
943 | " device=input_ids.device\n",
944 | " ).unsqueeze(0).repeat(len(input_ids), 1)\n",
945 | "\n",
946 | " # extract bert embeddings | torch.Size([batch_size, article_len, d_bert])\n",
947 | " enc_output = self.encoder(input_ids, attention_mask, token_type_ids, pos_ids)\n",
948 | "\n",
949 | " scores = []\n",
950 | " for eo, cm in zip(enc_output, cls_mask):\n",
951 | " if self.features:\n",
952 | " out_features = self.features(input_features)\n",
953 | " score = self.decoder.evaluate(torch.cat([eo[cm], out_features], dim=-1))\n",
954 | " else:\n",
955 | " score = self.decoder.evaluate(eo[cm])\n",
956 | " scores.append(score)\n",
957 | " return scores\n",
958 | "\n",
959 | "\n",
960 | "class Classifier(nn.Module):\n",
961 | " def __init__(self, hidden_size):\n",
962 | " super(Classifier, self).__init__()\n",
963 | " self.dropout = nn.Dropout(DECODER_DROPOUT)\n",
964 | " self.linear = nn.Linear(hidden_size, 1)\n",
965 | " self.sigmoid = nn.Sigmoid()\n",
966 | "\n",
967 | " def forward(self, x):\n",
968 | " return self.sigmoid(self.linear(self.dropout(x)).squeeze(-1))\n",
969 | "\n",
970 | " def evaluate(self, x):\n",
971 | " return self.sigmoid(self.linear(self.dropout(x)).squeeze(-1))\n",
972 | "\n",
973 | "\n",
974 | "def create_model(model_type, froze_strategy, article_len, additional_features):\n",
975 | " model = Summarizer(model_type, article_len, additional_features)\n",
976 | " model.expand_positional_embs_if_need()\n",
977 | " # Load intermediate model\n",
978 | " # model.load('temp')\n",
979 | " model.froze_backbone(froze_strategy)\n",
980 | " model.unfroze_head()\n",
981 | " if additional_features:\n",
982 | " print('Parameters for features NN', sum(p.numel() for p in model.features.parameters() if p.requires_grad))\n",
983 | " print('Parameters for backbone', sum(p.numel() for p in model.backbone.parameters() if p.requires_grad))\n",
984 | " print('Parameters for classifier', sum(p.numel() for p in model.decoder.parameters() if p.requires_grad))\n",
985 | " return model"
986 | ],
987 | "metadata": {
988 | "collapsed": false,
989 | "pycharm": {
990 | "name": "#%%\n"
991 | }
992 | }
993 | },
994 | {
995 | "cell_type": "markdown",
996 | "source": [
997 | "# Preprocessing text for BERT"
998 | ],
999 | "metadata": {
1000 | "collapsed": false,
1001 | "pycharm": {
1002 | "name": "#%% md\n"
1003 | }
1004 | }
1005 | },
1006 | {
1007 | "cell_type": "code",
1008 | "execution_count": null,
1009 | "outputs": [],
1010 | "source": [
1011 | "# TODO: likely incorrect training scheme\n",
1012 | "# BERT was pretrained to support only a single [CLS] token in the input.\n",
1013 | "# [SEP] and token_type_ids are used to process two connected sentences in a way:\n",
1014 | "# [CLS] ... sent1 ... [SEP] ... sent2 ... [SEP] [PAD] ... with token_type_ids\n",
1015 | "# 0 ... ... ... ... 0 1 ... ... ... ... ... 1 ... ...\n",
1016 | "# Current learning strategy may produce incorrect results.\n",
1017 | "# We can concat max N sentences into a single input and use average score for them.\n",
1018 | "\n",
1019 | "def preprocess_text(text, max_len, tokenizer):\n",
1020 | " \"\"\"\n",
1021 | " Preprocess text for BERT / ROBERTA model.\n",
1022 | " NOTE: not all the text can be processed because of max_len.\n",
1023 | " IMPORTANT: preprocessed text may contain more than one [CLS] token!!! and more than two sentences!!!\n",
1024 | " :param text: sequence of str, one sentence per element (each is tokenized separately)\n",
1025 | " :param max_len: maximum length of preprocessing\n",
1026 | " :param tokenizer: BERT or ROBERTA tokenizer\n",
1027 | " :return:\n",
1028 | " ids | list(int) of length max_len: token ids, padded with the PAD token id\n",
1029 | " attention_mask | list(int): 1 for a real token, 0 for padding\n",
1030 | " token_type_ids | list(int): alternates 0/1 from one sentence to the next\n",
1031 | " n_sents | number of actual sentences encoded\n",
1032 | " \"\"\"\n",
1033 | " sents = [\n",
1034 | " [tokenizer.BOS] + tokenizer.tokenize(sent) + [tokenizer.EOS] for sent in text\n",
1035 | " ]\n",
1036 | " logging.debug(f'sents {sents}')\n",
1037 | " ids, token_type_ids, segment_signature = [], [], 0\n",
1038 | " n_sents = 0\n",
1039 | " for i, s in enumerate(sents):\n",
1040 | " logging.debug(f'sentence {i} {s}')\n",
1041 | " logging.debug(f'ids {len(ids)}')\n",
1042 | " logging.debug(f'segments {len(token_type_ids)}')\n",
1043 | " logging.debug(f'segment_signature {segment_signature}')\n",
1044 | " if len(ids) + len(s) <= max_len:\n",
1045 | " n_sents += 1\n",
1046 | " ids.extend(tokenizer.convert_tokens_to_ids(s))\n",
1047 | " token_type_ids.extend([segment_signature] * len(s))\n",
1048 | " segment_signature = (segment_signature + 1) % 2\n",
1049 | " else:\n",
1050 | " logging.debug(f'break, len(s)={len(s)}')\n",
1051 | " break\n",
1052 | " attention_mask = [1] * len(ids)\n",
1053 | "\n",
1054 | " logging.debug('Padding data')\n",
1055 | " pad_len = max(0, max_len - len(ids))\n",
1056 | " ids += [get_token_id(tokenizer, tokenizer.PAD)] * pad_len\n",
1057 | " attention_mask += [0] * pad_len\n",
1058 | " token_type_ids += [segment_signature] * pad_len\n",
1059 | " assert len(ids) == len(attention_mask)\n",
1060 | " assert len(ids) == len(token_type_ids)\n",
1061 | " return ids, attention_mask, token_type_ids, n_sents"
1062 | ],
1063 | "metadata": {
1064 | "collapsed": false,
1065 | "pycharm": {
1066 | "name": "#%%\n",
1067 | "is_executing": true
1068 | }
1069 | }
1070 | },
1071 | {
1072 | "cell_type": "code",
1073 | "execution_count": null,
1074 | "outputs": [],
1075 | "source": [
1076 | "print('Inspect preprocessing text for BERT')\n",
1077 | "\n",
1078 | "tokenizer = Summarizer.initialize_bert_tokenizer()\n",
1079 | "text = list(train_df.head(5)['sentence'])\n",
1080 | "input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(text, 512, tokenizer)\n",
1081 | "\n",
1082 | "print(f'input_ids {input_ids}')\n",
1083 | "print(f'attention_mask {attention_mask}')\n",
1084 | "print(f'token_type_ids {token_type_ids}')\n",
1085 | "print(f'n_sents {n_sents}')"
1086 | ],
1087 | "metadata": {
1088 | "collapsed": false,
1089 | "pycharm": {
1090 | "name": "#%%\n"
1091 | }
1092 | }
1093 | },
1094 | {
1095 | "cell_type": "markdown",
1096 | "source": [
1097 | "## Train and evaluate functions"
1098 | ],
1099 | "metadata": {
1100 | "collapsed": false,
1101 | "pycharm": {
1102 | "name": "#%% md\n"
1103 | }
1104 | }
1105 | },
1106 | {
1107 | "cell_type": "code",
1108 | "execution_count": null,
1109 | "outputs": [],
1110 | "source": [
1111 | "import torch.nn as nn\n",
1112 | "from torch.optim.optimizer import Optimizer\n",
1113 | "import math\n",
1114 | "\n",
1115 | "\n",
1116 | "def get_enc_lr(optimizer):\n",
1117 | " return optimizer.param_groups[0]['lr']\n",
1118 | "\n",
1119 | "\n",
1120 | "def get_dec_lr(optimizer):\n",
1121 | " return optimizer.param_groups[1]['lr']\n",
1122 | "\n",
1123 | "\n",
1124 | "def backward_step(loss: torch.Tensor, model: nn.Module, clip: float):\n",
1125 | " loss.backward()\n",
1126 | " total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
1127 | " return total_norm\n",
1128 | "\n",
1129 | "\n",
1130 | "def batch_to_device(batch, additional_features, device):\n",
1131 | " batch_on_device = [(x.to(device) if isinstance(x, torch.Tensor) else x) for x in batch]\n",
1132 | " if additional_features:\n",
1133 | " input_ids, attention_mask, token_type_ids, target_scores, input_features = batch_on_device\n",
1134 | " input_features = torch.cat(input_features).to(device)\n",
1135 | " else:\n",
1136 | " input_ids, attention_mask, token_type_ids, target_scores = batch_on_device\n",
1137 | " input_features = None\n",
1138 | " return input_ids, attention_mask, token_type_ids, target_scores, input_features\n",
1139 | "\n",
1140 | "\n",
1141 | "def train_fun(\n",
1142 | " model,\n",
1143 | " dataloader,\n",
1144 | " optimizer,\n",
1145 | " scheduler,\n",
1146 | " criter,\n",
1147 | " device,\n",
1148 | " writer,\n",
1149 | " additional_features\n",
1150 | "):\n",
1151 | " # Put the model into training mode. Don't be misled--the call to\n",
1152 | " # `train` just changes the *mode*, it doesn't *perform* the training.\n",
1153 | " # `dropout` and `batchnorm` layers behave differently during training\n",
1154 | " # vs. test (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)\n",
1155 | " model.train()\n",
1156 | " loss_val = 0\n",
1157 | " target_val = 0\n",
1158 | " mean_sents = 0\n",
1159 | " szs = 0\n",
1160 | "\n",
1161 | " for idx_batch, batch in enumerate(dataloader):\n",
1162 | " input_ids, attention_mask, token_type_ids, target_scores, input_features = batch_to_device(\n",
1163 | " batch, additional_features, device\n",
1164 | " )\n",
1165 | "\n",
1166 | " sizes = [dc.shape[0] for dc in target_scores]\n",
1167 | " mean_sents += sum(sizes)\n",
1168 | " szs += len(sizes)\n",
1169 | "\n",
1170 | " # forward pass\n",
1171 | " print(f'batch {idx_batch}')\n",
1172 | " print('forward')\n",
1173 | " print(f'input_ids {input_ids.shape}')\n",
1174 | " print(f'attention_mask {attention_mask.shape}')\n",
1175 | " print(f'token_type_ids {token_type_ids.shape}')\n",
1176 | " if additional_features:\n",
1177 | " print(f'input_features {input_features.shape}')\n",
1178 | "\n",
1179 | " model_scores = model(input_ids, attention_mask, token_type_ids, input_features)\n",
1180 | " print(f'models_scores {model_scores.shape}')\n",
1181 | " print(f'target_scores {len(target_scores)}')\n",
1182 | "\n",
1183 | " target_scores = torch.cat(target_scores).to(device)\n",
1184 | " target_val += sum(target_scores) / len(target_scores)\n",
1185 | " try:\n",
1186 | " # loss\n",
1187 | " loss = criter(model_scores, target_scores, )\n",
1188 | " loss_val += loss.item()\n",
1189 | " print(f'loss {loss}')\n",
1190 | " except Exception:\n",
1191 | " print(idx_batch, model_scores.shape, target_scores.shape, token_type_ids)\n",
1192 | " return\n",
1193 | "\n",
1194 | " # backward\n",
1195 | " print('backward')\n",
1196 | " grad_norm = backward_step(loss, model, optimizer.clip_value)\n",
1197 | " grad_norm = 0 if (math.isinf(grad_norm) or math.isnan(grad_norm)) else grad_norm\n",
1198 | "\n",
1199 | " # record a loss value\n",
1200 | " print(f'{idx_batch} / {len(dataloader)} train loss {loss.item()}')\n",
1201 | " print(f'\\r{idx_batch} / {len(dataloader)} train loss {loss.item()}', end='')\n",
1202 | " writer.add_scalar(f\"Train/loss\", loss.item(), writer.train_step)\n",
1203 | " writer.add_scalar(\"Train/grad_norm\", grad_norm, writer.train_step)\n",
1204 | " writer.add_scalar(\"Train/lr_enc\", get_enc_lr(optimizer), writer.train_step)\n",
1205 | " writer.add_scalar(\"Train/lr_dec\", get_dec_lr(optimizer), writer.train_step)\n",
1206 | " writer.train_step += 1\n",
1207 | "\n",
1208 | " # make a gradient step\n",
1209 | " if (idx_batch + 1) % optimizer.accumulation_interval == 0 or (idx_batch + 1) == len(dataloader):\n",
1210 | " print('optimizer step')\n",
1211 | " optimizer.step()\n",
1212 | "\n",
1213 | " # Always clear any previously calculated gradients before performing a\n",
1214 | " # backward pass. PyTorch doesn't do this automatically because\n",
1215 | " # accumulating the gradients is \"convenient while training RNNs\".\n",
1216 | " # (source: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch)\n",
1217 | " optimizer.zero_grad()\n",
1218 | "\n",
1219 | " print('scheduler step')\n",
1220 | " scheduler.step()\n",
1221 | "\n",
1222 | " print(\"\\rTrain loss:\", loss_val / len(dataloader), f\"{100 * loss_val / target_val:.5f}%\")\n",
1223 | " # print(\"Train mean sent len:\", mean_sents / szs)\n",
1224 | "\n",
1225 | " # save model, just in case\n",
1226 | " model.save('temp')\n",
1227 | " print('save model to temp')\n",
1228 | "\n",
1229 | " return model, optimizer, scheduler, writer\n",
1230 | "\n",
1231 | "\n",
1232 | "def evaluate_fun(\n",
1233 | " model,\n",
1234 | " dataloader,\n",
1235 | " criter,\n",
1236 | " device,\n",
1237 | " writer,\n",
1238 | " additional_features\n",
1239 | "):\n",
1240 | " # Put the model in evaluation mode--the dropout layers behave differently\n",
1241 | " # during evaluation.\n",
1242 | " model.eval()\n",
1243 | " loss_val = 0\n",
1244 | " target_val = 0\n",
1245 | " mean_sents = 0\n",
1246 | " szs = 0\n",
1247 | "\n",
1248 | " for idx_batch, batch in enumerate(dataloader):\n",
1249 | " input_ids, attention_mask, token_type_ids, target_scores, input_features = batch_to_device(\n",
1250 | " batch, additional_features, device\n",
1251 | " )\n",
1252 | " sizes = [dc.shape[0] for dc in target_scores]\n",
1253 | " mean_sents += sum(sizes)\n",
1254 | " szs += len(sizes)\n",
1255 | " target_scores = torch.cat(target_scores).to(device)\n",
1256 | "\n",
1257 | " # evaluate pass\n",
1258 | " print('evaluate')\n",
1259 | " print(f'input_ids {input_ids.shape}')\n",
1260 | " print(f'attention_mask {attention_mask.shape}')\n",
1261 | " print(f'token_type_ids {token_type_ids.shape}')\n",
1262 | " if additional_features:\n",
1263 | " print(f'input_features {input_features.shape}')\n",
1264 | " # Tell pytorch not to bother with constructing the compute graph during\n",
1265 | " # the forward pass, since this is only needed for backprop (training).\n",
1266 | " with torch.no_grad():\n",
1267 | " model_scores = model(input_ids, attention_mask, token_type_ids, input_features)\n",
1268 | " print(f'model_scores {model_scores.shape}')\n",
1269 | " print(f'target_scores {len(target_scores)}')\n",
1270 | "\n",
1271 | " # loss\n",
1272 | " loss = criter(model_scores, target_scores, )\n",
1273 | " target_val += sum(target_scores) / len(target_scores)\n",
1274 | "\n",
1275 | " # record a loss value\n",
1276 | " print(f'{idx_batch} / {len(dataloader)} val loss {loss.item()}')\n",
1277 | " print(f'\\r{idx_batch} / {len(dataloader)} val loss {loss.item()}', end='')\n",
1278 | " loss_val += loss.item()\n",
1279 | " writer.add_scalar(f\"Eval/loss\", loss.item(), writer.train_step)\n",
1280 | " writer.train_step += 1\n",
1281 | "\n",
1282 | " print(\"\\rValidate loss:\", loss_val / len(dataloader), f\"{100 * loss_val / target_val:.5f}%\")\n",
1283 | " # print(\"Validate mean sent len:\", mean_sents / szs)\n",
1284 | "\n",
1285 | " # save model, just in case\n",
1286 | " model.save('validated_weights')\n",
1287 | " print('save model to validated_weights')\n",
1288 | "\n",
1289 | " return model"
1290 | ],
1291 | "metadata": {
1292 | "collapsed": false,
1293 | "pycharm": {
1294 | "name": "#%%\n"
1295 | }
1296 | }
1297 | },
1298 | {
1299 | "cell_type": "markdown",
1300 | "source": [
1301 | "# Dataset and dataloader classes"
1302 | ],
1303 | "metadata": {
1304 | "collapsed": false,
1305 | "pycharm": {
1306 | "name": "#%% md\n"
1307 | }
1308 | }
1309 | },
1310 | {
1311 | "cell_type": "code",
1312 | "execution_count": null,
1313 | "outputs": [],
1314 | "source": [
1315 | "from torch.utils.data import Dataset, DataLoader\n",
1316 | "\n",
1317 | "# Overlap helps to keep context and connect different parts of abstract\n",
1318 | "OVERLAP = 5\n",
1319 | "\n",
1320 | "\n",
1321 | "class TrainDataset(Dataset):\n",
1322 | " \"\"\" Custom Train Dataset with optional additional features.\n",
1323 | " All the data is preprocessed up front, then examples are served by index.\n",
1324 | " Consecutive examples of the same pmid overlap by up to OVERLAP sentences to keep context.\n",
1325 | " \"\"\"\n",
1326 | "\n",
1327 | " def __init__(self, dataframe, tokenizer, article_len, additional_features):\n",
1328 | " self.df = dataframe\n",
1329 | " self.data = []\n",
1330 | " self.pmids = list(set(dataframe['pmid'].values))\n",
1331 | " self.tokenizer = tokenizer\n",
1332 | " self.article_len = article_len\n",
1333 | " self.additional_features = additional_features\n",
1334 | " # Create a list of training inputs for each pmid\n",
1335 | " for pmid in tqdm(self.pmids):\n",
1336 | " ex = self.df[self.df['pmid'] == pmid]\n",
1337 | " text = ex['sentence'].values\n",
1338 | " features = np.nan_to_num(\n",
1339 | " ex[['sent_id', 'sent_type', 'r_abs', 'num_refs',\n",
1340 | " 'mean_r_fig', 'mean_r_tab',\n",
1341 | " 'min_r_fig', 'min_r_tab',\n",
1342 | " 'max_r_fig', 'max_r_tab']].values.astype(float)\n",
1343 | " ) if additional_features else None\n",
1344 | " # BERT preprocessing cannot encode the whole text at once,\n",
1345 | " # only a limited number of sentences per single model run is supported.\n",
1346 | " total_sents = 0\n",
1347 | " while total_sents < len(text):\n",
1348 | " offset = max(0, total_sents - OVERLAP)\n",
1349 | " input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(\n",
1350 | " text[offset:], self.article_len, self.tokenizer\n",
1351 | " )\n",
1352 | " if n_sents <= OVERLAP:\n",
1353 | " total_sents += 1\n",
1354 | " continue\n",
1355 | " total_sents = offset + n_sents\n",
1356 | " target_scores = ex['score'].values[offset: offset + n_sents] / 100\n",
1357 | " input_features = features[offset: offset + n_sents] if additional_features else None\n",
1358 | " # print(f'Train dataset example {len(self.data)}\\n'\n",
1359 | " # f'input_ids {input_ids}\\n'\n",
1360 | " # f'attention_mask {attention_mask}\\n'\n",
1361 | " # f'token_type_ids {token_type_ids}\\n'\n",
1362 | " # f'target_scores {target_scores}\\n'\n",
1363 | " # f'features {input_features}')\n",
1364 | " if additional_features:\n",
1365 | " self.data.append((input_ids, attention_mask, token_type_ids, target_scores, input_features))\n",
1366 | " else:\n",
1367 | " self.data.append((input_ids, attention_mask, token_type_ids, target_scores))\n",
1368 | "\n",
1369 | " logging.info(f'Train dataset size {len(self.data)}')\n",
1370 | "\n",
1371 | " def __getitem__(self, idx):\n",
1372 | " return self.data[idx]\n",
1373 | "\n",
1374 | " def __len__(self):\n",
1375 | " return len(self.data)\n",
1376 | "\n",
1377 | "\n",
1378 | "class EvalDataset(Dataset):\n",
1379 | " \"\"\" Custom Valid/Test Dataset\n",
1380 | " \"\"\"\n",
1381 | "\n",
1382 | " def __init__(self, dataframe, tokenizer, article_len, additional_features):\n",
1383 | " self.df = dataframe\n",
1384 | " self.pmids = list(set(dataframe['pmid'].values))\n",
1385 | " logging.info(f'Eval dataset size {len(self.pmids)}')\n",
1386 | " self.tokenizer = tokenizer\n",
1387 | " self.article_len = article_len\n",
1388 | " self.additional_features = additional_features\n",
1389 | "\n",
1390 | " def __getitem__(self, idx):\n",
1391 | " pmid = self.pmids[idx]\n",
1392 | " ex = self.df[self.df['pmid'] == pmid]\n",
1393 | " paper = ex['sentence'].values\n",
1394 | " features = np.nan_to_num(\n",
1395 | " ex[['sent_id', 'sent_type', 'r_abs',\n",
1396 | " 'num_refs', 'mean_r_fig', 'mean_r_tab',\n",
1397 | " 'min_r_fig', 'min_r_tab',\n",
1398 | " 'max_r_fig', 'max_r_tab']].values.astype(float)\n",
1399 | " ) if self.additional_features else None\n",
1400 | " input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(\n",
1401 | " paper, self.article_len, self.tokenizer\n",
1402 | " )\n",
1403 | "\n",
1404 | " # form target\n",
1405 | " target_scores = ex['score'].values[:n_sents] / 100\n",
1406 | " input_features = features[:n_sents] if self.additional_features else None\n",
1407 | " # print(f'Eval dataset example {idx}\\n'\n",
1408 | " # f'input_ids {input_ids}\\n'\n",
1409 | " # f'attention_mask {attention_mask}\\n'\n",
1410 | " # f'token_type_ids {token_type_ids}\\n'\n",
1411 | " # f'target_scores {target_scores}\\n'\n",
1412 | " # f'features {input_features}')\n",
1413 | " if self.additional_features:\n",
1414 | " return input_ids, attention_mask, token_type_ids, target_scores, input_features\n",
1415 | " else:\n",
1416 | " return input_ids, attention_mask, token_type_ids, target_scores\n",
1417 | "\n",
1418 | " def __len__(self):\n",
1419 | " return len(self.pmids)\n",
1420 | "\n",
1421 | "\n",
1422 | "def create_collate_fn(additional_features):\n",
1423 | " \"\"\"Create Function to pull batch for train / eval.\"\"\"\n",
1424 | "\n",
1425 | " def _collate_fn(batch_data):\n",
1426 | " \"\"\"\n",
1427 | " :param batch_data: list of `TrainDataset` or `EvalDataset` Examples\n",
1428 | " :return: one batch of data\n",
1429 | " \"\"\"\n",
1430 | " data = list(zip(*batch_data))\n",
1431 | " result = [\n",
1432 | " torch.tensor(data[0], dtype=torch.long),\n",
1433 | " torch.tensor(data[1], dtype=torch.long),\n",
1434 | " torch.tensor(data[2], dtype=torch.long),\n",
1435 | " [torch.tensor(e, dtype=torch.float) for e in data[3]]\n",
1436 | " ]\n",
1437 | " if additional_features:\n",
1438 | " result.append([torch.tensor(e, dtype=torch.float) for e in data[4]])\n",
1439 | " return result\n",
1440 | "\n",
1441 | " return _collate_fn\n",
1442 | "\n",
1443 | "\n",
1444 | "# The DataLoader needs to know our batch size for training, so we specify it\n",
1445 | "# here. For fine-tuning BERT on a specific task, the authors recommend a batch\n",
1446 | "# size of 16 or 32.\n",
1447 | "BATCH_SIZE = 16\n",
1448 | "\n",
1449 | "\n",
1450 | "def get_dataloaders(train, val, batch_size, article_len, tokenizer, additional_features):\n",
1451 | " logging.info('Creating train dataset...')\n",
1452 | " train_ds = TrainDataset(train, tokenizer, article_len, additional_features)\n",
1453 | "\n",
1454 | " logging.info('Applying loader functions to train...')\n",
1455 | " train_dl = DataLoader(\n",
1456 | " dataset=train_ds, batch_size=batch_size, shuffle=False,\n",
1457 | " pin_memory=True, collate_fn=create_collate_fn(additional_features), num_workers=1\n",
1458 | " )\n",
1459 | "\n",
1460 | " logging.info('Creating val dataset...')\n",
1461 | " val_ds = EvalDataset(val, tokenizer, article_len, additional_features)\n",
1462 | "\n",
1463 | " logging.info('Applying loader functions to val...')\n",
1464 | " val_dl = DataLoader(\n",
1465 | " dataset=val_ds, batch_size=batch_size, shuffle=False,\n",
1466 | " pin_memory=True, collate_fn=create_collate_fn(additional_features), num_workers=1\n",
1467 | " )\n",
1468 | "\n",
1469 | " return train_dl, val_dl"
1470 | ],
1471 | "metadata": {
1472 | "collapsed": false,
1473 | "pycharm": {
1474 | "name": "#%%\n"
1475 | }
1476 | }
1477 | },
1478 | {
1479 | "cell_type": "code",
1480 | "execution_count": null,
1481 | "outputs": [],
1482 | "source": [
1483 | "train_dl, _ = get_dataloaders(\n",
1484 | " train, val, 1, 512, create_model('bert', 'froze_all', 512, False).tokenizer, False\n",
1485 | ")\n",
1486 | "for batch in train_dl:\n",
1487 | " break\n",
1488 | "batch"
1489 | ],
1490 | "metadata": {
1491 | "collapsed": false,
1492 | "pycharm": {
1493 | "name": "#%%\n"
1494 | }
1495 | }
1496 | },
1497 | {
1498 | "cell_type": "markdown",
1499 | "source": [
1500 | "# Custom scheduler used in training"
1501 | ],
1502 | "metadata": {
1503 | "collapsed": false,
1504 | "pycharm": {
1505 | "name": "#%% md\n"
1506 | }
1507 | }
1508 | },
1509 | {
1510 | "cell_type": "code",
1511 | "execution_count": null,
1512 | "outputs": [],
1513 | "source": [
1514 | "from torch.optim.lr_scheduler import _LRScheduler, ExponentialLR\n",
1515 | "\n",
1516 | "\n",
1517 | "class CustomScheduler(_LRScheduler):\n",
1518 | " timestep: int = 0\n",
1519 | "\n",
1520 | " def __init__(self, optimizer, gamma, warmup=None):\n",
1521 | " self.optimizer = optimizer\n",
1522 | " self.after_warmup = ExponentialLR(optimizer, gamma=gamma)\n",
1523 | " self.initial_lrs = [p_group['lr'] for p_group in self.optimizer.param_groups]\n",
1524 | " self.warmup = 0 if warmup is None else warmup\n",
1525 | " super(CustomScheduler, self).__init__(optimizer)\n",
1526 | "\n",
1527 | " def get_lr(self):\n",
1528 | " return [self.timestep * group_init_lr / self.warmup for group_init_lr in\n",
1529 | " self.initial_lrs] if self.timestep < self.warmup else self.after_warmup.get_lr()\n",
1530 | "\n",
1531 | " def step(self, epoch=None):\n",
1532 | " if self.timestep < self.warmup:\n",
1533 | " self.timestep += 1\n",
1534 | " super(CustomScheduler, self).step(epoch)\n",
1535 | " else:\n",
1536 | " self.after_warmup.step(epoch)\n",
1537 | "\n",
1538 | "\n",
1539 | "class NoamScheduler(_LRScheduler):\n",
1540 | " \"\"\"\n",
1541 | " Noam scheduler: a linear learning-rate warm-up period followed by inverse square-root decay.\n",
1542 | " This is a PyTorch implementation of the learning-rate schedule introduced in the paper \"Attention is all you need\"\n",
1543 | " \"\"\"\n",
1544 | "\n",
1545 | " def __init__(self, optimizer, warmup):\n",
1546 | " assert warmup > 0\n",
1547 | " self.optimizer = optimizer\n",
1548 | " self.initial_lrs = [p_group['lr'] for p_group in self.optimizer.param_groups]\n",
1549 | " self.warmup = warmup\n",
1550 | " self.timestep = 0\n",
1551 | " super(NoamScheduler, self).__init__(optimizer)\n",
1552 | "\n",
1553 | " def get_lr(self):\n",
1554 | " noam_lr = self.get_noam_lr()\n",
1555 | " return [group_init_lr * noam_lr for group_init_lr in self.initial_lrs]\n",
1556 | "\n",
1557 | " def get_noam_lr(self):\n",
1558 | " return min(self.timestep ** -0.5, self.timestep * self.warmup ** -1.5)\n",
1559 | "\n",
1560 | " def step(self, epoch=None):\n",
1561 | " self.timestep += 1\n",
1562 | " super(NoamScheduler, self).step(epoch)"
1563 | ],
1564 | "metadata": {
1565 | "collapsed": false,
1566 | "pycharm": {
1567 | "name": "#%%\n"
1568 | }
1569 | }
1570 | },
1571 | {
1572 | "cell_type": "markdown",
1573 | "source": [
1574 | "# Prepare and configure training"
1575 | ],
1576 | "metadata": {
1577 | "collapsed": false,
1578 | "pycharm": {
1579 | "name": "#%% md\n"
1580 | }
1581 | }
1582 | },
1583 | {
1584 | "cell_type": "code",
1585 | "execution_count": null,
1586 | "outputs": [],
1587 | "source": [
1588 | "from torch.optim import AdamW\n",
1589 | "from torch.nn import MSELoss\n",
1590 | "from tensorboardX import SummaryWriter\n",
1591 | "\n",
1592 | "ENCODER_LEARNING_RATE = 0.0001\n",
1593 | "DECODER_LEARNING_RATE = 0.001\n",
1594 | "\n",
1595 | "WARMUP = 5\n",
1596 | "WEIGHT_DECAY = 0.01\n",
1597 | "CLIP_VALUE = 1.0\n",
1598 | "ACCUMULATION_INTERVAL = 1\n",
1599 | "\n",
1600 | "# Number of training epochs. The BERT authors recommend between 2 and 4.\n",
1601 | "# We run for more epochs than that (see EPOCHS_NUMBER below), so watch for over-fitting\n",
1602 | "# of the training data.\n",
1603 | "EPOCHS_NUMBER = 10\n",
1604 | "\n",
1605 | "\n",
1606 | "def prepare_learning_tools(\n",
1607 | " model,\n",
1608 | " enc_lr=ENCODER_LEARNING_RATE,\n",
1609 | " dec_lr=DECODER_LEARNING_RATE,\n",
1610 | " warmup=WARMUP,\n",
1611 | " weight_decay=WEIGHT_DECAY,\n",
1612 | " clip_value=CLIP_VALUE,\n",
1613 | " accumulation_interval=ACCUMULATION_INTERVAL\n",
1614 | "):\n",
1615 | " # TODO fix for Roberta model\n",
1616 | " enc_parameters = [\n",
1617 | " param for name, param in model.named_parameters()\n",
1618 | " if param.requires_grad and name.startswith('bert.')\n",
1619 | " ]\n",
1620 | " dec_parameters = [\n",
1621 | " param for name, param in model.named_parameters() if param.requires_grad\n",
1622 | " ]\n",
1623 | " optimizer = AdamW([\n",
1624 | " dict(params=enc_parameters, lr=enc_lr),\n",
1625 | " dict(params=dec_parameters, lr=dec_lr),\n",
1626 | " ], weight_decay=weight_decay)\n",
1627 | " optimizer.clip_value = clip_value\n",
1628 | " optimizer.accumulation_interval = accumulation_interval\n",
1629 | "\n",
1630 | " scheduler = NoamScheduler(optimizer, warmup=warmup)\n",
1631 | " criter = MSELoss()\n",
1632 | "\n",
1633 | " return optimizer, scheduler, criter\n",
1634 | "\n",
1635 | "\n",
1636 | "def load_or_train_model(model, device, additional_features):\n",
1637 | " model_name = f'learn_simple_berta_{additional_features}.pth'.lower()\n",
1638 | " MODEL_PATH = f'{os.path.expanduser(cfg.weights_path)}/{model_name}'\n",
1639 | " ! rm {MODEL_PATH}\n",
1640 | "\n",
1641 | " if os.path.exists(MODEL_PATH):\n",
1642 | " logging.info(f'Loading model {MODEL_PATH}')\n",
1643 | " model.load(model_name)\n",
1644 | " model, device = setup_cuda_device(model)\n",
1645 | " return model\n",
1646 | " else:\n",
1647 | " logging.info('Create dataloaders...')\n",
1648 | " train_loader, valid_loader = get_dataloaders(\n",
1649 | " train, val, BATCH_SIZE, ARTICLE_LENGTH, model.tokenizer, additional_features\n",
1650 | " )\n",
1651 | "\n",
1652 | " writer = SummaryWriter(log_dir=os.path.expanduser(cfg.log_path))\n",
1653 | " writer.train_step, writer.eval_step = 0, 0\n",
1654 | "\n",
1655 | " optimizer, scheduler, criter = prepare_learning_tools(model)\n",
1656 | "\n",
1657 | " logging.info(f\"Start training {EPOCHS_NUMBER} epochs...\")\n",
1658 | " for epoch in tqdm(range(1, EPOCHS_NUMBER + 1)):\n",
1659 | " print(f'Epoch {epoch}')\n",
1660 | " model, optimizer, scheduler, writer = train_fun(\n",
1661 | " model, train_loader, optimizer, scheduler,\n",
1662 | " criter, device, writer, additional_features\n",
1663 | " )\n",
1664 | " model = evaluate_fun(\n",
1665 | " model, valid_loader, criter, device, writer, additional_features\n",
1666 | " )\n",
1667 | " logging.info(f\"Done training {EPOCHS_NUMBER} epochs...\")\n",
1668 | " logging.info(f'Save trained model to {model_name}')\n",
1669 | " model.save(model_name)\n",
1670 | "\n",
1671 | " return model"
1672 | ],
1673 | "metadata": {
1674 | "collapsed": false,
1675 | "pycharm": {
1676 | "name": "#%%\n"
1677 | }
1678 | }
1679 | },
1680 | {
1681 | "cell_type": "markdown",
1682 | "source": [
1683 | "# Create and train model"
1684 | ],
1685 | "metadata": {
1686 | "collapsed": false,
1687 | "pycharm": {
1688 | "name": "#%% md\n"
1689 | }
1690 | }
1691 | },
1692 | {
1693 | "cell_type": "code",
1694 | "execution_count": null,
1695 | "outputs": [],
1696 | "source": [
1697 | "ARTICLE_LENGTH = 512\n",
1698 | "\n",
1699 | "model = create_model(\"bert\", \"froze_all\", ARTICLE_LENGTH, additional_features=ADDITIONAL_FEATURES)\n",
1700 | "model, device = setup_cuda_device(model)"
1701 | ],
1702 | "metadata": {
1703 | "collapsed": false,
1704 | "pycharm": {
1705 | "name": "#%%\n"
1706 | }
1707 | }
1708 | },
1709 | {
1710 | "cell_type": "code",
1711 | "execution_count": null,
1712 | "outputs": [],
1713 | "source": [
1714 | "model = load_or_train_model(model, device, additional_features=ADDITIONAL_FEATURES)"
1715 | ],
1716 | "metadata": {
1717 | "collapsed": false,
1718 | "pycharm": {
1719 | "name": "#%%\n",
1720 | "is_executing": true
1721 | }
1722 | }
1723 | },
1724 | {
1725 | "cell_type": "markdown",
1726 | "source": [
1727 | "# Example of model predictions"
1728 | ],
1729 | "metadata": {
1730 | "collapsed": false,
1731 | "pycharm": {
1732 | "name": "#%% md\n"
1733 | }
1734 | }
1735 | },
1736 | {
1737 | "cell_type": "code",
1738 | "execution_count": null,
1739 | "outputs": [],
1740 | "source": [
1741 | "print('Prepare data for model')\n",
1742 | "ex = val[val['pmid'] == val['pmid'].values[0]]\n",
1743 | "print('ex', len(ex))\n",
1744 | "text = ex['sentence'].values\n",
1745 | "\n",
1746 | "input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(\n",
1747 | " text, ARTICLE_LENGTH, model.tokenizer\n",
1748 | ")\n",
1749 | "print('n_sents', n_sents)\n",
1750 | "res_sents = text[:n_sents]\n",
1751 | "scores = ex['score'].values[:n_sents]\n",
1752 | "\n",
1753 | "input_ids = torch.tensor([input_ids]).to(device)\n",
1754 | "attention_mask = torch.tensor([attention_mask]).to(device)\n",
1755 | "token_type_ids = torch.tensor([token_type_ids]).to(device)\n",
1756 | "\n",
1757 | "if ADDITIONAL_FEATURES:\n",
1758 | " features = np.nan_to_num(\n",
1759 | " ex[['sent_id', 'sent_type', 'r_abs', 'num_refs',\n",
1760 | " 'mean_r_fig', 'mean_r_tab',\n",
1761 | " 'min_r_fig', 'min_r_tab',\n",
1762 | " 'max_r_fig', 'max_r_tab']].values.astype(float)\n",
1763 | " )\n",
1764 | " features = features[:n_sents]\n",
1765 | " features = [torch.tensor(e, dtype=torch.float) for e in features]\n",
1766 | " features = torch.stack(features).to(device)\n",
1767 | "else:\n",
1768 | " features = None\n",
1769 | "\n",
1770 | "print('Apply model')\n",
1771 | "model_scores = model(input_ids, attention_mask, token_type_ids, features)\n",
1772 | "\n",
1773 | "to_show_df = pd.DataFrame(\n",
1774 | " dict(sentence=res_sents, ideal_score=scores / 100, res_score=model_scores.cpu().detach().numpy())\n",
1775 | ")\n",
1776 | "display(to_show_df.head())\n",
1777 | "print(((to_show_df['ideal_score'].values - to_show_df['res_score'].values) ** 2).mean() ** 0.5)"
1778 | ],
1779 | "metadata": {
1780 | "collapsed": false,
1781 | "pycharm": {
1782 | "name": "#%%\n"
1783 | }
1784 | }
1785 | },
1786 | {
1787 | "cell_type": "markdown",
1788 | "source": [
1789 | "# Evaluate model performance"
1790 | ],
1791 | "metadata": {
1792 | "collapsed": false,
1793 | "pycharm": {
1794 | "name": "#%% md\n"
1795 | }
1796 | }
1797 | },
1798 | {
1799 | "cell_type": "code",
1800 | "execution_count": null,
1801 | "outputs": [],
1802 | "source": [
1803 | "if 'model' not in globals():\n",
1804 | " model = create_model(\"bert\", \"froze_all\", ARTICLE_LENGTH, additional_features=ADDITIONAL_FEATURES)\n",
1805 | " model, device = setup_cuda_device(model)\n",
1806 | " model = load_or_train_model(model, device, additional_features=ADDITIONAL_FEATURES)"
1807 | ],
1808 | "metadata": {
1809 | "collapsed": false,
1810 | "pycharm": {
1811 | "name": "#%%\n"
1812 | }
1813 | }
1814 | },
1815 | {
1816 | "cell_type": "code",
1817 | "execution_count": null,
1818 | "outputs": [],
1819 | "source": [
1820 | "print('Prepare refs_and_scores dataset')\n",
1821 | "REF_SCORES_PATH = os.path.expanduser(f\"{cfg.base_path}/refs_and_scores.csv\")\n",
1822 | "# Reuse the cached CSV when it exists; delete the file manually to force a rebuild.\n",
1823 | "\n",
1824 | "if os.path.exists(REF_SCORES_PATH):\n",
1825 | "    final_ref_show_df = pd.read_csv(REF_SCORES_PATH)\n",
1826 | "else:\n",
1827 | "    to_show_ref = pd.merge(train_df, ref_sents_df[['ref_pmid', 'sentence']],\n",
1828 | "                           left_on=['pmid'], right_on=['ref_pmid'])\n",
1829 | "    to_show_ref = to_show_ref.rename(columns=dict(sentence_x='sentence', sentence_y='ref_sentence'))\n",
1830 | "    to_show_ref = to_show_ref[['pmid', 'sentence', 'ref_sentence', 'score']]\n",
1831 | "    final_ref_show_dic = dict(pmid=[], sentence=[], ref_sentence=[], score=[])\n",
1832 | "    ite = [(pmid, sent) for pmid, sent in to_show_ref[['pmid', 'sentence']].values]\n",
1833 | "\n",
1834 | "    for pmid, sent in tqdm(set(ite)):\n",
1835 | "        refs_df = to_show_ref[(to_show_ref['pmid'] == pmid) & (to_show_ref['sentence'] == sent)]\n",
1836 | "        final_ref_show_dic['pmid'].append(pmid)\n",
1837 | "        final_ref_show_dic['sentence'].append(sent)\n",
1838 | "        final_ref_show_dic['ref_sentence'].append(\" \".join(refs_df['ref_sentence'].values))\n",
1839 | "        final_ref_show_dic['score'].append(refs_df['score'].values[0])\n",
1840 | "    final_ref_show_df = pd.DataFrame(final_ref_show_dic)\n",
1841 | "    final_ref_show_df.to_csv(REF_SCORES_PATH, index=False)\n",
1842 | "\n",
1843 | "display(final_ref_show_df)"
1844 | ],
1845 | "metadata": {
1846 | "collapsed": false,
1847 | "pycharm": {
1848 | "name": "#%%\n"
1849 | }
1850 | }
1851 | },
1852 | {
1853 | "cell_type": "code",
1854 | "execution_count": null,
1855 | "outputs": [],
1856 | "source": [
1857 | "print('Prepare dataset to estimate performance')\n",
1858 | "to_test = final_ref_show_df[final_ref_show_df['pmid'].isin(set(val['pmid'].values))]\n",
1859 | "if ADDITIONAL_FEATURES:\n",
1860 | " to_test = pd.merge(to_test, train_df[['pmid', 'sentence',\n",
1861 | " 'sent_id', 'sent_type', 'r_abs', 'num_refs',\n",
1862 | " 'mean_r_fig', 'mean_r_tab',\n",
1863 | " 'min_r_fig', 'min_r_tab',\n",
1864 | " 'max_r_fig', 'max_r_tab']],\n",
1865 | " left_on=['pmid', 'sentence'], right_on=['pmid', 'sentence'])\n",
1866 | "display(to_test)"
1867 | ],
1868 | "metadata": {
1869 | "collapsed": false,
1870 | "pycharm": {
1871 | "name": "#%%\n"
1872 | }
1873 | }
1874 | },
1875 | {
1876 | "cell_type": "code",
1877 | "execution_count": null,
1878 | "outputs": [],
1879 | "source": [
1880 | "print('Using model for predictions')\n",
1881 | "res = dict(pmid=[], sentence=[], ref_sentences=[], score=[], res_score=[])\n",
1882 | "\n",
1883 | "for pmid in tqdm(set(to_test['pmid'].values)):\n",
1884 | "    ex = to_test[to_test['pmid'] == pmid]\n",
1885 | "    text = ex['sentence'].values\n",
1886 | "    input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(\n",
1887 | "        text, ARTICLE_LENGTH, model.tokenizer\n",
1888 | "    )\n",
1889 | "    res_sents = text[:n_sents]\n",
1890 | "    scores = ex['score'].values[:n_sents] / 100\n",
1891 | "    input_ids = torch.tensor([input_ids]).to(device)\n",
1892 | "    attention_mask = torch.tensor([attention_mask]).to(device)\n",
1893 | "    token_type_ids = torch.tensor([token_type_ids]).to(device)\n",
1894 | "    if ADDITIONAL_FEATURES:\n",
1895 | "        features = np.nan_to_num(\n",
1896 | "            ex[['sent_id', 'sent_type', 'r_abs', 'num_refs',\n",
1897 | "                'mean_r_fig', 'mean_r_tab',\n",
1898 | "                'min_r_fig', 'min_r_tab',\n",
1899 | "                'max_r_fig', 'max_r_tab']].values.astype(float)\n",
1900 | "        )\n",
1901 | "        features = features[:n_sents]\n",
1902 | "        input_features = [torch.tensor(e, dtype=torch.float) for e in features]\n",
1903 | "        input_features = torch.stack(input_features).to(device)\n",
1904 | "        model_scores = model(input_ids, attention_mask, token_type_ids, input_features)\n",
1905 | "    else:\n",
1906 | "        model_scores = model(input_ids, attention_mask, token_type_ids)\n",
1907 | "    for sent, sc, res_sc in zip(res_sents, scores, model_scores.cpu().detach().numpy()):\n",
1908 | "        res['pmid'].append(pmid)\n",
1909 | "        res['sentence'].append(sent)\n",
1910 | "        res['ref_sentences'].append(ex['ref_sentence'].values[0])\n",
1911 | "        res['score'].append(sc)\n",
1912 | "        res['res_score'].append(res_sc)\n",
1913 | "\n",
1914 | "res_df = pd.DataFrame(res)\n",
1915 | "display(res_df.head())\n",
1916 | "res_df.to_csv(f\"{cfg.base_path}/saved_example_refs.csv\")\n",
1917 | "# The square root below makes this a root-mean-square error, so it is labelled RMSE.\n",
1918 | "print('RMSE score', ((res_df['score'].values - res_df['res_score'].values) ** 2).mean() ** 0.5)"
1919 | ],
1920 | "metadata": {
1921 | "collapsed": false,
1922 | "pycharm": {
1923 | "name": "#%%\n"
1924 | }
1925 | }
1926 | },
1927 | {
1928 | "cell_type": "markdown",
1929 | "source": [
1930 | "# Quality analysis"
1931 | ],
1932 | "metadata": {
1933 | "collapsed": false,
1934 | "pycharm": {
1935 | "name": "#%% md\n"
1936 | }
1937 | }
1938 | },
1939 | {
1940 | "cell_type": "code",
1941 | "execution_count": null,
1942 | "outputs": [],
1943 | "source": [
1944 | "if 'model' not in globals():\n",
1945 | " model = create_model(\"bert\", \"froze_all\", ARTICLE_LENGTH, additional_features=ADDITIONAL_FEATURES)\n",
1946 | " model, device = setup_cuda_device(model)\n",
1947 | " model = load_or_train_model(model, device, additional_features=ADDITIONAL_FEATURES)"
1948 | ],
1949 | "metadata": {
1950 | "collapsed": false,
1951 | "pycharm": {
1952 | "name": "#%%\n"
1953 | }
1954 | }
1955 | },
1956 | {
1957 | "cell_type": "code",
1958 | "execution_count": null,
1959 | "outputs": [],
1960 | "source": [
1961 | "print('Searching for review papers')\n",
1962 | "inter = set(sentences_df['pmid'].values) & set(ref_sents_df['ref_pmid'].values)\n",
1963 | "review_papers = list(set(ref_sents_df[ref_sents_df['ref_pmid'].isin(inter)]['pmid'].values))\n",
1964 | "print('Review papers', len(review_papers))"
1965 | ],
1966 | "metadata": {
1967 | "collapsed": false,
1968 | "pycharm": {
1969 | "name": "#%%\n"
1970 | }
1971 | }
1972 | },
1973 | {
1974 | "cell_type": "code",
1975 | "execution_count": null,
1976 | "outputs": [],
1977 | "source": [
1978 | "test_stat = dict(rev_pmid=[], sent_num=[], rouge=[], true_rouge=[], diff_papers=[])\n",
1979 | "\n",
1980 | "for rev_id in tqdm(review_papers):\n",
1981 | "    paper_ref = sentences_df[sentences_df['pmid'] == rev_id]['sentence'].values\n",
1982 | "    papers_to_check = list(set(ref_sents_df[ref_sents_df['pmid'] == rev_id]['ref_pmid'].values))\n",
1983 | "    result = {'pmid': [], 'sentence': [], 'score': []}\n",
1984 | "    for paper_id in papers_to_check:\n",
1985 | "        ex = test[test['pmid'] == paper_id]\n",
1986 | "        text = ex['sentence'].values\n",
1987 | "\n",
1988 | "        features = np.nan_to_num(\n",
1989 | "            ex[['sent_id', 'sent_type', 'r_abs',\n",
1990 | "                'num_refs', 'mean_r_fig', 'mean_r_tab',\n",
1991 | "                'min_r_fig', 'min_r_tab',\n",
1992 | "                'max_r_fig', 'max_r_tab']].values.astype(float)\n",
1993 | "        ) if ADDITIONAL_FEATURES else None\n",
1994 | "        total_sents = 0\n",
1995 | "        while total_sents < len(text):\n",
1996 | "            offset = max(0, total_sents - OVERLAP)\n",
1997 | "            input_ids, attention_mask, token_type_ids, n_sents = preprocess_text(\n",
1998 | "                text[offset:], ARTICLE_LENGTH, model.tokenizer\n",
1999 | "            )\n",
2000 | "            if n_sents <= OVERLAP:\n",
2001 | "                total_sents += 1\n",
2002 | "                continue\n",
2003 | "            old_total = total_sents\n",
2004 | "            total_sents = offset + n_sents\n",
2005 | "            input_ids = torch.tensor([input_ids]).to(device)\n",
2006 | "            attention_mask = torch.tensor([attention_mask]).to(device)\n",
2007 | "            token_type_ids = torch.tensor([token_type_ids]).to(device)\n",
2008 | "            if ADDITIONAL_FEATURES:\n",
2009 | "                input_features = [torch.tensor(e, dtype=torch.float) for e in features[offset:total_sents]]\n",
2010 | "                input_features = torch.stack(input_features).to(device)\n",
2011 | "                model_scores = model(input_ids, attention_mask, token_type_ids, input_features)\n",
2012 | "            else:\n",
2013 | "                model_scores = model(input_ids, attention_mask, token_type_ids)\n",
2014 | "\n",
2015 | "            result['pmid'].extend([paper_id] * (total_sents - old_total))\n",
2016 | "            result['sentence'].extend(list(text[old_total:total_sents]))\n",
2017 | "            result['score'].extend(list(model_scores.cpu().detach().numpy())[old_total - offset:])\n",
2018 | "\n",
2019 | "    res_df = pd.DataFrame(result)\n",
2020 | "    sorted_arr = sorted(list(res_df['score'].values))\n",
2021 | "    for i in range(5, 103, 5):\n",
2022 | "        if len(sorted_arr) < i:\n",
2023 | "            break\n",
2024 | "        threshold = sorted_arr[-i]\n",
2025 | "        final_text = res_df[res_df['score'] >= threshold][['pmid', 'sentence']]\n",
2026 | "        mean_score = 0\n",
2027 | "        num = 0\n",
2028 | "        for sent in final_text['sentence'].values:\n",
2029 | "            for ref_sent in paper_ref:\n",
2030 | "                try:\n",
2031 | "                    mean_score += get_rouge(sent, ref_sent)\n",
2032 | "                    num += 1\n",
2033 | "                except Exception:\n",
2034 | "                    continue\n",
2035 | "        mean_score = mean_score / num if num else float('nan')  # no scorable pair: record NaN, not ZeroDivisionError\n",
2036 | "        real_score = get_rouge(\" \".join(final_text['sentence'].values), \" \".join(paper_ref))\n",
2037 | "        test_stat['rev_pmid'].append(rev_id)\n",
2038 | "        test_stat['sent_num'].append(i)\n",
2039 | "        print(len(\" \".join(final_text['sentence'].values)), len(\" \".join(paper_ref)))\n",
2040 | "\n",
2041 | "        test_stat['rouge'].append(mean_score)\n",
2042 | "        test_stat['true_rouge'].append(real_score)\n",
2043 | "        test_stat['diff_papers'].append(len(set(final_text['pmid'])))\n",
2044 | "\n",
2045 | "test_stat_df = pd.DataFrame(test_stat)\n",
2046 | "test_stat_df"
2047 | ],
2048 | "metadata": {
2049 | "collapsed": false,
2050 | "pycharm": {
2051 | "name": "#%%\n"
2052 | }
2053 | }
2054 | },
2055 | {
2056 | "cell_type": "code",
2057 | "execution_count": null,
2058 | "outputs": [],
2059 | "source": [
2060 | "print(*[len(arr) for key, arr in test_stat.items()])"
2061 | ],
2062 | "metadata": {
2063 | "collapsed": false,
2064 | "pycharm": {
2065 | "name": "#%%\n"
2066 | }
2067 | }
2068 | },
2069 | {
2070 | "cell_type": "code",
2071 | "execution_count": null,
2072 | "outputs": [],
2073 | "source": [
2074 | "test_stat_df.to_csv(f\"{cfg.base_path}/simple_right_test_on_review_{ADDITIONAL_FEATURES}.csv\", index=False)"
2075 | ],
2076 | "metadata": {
2077 | "collapsed": false,
2078 | "pycharm": {
2079 | "name": "#%%\n"
2080 | }
2081 | }
2082 | },
2083 | {
2084 | "cell_type": "code",
2085 | "execution_count": null,
2086 | "outputs": [],
2087 | "source": [
2088 | "rouge_means = []\n",
2089 | "rouge_err = []\n",
2090 | "papers_means = []\n",
2091 | "papers_err = []\n",
2092 | "# A boolean mask yields an empty frame (NaN mean/std) for sentence counts no review reached.\n",
2093 | "for i in range(5, 103, 5):\n",
2094 | "    tmp = test_stat_df[test_stat_df['sent_num'] == i]\n",
2095 | "    rouge_means.append(tmp['rouge'].mean())\n",
2096 | "    rouge_err.append(tmp['rouge'].std())\n",
2097 | "    papers_means.append(tmp['diff_papers'].mean())\n",
2098 | "    papers_err.append(tmp['diff_papers'].std())"
2099 | ],
2100 | "metadata": {
2101 | "collapsed": false,
2102 | "pycharm": {
2103 | "name": "#%%\n"
2104 | }
2105 | }
2106 | },
2107 | {
2108 | "cell_type": "code",
2109 | "execution_count": null,
2110 | "outputs": [],
2111 | "source": [
2112 | "plt.errorbar(list(range(5, 103, 5)), rouge_means, yerr=rouge_err, fmt='-o')\n",
2113 | "plt.title('Mean rouge value')\n",
2114 | "plt.show()"
2115 | ],
2116 | "metadata": {
2117 | "collapsed": false,
2118 | "pycharm": {
2119 | "name": "#%%\n"
2120 | }
2121 | }
2122 | },
2123 | {
2124 | "cell_type": "code",
2125 | "execution_count": null,
2126 | "outputs": [],
2127 | "source": [
2128 | "plt.errorbar(list(range(5, 103, 5)), papers_means, yerr=papers_err, fmt='-o')\n",
2129 | "plt.title('Mean number of papers')\n",
2130 | "plt.show()"
2131 | ],
2132 | "metadata": {
2133 | "collapsed": false,
2134 | "pycharm": {
2135 | "name": "#%%\n"
2136 | }
2137 | }
2138 | },
2139 | {
2140 | "cell_type": "markdown",
2141 | "source": [
2142 | "# Compare models with and without features"
2143 | ],
2144 | "metadata": {
2145 | "collapsed": false,
2146 | "pycharm": {
2147 | "name": "#%% md\n"
2148 | }
2149 | }
2150 | },
2151 | {
2152 | "cell_type": "code",
2153 | "execution_count": null,
2154 | "outputs": [],
2155 | "source": [
2156 | "import pandas as pd\n",
2157 | "# Load both evaluation runs and tag every row with the model that produced it.\n",
2158 | "df1 = pd.read_csv(f\"{cfg.base_path}/simple_right_test_on_review_{False}.csv\")\n",
2159 | "df1 = df1.assign(model=['BERTSUM'] * len(df1))\n",
2160 | "df2 = pd.read_csv(f\"{cfg.base_path}/simple_right_test_on_review_{True}.csv\")\n",
2161 | "df2 = df2.assign(model=['BERTSUM with features'] * len(df2))\n",
2162 | "draw_df = pd.concat([df1, df2], ignore_index=True)"
2163 | ],
2164 | "metadata": {
2165 | "collapsed": false,
2166 | "pycharm": {
2167 | "name": "#%%\n"
2168 | }
2169 | }
2170 | },
2171 | {
2172 | "cell_type": "code",
2173 | "execution_count": null,
2174 | "outputs": [],
2175 | "source": [
2176 | "import seaborn as sns\n",
2177 | "\n",
2178 | "sns.catplot(x=\"sent_num\", y=\"rouge\", kind=\"box\", hue='model', aspect=1.7, color='lightblue',\n",
2179 | " data=draw_df).set_axis_labels(\"Number of sentences\", \"ROUGE, %\")\n",
2180 | "plt.show()"
2181 | ],
2182 | "metadata": {
2183 | "collapsed": false,
2184 | "pycharm": {
2185 | "name": "#%%\n"
2186 | }
2187 | }
2188 | },
2189 | {
2190 | "cell_type": "code",
2191 | "execution_count": null,
2192 | "outputs": [],
2193 | "source": [],
2194 | "metadata": {
2195 | "collapsed": false,
2196 | "pycharm": {
2197 | "name": "#%%\n"
2198 | }
2199 | }
2200 | }
2201 | ],
2202 | "metadata": {
2203 | "kernelspec": {
2204 | "display_name": "Python 3 (ipykernel)",
2205 | "language": "python",
2206 | "name": "python3"
2207 | },
2208 | "language_info": {
2209 | "codemirror_mode": {
2210 | "name": "ipython",
2211 | "version": 3
2212 | },
2213 | "file_extension": ".py",
2214 | "mimetype": "text/x-python",
2215 | "name": "python",
2216 | "nbconvert_exporter": "python",
2217 | "pygments_lexer": "ipython3",
2218 | "version": "3.10.4"
2219 | }
2220 | },
2221 | "nbformat": 4,
2222 | "nbformat_minor": 4
2223 | }
--------------------------------------------------------------------------------