├── .gitignore
├── README.md
├── SECURITY.md
├── __init__.py
├── data
│   ├── __init__.py
│   ├── data_utils.py
│   ├── datasets.py
│   ├── datasets
│   │   ├── video_games
│   │   │   └── gt
│   │   └── wines
│   │       └── gt
│   └── images
│       ├── Results.png
│       ├── inference.png
│       ├── training_intuition.png
│       └── training_intuition2.png
├── instructions
│   ├── CODE_OF_CONDUCT.md
│   ├── LICENSE
│   ├── SECURITY.md
│   ├── SUPPORT.md
│   └── installation.sh
├── models
│   ├── SDR
│   │   ├── SDR.py
│   │   ├── SDR_utils.py
│   │   └── similarity_modeling.py
│   ├── doc_similarity_pl_template.py
│   ├── reco
│   │   ├── __init__.py
│   │   ├── hierarchical_reco.py
│   │   ├── recos_utils.py
│   │   └── wiki_recos_eval
│   │       ├── __init__.py
│   │       └── eval_metrics.py
│   ├── transformer_utils.py
│   └── transformers_base.py
├── sdr_main.py
└── utils
    ├── __init__.py
    ├── argparse_init.py
    ├── logging_utils.py
    ├── metrics_utils.py
    ├── model_utils.py
    ├── pytorch_lightning_utils
    │   ├── __init__.py
    │   ├── callbacks.py
    │   └── pytorch_lightning_utils.py
    ├── switch_functions.py
    └── torch_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | **/.vscode/*
3 | **/.history/*
4 |
5 | **/datasets/*
6 | !**/datasets/wines
7 | **/datasets/wines/*
8 | !**/datasets/wines/gt
9 | !**/datasets/video_games
10 | **/datasets/video_games/*
11 | !**/datasets/video_games/gt
12 | **/output/*
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-Supervised Document Similarity Ranking (SDR) via Contextualized Language Models and Hierarchical Inference
2 |
3 | This repo is the implementation for [**SDR**](https://arxiv.org/abs/2106.01186).
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | ## Tested environment
13 | - Python 3.7
14 | - PyTorch 1.7
15 | - CUDA 11.0
16 |
17 | Lower CUDA and PyTorch versions should work as well.
18 |
19 |
20 | ## Contents
21 | - [Installation](#installation)
22 | - [Datasets](#datasets)
23 | - [Train with our datasets](#training)
24 | - [Hierarchical Inference](#inference)
25 | - [Cite](#citing--authors)
26 |
27 | License, security, support, and code of conduct specifications are under the `instructions` directory.
28 |
29 | ## Installation
30 | Run
31 | ```
32 | bash instructions/installation.sh
33 | ```
34 |
35 |
36 | ## Datasets
37 | The published datasets are:
38 | * Video games
39 | * 21,935 articles
40 | * Expert annotated test set. 90 articles with 12 ground-truth recommendations.
41 | * Examples:
42 | * Grand Theft Auto - Mafia
43 | * Burnout Paradise - Forza Horizon 3
44 | * Wines
45 | * 1635 articles
46 | * Crafted by a human sommelier, 92 articles with ~10 ground-truth recommendations.
47 | * Examples:
48 | * Pinot Meunier - Chardonnay
49 | * Dom Pérignon - Moët & Chandon
50 |
51 | For more details and direct download see [Wines](https://zenodo.org/record/4812960#.YK8zqagzaUk) and [Video Games](https://zenodo.org/record/4812962#.YK8zqqgzaUk).
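Each dataset also ships with a ground-truth file at `data/datasets/<dataset_name>/gt`: a pickled dictionary that maps a seed article title to its annotated recommendations, where the per-recommendation value is used as a count/weight by the evaluation code. A minimal inspection sketch, assuming the format read by `data/data_utils.py` and `models/reco/wiki_recos_eval/eval_metrics.py`:

```
import pickle

# {seed article title: {recommended article title: count}}
with open("data/datasets/wines/gt", "rb") as f:
    gt = pickle.load(f)

seed = next(iter(gt))
print(seed, list(gt[seed].items())[:3])
```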
52 |
53 |
54 |
55 | ## Training
56 | **The training process downloads the datasets automatically.**
57 |
58 | ```
59 | python sdr_main.py --dataset_name video_games
60 | ```
61 | The code is based on [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/); all PL [hyperparameters](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html) are supported (e.g. `limit_train_batches`, `limit_val_batches`, `limit_test_batches`, `check_val_every_n_epoch`).
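For example, a shortened sanity run can be launched by passing trainer flags directly on the command line (a sketch; the flag values here are arbitrary):
```
python sdr_main.py --dataset_name video_games --limit_train_batches 100 --check_val_every_n_epoch 1
```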
62 |
63 | ## Tensorboard support
64 | All metrics are logged automatically and stored in
65 | ```
66 | SDR/output/document_similarity/SDR/arch_SDR/dataset_name_/
67 | ```
68 | Run `tensorboard --logdir=` to see the logs.
69 |
70 |
71 |
72 | ## Inference
73 | The hierarchical inference described in the paper is implemented as a stand-alone service and can be used with any backbone algorithm (`models/reco/hierarchical_reco.py`).
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 | ```
82 | python sdr_main.py --dataset_name --resume_from_checkpoint --test_only
83 | ```
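The service can also be called directly on per-sentence embeddings produced by any backbone. A minimal sketch (the random features and the `recos_out` path are illustrative; shapes follow the usage in `models/SDR/SDR.py`, and a CUDA device is assumed because the service moves tensors to the GPU):

```
import pickle
import numpy as np
from models.reco.hierarchical_reco import vectorize_reco_hierarchical

gt_path = "data/datasets/wines/gt"
titles = list(pickle.load(open(gt_path, "rb")).keys())[:50]

# One [num_sentences, dim] matrix per section, one list of sections per article.
all_features = [[np.random.randn(5, 1024).astype(np.float32) for _ in range(3)] for _ in titles]

recos, metrics = vectorize_reco_hierarchical(
    all_features=all_features, titles=titles, gt_path=gt_path, output_path="recos_out",
)
print(metrics)  # {"mrr": ..., "mpr": ..., "hit_rates": [...]}
```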
84 |
85 | ## Results
86 |
87 | ![Results](data/images/Results.png)
88 |
89 |
90 | ## Citing & Authors
91 | If you find this repository or the annotated datasets helpful, feel free to cite our publication -
92 |
93 | SDR: Self-Supervised Document-to-Document Similarity Ranking via Contextualized Language Models and Hierarchical Inference
94 | ```
95 | @misc{ginzburg2021selfsupervised,
96 | title={Self-Supervised Document Similarity Ranking via Contextualized Language Models and Hierarchical Inference},
97 | author={Dvir Ginzburg and Itzik Malkiel and Oren Barkan and Avi Caciularu and Noam Koenigstein},
98 | year={2021},
99 | eprint={2106.01186},
100 | archivePrefix={arXiv},
101 | primaryClass={cs.CL}
102 | }
103 | ```
104 |
105 | Contact: [Dvir Ginzburg](mailto:dvirginz@gmail.com), [Itzik Malkiel](mailto:itzik.malkiel@microsoft.com).
106 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/__init__.py
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/__init__.py
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from typing import List
3 | import torch
4 | from torch.nn.utils.rnn import pad_sequence
5 |
6 |
7 | def get_gt_seeds_titles(titles=None, dataset_name="wines"):
8 | idxs = None
9 | gt_path = f"data/datasets/{dataset_name}/gt"
10 | popular_titles = list(pickle.load(open(gt_path, "rb")).keys())
11 |     if titles is not None:
12 | idxs = [titles.index(pop_title) for pop_title in popular_titles if pop_title in titles]
13 | return popular_titles, idxs, gt_path
14 |
15 |
16 | def reco_sentence_test_collate(examples: List[torch.Tensor], tokenizer):
17 | examples_ = []
18 | for example in examples:
19 | sections = []
20 | for section in example:
21 | if section == []:
22 | continue
23 | sections.append(
24 | (
25 | pad_sequence([i[0] for i in section], batch_first=True, padding_value=tokenizer.pad_token_id),
26 | [i[2] for i in section],
27 | [i[3] for i in section],
28 | [i[4] for i in section],
29 | [i[5] for i in section],
30 | [i[6] for i in section],
31 | [i[7] for i in section],
32 | torch.tensor([i[8] for i in section]),
33 | )
34 | )
35 | examples_.append(sections)
36 | return examples_
37 |
38 |
39 | def reco_sentence_collate(examples: List[torch.Tensor], tokenizer):
40 | return (
41 | pad_sequence([i[0] for i in examples], batch_first=True, padding_value=tokenizer.pad_token_id),
42 | [i[2] for i in examples],
43 | [i[3] for i in examples],
44 | [i[4] for i in examples],
45 | [i[5] for i in examples],
46 | [i[6] for i in examples],
47 | [i[7] for i in examples],
48 | torch.tensor([i[8] for i in examples]),
49 | )
50 |
51 |
52 | def raw_data_link(dataset_name):
53 | if dataset_name == "wines":
54 | return "https://zenodo.org/record/4812960/files/wines.txt?download=1"
55 | if dataset_name == "video_games":
56 | return "https://zenodo.org/record/4812962/files/video_games.txt?download=1"
57 |
--------------------------------------------------------------------------------
/data/datasets.py:
--------------------------------------------------------------------------------
1 | import ast
2 | from data.data_utils import get_gt_seeds_titles, raw_data_link
3 | import nltk
4 | from torch.utils.data import Dataset
5 | from transformers import PreTrainedTokenizer
6 | import os
7 | import pickle
8 | import numpy as np
9 | from tqdm import tqdm
10 | import torch
11 | import json
12 | import csv
13 | import sys
14 | from models.reco.recos_utils import index_amp
15 |
16 |
17 | nltk.download("punkt")
18 |
19 |
20 | class WikipediaTextDatasetParagraphsSentences(Dataset):
21 | def __init__(self, tokenizer: PreTrainedTokenizer, hparams, dataset_name, block_size, mode="train"):
22 | self.hparams = hparams
23 | cached_features_file = os.path.join(
24 | f"data/datasets/cached_proccessed/{dataset_name}",
25 | f"bs_{block_size}_{dataset_name}_{type(self).__name__}_tokenizer_{str(type(tokenizer)).split('.')[-1][:-2]}_mode_{mode}",
26 | )
27 | self.cached_features_file = cached_features_file
28 | os.makedirs(os.path.dirname(cached_features_file), exist_ok=True)
29 |
30 | raw_data_path = self.download_raw(dataset_name)
31 |
32 | all_articles = self.save_load_splitted_dataset(mode, cached_features_file, raw_data_path)
33 |
34 | self.hparams = hparams
35 |
36 | max_article_len,max_sentences, max_sent_len = int(1e6), 16, 10000
37 | block_size = min(block_size, tokenizer.max_len_sentences_pair) if tokenizer is not None else block_size
38 | self.block_size = block_size
39 | self.tokenizer = tokenizer
40 |
41 | if os.path.exists(cached_features_file) and (self.hparams is None or not self.hparams.overwrite_data_cache):
42 | print("\nLoading features from cached file %s", cached_features_file)
43 | with open(cached_features_file, "rb") as handle:
44 | self.examples, self.indices_map = pickle.load(handle)
45 | else:
46 | print("\nCreating features from dataset file at ", cached_features_file)
47 |
48 | self.examples = []
49 | self.indices_map = []
50 |
51 | for idx_article, article in enumerate(tqdm(all_articles)):
52 | this_sample_sections = []
53 | title, sections = article[0], ast.literal_eval(article[1])
54 | valid_sections_count = 0
55 | for section_idx, section in enumerate(sections):
56 | this_sections_sentences = []
57 | if section[1] == "":
58 | continue
59 | valid_sentences_count = 0
60 | title_with_base_title = "{}:{}".format(title, section[0])
61 | for sent_idx, sent in enumerate(nltk.sent_tokenize(section[1][:max_article_len])[:max_sentences]):
62 | tokenized_desc = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(json.dumps(sent[:max_sent_len])))[
63 | :block_size
64 | ]
65 | this_sections_sentences.append(
66 | (
67 | tokenized_desc,
68 | len(tokenized_desc),
69 | idx_article,
70 | valid_sections_count,
71 | valid_sentences_count,
72 | sent[:max_sent_len],
73 | ),
74 | )
75 | self.indices_map.append((idx_article, valid_sections_count, valid_sentences_count))
76 | valid_sentences_count += 1
77 | this_sample_sections.append((this_sections_sentences, title_with_base_title))
78 | valid_sections_count += 1
79 | self.examples.append((this_sample_sections, title))
80 |
81 | print("\nSaving features into cached file %s", cached_features_file)
82 | with open(cached_features_file, "wb") as handle:
83 | pickle.dump((self.examples, self.indices_map), handle, protocol=pickle.HIGHEST_PROTOCOL)
84 |
85 | self.labels = [idx_article for idx_article, _, _ in self.indices_map]
86 |
87 | def save_load_splitted_dataset(self, mode, cached_features_file, raw_data_path):
88 | proccessed_path = f"{cached_features_file}_EXAMPLES"
89 | if not os.path.exists(proccessed_path):
90 | all_articles = self.read_all_articles(raw_data_path)
91 | indices = list(range(len(all_articles)))
92 | if mode != "test":
93 | train_indices = sorted(
94 | np.random.choice(indices, replace=False, size=int(len(all_articles) * self.hparams.train_val_ratio))
95 | )
96 | val_indices = np.setdiff1d(list(range(len(all_articles))), train_indices)
97 | indices = train_indices if mode == "train" else val_indices
98 |
99 | articles = []
100 | for i in indices:
101 | articles.append(all_articles[i])
102 | all_articles = articles
103 | pickle.dump(all_articles, open(proccessed_path, "wb"))
104 | print(f"\nsaved dataset at {proccessed_path}")
105 | else:
106 | all_articles = pickle.load(open(proccessed_path, "rb"))
107 | setattr(self.hparams, f"{mode}_data_file", proccessed_path)
108 | return all_articles
109 |
110 | def read_all_articles(self, raw_data_path):
111 | csv.field_size_limit(sys.maxsize)
112 | with open(raw_data_path, newline="") as f:
113 | reader = csv.reader(f)
114 | all_articles = list(reader)
115 | return all_articles[1:]
116 |
117 | def download_raw(self, dataset_name):
118 | raw_data_path = f"data/datasets/{dataset_name}/raw_data"
119 | os.makedirs(os.path.dirname(raw_data_path), exist_ok=True)
120 | if not os.path.exists(raw_data_path):
121 | os.system(f"wget -O {raw_data_path} {raw_data_link(dataset_name)}")
122 | return raw_data_path
123 |
124 | def __len__(self):
125 | return len(self.indices_map)
126 |
127 | def __getitem__(self, item):
128 | idx_article, idx_section, idx_sentence = self.indices_map[item]
129 | sent = self.examples[idx_article][0][idx_section][0][idx_sentence]
130 |
131 | return (
132 | torch.tensor(self.tokenizer.build_inputs_with_special_tokens(sent[0]), dtype=torch.long,)[
133 | : self.hparams.limit_tokens
134 | ],
135 | self.examples[idx_article][1],
136 | self.examples[idx_article][0][idx_section][1],
137 | sent[1],
138 | idx_article,
139 | idx_section,
140 | idx_sentence,
141 | item,
142 | self.labels[item],
143 | )
144 |
145 | class WikipediaTextDatasetParagraphsSentencesTest(WikipediaTextDatasetParagraphsSentences):
146 | def __init__(self, tokenizer: PreTrainedTokenizer, hparams, dataset_name, block_size, mode="test"):
147 | super().__init__(tokenizer, hparams, dataset_name, block_size, mode=mode)
148 |
149 | def __len__(self):
150 | return len(self.examples)
151 |
152 | def __getitem__(self, item):
153 | sections = []
154 | for idx_section, section in enumerate(self.examples[item][0]):
155 | sentences = []
156 | for idx_sentence, sentence in enumerate(section[0]):
157 | sentences.append(
158 | (
159 | torch.tensor(self.tokenizer.build_inputs_with_special_tokens(sentence[0]), dtype=torch.long,),
160 | self.examples[item][1],
161 | section[1],
162 | sentence[1],
163 | item,
164 | idx_section,
165 | idx_sentence,
166 | item,
167 | self.labels[item],
168 | )
169 | )
170 | sections.append(sentences)
171 | return sections
172 |
173 |
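# Example (a sketch; values are illustrative): building the wines training split.
# `hparams` only needs the attributes this class reads (overwrite_data_cache,
# train_val_ratio, limit_tokens), e.g. an argparse.Namespace:
#   from argparse import Namespace
#   from transformers import RobertaTokenizer
#   tok = RobertaTokenizer.from_pretrained("roberta-base")
#   hp = Namespace(overwrite_data_cache=False, train_val_ratio=0.9, limit_tokens=64)
#   ds = WikipediaTextDatasetParagraphsSentences(tok, hp, "wines", block_size=64, mode="train")
#   tokens, title, section_title, sent_len, *_ = ds[0]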
--------------------------------------------------------------------------------
/data/datasets/video_games/gt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/datasets/video_games/gt
--------------------------------------------------------------------------------
/data/datasets/wines/gt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/datasets/wines/gt
--------------------------------------------------------------------------------
/data/images/Results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/images/Results.png
--------------------------------------------------------------------------------
/data/images/inference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/images/inference.png
--------------------------------------------------------------------------------
/data/images/training_intuition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/images/training_intuition.png
--------------------------------------------------------------------------------
/data/images/training_intuition2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/images/training_intuition2.png
--------------------------------------------------------------------------------
/instructions/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/instructions/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/instructions/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
--------------------------------------------------------------------------------
/instructions/SUPPORT.md:
--------------------------------------------------------------------------------
1 |
2 | # Support
3 |
4 | ## How to file issues and get help
5 |
6 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
7 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
8 | feature request as a new Issue.
9 |
10 | For help and questions about using this project, please use the issues section as well; we will respond to everyone in time.
11 |
12 | ## Microsoft Support Policy
13 |
14 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
15 |
--------------------------------------------------------------------------------
/instructions/installation.sh:
--------------------------------------------------------------------------------
1 | # The installation script can be executed as a bash script.
2 | conda create -n SDR python=3.7 --yes
3 | source ~/anaconda3/etc/profile.d/conda.sh
4 | conda activate SDR
5 |
6 | conda install -c pytorch pytorch==1.7.0 torchvision cudatoolkit=11.0 --yes
7 | pip install -U cython transformers==3.1.0 nltk pytorch-metric-learning joblib pytorch-lightning==1.1.8 pandas
8 |
--------------------------------------------------------------------------------
/models/SDR/SDR.py:
--------------------------------------------------------------------------------
1 | from data.datasets import (
2 | WikipediaTextDatasetParagraphsSentences,
3 | WikipediaTextDatasetParagraphsSentencesTest,
4 | )
5 | from utils.argparse_init import str2bool
6 | from models.SDR.SDR_utils import MPerClassSamplerDeter
7 | from data.data_utils import get_gt_seeds_titles, reco_sentence_collate, reco_sentence_test_collate
8 | from functools import partial
9 | import os
10 | from models.reco.hierarchical_reco import vectorize_reco_hierarchical
11 | from utils.torch_utils import to_numpy
12 | from models.transformers_base import TransformersBase
13 | from models.doc_similarity_pl_template import DocEmbeddingTemplate
14 | from utils import switch_functions
15 | from models import transformer_utils
16 | import numpy as np
17 | import torch
18 | from utils import metrics_utils
19 | from pytorch_metric_learning.samplers import MPerClassSampler
20 | from torch.utils.data.dataloader import DataLoader
21 | import json
22 |
23 |
24 | class SDR(TransformersBase):
25 |
26 | """
27 | Author: Dvir Ginzburg.
28 |
29 | SDR model (ACL IJCNLP 2021)
30 | """
31 |
32 | def __init__(
33 | self, hparams,
34 | ):
35 | """Stub."""
36 | super(SDR, self).__init__(hparams)
37 |
38 |
39 | def forward_train(self, batch):
40 | inputs, labels = transformer_utils.mask_tokens(batch[0].clone().detach(), self.tokenizer, self.hparams)
41 |
42 | outputs = self.model(
43 | inputs,
44 | masked_lm_labels=labels,
45 | non_masked_input_ids=batch[0],
46 | sample_labels=batch[-1],
47 | run_similarity=True,
48 | run_mlm=True,
49 | )
50 |
51 | self.losses["mlm_loss"] = outputs[0]
52 | self.losses["d2v_loss"] = (outputs[1] or 0) * self.hparams.sim_loss_lambda # If no similarity loss we ignore
53 |
54 | tracked = self.track_metrics(input_ids=inputs, outputs=outputs, is_train=self.hparams.mode == "train", labels=labels,)
55 | self.tracks.update(tracked)
56 |
57 | return
58 |
59 | def forward_val(self, batch):
60 | self.forward_train(batch)
61 |
62 | def test_step(self, batch, batch_idx):
63 | section_out = []
64 | for section in batch[0]: # batch=1 for test
65 | sentences=[]
66 | sentences_embed_per_token = [
67 | self.model(
68 | sentence.unsqueeze(0), masked_lm_labels=None, run_similarity=False
69 | )[5].squeeze(0)
70 | for sentence in section[0][:8]
71 | ]
72 | for idx, sentence in enumerate(sentences_embed_per_token):
73 | sentences.append(sentence[: section[2][idx]].mean(0)) # We take the non-padded tokens and mean them
74 | section_out.append(torch.stack(sentences))
75 | return (section_out, batch[0][0][1][0]) # title name
76 |
77 | def forward(self, batch):
79 |         getattr(self, f"forward_{self.hparams.mode}")(batch)
79 |
80 | @staticmethod
81 | def track_metrics(
82 | outputs=None, input_ids=None, labels=None, is_train=True, batch_idx=0,
83 | ):
84 | mode = "train" if is_train else "val"
85 |
86 | trackes = {}
87 | lm_pred = np.argmax(outputs[3].cpu().detach().numpy(), axis=2)
88 | labels_numpy = labels.cpu().numpy()
89 | labels_non_zero = labels_numpy[np.array(labels_numpy != -100)] if np.any(labels_numpy != -100) else np.zeros(1)
90 | lm_pred_non_zero = lm_pred[np.array(labels_numpy != -100)] if np.any(labels_numpy != -100) else np.ones(1)
91 | lm_acc = torch.tensor(
92 | metrics_utils.simple_accuracy(lm_pred_non_zero, labels_non_zero), device=outputs[3].device,
93 | ).reshape((1, -1))
94 |
95 | trackes["lm_acc_{}".format(mode)] = lm_acc.detach().cpu()
96 |
97 | return trackes
98 |
99 | def test_epoch_end(self, outputs, recos_path=None):
100 | if self.trainer.checkpoint_callback.last_model_path == "" and self.hparams.resume_from_checkpoint is None:
101 | self.trainer.checkpoint_callback.last_model_path = f"{self.hparams.hparams_dir}/no_train"
102 | elif(self.hparams.resume_from_checkpoint is not None):
103 | self.trainer.checkpoint_callback.last_model_path = self.hparams.resume_from_checkpoint
104 | if recos_path is None:
105 | save_outputs_path = f"{self.trainer.checkpoint_callback.last_model_path}_FEATURES_NumSamples_{len(outputs)}"
106 |
107 | if isinstance(outputs[0][0][0], torch.Tensor):
108 | outputs = [([to_numpy(section) for section in sample[0]], sample[1]) for sample in outputs]
109 | torch.save(outputs, save_outputs_path)
110 | print(f"\nSaved to {save_outputs_path}\n")
111 |
112 | titles = popular_titles = [out[1][:-1] for out in outputs]
113 | idxs, gt_path = list(range(len(titles))), ""
114 |
115 | section_sentences_features = [out[0] for out in outputs]
116 | popular_titles, idxs, gt_path = get_gt_seeds_titles(titles, self.hparams.dataset_name)
117 |
118 | self.hparams.test_sample_size = (
119 | self.hparams.test_sample_size if self.hparams.test_sample_size > 0 else len(popular_titles)
120 | )
121 | idxs = idxs[: self.hparams.test_sample_size]
122 |
123 | recos, metrics = vectorize_reco_hierarchical(
124 | all_features=section_sentences_features,
125 | titles=titles,
126 | gt_path=gt_path,
127 | output_path=self.trainer.checkpoint_callback.last_model_path,
128 | )
129 | metrics = {
130 | "mrr": float(metrics["mrr"]),
131 | "mpr": float(metrics["mpr"]),
132 | **{f"hit_rate_{rate[0]}": float(rate[1]) for rate in metrics["hit_rates"]},
133 | }
134 | print(json.dumps(metrics, indent=2))
135 | for k, v in metrics.items():
136 | self.logger.experiment.add_scalar(k, v, global_step=self.global_step)
137 |
138 | recos_path = os.path.join(
139 | os.path.dirname(self.trainer.checkpoint_callback.last_model_path),
140 | f"{os.path.basename(self.trainer.checkpoint_callback.last_model_path)[:-5]}"
141 | f"_numSamples_{self.hparams.test_sample_size}",
142 | )
143 | torch.save(recos, recos_path)
144 | print("Saving recos in {}".format(recos_path))
145 |
146 | setattr(self.hparams, "recos_path", recos_path)
147 | return
148 |
149 | def dataloader(self, mode=None):
150 | if mode == "train":
151 | sampler = MPerClassSampler(
152 | self.train_dataset.labels,
153 | 2,
154 | batch_size=self.hparams.train_batch_size,
155 | length_before_new_iter=(self.hparams.limit_train_batches) * self.hparams.train_batch_size,
156 | )
157 |
158 | loader = DataLoader(
159 | self.train_dataset,
160 | num_workers=self.hparams.num_data_workers,
161 | sampler=sampler,
162 | batch_size=self.hparams.train_batch_size,
163 | collate_fn=partial(reco_sentence_collate, tokenizer=self.tokenizer,),
164 | )
165 |
166 | elif mode == "val":
167 | sampler = MPerClassSamplerDeter(
168 | self.val_dataset.labels,
169 | 2,
170 | length_before_new_iter=self.hparams.limit_val_indices_batches,
171 | batch_size=self.hparams.val_batch_size,
172 | )
173 |
174 | loader = DataLoader(
175 | self.val_dataset,
176 | num_workers=self.hparams.num_data_workers,
177 | sampler=sampler,
178 | batch_size=self.hparams.val_batch_size,
179 | collate_fn=partial(reco_sentence_collate, tokenizer=self.tokenizer,),
180 | )
181 |
182 | else:
183 | loader = DataLoader(
184 | self.test_dataset,
185 | num_workers=self.hparams.num_data_workers,
186 | batch_size=self.hparams.test_batch_size,
187 | collate_fn=partial(reco_sentence_test_collate, tokenizer=self.tokenizer,),
188 | )
189 | return loader
190 |
191 | @staticmethod
192 | def add_model_specific_args(parent_parser, task_name, dataset_name, is_lowest_leaf=False):
193 | parser = TransformersBase.add_model_specific_args(parent_parser, task_name, dataset_name, is_lowest_leaf=False)
194 | parser.add_argument("--hard_mine", type=str2bool, nargs="?", const=True, default=True)
195 | parser.add_argument("--metric_loss_func", type=str, default="ContrastiveLoss") # TripletMarginLoss #CosineLoss
196 | parser.add_argument("--sim_loss_lambda", type=float, default=0.1)
197 | parser.add_argument("--limit_tokens", type=int, default=64)
198 | parser.add_argument("--limit_val_indices_batches", type=int, default=500)
199 | parser.add_argument("--metric_for_similarity", type=str, choices=["cosine", "norm_euc"], default="cosine")
200 |
201 | parser.set_defaults(
202 | check_val_every_n_epoch=1,
203 | batch_size=32,
204 | accumulate_grad_batches=2,
205 | metric_to_track="train_mlm_loss_epoch",
206 | )
207 |
208 | return parser
209 |
210 | def prepare_data(self):
211 | block_size = (
212 | self.hparams.block_size
213 | if hasattr(self.hparams, "block_size")
214 | and self.hparams.block_size > 0
215 | and self.hparams.block_size < self.tokenizer.max_len
216 | else self.tokenizer.max_len
217 | )
218 | self.train_dataset = WikipediaTextDatasetParagraphsSentences(
219 | tokenizer=self.tokenizer,
220 | hparams=self.hparams,
221 | dataset_name=self.hparams.dataset_name,
222 | block_size=block_size,
223 | mode="train",
224 | )
225 | self.val_dataset = WikipediaTextDatasetParagraphsSentences(
226 | tokenizer=self.tokenizer,
227 | hparams=self.hparams,
228 | dataset_name=self.hparams.dataset_name,
229 | block_size=block_size,
230 | mode="val",
231 | )
232 | self.val_dataset.indices_map = self.val_dataset.indices_map[: self.hparams.limit_val_indices_batches]
233 | self.val_dataset.labels = self.val_dataset.labels[: self.hparams.limit_val_indices_batches]
234 |
235 | self.test_dataset = WikipediaTextDatasetParagraphsSentencesTest(
236 | tokenizer=self.tokenizer,
237 | hparams=self.hparams,
238 | dataset_name=self.hparams.dataset_name,
239 | block_size=block_size,
240 | mode="test",
241 | )
242 |
243 |
--------------------------------------------------------------------------------
/models/SDR/SDR_utils.py:
--------------------------------------------------------------------------------
1 | from pytorch_metric_learning.samplers.m_per_class_sampler import MPerClassSampler
2 | import torch
3 | from torch.utils.data.sampler import Sampler
4 | from pytorch_metric_learning.utils import common_functions as c_f
5 |
6 | # modified from
7 | # https://raw.githubusercontent.com/bnulihaixia/Deep_metric/master/utils/sampler.py
8 | class MPerClassSamplerDeter(MPerClassSampler):
9 | """
10 |     Deterministic variant of MPerClassSampler: every iteration returns m samples per class
11 |     (e.g., with a dataloader batch size of 100 and m = 5, each batch covers 20 classes with 5
12 |     samples each). The shuffled index list is cached, so every epoch yields the same indices.
13 | """
14 |
15 | def __init__(self, labels, m, batch_size=None, length_before_new_iter=100000):
16 | super(MPerClassSamplerDeter, self).__init__(labels, m, batch_size, int(length_before_new_iter))
17 | self.shuffled_idx_list = None
18 |
19 | def __iter__(self):
20 | idx_list = [0] * self.list_size
21 | i = 0
22 | num_iters = self.calculate_num_iters()
23 | if self.shuffled_idx_list is None:
24 | for _ in range(num_iters):
25 |
26 | c_f.NUMPY_RANDOM.shuffle(self.labels)
27 | if self.batch_size is None:
28 | curr_label_set = self.labels
29 | else:
30 | curr_label_set = self.labels[: self.batch_size // self.m_per_class]
31 | for label in curr_label_set:
32 | t = self.labels_to_indices[label]
33 | idx_list[i : i + self.m_per_class] = c_f.safe_random_choice(t, size=self.m_per_class)
34 | i += self.m_per_class
35 | self.shuffled_idx_list = idx_list
36 | return iter(self.shuffled_idx_list)
37 |
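# Example usage (a sketch mirroring the validation loader in models/SDR/SDR.py,
# where `labels` holds the article index of every sample):
#   sampler = MPerClassSamplerDeter(dataset.labels, m=2, batch_size=32, length_before_new_iter=500)
#   loader = DataLoader(dataset, sampler=sampler, batch_size=32, collate_fn=...)
# Because the shuffled index list is cached after the first pass, every epoch
# iterates over the same class-balanced indices.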
--------------------------------------------------------------------------------
/models/SDR/similarity_modeling.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import logging
3 | import math
4 | from pytorch_metric_learning.distances.lp_distance import LpDistance
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import CrossEntropyLoss
9 | from torch.nn.functional import gelu
10 | from pytorch_metric_learning import miners, losses, reducers
11 |
12 | from transformers.configuration_roberta import RobertaConfig
13 | from transformers.modeling_bert import BertLayerNorm, BertPreTrainedModel
14 | from transformers.modeling_roberta import RobertaModel, RobertaLMHead
15 | from pytorch_metric_learning.distances import CosineSimilarity
16 |
17 |
18 | class SimilarityModeling(BertPreTrainedModel):
19 | config_class = RobertaConfig
20 | base_model_prefix = "roberta"
21 |
22 | def __init__(self, config, hparams):
23 | super().__init__(config)
24 | self.hparams = hparams
25 | config.output_hidden_states = True
26 |
27 | self.roberta = RobertaModel(config)
28 | self.lm_head = RobertaLMHead(config)
29 | self.init_weights()
30 |
31 | if self.hparams.metric_for_similarity == "cosine":
32 | self.metric = CosineSimilarity()
33 | pos_margin, neg_margin = 1, 0
34 | neg_margin = -1 if (getattr(self.hparams, "metric_loss_func", "ContrastiveLoss") == "CosineLoss") else 0
35 | elif self.hparams.metric_for_similarity == "norm_euc":
36 | self.metric = LpDistance(normalize_embeddings=True, p=2)
37 | pos_margin, neg_margin = 0, 1
38 |
39 | self.reducer = reducers.DoNothingReducer()
40 | if self.hparams.hard_mine:
41 | self.miner_func = miners.MultiSimilarityMiner()
42 | else:
43 | self.miner_func = miners.BatchEasyHardMiner(
44 | pos_strategy=miners.BatchEasyHardMiner.ALL,
45 | neg_strategy=miners.BatchEasyHardMiner.ALL,
46 | distance=CosineSimilarity(),
47 | )
48 |
49 | if getattr(self.hparams, "metric_loss_func", "ContrastiveLoss") in ["ContrastiveLoss", "CosineLoss"]:
50 | self.similarity_loss_func = losses.ContrastiveLoss(
51 | pos_margin=pos_margin, neg_margin=neg_margin, distance=self.metric
52 |             )  # [pos_margin - s_p]_+ + [s_n - neg_margin]_+ with a similarity metric, so for cosine: pos_margin=1, neg_margin=0
53 | else:
54 | self.similarity_loss_func = losses.TripletMarginLoss(margin=1, distance=self.metric)
55 |
56 | def get_output_embeddings(self):
57 | return self.lm_head.decoder
58 |
59 | @staticmethod
60 | def mean_mask(features, mask):
61 | return (features * mask.unsqueeze(-1)).sum(1) / mask.sum(-1, keepdim=True)
62 |
63 | def forward(
64 | self,
65 | input_ids=None,
66 | sample_labels=None,
67 | samples_idxs=None,
68 | track_sim_dict=None,
69 | non_masked_input_ids=None,
70 | attention_mask=None,
71 | token_type_ids=None,
72 | position_ids=None,
73 | head_mask=None,
74 | inputs_embeds=None,
75 | masked_lm_labels=None,
76 | labels=None,
77 | output_hidden_states=False,
78 | return_dict=False,
79 | run_similarity=False,
80 | run_mlm=True,
81 | ):
82 | if run_mlm:
83 | outputs = list(
84 | self.roberta(
85 | input_ids,
86 | attention_mask=attention_mask,
87 | token_type_ids=token_type_ids,
88 | position_ids=position_ids,
89 | head_mask=head_mask,
90 | inputs_embeds=inputs_embeds,
91 | output_hidden_states=output_hidden_states,
92 | return_dict=return_dict,
93 | )
94 | )
95 | sequence_output = outputs[0]
96 | prediction_scores = self.lm_head(sequence_output)
97 | outputs = (prediction_scores, None, sequence_output) # Add hidden states and attention if they are here
98 |
99 | #######
100 | # MLM
101 | #######
102 | masked_lm_loss = torch.zeros(1, device=prediction_scores.device).float()
103 |
104 | if (masked_lm_labels is not None and (not (masked_lm_labels == -100).all())) and self.hparams.mlm:
105 | loss_fct = CrossEntropyLoss()
106 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
107 | else:
108 | masked_lm_loss = 0
109 | else:
110 | outputs = (
111 | torch.zeros([*input_ids.shape, 50265]).to(input_ids.device).float(),
112 | None,
113 | torch.zeros([*input_ids.shape, 1024]).to(input_ids.device).float(),
114 | )
115 | masked_lm_loss = torch.zeros(1)[0].to(input_ids.device).float()
116 |
117 | #######
118 | # Similarity
119 | #######
120 | if run_similarity:
121 | non_masked_outputs = self.roberta(
122 | non_masked_input_ids,
123 | attention_mask=attention_mask,
124 | token_type_ids=token_type_ids,
125 | position_ids=position_ids,
126 | head_mask=head_mask,
127 | inputs_embeds=inputs_embeds,
128 | output_hidden_states=output_hidden_states,
129 | return_dict=return_dict,
130 | )
131 | non_masked_seq_out = non_masked_outputs[0]
132 |
133 | meaned_sentences = non_masked_seq_out.mean(1)
134 | miner_output = list(self.miner_func(meaned_sentences, sample_labels))
135 |
136 | sim_loss = self.similarity_loss_func(meaned_sentences, sample_labels, miner_output)
137 | outputs = (masked_lm_loss, sim_loss, torch.zeros(1)) + outputs
138 | else:
139 | outputs = (
140 | masked_lm_loss,
141 | torch.zeros(1)[0].to(input_ids.device).float(),
142 | torch.zeros(1)[0].to(input_ids.device).float(),
143 | ) + outputs
144 |
145 | return outputs
146 |
--------------------------------------------------------------------------------
/models/doc_similarity_pl_template.py:
--------------------------------------------------------------------------------
1 | import collections
2 | from utils.torch_utils import to_numpy
3 | from pytorch_lightning import _logger as log
4 | from pytorch_lightning.core import LightningModule
5 | import torch
6 | from utils import argparse_init
7 | from utils import switch_functions
8 | from utils.model_utils import extract_model_path_for_hyperparams
9 | from subprocess import Popen
10 | import pandas as pd
11 |
12 |
13 | class DocEmbeddingTemplate(LightningModule):
14 |
15 | """
16 | Author: Dvir Ginzburg.
17 |
18 |     A template for document-embedding models built on PyTorch Lightning.
19 | """
20 |
21 | def __init__(
22 | self, hparams,
23 | ):
24 | super(DocEmbeddingTemplate, self).__init__()
25 | self.hparams = hparams
26 | self.hparams.hparams_dir = extract_model_path_for_hyperparams(self.hparams.default_root_dir, self)
27 | self.losses = {}
28 | self.tracks = {}
29 | self.hparams.mode = "val"
30 |
31 |
32 | def forward(self, data):
33 | """
34 | forward function for the doc similarity network
35 | """
36 | raise NotImplementedError()
37 |
38 | def training_step(self, batch, batch_idx, mode="train"):
39 | """
40 | Lightning calls this inside the training loop with the
41 | data from the training dataloader passed in as `batch`.
42 | """
43 | self.losses = {}
44 | self.tracks = {}
45 | self.hparams.batch_idx = batch_idx
46 | self.hparams.mode = mode
47 | self.batch = batch
48 |
49 | batch = self(batch)
50 |
51 | self.tracks[f"tot_loss"] = sum(self.losses.values()).mean()
52 |
53 | all = {k: to_numpy(v) for k, v in {**self.tracks, **self.losses}.items()}
54 | getattr(self, f"{mode}_logs", None).append(all)
55 | self.log_step(all)
56 |
57 | output = collections.OrderedDict({"loss": self.tracks[f"tot_loss"]})
58 | return output
59 |
60 | def validation_step(self, batch, batch_idx, mode="val"):
61 | """Lightning calls this inside the validation loop with the data from the validation dataloader passed in as `batch`."""
62 |
63 | return self.training_step(batch, batch_idx, mode=mode)
64 |
65 | def log_step(self, all):
66 | if not (
67 | getattr(self.hparams, f"{self.hparams.mode}_batch_size")
68 | % (getattr(self.hparams, f"{self.hparams.mode}_log_every_n_steps"))
69 | == 0
70 | ):
71 | return
72 | for k, v in all.items():
73 | if v.shape != ():
74 | v = v.sum()
75 | self.logger.experiment.add_scalar(f"{self.hparams.mode}_{k}_step", v, global_step=self.global_step)
76 |
77 | def test_step(self, batch, batch_idx):
78 | return self.validation_step(batch, batch_idx, mode="test")
79 |
80 | def on_validation_epoch_start(self):
81 | self.val_logs = []
82 | self.hparams.mode = "val"
83 |
84 | def on_train_epoch_start(self):
85 | self.hparams.current_epoch = self.current_epoch
86 | self.train_logs = []
87 | self.hparams.mode = "train"
88 |
89 | def on_test_epoch_start(self):
90 | self.test_logs = []
91 | self.hparams.mode = "test"
92 |
93 | def on_epoch_end_generic(self):
94 | if self.trainer.running_sanity_check:
95 | return
96 | logs = getattr(self, f"{self.hparams.mode}_logs", None)
97 |
98 | self.log_dict(logs, prefix=self.hparams.mode)
99 |
100 | def log_dict(self, logs, prefix):
101 | dict_of_lists = pd.DataFrame(logs).to_dict("list")
102 | for lst in dict_of_lists:
103 | dict_of_lists[lst] = list(filter(lambda x: not pd.isnull(x), dict_of_lists[lst]))
104 | for key, lst in dict_of_lists.items():
105 | s = 0
106 | for item in lst:
107 | s += item.sum()
108 | name = f"{prefix}_{key}_epoch"
109 | val = s / len(lst)
110 | self.logger.experiment.add_scalar(name, val, global_step=self.global_step)
111 | if self.hparams.metric_to_track == name:
112 | self.log(name, torch.tensor(val))
113 |
114 | def on_train_epoch_end(self, outputs) -> None:
115 | self.on_epoch_end_generic()
116 |
117 | def on_validation_epoch_end(self) -> None:
118 | if self.trainer.running_sanity_check:
119 | return
120 | if self.current_epoch % 10 == 0:
121 | self.logger.experiment.add_text("Profiler", self.trainer.profiler.summary(), global_step=self.global_step)
122 |
123 |
124 | def on_test_epoch_end(self) -> None:
125 | self.on_epoch_end_generic()
126 |
127 | def validation_epoch_end(self, outputs):
128 | self.on_epoch_end_generic()
129 |
130 | # ---------------------
131 | # TRAINING SETUP
132 | # ---------------------
133 | def configure_optimizers(self):
134 | """
135 | Return whatever optimizers and learning rate schedulers you want here.
136 |
137 | At least one optimizer is required.
138 | """
139 | no_decay = ["bias", "LayerNorm.weight"]
140 | optimizer_grouped_parameters = [
141 | {
142 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
143 | "weight_decay": self.hparams.weight_decay,
144 | },
145 | {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
146 | ]
147 |
148 | optimizer = switch_functions.choose_optimizer(self.hparams, optimizer_grouped_parameters)
149 | scheduler = switch_functions.choose_scheduler(
150 | self.hparams.scheduler, optimizer, warmup_steps=0, params=self.hparams
151 | )
152 |
153 | return [optimizer], [scheduler]
154 |
155 | def dataloader(self):
156 | """
157 | Returns the relevant dataloader (called once per training).
158 |
159 | Args:
160 | train_val_test (str, optional): Define which dataset to choose from. Defaults to 'train'.
161 |
162 | Returns:
163 | Dataset
164 | """
165 | raise NotImplementedError()
166 |
167 | def prepare_data(self):
168 | """
169 |         Load the data (called once) and build the train/val/test splits.
170 |
171 | Returns:
172 | Tuple of datasets: train,val and test dataset splits
173 | """
174 | raise NotImplementedError()
175 |
176 | def train_dataloader(self):
177 | log.info("Training data loader called.")
178 | return self.dataloader(mode="train")
179 |
180 | def val_dataloader(self):
181 | log.info("Validation data loader called.")
182 | return self.dataloader(mode="val")
183 |
184 | def test_dataloader(self):
185 | log.info("Test data loader called.")
186 | return self.dataloader(mode="test")
187 |
188 | @staticmethod
189 | def add_model_specific_args(parser, task_name, dataset_name, is_lowest_leaf=False):
190 | """
191 | Static function to add all arguments that are relevant only for this module
192 |
193 | Args:
194 | parent_parser (ArgparseManager): The general argparser
195 |
196 | Returns:
197 | ArgparseManager : The new argparser
198 | """
199 | parser.add_argument(
200 | "--test_sample_size", default=-1, type=int, help="The number of samples to eval recos on. (-1 is all)"
201 | )
202 | parser.add_argument("--top_k_size", default=-1, type=int, help="The number of top k correspondences. (-1 is all)")
203 |
204 | parser.add_argument("--with_same_series", type=argparse_init.str2bool, nargs="?", const=True, default=True)
205 |
206 | return parser
207 |
208 |
--------------------------------------------------------------------------------
/models/reco/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/models/reco/__init__.py
--------------------------------------------------------------------------------
/models/reco/hierarchical_reco.py:
--------------------------------------------------------------------------------
1 | import json
2 | from data.data_utils import get_gt_seeds_titles
3 | from models.reco.wiki_recos_eval.eval_metrics import evaluate_wiki_recos
4 | from utils.torch_utils import mean_non_pad_value, to_numpy
5 | from models.reco.recos_utils import index_amp, sim_matrix
6 | import numpy as np
7 | import torch
8 | from tqdm import tqdm
9 | import pickle
10 | from sklearn.preprocessing import normalize
11 |
12 |
13 | def vectorize_reco_hierarchical(all_features, titles,gt_path, output_path=""):
14 | gt = pickle.load(open(gt_path, "rb"))
15 | to_reco_indices = [index_amp(titles, title) for title in gt.keys()]
16 |     to_reco_indices = [idx for idx in to_reco_indices if idx is not None]
17 | sections_per_article = np.array([len(article) for article in all_features])
18 | sections_per_article_cumsum = np.array([0,] + [len(article) for article in all_features]).cumsum()
19 | features_per_section = [sec for article in all_features for sec in article]
20 | features_per_section_torch = [torch.from_numpy(feat) for feat in features_per_section]
21 | features_per_section_padded = torch.nn.utils.rnn.pad_sequence(
22 | features_per_section_torch, batch_first=True, padding_value=torch.tensor(float("nan"))
23 | ).cuda()
24 |
25 | num_samples, max_after_pad = features_per_section_padded.shape[:2]
26 |
27 | flattened = features_per_section_padded.reshape(-1, features_per_section_padded.shape[-1])
28 |
29 | recos = []
30 | for i in tqdm(to_reco_indices):
31 | if i > len(all_features):
32 | print(f"GT title {titles[i]} was not evaluated")
33 | continue
34 |
35 | to_reco_flattened = features_per_section_padded[
36 | sections_per_article_cumsum[i] : sections_per_article_cumsum[i + 1]
37 | ].reshape(-1, features_per_section_padded.shape[-1])
38 |
39 | sim = sim_matrix(to_reco_flattened, flattened)
40 | reshaped_sim = sim.reshape(
41 | sections_per_article_cumsum[i + 1] - sections_per_article_cumsum[i], max_after_pad, num_samples, max_after_pad
42 | )
43 | sim = reshaped_sim.permute(0, 2, 1, 3)
44 | sim[sim.isnan()] = float("-Inf")
45 | score_mat = sim.max(-1)[0]
46 | score = mean_non_pad_value(score_mat, axis=-1, pad_value=torch.tensor(float("-Inf")).cuda())
47 |
48 | score_per_article = torch.split(score.t(), sections_per_article.tolist(), dim=0)
49 | score_per_article_padded = torch.nn.utils.rnn.pad_sequence(
50 | score_per_article, batch_first=True, padding_value=float("-Inf")
51 | ).permute(0, 2, 1)
52 | score_per_article_padded[torch.isnan(score_per_article_padded)] = float("-Inf")
53 | par_score_mat = score_per_article_padded.max(-1)[0]
54 | par_score = mean_non_pad_value(par_score_mat, axis=-1, pad_value=float("-Inf"))
55 |
56 | recos.append((i, to_numpy(par_score.argsort(descending=True)[1:])))
57 |
58 |     examples = [[None, title] for title in titles]  # reco_utils compatible
59 | _, mpr, _, mrr, _, hit_rate = evaluate_wiki_recos(recos, output_path, gt_path, examples=examples)
60 | metrics = {"mrr": mrr, "mpr": mpr, "hit_rates": hit_rate}
61 | return recos, metrics
62 |
63 |
--------------------------------------------------------------------------------
/models/reco/recos_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import torch
5 | import pickle
6 | import numpy as np
7 |
8 | from sklearn.preprocessing import normalize
9 |
10 |
11 | def index_amp(lst, k):
12 | try:
13 |         return lst.index(k) if k in lst else lst.index(k.replace("&", "&amp;"))
14 | except:
15 | return
16 |
17 |
18 | def sim_matrix(a, b, eps=1e-8):
19 | """
20 |     Cosine-similarity matrix between the rows of a and b (eps guards against zero-norm rows).
21 | """
22 | a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
23 | a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
24 | b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
25 | sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
26 | return sim_mt
27 |
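# Example (sketch): for a of shape [m, d] and b of shape [n, d], sim_matrix(a, b)
# has shape [m, n] and entry [i, j] is the cosine similarity between a[i] and b[j].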
28 |
--------------------------------------------------------------------------------
/models/reco/wiki_recos_eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/models/reco/wiki_recos_eval/__init__.py
--------------------------------------------------------------------------------
/models/reco/wiki_recos_eval/eval_metrics.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | from tqdm import tqdm
4 | from pathlib import Path
5 | import sys
6 | from utils.logging_utils import Unbuffered
7 | import json
8 |
9 |
10 | def evaluate_wiki_recos(recos, output_path, gt_path, examples):
11 | original_stdout = sys.stdout
12 | sys.stdout = Unbuffered(open(f"{output_path}_reco_scores", "w"))
13 | dataset_as_dict = {sample[1]: sample for sample in examples}
14 | recos_as_dict = {reco[0]: reco for reco in recos}
15 | names_to_id = {sample[1]: idx for idx, sample in enumerate(examples)}
16 | titles = [ex[1] for ex in examples]
17 |
18 | article_recos_per_article = pickle.load(open(gt_path, "rb"))
19 |
20 | percentiles, mpr = calculate_mpr(recos_as_dict, article_recos_per_article, dataset_as_dict, names_to_id, titles=titles)
21 | recepricals, mrr = calculate_mrr(recos_as_dict, article_recos_per_article, dataset_as_dict, names_to_id, titles=titles)
22 | hit_rates, hit_rate = calculate_mean_hit_rate(
23 | recos_as_dict,
24 | article_recos_per_article,
25 | dataset_as_dict,
26 | names_to_id,
27 | rate_thresulds=[5, 10, 50, 100, 1000],
28 | titles=titles,
29 | )
30 |
31 | metrics = {"mrr": float(mrr), "mpr": float(mpr), **{f"hit_rate_{rate[0]}": float(rate[1]) for rate in hit_rate}}
32 | print(json.dumps(metrics, indent=2))
33 |
34 | sys.stdout = original_stdout
35 | return percentiles, mpr, recepricals, mrr, hit_rates, hit_rate
36 |
37 |
38 | def calculate_mpr(
39 | input_recommandations, article_article_gt, dataset, names_to_id, sample_size=-1, popular_titles=None, titles=[]
40 | ):
41 | percentiles = []
42 | for reco_idx in tqdm(input_recommandations):
43 | wiki_title = titles[reco_idx]
44 | curr_gts, text = [], []
45 | recommandations = input_recommandations[reco_idx][1]
46 | if wiki_title not in article_article_gt:
47 | continue
48 | for gt_title in article_article_gt[wiki_title].keys():
49 |             lookup = gt_title.replace("&", "&amp;") if "amp;" not in gt_title and gt_title not in names_to_id else gt_title
50 | if lookup not in names_to_id:
51 | print(f"{lookup} not in names_to_id")
52 | continue
53 | recommended_idx_ls = np.where(recommandations == names_to_id[lookup])[0]
54 | if recommended_idx_ls.shape[0] == 0:
55 | continue
56 | curr_gts.append(recommended_idx_ls[0])
57 | percentiles.extend((recommended_idx_ls[0] / len(recommandations),) * article_article_gt[wiki_title][gt_title])
58 | text.append("gt: {} gt place: {}".format(gt_title, recommended_idx_ls[0]))
59 |
60 | if len(curr_gts) > 0:
61 | print(
62 | "title: {}\n".format(wiki_title)
63 | + "\n".join(text)
64 | + "\ntopk: {}\n\n\n".format([titles[reco_i] for reco_i in recommandations[:10]])
65 | )
66 |
67 | percentiles = percentiles if percentiles != [] else [0]
68 | print("percentiles_mean:{}\n\n\n\n".format(sum(percentiles) / len(percentiles)))
69 | return percentiles, sum(percentiles) / len(percentiles)
70 |
71 |
72 | def calculate_mrr(
73 | input_recommandations, article_article_gt, dataset, names_to_id, sample_size=-1, popular_titles=None, titles=[]
74 | ):
75 | """
76 | input_recommandations - list of [] per title the order of all titles recommended with it
77 | article_article_gt - dict of dicts, each element is a sample, and all the gt samples goes with it and the count each sample
78 | sample_size - the amount of candidates to calculate the MPR on
79 | """
80 | recepricals = []
81 | for reco_idx in tqdm(input_recommandations):
82 | wiki_title = titles[reco_idx]
83 | text = []
84 | recommandations = input_recommandations[reco_idx][1]
85 | top = len(input_recommandations)
86 | for gt_title in article_article_gt[wiki_title].keys():
87 |             lookup = gt_title.replace("&", "&amp;") if "amp;" not in gt_title and gt_title not in names_to_id else gt_title
88 | if lookup not in names_to_id:
89 | print(f"{lookup} not in names_to_id")
90 | continue
91 | recommended_idx_ls = np.where(recommandations == names_to_id[lookup])[0]
92 | if recommended_idx_ls.shape[0] > 0 and recommended_idx_ls[0] < top:
93 | top = recommended_idx_ls[0]
94 | if recommended_idx_ls.shape[0] == 0:
95 | continue
96 | text.append("gt: {} gt place: {} ".format(gt_title, recommended_idx_ls[0]))
97 |
98 | if top == 0:
99 | top = 1
100 |
101 | if len(text) > 0:
102 | recepricals.append(1 / (top))
103 |             text.append(f"\n reciprocal: {recepricals[-1]}")
104 | print(
105 | "title: {}\n".format(wiki_title)
106 | + "\n".join(text)
107 | + "\ntopk: {}\n\n\n".format([titles[reco_i] for reco_i in recommandations[:10]])
108 | )
109 |
110 | recepricals = recepricals if recepricals != [] else [0]
111 |     print(f"Reciprocal mean: {sum(recepricals) / len(recepricals)}")
112 |     print(f"Reciprocals \n {recepricals}")
113 | return recepricals, sum(recepricals) / len(recepricals)
114 |
115 |
116 | def calculate_mean_hit_rate(
117 | input_recommandations,
118 | article_article_gt,
119 | dataset,
120 | names_to_id,
121 | sample_size=-1,
122 | popular_titles=None,
123 | rate_thresulds=[100],
124 | titles=[],
125 | ):
126 | mean_hits = [[] for i in rate_thresulds]
127 | for reco_idx in tqdm(input_recommandations):
128 | wiki_title = titles[reco_idx]
129 | curr_gts, text = [], []
130 | hit_by_rate = [0 for i in rate_thresulds]
131 | recommandations = input_recommandations[reco_idx][1]
132 | for gt_title in article_article_gt[wiki_title].keys():
133 |
134 |             lookup = gt_title.replace("&", "&amp;") if "amp;" not in gt_title and gt_title not in names_to_id else gt_title
135 | if lookup not in names_to_id:
136 | print(f"{lookup} not in names_to_id")
137 | continue
138 | recommended_idx_ls = np.where(recommandations == names_to_id[lookup])[0]
139 | for thr_idx, thresuld in enumerate(rate_thresulds):
140 | if recommended_idx_ls.shape[0] != 0 and recommended_idx_ls[0] < thresuld:
141 | hit_by_rate[thr_idx] += 1
142 | text.append(f"gt: {gt_title} gt place: {recommended_idx_ls}")
143 |
144 | if len(text) > 0:
145 | for thr_idx, thresuld in enumerate(rate_thresulds):
146 | print(
147 | f"title: {wiki_title} Hit rate at {thresuld}: {hit_by_rate[thr_idx]} \n \n {''.join(text)} \n topk: {[titles[reco_i] for reco_i in recommandations[:10]]}\n\n\n"
148 | )
149 | hit_mean = hit_by_rate[thr_idx] / len(article_article_gt[wiki_title].keys()) if hit_by_rate[thr_idx] > 0 else 0
150 | mean_hits[thr_idx].append(hit_mean)
151 |
152 | mean_hits = mean_hits if mean_hits != [[] for rate_thresuld in rate_thresulds] else [[0] for rate_thresuld in rate_thresulds]
153 | mean_hit = [sum(mean_hit) / len(mean_hit) for mean_hit in mean_hits]
154 | mean_hits_with_thresuld = [(thresuld, mean) for (thresuld, mean) in zip(*[rate_thresulds, mean_hit])]
155 | print(f"Hit rate mean:{mean_hits_with_thresuld}")
156 | return mean_hits, mean_hits_with_thresuld
157 |
--------------------------------------------------------------------------------
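The evaluation above reports MPR (mean percentile rank), MRR (mean reciprocal rank), and hit rate over ranked recommendation lists. Below is a minimal, self-contained sketch of the reciprocal-rank and percentile-rank arithmetic on toy data; the titles and counts are made up and this is not the repo's API, but the rank handling mirrors `calculate_mrr` / `calculate_mpr` above.

```python
# Toy illustration (not the repo API) of the rank arithmetic used above.
import numpy as np

ranked = np.array(["Forza Horizon 3", "Mafia", "Portal", "Tetris"])  # one query's ranking
ground_truth = {"Mafia": 1, "Portal": 2}  # gt title -> annotation count, as in the gt pickle

positions = [int(np.where(ranked == t)[0][0]) for t in ground_truth if t in ranked]

# calculate_mrr keeps the best rank and clamps a rank of 0 to 1 before inverting.
best = min(positions)
reciprocal = 1.0 / max(best, 1)

# calculate_mpr averages each hit's relative position, repeated by its annotation count.
percentiles = []
for title, count in ground_truth.items():
    idx = np.where(ranked == title)[0]
    if idx.size:
        percentiles.extend([idx[0] / len(ranked)] * count)
mpr = sum(percentiles) / len(percentiles)

print(reciprocal, mpr)  # 1.0 and ~0.417
```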
/models/transformer_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch
3 | from transformers import PreTrainedTokenizer
4 |
5 |
6 | def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]:
7 | """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
8 |
9 | if tokenizer.mask_token is None:
10 | raise ValueError(
11 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
12 | )
13 |
14 | labels = inputs.clone()
15 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
16 | probability_matrix = torch.full(labels.shape, args.mlm_probability, device=labels.device)
17 | special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
18 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool, device=labels.device), value=0.0)
19 | if tokenizer._pad_token is not None:
20 | padding_mask = labels.eq(tokenizer.pad_token_id)
21 | probability_matrix.masked_fill_(padding_mask, value=0.0)
22 | masked_indices = torch.bernoulli(probability_matrix).bool()
23 | if (~masked_indices).all():
24 | masked_indices = ~masked_indices # If we choose to not learn anything - learn everything
25 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
26 |
27 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
28 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8, device=labels.device)).bool() & masked_indices
29 | inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
30 |
31 | # 10% of the time, we replace masked input tokens with random word
32 | indices_random = (
33 | torch.bernoulli(torch.full(labels.shape, 0.5, device=labels.device)).bool() & masked_indices & ~indices_replaced
34 | )
35 | random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long, device=labels.device)
36 | inputs[indices_random] = random_words[indices_random]
37 |
38 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
39 | return inputs, labels
40 |
41 |
--------------------------------------------------------------------------------
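The 80/10/10 corruption in `mask_tokens` is easy to misread because the 10%-random branch is drawn with probability 0.5 from the masked positions that were *not* already replaced by the mask token. A tokenizer-free sketch of the same Bernoulli logic, with a made-up `mask_id` and vocabulary size:

```python
# Tokenizer-free sketch of the 80/10/10 MLM corruption used in mask_tokens.
# mask_id, vocab_size and the 0.15 rate are illustrative placeholders.
import torch

def toy_mask(inputs: torch.Tensor, mlm_probability=0.15, mask_id=4, vocab_size=100):
    labels = inputs.clone()
    # Choose which positions are masked at all.
    masked = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
    labels[~masked] = -100  # loss is only computed on masked positions

    # 80% of masked positions become the mask token.
    replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked
    inputs[replaced] = mask_id

    # Half of the remaining 20% (i.e. 10% overall) become a random token;
    # the rest are left unchanged.
    randomized = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked & ~replaced
    inputs[randomized] = torch.randint(vocab_size, labels.shape)[randomized]
    return inputs, labels

corrupted, labels = toy_mask(torch.randint(5, 100, (2, 8)))
```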
/models/transformers_base.py:
--------------------------------------------------------------------------------
1 | """Stub."""
2 | from models.doc_similarity_pl_template import DocEmbeddingTemplate
3 | from utils import argparse_init
4 | from utils import switch_functions
5 |
6 |
7 | class TransformersBase(DocEmbeddingTemplate):
8 |
9 | """
10 | Author: Dvir Ginzburg.
11 |
12 | This is a template for future document templates using transformers.
13 | """
14 |
15 | def __init__(
16 | self, hparams,
17 | ):
18 | super(TransformersBase, self).__init__(hparams)
19 | (self.config_class, self.model_class, self.tokenizer_class,) = switch_functions.choose_model_class_configuration(
20 | self.hparams.arch, self.hparams.base_model_name
21 | )
22 | if self.hparams.config_name:
23 | self.config = self.config_class.from_pretrained(self.hparams.config_name, cache_dir=None)
24 | elif self.hparams.arch_or_path:
25 | self.config = self.config_class.from_pretrained(self.hparams.arch_or_path)
26 | else:
27 | self.config = self.config_class()
28 | if self.hparams.tokenizer_name:
29 | self.tokenizer = self.tokenizer_class.from_pretrained(self.hparams.tokenizer_name)
30 | elif self.hparams.arch_or_path:
31 | self.tokenizer = self.tokenizer_class.from_pretrained(self.hparams.arch_or_path)
32 | else:
33 | raise ValueError(
34 |                 "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it, "
35 | "and load it from here, using --tokenizer_name".format(self.tokenizer_class.__name__)
36 | )
37 | self.hparams.tokenizer_pad_id = self.tokenizer.pad_token_id
38 | self.model = self.model_class.from_pretrained(
39 | self.hparams.config_name, from_tf=bool(".ckpt" in self.hparams.config_name), config=self.config, hparams=self.hparams
40 | )
41 |
42 | @staticmethod
43 | def add_model_specific_args(parent_parser, task_name, dataset_name, is_lowest_leaf=False):
44 | parser = DocEmbeddingTemplate.add_model_specific_args(parent_parser, task_name, dataset_name, is_lowest_leaf=False)
45 | parser.add_argument(
46 | "--mlm",
47 | type=argparse_init.str2bool,
48 | nargs="?",
49 | const=True,
50 | default=True,
51 | help="Train with masked-language modeling loss instead of language modeling.",
52 | )
53 |
54 | parser.add_argument(
55 | "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss",
56 | )
57 | parser.add_argument(
58 |         "--base_model_name", type=str, default="roberta", help="The underlying BERT-like model for this architecture.",
59 | )
60 | base_model_name = parser.parse_known_args()[0].base_model_name
61 | if base_model_name in ["roberta", "tnlr"]:
62 | default_config, default_tokenizer = "roberta-large", "roberta-large"
63 | elif base_model_name in ["bert", "tnlr3"]:
64 | default_config, default_tokenizer = "bert-large-uncased", "bert-large-uncased"
65 | elif base_model_name == "longformer":
66 | default_config, default_tokenizer = "allenai/longformer-base-4096", "allenai/longformer-base-4096"
67 | parser.set_defaults(batch_size=2)
68 | parser.add_argument(
69 | "--config_name",
70 | type=str,
71 | default=default_config,
72 |         help="Optional pretrained config name or path if not the same as model_name_or_path. If both are None, initialize a new config.",
73 | )
74 | parser.add_argument(
75 | "--tokenizer_name",
76 | default=default_tokenizer,
77 | type=str,
78 | help="Optional pretrained tokenizer name or path if not the same as model_name_or_path. If both are None, initialize a new tokenizer.",
79 | )
80 |
81 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
82 |
83 | parser.set_defaults(lr=2e-5, weight_decay=0)
84 |
85 | arch, mlm = parser.parse_known_args()[0].arch, parser.parse_known_args()[0].mlm
86 | if arch in ["bert", "roberta", "distilbert", "camembert", "recoberta", "recoberta_cosine"] and not mlm:
87 | raise ValueError(
88 | "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
89 | "flag (masked language modeling)."
90 | )
91 |
92 | return parser
93 |
--------------------------------------------------------------------------------
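`TransformersBase` resolves the config and tokenizer from `--config_name` / `--tokenizer_name` (falling back to `arch_or_path`) before loading the pretrained weights. A rough sketch of that resolution with plain Hugging Face classes, using the `roberta-large` defaults set above; this is an illustration, not the module itself:

```python
# Sketch of the resolution order used in TransformersBase.__init__,
# shown with plain Hugging Face classes ("roberta-large" matches the defaults above).
from transformers import RobertaConfig, RobertaTokenizer

config_name = "roberta-large"      # --config_name
tokenizer_name = "roberta-large"   # --tokenizer_name

config = RobertaConfig.from_pretrained(config_name) if config_name else RobertaConfig()
if tokenizer_name:
    tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)
else:
    raise ValueError("A pretrained tokenizer name must be given via --tokenizer_name")

pad_id = tokenizer.pad_token_id  # stored on hparams as tokenizer_pad_id
```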
/sdr_main.py:
--------------------------------------------------------------------------------
1 | """Top level file, parse flags and call trining loop."""
2 | import os
3 | from utils.pytorch_lightning_utils.pytorch_lightning_utils import load_params_from_checkpoint
4 | import torch
5 | from pytorch_lightning.profiler.profilers import SimpleProfiler
6 | from utils.pytorch_lightning_utils.callbacks import RunValidationOnStart
7 | from utils import switch_functions
8 | import pytorch_lightning
9 | from pytorch_lightning import Trainer
10 | from pytorch_lightning.callbacks import ModelCheckpoint
11 | from utils.argparse_init import default_arg_parser, init_parse_argparse_default_params
12 | import logging
13 |
14 | logging.basicConfig(level=logging.INFO)
15 | from pytorch_lightning.loggers import TensorBoardLogger
16 |
17 | def main():
18 | """Initialize all the parsers, before training init."""
19 | parser = default_arg_parser()
20 | parser = Trainer.add_argparse_args(parser) # Bug in PL
21 | parser = default_arg_parser(description="docBert", parents=[parser])
22 |
23 | eager_flags = init_parse_argparse_default_params(parser)
24 | model_class_pointer = switch_functions.model_class_pointer(eager_flags["task_name"], eager_flags["architecture"])
25 | parser = model_class_pointer.add_model_specific_args(parser, eager_flags["task_name"], eager_flags["dataset_name"])
26 |
27 | hyperparams = parser.parse_args()
28 |     main_train(model_class_pointer, hyperparams, parser)
29 |
30 |
31 | def main_train(model_class_pointer, hparams, parser):
32 | """Initialize the model, call training loop."""
33 | pytorch_lightning.utilities.seed.seed_everything(seed=hparams.seed)
34 |
35 |     if hparams.resume_from_checkpoint not in [None, '']:
36 | hparams = load_params_from_checkpoint(hparams, parser)
37 |
38 | model = model_class_pointer(hparams)
39 |
40 |
41 |     logger = TensorBoardLogger(save_dir=model.hparams.hparams_dir, name='', default_hp_metric=False)
42 | logger.log_hyperparams(model.hparams, metrics={model.hparams.metric_to_track: 0})
43 | print(f"\nLog directory:\n{model.hparams.hparams_dir}\n")
44 |
45 | trainer = pytorch_lightning.Trainer(
46 | num_sanity_val_steps=2,
47 | gradient_clip_val=hparams.max_grad_norm,
48 | callbacks=[RunValidationOnStart()],
49 | checkpoint_callback=ModelCheckpoint(
50 | save_top_k=3,
51 | save_last=True,
52 | mode="min" if "acc" not in hparams.metric_to_track else "max",
53 | monitor=hparams.metric_to_track,
54 | filepath=os.path.join(model.hparams.hparams_dir, "{epoch}"),
55 | verbose=True,
56 | ),
57 | logger=logger,
58 | max_epochs=hparams.max_epochs,
59 | gpus=hparams.gpus,
60 | distributed_backend="dp",
61 | limit_val_batches=hparams.limit_val_batches,
62 | limit_train_batches=hparams.limit_train_batches,
63 | limit_test_batches=hparams.limit_test_batches,
64 | check_val_every_n_epoch=hparams.check_val_every_n_epoch,
65 | profiler=SimpleProfiler(),
66 | accumulate_grad_batches=hparams.accumulate_grad_batches,
67 | reload_dataloaders_every_epoch=True,
68 | # load
69 | resume_from_checkpoint=hparams.resume_from_checkpoint,
70 | )
71 |     if not hparams.test_only:
72 |         trainer.fit(model)
73 |     else:
74 |         if hparams.resume_from_checkpoint is not None:
75 |             model = model.load_from_checkpoint(hparams.resume_from_checkpoint, hparams=hparams, map_location=torch.device("cpu"))
76 | trainer.test(model)
77 |
78 |
79 | if __name__ == "__main__":
80 | main()
81 |
82 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/utils/__init__.py
--------------------------------------------------------------------------------
/utils/argparse_init.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import os
4 | from argparse import ArgumentDefaultsHelpFormatter
5 |
6 |
7 | def str2intlist(v):
8 | if v.isdigit():
9 | return [int(v)]
10 | try:
11 | return [int(dig) for dig in v.split("_")]
12 | except Exception as e:
13 |         raise argparse.ArgumentTypeError('Expected an int or an underscore-separated list like "4_4"')
14 |
15 |
16 | def str2bool(v):
17 | if isinstance(v, bool):
18 | return v
19 | if v.lower() in ("yes", "true", "t", "y", "1"):
20 | return True
21 | elif v.lower() in ("no", "false", "f", "n", "0"):
22 | return False
23 | else:
24 | raise argparse.ArgumentTypeError("Boolean value expected.")
25 |
26 |
27 | def default_arg_parser(description="", conflict_handler="resolve", parents=[], is_lowest_leaf=False):
28 | """
29 | Generate the default parser - Helper for readability
30 |
31 | Args:
32 | description (str, optional): name of the parser - usually project name. Defaults to ''.
33 | conflict_handler (str, optional): whether to raise error on conflict or resolve(take last). Defaults to 'resolve'.
34 | parents (list, optional): [the name of parent argument managers]. Defaults to [].
35 |
36 | Returns:
37 |         argparse.ArgumentParser: the configured parser
38 | """
39 | description = (
40 | parents[0].description + description
41 | if len(parents) != 0 and parents[0] is not None and parents[0].description is not None
42 | else description
43 | )
44 | parser = argparse.ArgumentParser(
45 | description=description,
46 | add_help=is_lowest_leaf,
47 | formatter_class=ArgumentDefaultsHelpFormatter,
48 | conflict_handler=conflict_handler,
49 | parents=parents,
50 | )
51 |
52 | return parser
53 |
54 | def get_non_default(parsed,parser):
55 | non_default = {
56 | opt.dest: getattr(parsed, opt.dest)
57 | for opt in parser._option_string_actions.values()
58 | if hasattr(parsed, opt.dest) and opt.default != getattr(parsed, opt.dest)
59 | }
60 | return non_default
61 |
62 |
63 | def init_parse_argparse_default_params(parser, dataset_name=None, arch=None):
64 | TASK_OPTIONS = ["document_similarity"]
65 |
66 | parser.add_argument(
67 | "--task_name", type=str, default="document_similarity", choices=TASK_OPTIONS, help="The task to solve",
68 | )
69 | task_name = parser.parse_known_args()[0].task_name
70 |
71 | DATASET_OPTIONS = {
72 | "document_similarity": ["video_games", "wines",],
73 | }
74 | parser.add_argument(
75 | "--dataset_name",
76 | type=str,
77 | default=DATASET_OPTIONS[task_name][0],
78 | choices=DATASET_OPTIONS[task_name],
79 |         help="The dataset to evaluate on",
80 | )
81 | dataset_name = dataset_name or parser.parse_known_args()[0].dataset_name
82 |
83 | ## General learning parameters
84 | parser.add_argument(
85 | "--train_batch_size", default={"document_similarity": 32}[task_name], type=int, help="Number of samples in batch",
86 | )
87 | parser.add_argument(
88 | "--max_epochs", default={"document_similarity": 50}[task_name], type=int, help="Number of epochs to train",
89 | )
90 | parser.add_argument(
91 | "-lr", default={"document_similarity": 2e-5}[task_name], type=float, help="Learning rate",
92 | )
93 |
94 | parser.add_argument("--optimizer", default="adamW", help="Optimizer to use")
95 | parser.add_argument(
96 | "--scheduler",
97 | default="linear_with_warmup",
98 | choices=["linear_with_warmup", "cosine_annealing_lr"],
99 | help="Scheduler to use",
100 | )
101 |     parser.add_argument("--weight_decay", default=5e-3, type=float, help="weight decay")
102 |
103 | ## Input Output parameters
104 | parser.add_argument(
105 | "--default_root_dir", default=os.path.join(os.getcwd(), "output", task_name), help="The path to store this run output",
106 | )
107 | output_dir = parser.parse_known_args()[0].default_root_dir
108 | os.makedirs(output_dir, exist_ok=True)
109 |
110 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
111 |
112 | ### Model Parameters
113 | parser.add_argument(
114 | "--arch", "--architecture", default={"document_similarity": "SDR"}[task_name], help="Architecture",
115 | )
116 |
117 | architecture = arch or parser.parse_known_args()[0].arch
118 |
119 | parser.add_argument("--accumulate_grad_batches", default=1, type=int)
120 |
121 | ### Auxiliary parameters
122 | parser.add_argument("--gpus", default=1, type=str, help="gpu count")
123 | parser.add_argument("--num_data_workers", default=0, type=int, help="for parallel data load")
124 | parser.add_argument("--overwrite_data_cache", type=str2bool, nargs="?", const=True, default=False)
125 |
126 | parser.add_argument("--train_val_ratio", default=0.90, type=float, help="The split ratio of the data")
127 | parser.add_argument(
128 | "--limit_train_batches", default=10000, type=int,
129 | )
130 |
131 | parser.add_argument(
132 | "--train_log_every_n_steps", default=50, type=int,
133 | )
134 | parser.add_argument(
135 | "--val_log_every_n_steps", default=1, type=int,
136 | )
137 | parser.add_argument(
138 | "--test_log_every_n_steps", default=1, type=int,
139 | )
140 |
141 |
142 | parser.add_argument("--resume_from_checkpoint", default=None, type=str, help="Path to reload pretrained weights")
143 | parser.add_argument(
144 | "--metric_to_track", default=None, help="which parameter to track on saving",
145 | )
146 | parser.add_argument("--val_batch_size", default=8, type=int)
147 | parser.add_argument("--test_batch_size", default=1, type=int)
148 | parser.add_argument("--test_only", type=str2bool, nargs="?", const=True, default=False)
149 |
150 | return {
151 | "dataset_name": dataset_name,
152 | "task_name": task_name,
153 | "architecture": architecture,
154 | }
155 |
156 |
--------------------------------------------------------------------------------
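`init_parse_argparse_default_params` repeatedly calls `parser.parse_known_args()` so that flags parsed early (such as `--task_name`) can drive the defaults of flags registered later. A minimal sketch of that two-stage pattern with hypothetical `--task` / `--size` flags:

```python
# Minimal sketch of the eager parse_known_args pattern used above;
# --task and --size are hypothetical flags, not the repo's.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--task", default="document_similarity")

# Peek at the value already on the command line without failing on
# flags that have not been registered yet.
task = parser.parse_known_args()[0].task

# Use it to pick the default of a flag added afterwards.
parser.add_argument("--size", type=int, default={"document_similarity": 32}.get(task, 8))

args = parser.parse_args(["--task", "document_similarity"])
print(args.size)  # 32
```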
/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | class Unbuffered(object):
2 | """
3 |     Wrap a stream so every write is flushed immediately (useful for redirecting stdout to a file).
4 |
5 | Example: sys.stdout = Unbuffered(open(path + '_output','w'))
6 | """
7 |
8 | def __init__(self, stream):
9 | self.stream = stream
10 |
11 | def write(self, data):
12 | self.stream.write(data)
13 | self.stream.flush()
14 |
15 | def writelines(self, datas):
16 | self.stream.writelines(datas)
17 | self.stream.flush()
18 |
19 | def __getattr__(self, attr):
20 | return getattr(self.stream, attr)
21 |
--------------------------------------------------------------------------------
/utils/metrics_utils.py:
--------------------------------------------------------------------------------
1 | def simple_accuracy(preds, labels):
2 | return (preds == labels).mean()
3 |
--------------------------------------------------------------------------------
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datetime import datetime
3 |
4 |
5 | def extract_model_path_for_hyperparams(start_path, model):
6 | relevant_hparams = {}
7 | for key in [
8 | "arch",
9 | "dataset_name",
10 | "test_only"
11 | ]:
12 | if hasattr(model.hparams, key):
13 |             relevant_hparams[key] = getattr(model.hparams, key)
14 |
15 | path = os.path.join(start_path, *["{}_{}".format(key, val) for key, val in relevant_hparams.items()])
16 |
17 | dt_string = datetime.now().strftime("%d_%m_%Y-%H_%M_%S")
18 | path = os.path.join(path, dt_string)
19 |
20 | os.makedirs(path, exist_ok=True)
21 |
22 | return path
--------------------------------------------------------------------------------
/utils/pytorch_lightning_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/utils/pytorch_lightning_utils/__init__.py
--------------------------------------------------------------------------------
/utils/pytorch_lightning_utils/callbacks.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.callbacks import Callback
2 | from pytorch_lightning.trainer.trainer import Trainer
3 |
4 |
5 | class RunValidationOnStart(Callback):
6 | def __init__(self):
7 | pass
8 |
9 | def on_train_start(self, trainer: Trainer, pl_module):
10 | return trainer.run_evaluation()
11 |
--------------------------------------------------------------------------------
/utils/pytorch_lightning_utils/pytorch_lightning_utils.py:
--------------------------------------------------------------------------------
1 | """Diagnose your system and show basic information
2 |
3 | This serves mainly to collect detailed environment information for better bug reporting.
4 |
5 | """
6 |
7 | import os
8 | import platform
9 | import re
10 | import sys
11 | from argparse import Namespace
12 |
13 | import numpy
14 | import tensorboard
15 | import torch
16 | import tqdm
17 | from utils.argparse_init import get_non_default
18 |
19 | sys.path += [os.path.abspath(".."), os.path.abspath(".")]
20 | import pytorch_lightning # noqa: E402
21 |
22 | LEVEL_OFFSET = "\t"
23 | KEY_PADDING = 20
24 |
25 |
26 | def run_and_parse_first_match(run_lambda, command, regex):
27 | """Runs command using run_lambda, returns the first regex match if it exists"""
28 | rc, out, _ = run_lambda(command)
29 | if rc != 0:
30 | return None
31 | match = re.search(regex, out)
32 | if match is None:
33 | return None
34 | return match.group(1)
35 |
36 |
37 | def get_running_cuda_version(run_lambda):
38 | return run_and_parse_first_match(run_lambda, "nvcc --version", r"V(.*)$")
39 |
40 |
41 | def info_system():
42 | return {
43 | "OS": platform.system(),
44 | "architecture": platform.architecture(),
45 | "version": platform.version(),
46 | "processor": platform.processor(),
47 | "python": platform.python_version(),
48 | }
49 |
50 |
51 | def info_cuda():
52 | return {
53 | "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())],
54 | # 'nvidia_driver': get_nvidia_driver_version(run_lambda),
55 | "available": torch.cuda.is_available(),
56 | "version": torch.version.cuda,
57 | }
58 |
59 |
60 | def info_packages():
61 | return {
62 | "numpy": numpy.__version__,
63 | "pyTorch_version": torch.__version__,
64 | "pyTorch_debug": torch.version.debug,
65 | "pytorch-lightning": pytorch_lightning.__version__,
66 | "tensorboard": tensorboard.__version__,
67 | "tqdm": tqdm.__version__,
68 | }
69 |
70 |
71 | def nice_print(details, level=0):
72 | lines = []
73 | for k in sorted(details):
74 | key = f"* {k}:" if level == 0 else f"- {k}:"
75 | if isinstance(details[k], dict):
76 | lines += [level * LEVEL_OFFSET + key]
77 | lines += nice_print(details[k], level + 1)
78 | elif isinstance(details[k], (set, list, tuple)):
79 | lines += [level * LEVEL_OFFSET + key]
80 | lines += [(level + 1) * LEVEL_OFFSET + "- " + v for v in details[k]]
81 | else:
82 | template = "{:%is} {}" % KEY_PADDING
83 | key_val = template.format(key, details[k])
84 | lines += [(level * LEVEL_OFFSET) + key_val]
85 | return lines
86 |
87 |
88 | def main():
89 | details = {
90 | "System": info_system(),
91 | "CUDA": info_cuda(),
92 | "Packages": info_packages(),
93 | }
94 | lines = nice_print(details)
95 | text = os.linesep.join(lines)
96 | print(text)
97 |
98 | def load_params_from_checkpoint(hparams, parser):
99 | path = hparams.resume_from_checkpoint
100 |     hparams_model = Namespace(**torch.load(path, map_location=torch.device("cpu"))["hyper_parameters"])
101 | hparams_model.max_epochs = hparams_model.max_epochs + 30
102 |     for k, v in get_non_default(hparams, parser).items():
103 |         setattr(hparams_model, k, v)
104 |     hparams_model.gpus = hparams.gpus
105 |     for key in vars(hparams):
106 |         if key not in vars(hparams_model):
107 |             setattr(hparams_model, key, getattr(hparams, key, None))
108 | hparams = hparams_model
109 | return hparams
110 |
111 | if __name__ == "__main__":
112 | main()
113 |
--------------------------------------------------------------------------------
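`load_params_from_checkpoint` restores the hyperparameters stored in a checkpoint but lets any flag the user explicitly set on the command line win, via `get_non_default`. A toy sketch of that override rule with made-up flags and values:

```python
# Toy illustration of the override rule in load_params_from_checkpoint:
# checkpoint hparams are the base, flags explicitly set on the CLI win.
import argparse
from argparse import Namespace

from utils.argparse_init import get_non_default

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=2e-5)
parser.add_argument("--max_epochs", type=int, default=50)
cli = parser.parse_args(["--lr", "1e-5"])          # user only changed lr

ckpt_hparams = Namespace(lr=2e-5, max_epochs=80)   # pretend these came from torch.load

for key, value in get_non_default(cli, parser).items():
    setattr(ckpt_hparams, key, value)              # lr -> 1e-5, max_epochs stays 80

print(ckpt_hparams)
```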
/utils/switch_functions.py:
--------------------------------------------------------------------------------
1 | from models.SDR.similarity_modeling import SimilarityModeling
2 | import torch
3 | import transformers
4 | from transformers import get_linear_schedule_with_warmup
5 | from torch import optim
6 |
7 |
8 | def model_class_pointer(task_name, arch):
9 |     """Get a pointer to the model class based on flags.
10 |     Arguments:
11 |         task_name {str} -- the task name, e.g. document_similarity
12 |         arch {str} -- the architecture name, e.g. SDR
13 | Raises:
14 |         Exception: If the task or architecture is unknown
15 |
16 | Returns:
17 | torch.nn.Module -- The module to train on
18 |
19 | """
20 |
21 | if task_name == "document_similarity":
22 | if arch == "SDR":
23 | from models.SDR.SDR import SDR
24 |
25 | return SDR
26 |     raise Exception("Unknown task or architecture")
27 |
28 |
29 | def choose_optimizer(params, network_parameters):
30 | """
31 | Choose the optimizer from params.optimizer flag
32 |
33 | Args:
34 | params (dict): The input flags
35 | network_parameters (dict): from net.parameters()
36 |
37 | Raises:
38 | Exception: If not matched optimizer
39 | """
40 | if params.optimizer == "adamW":
41 | optimizer = transformers.AdamW(network_parameters, lr=params.lr,)
42 | elif params.optimizer == "sgd":
43 | optimizer = torch.optim.SGD(network_parameters, lr=params.lr, weight_decay=params.weight_decay, momentum=0.9,)
44 | else:
45 | raise Exception("No valid optimizer provided")
46 | return optimizer
47 |
48 |
49 | def choose_scheduler(scheduler_name, optimizer, warmup_steps, params):
50 |
51 | if scheduler_name == "linear_with_warmup":
52 | return get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=params.max_epochs)
53 | elif scheduler_name == "cosine_annealing_lr":
54 | return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0,)
55 |
56 | else:
57 |         raise Exception("No valid scheduler provided")
58 |
59 |
60 | from transformers import (
61 | RobertaConfig,
62 | RobertaTokenizer,
63 | get_linear_schedule_with_warmup,
64 | )
65 |
66 |
67 | def choose_model_class_configuration(arch, base_model_name):
68 | MODEL_CLASSES = {
69 | "SDR_roberta": (RobertaConfig, SimilarityModeling, RobertaTokenizer),
70 | }
71 | return MODEL_CLASSES[f"{arch}_{base_model_name}"]
72 |
--------------------------------------------------------------------------------
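`choose_optimizer` and `choose_scheduler` map the `--optimizer` / `--scheduler` flags to concrete objects. A small sketch wiring them to a toy model; the `Namespace` fields mirror the flags defined in `argparse_init.py` and the values are illustrative:

```python
# Sketch: wiring choose_optimizer / choose_scheduler to a toy model.
from argparse import Namespace

import torch
from utils import switch_functions

params = Namespace(optimizer="adamW", scheduler="linear_with_warmup", lr=2e-5,
                   weight_decay=5e-3, max_epochs=50)

model = torch.nn.Linear(16, 2)
optimizer = switch_functions.choose_optimizer(params, model.parameters())
scheduler = switch_functions.choose_scheduler(params.scheduler, optimizer,
                                              warmup_steps=0, params=params)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()
scheduler.step()
```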
/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import numbers
4 |
5 | def mean_non_pad_value(tensor, axis, pad_value=0):
6 | mask = tensor != pad_value
7 | tensor[~mask] = 0
8 | tensor_mean = (tensor * mask).sum(dim=axis) / (mask.sum(dim=axis))
9 |
10 | ignore_mask = (mask.sum(dim=axis)) == 0
11 | tensor_mean[ignore_mask] = pad_value
12 | return tensor_mean
13 |
14 |
15 | def to_numpy(tensor):
16 | """Wrapper around .detach().cpu().numpy() """
17 | if isinstance(tensor, torch.Tensor):
18 | return tensor.detach().cpu().numpy()
19 | elif isinstance(tensor, np.ndarray):
20 | return tensor
21 | elif isinstance(tensor, numbers.Number):
22 | return np.array(tensor)
23 | else:
24 | raise NotImplementedError
25 |
--------------------------------------------------------------------------------
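`mean_non_pad_value` averages along an axis while ignoring a padding value, falling back to the pad value when a row is entirely padding. A quick check on a toy padded batch (values are illustrative), together with `to_numpy`:

```python
# Quick check of mean_non_pad_value / to_numpy on a toy padded batch.
import torch
from utils.torch_utils import mean_non_pad_value, to_numpy

batch = torch.tensor([[1.0, 2.0, 0.0, 0.0],   # two real values, two pads
                      [4.0, 0.0, 0.0, 0.0],   # one real value
                      [0.0, 0.0, 0.0, 0.0]])  # all pads -> mean falls back to pad_value

means = mean_non_pad_value(batch, axis=1, pad_value=0)
print(to_numpy(means))  # [1.5, 4.0, 0.0]
```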