├── .gitignore
├── README.md
├── SECURITY.md
├── __init__.py
├── data
│   ├── __init__.py
│   ├── data_utils.py
│   ├── datasets.py
│   ├── datasets
│   │   ├── video_games
│   │   │   └── gt
│   │   └── wines
│   │       └── gt
│   └── images
│       ├── Results.png
│       ├── inference.png
│       ├── training_intuition.png
│       └── training_intuition2.png
├── instructions
│   ├── CODE_OF_CONDUCT.md
│   ├── LICENSE
│   ├── SECURITY.md
│   ├── SUPPORT.md
│   └── installation.sh
├── models
│   ├── SDR
│   │   ├── SDR.py
│   │   ├── SDR_utils.py
│   │   └── similarity_modeling.py
│   ├── doc_similarity_pl_template.py
│   ├── reco
│   │   ├── __init__.py
│   │   ├── hierarchical_reco.py
│   │   ├── recos_utils.py
│   │   └── wiki_recos_eval
│   │       ├── __init__.py
│   │       └── eval_metrics.py
│   ├── transformer_utils.py
│   └── transformers_base.py
├── sdr_main.py
└── utils
    ├── __init__.py
    ├── argparse_init.py
    ├── logging_utils.py
    ├── metrics_utils.py
    ├── model_utils.py
    ├── pytorch_lightning_utils
    │   ├── __init__.py
    │   ├── callbacks.py
    │   └── pytorch_lightning_utils.py
    ├── switch_functions.py
    └── torch_utils.py

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
**/.vscode/*
**/.history/*

**/datasets/*
!**/datasets/wines
**/datasets/wines/*
!**/datasets/wines/gt
!**/datasets/video_games
**/datasets/video_games/*
!**/datasets/video_games/gt
**/output/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Self-Supervised Document Similarity Ranking (SDR) via Contextualized Language Models and Hierarchical Inference

This repo is the implementation for [**SDR**](https://arxiv.org/abs/2106.01186).

## Tested environment
- Python 3.7
- PyTorch 1.7
- CUDA 11.0

Lower CUDA and PyTorch versions should work as well.

## Contents
- [Installation](#installation)
- [Datasets](#datasets)
- [Train with our datasets](#training)
- [Hierarchical Inference](#inference)
- [Cite](#citing--authors)

License, security, support, and code-of-conduct documents are under the `instructions` directory.

## Installation
Run
```
bash instructions/installation.sh
```

## Datasets
The published datasets are:
* Video games
    * 21,935 articles
    * Expert-annotated test set: 90 articles with 12 ground-truth recommendations.
    * Examples:
        * Grand Theft Auto - Mafia
        * Burnout Paradise - Forza Horizon 3
* Wines
    * 1,635 articles
    * Crafted by a human sommelier: 92 articles with ~10 ground-truth recommendations.
    * Examples:
        * Pinot Meunier - Chardonnay
        * Dom Pérignon - Moët & Chandon

For more details and direct download, see [Wines](https://zenodo.org/record/4812960#.YK8zqagzaUk) and [Video Games](https://zenodo.org/record/4812962#.YK8zqqgzaUk).

# Training
**The training process downloads the datasets automatically.**

```
python sdr_main.py --dataset_name video_games
```
The code is based on [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/); all PL [hyperparameters](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html) are supported (`limit_train/val/test_batches`, `check_val_every_n_epoch`, etc.).

## Tensorboard support
All metrics are logged automatically and stored in
```
SDR/output/document_similarity/SDR/arch_SDR/dataset_name_/
```
Run `tensorboard --logdir=` to see the logs.
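
Training combines two objectives. `forward_train` in `models/SDR/SDR.py` stores an MLM loss and a sentence-similarity loss scaled by `--sim_loss_lambda` (default `0.1`). The training sampler (`MPerClassSampler` with `m=2` over per-sentence article labels) draws two samples per article, so each sentence in a batch has a same-article counterpart. The snippet below is a simplified, pure-Python illustration of that batching idea and of the loss weighting. `m_per_class_batch` and `combined_loss` are hypothetical helpers, not the `pytorch-metric-learning` sampler the repo uses. It also assumes the two entries of `self.losses` are summed by the Lightning template, which is not shown here.

```python
import random
from collections import defaultdict
from typing import Dict, List, Optional, Sequence


def m_per_class_batch(labels: Sequence[int], m: int, batch_size: int, seed: int = 0) -> List[int]:
    """Hypothetical helper: pick `batch_size // m` classes (articles) at random
    and return `m` dataset indices (sentences) for each of them."""
    rng = random.Random(seed)
    by_label: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    classes = list(by_label)
    rng.shuffle(classes)
    batch: List[int] = []
    for label in classes[: batch_size // m]:
        pool = by_label[label]
        # Without replacement when the article has enough sentences,
        # otherwise with replacement (so short articles can still fill m slots).
        picks = rng.sample(pool, m) if len(pool) >= m else rng.choices(pool, k=m)
        batch.extend(picks)
    return batch


def combined_loss(mlm_loss: float, sim_loss: Optional[float], sim_loss_lambda: float = 0.1) -> float:
    """Hypothetical helper: MLM loss plus the weighted similarity loss.
    A missing similarity loss counts as 0, like `(outputs[1] or 0)` in forward_train."""
    return mlm_loss + (sim_loss or 0) * sim_loss_lambda


# Sentence-level labels: the article index of every sentence (like `dataset.labels`).
labels = [0, 0, 0, 1, 1, 2, 2, 2, 3]
batch = m_per_class_batch(labels, m=2, batch_size=4)
print([labels[i] for i in batch])  # two picks from each of two articles
print(combined_loss(2.0, 1.0))
```

The real sampler additionally controls epoch length through `length_before_new_iter`. This sketch builds only a single batch.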

# Inference
The hierarchical inference described in the paper is implemented as a stand-alone service in `models/reco/hierarchical_reco.py` and can be used with any backbone algorithm.
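
At test time, `test_step` in `models/SDR/SDR.py` represents every article as a list of sections. Each section is a stack of sentence embeddings: token embeddings averaged over the non-padded tokens, for up to 8 sentences per section. These are passed to `vectorize_reco_hierarchical`. The sketch below gives a rough illustration of hierarchical scoring over that structure. It assumes a simple two-level "best match, then average" aggregation with cosine similarity: sentence scores are pooled into section-to-section scores, and those are pooled into a document-to-document score. All helper names are hypothetical, and the exact aggregation and normalization in `hierarchical_reco.py` may differ.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]
Section = List[Vector]    # sentence embeddings of one section
Document = List[Section]  # all sections of one article


def cosine(a: Vector, b: Vector) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def best_match_mean(scores: List[List[float]]) -> float:
    # For every source item keep its best-matching target, then average over sources.
    return sum(max(row) for row in scores) / len(scores)


def section_similarity(src: Section, tgt: Section) -> float:
    return best_match_mean([[cosine(s, t) for t in tgt] for s in src])


def document_similarity(src: Document, tgt: Document) -> float:
    return best_match_mean([[section_similarity(s, t) for t in tgt] for s in src])


def rank_candidates(seed: Document, candidates: List[Document]) -> List[int]:
    # Candidate indices sorted from most to least similar to the seed article.
    scores = [document_similarity(seed, doc) for doc in candidates]
    return sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)


seed = [[[1.0, 0.0], [0.9, 0.1]], [[0.0, 1.0]]]
near = [[[1.0, 0.1]], [[0.1, 1.0]]]
far = [[[-1.0, 0.0]]]
print(rank_candidates(seed, [far, near]))  # → [1, 0]
```

The two-level pooling lets a long article match a shorter one section by section, instead of comparing a single averaged vector per document.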

To run inference with a trained checkpoint:
```
python sdr_main.py --dataset_name --resume_from_checkpoint --test_only
```

# Results
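
The test stage (`test_epoch_end` in `models/SDR/SDR.py`) prints and logs the `mrr`, `mpr` and `hit_rate_<k>` values returned by `vectorize_reco_hierarchical`. The ground truth comes from `data/datasets/{dataset_name}/gt`. The minimal sketch below computes MRR and a hit rate using common textbook definitions. It assumes the ground truth is a dict mapping each seed title to its recommended titles; the code shown only confirms that the keys are seed titles. `mpr` is omitted because its definition in `eval_metrics.py` is not shown here, and the repo's hit-rate definition may differ from this one.

```python
from typing import Dict, Iterable, List, Sequence, Set


def reciprocal_rank(ranked: Sequence[str], relevant: Set[str]) -> float:
    # 1 / (1-based position of the first relevant item); 0 if none is retrieved.
    for pos, item in enumerate(ranked, start=1):
        if item in relevant:
            return 1.0 / pos
    return 0.0


def hit_rate_at_k(ranked: Sequence[str], relevant: Set[str], k: int) -> float:
    # Fraction of the ground-truth items that appear among the top-k recommendations.
    return len(set(ranked[:k]) & relevant) / len(relevant)


def evaluate(recos: Dict[str, List[str]], gt: Dict[str, List[str]], ks: Iterable[int] = (10, 100)) -> Dict[str, float]:
    # Average over the seed titles that have both ground truth and recommendations.
    seeds = [s for s in gt if s in recos]
    metrics = {"mrr": sum(reciprocal_rank(recos[s], set(gt[s])) for s in seeds) / len(seeds)}
    for k in ks:
        metrics[f"hit_rate_{k}"] = sum(hit_rate_at_k(recos[s], set(gt[s]), k) for s in seeds) / len(seeds)
    return metrics


# Toy data (hypothetical titles): one seed with two ground-truth recommendations.
gt = {"seed": ["doc_b", "doc_d"]}
recos = {"seed": ["doc_a", "doc_b", "doc_c", "doc_d"]}
print(evaluate(recos, gt, ks=(2, 4)))  # → {'mrr': 0.5, 'hit_rate_2': 0.5, 'hit_rate_4': 1.0}
```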

# Citing & Authors
If you find this repository or the annotated datasets helpful, feel free to cite our publication:

SDR: Self-Supervised Document-to-Document Similarity Ranking via Contextualized Language Models and Hierarchical Inference
```
@misc{ginzburg2021selfsupervised,
    title={Self-Supervised Document Similarity Ranking via Contextualized Language Models and Hierarchical Inference},
    author={Dvir Ginzburg and Itzik Malkiel and Oren Barkan and Avi Caciularu and Noam Koenigstein},
    year={2021},
    eprint={2106.01186},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```

Contact: [Dvir Ginzburg](mailto:dvirginz@gmail.com), [Itzik Malkiel](mailto:itzik.malkiel@microsoft.com).
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/__init__.py
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/__init__.py
--------------------------------------------------------------------------------
/data/data_utils.py:
--------------------------------------------------------------------------------
import pickle
from typing import List
import torch
from torch.nn.utils.rnn import pad_sequence


def get_gt_seeds_titles(titles=None, dataset_name="wines"):
    idxs = None
    gt_path = f"data/datasets/{dataset_name}/gt"
    with open(gt_path, "rb") as f:
        popular_titles = list(pickle.load(f).keys())
    if titles is not None:
        idxs = [titles.index(pop_title) for pop_title in popular_titles if pop_title in titles]
    return popular_titles, idxs, gt_path


def reco_sentence_test_collate(examples: List[torch.Tensor], tokenizer):
    examples_ = []
    for example in examples:
        sections = []
        for section in example:
            if section == []:
                continue
            sections.append(
                (
                    pad_sequence([i[0] for i in section], batch_first=True, padding_value=tokenizer.pad_token_id),
                    [i[2] for i in section],
                    [i[3] for i in section],
                    [i[4] for i in section],
                    [i[5] for i in section],
                    [i[6] for i in section],
                    [i[7] for i in section],
                    torch.tensor([i[8] for i in section]),
                )
            )
        examples_.append(sections)
    return examples_


def reco_sentence_collate(examples: List[torch.Tensor], tokenizer):
    return (
        pad_sequence([i[0] for i in examples], batch_first=True, padding_value=tokenizer.pad_token_id),
        [i[2] for i in examples],
        [i[3] for i in examples],
        [i[4] for i in examples],
        [i[5] for i in examples],
        [i[6] for i in examples],
        [i[7] for i in examples],
        torch.tensor([i[8] for i in examples]),
    )


def raw_data_link(dataset_name):
    if dataset_name == "wines":
        return "https://zenodo.org/record/4812960/files/wines.txt?download=1"
    if dataset_name == "video_games":
        return "https://zenodo.org/record/4812962/files/video_games.txt?download=1"
--------------------------------------------------------------------------------
/data/datasets.py:
--------------------------------------------------------------------------------
import ast
from data.data_utils import get_gt_seeds_titles, raw_data_link
import nltk
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
import os
import pickle
import numpy as np
from tqdm import tqdm
import torch
import json
import csv
import sys
from models.reco.recos_utils import index_amp


nltk.download("punkt")


class WikipediaTextDatasetParagraphsSentences(Dataset):
    def __init__(self, tokenizer: PreTrainedTokenizer, hparams, dataset_name, block_size, mode="train"):
        self.hparams = hparams
        cached_features_file = os.path.join(
            f"data/datasets/cached_proccessed/{dataset_name}",
            f"bs_{block_size}_{dataset_name}_{type(self).__name__}_tokenizer_{str(type(tokenizer)).split('.')[-1][:-2]}_mode_{mode}",
        )
        self.cached_features_file = cached_features_file
        os.makedirs(os.path.dirname(cached_features_file), exist_ok=True)

        raw_data_path = self.download_raw(dataset_name)

        all_articles = self.save_load_splitted_dataset(mode, cached_features_file, raw_data_path)

        self.hparams = hparams

        max_article_len, max_sentences, max_sent_len = int(1e6), 16, 10000
        block_size = min(block_size, tokenizer.max_len_sentences_pair) if tokenizer is not None else block_size
        self.block_size = block_size
        self.tokenizer = tokenizer

        if os.path.exists(cached_features_file) and (self.hparams is None or not self.hparams.overwrite_data_cache):
            print(f"\nLoading features from cached file {cached_features_file}")
            with open(cached_features_file, "rb") as handle:
                self.examples, self.indices_map = pickle.load(handle)
        else:
            print(f"\nCreating features from dataset file at {cached_features_file}")

            self.examples = []
            self.indices_map = []

            for idx_article, article in enumerate(tqdm(all_articles)):
                this_sample_sections = []
                title, sections = article[0], ast.literal_eval(article[1])
                valid_sections_count = 0
                for section_idx, section in enumerate(sections):
                    this_sections_sentences = []
                    if section[1] == "":
                        continue
                    valid_sentences_count = 0
                    title_with_base_title = "{}:{}".format(title, section[0])
                    for sent_idx, sent in enumerate(nltk.sent_tokenize(section[1][:max_article_len])[:max_sentences]):
                        tokenized_desc = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(json.dumps(sent[:max_sent_len])))[
                            :block_size
                        ]
                        this_sections_sentences.append(
                            (
                                tokenized_desc,
                                len(tokenized_desc),
                                idx_article,
                                valid_sections_count,
                                valid_sentences_count,
                                sent[:max_sent_len],
                            ),
                        )
                        self.indices_map.append((idx_article, valid_sections_count, valid_sentences_count))
                        valid_sentences_count += 1
                    this_sample_sections.append((this_sections_sentences, title_with_base_title))
                    valid_sections_count += 1
                self.examples.append((this_sample_sections, title))

            print(f"\nSaving features into cached file {cached_features_file}")
            with open(cached_features_file, "wb") as handle:
                pickle.dump((self.examples, self.indices_map), handle, protocol=pickle.HIGHEST_PROTOCOL)

        self.labels = [idx_article for idx_article, _, _ in self.indices_map]

    def save_load_splitted_dataset(self, mode, cached_features_file, raw_data_path):
        proccessed_path = f"{cached_features_file}_EXAMPLES"
        if not os.path.exists(proccessed_path):
            all_articles = self.read_all_articles(raw_data_path)
            indices = list(range(len(all_articles)))
            if mode != "test":
                train_indices = sorted(
                    np.random.choice(indices, replace=False, size=int(len(all_articles) * self.hparams.train_val_ratio))
                )
                val_indices = np.setdiff1d(list(range(len(all_articles))), train_indices)
                indices = train_indices if mode == "train" else val_indices

            articles = []
            for i in indices:
                articles.append(all_articles[i])
            all_articles = articles
            pickle.dump(all_articles, open(proccessed_path, "wb"))
            print(f"\nsaved dataset at {proccessed_path}")
        else:
            all_articles = pickle.load(open(proccessed_path, "rb"))
        setattr(self.hparams, f"{mode}_data_file", proccessed_path)
        return all_articles

    def read_all_articles(self, raw_data_path):
        csv.field_size_limit(sys.maxsize)
        with open(raw_data_path, newline="") as f:
            reader = csv.reader(f)
            all_articles = list(reader)
        return all_articles[1:]

    def download_raw(self, dataset_name):
        raw_data_path = f"data/datasets/{dataset_name}/raw_data"
        os.makedirs(os.path.dirname(raw_data_path), exist_ok=True)
        if not os.path.exists(raw_data_path):
            os.system(f"wget -O {raw_data_path} {raw_data_link(dataset_name)}")
        return raw_data_path

    def __len__(self):
        return len(self.indices_map)

    def __getitem__(self, item):
        idx_article, idx_section, idx_sentence = self.indices_map[item]
        sent = self.examples[idx_article][0][idx_section][0][idx_sentence]

        return (
            torch.tensor(self.tokenizer.build_inputs_with_special_tokens(sent[0]), dtype=torch.long,)[
                : self.hparams.limit_tokens
            ],
            self.examples[idx_article][1],
            self.examples[idx_article][0][idx_section][1],
            sent[1],
            idx_article,
            idx_section,
            idx_sentence,
            item,
            self.labels[item],
        )


class WikipediaTextDatasetParagraphsSentencesTest(WikipediaTextDatasetParagraphsSentences):
    def __init__(self, tokenizer: PreTrainedTokenizer, hparams, dataset_name, block_size, mode="test"):
        super().__init__(tokenizer, hparams, dataset_name, block_size, mode=mode)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        sections = []
        for idx_section, section in enumerate(self.examples[item][0]):
            sentences = []
            for idx_sentence, sentence in enumerate(section[0]):
                sentences.append(
                    (
                        torch.tensor(self.tokenizer.build_inputs_with_special_tokens(sentence[0]), dtype=torch.long,),
                        self.examples[item][1],
                        section[1],
                        sentence[1],
                        item,
                        idx_section,
                        idx_sentence,
                        item,
                        self.labels[item],
                    )
                )
            sections.append(sentences)
        return sections
--------------------------------------------------------------------------------
/data/datasets/video_games/gt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/datasets/video_games/gt
--------------------------------------------------------------------------------
/data/datasets/wines/gt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/datasets/wines/gt
--------------------------------------------------------------------------------
/data/images/Results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/images/Results.png
--------------------------------------------------------------------------------
/data/images/inference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/images/inference.png
--------------------------------------------------------------------------------
/data/images/training_intuition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/images/training_intuition.png
--------------------------------------------------------------------------------
/data/images/training_intuition2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/data/images/training_intuition2.png
--------------------------------------------------------------------------------
/instructions/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
--------------------------------------------------------------------------------
/instructions/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
--------------------------------------------------------------------------------
/instructions/SECURITY.md:
--------------------------------------------------------------------------------
## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
--------------------------------------------------------------------------------
/instructions/SUPPORT.md:
--------------------------------------------------------------------------------
# Support

## How to file issues and get help

This project uses GitHub Issues to track bugs and feature requests. Please search the existing
issues before filing new issues to avoid duplicates. For new issues, file your bug or
feature request as a new Issue.

For help and questions about using this project, please use the issues section as well; we will respond to everyone in time.

## Microsoft Support Policy

Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
--------------------------------------------------------------------------------
/instructions/installation.sh:
--------------------------------------------------------------------------------
# This installation script can be executed with bash.
conda create -n SDR python=3.7 --yes
# Locate conda's shell hook from the active installation instead of assuming ~/anaconda3.
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate SDR

conda install -c pytorch pytorch==1.7.0 torchvision cudatoolkit=11.0 --yes
pip install -U cython transformers==3.1.0 nltk pytorch-metric-learning joblib pytorch-lightning==1.1.8 pandas
--------------------------------------------------------------------------------
/models/SDR/SDR.py:
--------------------------------------------------------------------------------
from data.datasets import (
    WikipediaTextDatasetParagraphsSentences,
    WikipediaTextDatasetParagraphsSentencesTest,
)
from utils.argparse_init import str2bool
from models.SDR.SDR_utils import MPerClassSamplerDeter
from data.data_utils import get_gt_seeds_titles, reco_sentence_collate, reco_sentence_test_collate
from functools import partial
import os
from models.reco.hierarchical_reco import vectorize_reco_hierarchical
from utils.torch_utils import to_numpy
from models.transformers_base import TransformersBase
from models.doc_similarity_pl_template import DocEmbeddingTemplate
from utils import switch_functions
from models import transformer_utils
import numpy as np
import torch
from utils import metrics_utils
from pytorch_metric_learning.samplers import MPerClassSampler
from torch.utils.data.dataloader import DataLoader
import json


class SDR(TransformersBase):

    """
    Author: Dvir Ginzburg.

    SDR model (ACL IJCNLP 2021)
    """

    def __init__(
        self, hparams,
    ):
        """Stub."""
        super(SDR, self).__init__(hparams)

    def forward_train(self, batch):
        inputs, labels = transformer_utils.mask_tokens(batch[0].clone().detach(), self.tokenizer, self.hparams)

        outputs = self.model(
            inputs,
            masked_lm_labels=labels,
            non_masked_input_ids=batch[0],
            sample_labels=batch[-1],
            run_similarity=True,
            run_mlm=True,
        )

        self.losses["mlm_loss"] = outputs[0]
        self.losses["d2v_loss"] = (outputs[1] or 0) * self.hparams.sim_loss_lambda  # A missing similarity loss counts as 0

        tracked = self.track_metrics(input_ids=inputs, outputs=outputs, is_train=self.hparams.mode == "train", labels=labels,)
        self.tracks.update(tracked)

        return

    def forward_val(self, batch):
        self.forward_train(batch)

    def test_step(self, batch, batch_idx):
        section_out = []
        for section in batch[0]:  # batch=1 for test
            sentences = []
            sentences_embed_per_token = [
                self.model(
                    sentence.unsqueeze(0), masked_lm_labels=None, run_similarity=False
                )[5].squeeze(0)
                for sentence in section[0][:8]
            ]
            for idx, sentence in enumerate(sentences_embed_per_token):
                sentences.append(sentence[: section[2][idx]].mean(0))  # We take the non-padded tokens and mean them
            section_out.append(torch.stack(sentences))
        return (section_out, batch[0][0][1][0])  # title name

    def forward(self, batch):
        # Dispatch to forward_train / forward_val according to the current mode.
        getattr(self, f"forward_{self.hparams.mode}")(batch)

    @staticmethod
    def track_metrics(
        outputs=None, input_ids=None, labels=None, is_train=True, batch_idx=0,
    ):
        mode = "train" if is_train else "val"

        trackes = {}
        lm_pred = np.argmax(outputs[3].cpu().detach().numpy(), axis=2)
        labels_numpy = labels.cpu().numpy()
        labels_non_zero = labels_numpy[np.array(labels_numpy != -100)] if np.any(labels_numpy != -100) else np.zeros(1)
        lm_pred_non_zero = lm_pred[np.array(labels_numpy != -100)] if np.any(labels_numpy != -100) else np.ones(1)
        lm_acc = torch.tensor(
            metrics_utils.simple_accuracy(lm_pred_non_zero, labels_non_zero), device=outputs[3].device,
        ).reshape((1, -1))

        trackes["lm_acc_{}".format(mode)] = lm_acc.detach().cpu()

        return trackes

    def test_epoch_end(self, outputs, recos_path=None):
        if self.trainer.checkpoint_callback.last_model_path == "" and self.hparams.resume_from_checkpoint is None:
            self.trainer.checkpoint_callback.last_model_path = f"{self.hparams.hparams_dir}/no_train"
        elif self.hparams.resume_from_checkpoint is not None:
            self.trainer.checkpoint_callback.last_model_path = self.hparams.resume_from_checkpoint
        if recos_path is None:
            save_outputs_path = f"{self.trainer.checkpoint_callback.last_model_path}_FEATURES_NumSamples_{len(outputs)}"

            if isinstance(outputs[0][0][0], torch.Tensor):
                outputs = [([to_numpy(section) for section in sample[0]], sample[1]) for sample in outputs]
            torch.save(outputs, save_outputs_path)
            print(f"\nSaved to {save_outputs_path}\n")

            titles = popular_titles = [out[1][:-1] for out in outputs]
            idxs, gt_path = list(range(len(titles))), ""

            section_sentences_features = [out[0] for out in outputs]
            popular_titles, idxs, gt_path = get_gt_seeds_titles(titles, self.hparams.dataset_name)

            self.hparams.test_sample_size = (
                self.hparams.test_sample_size if self.hparams.test_sample_size > 0 else len(popular_titles)
            )
            idxs = idxs[: self.hparams.test_sample_size]

            recos, metrics = vectorize_reco_hierarchical(
                all_features=section_sentences_features,
                titles=titles,
                gt_path=gt_path,
                output_path=self.trainer.checkpoint_callback.last_model_path,
            )
            metrics = {
                "mrr": float(metrics["mrr"]),
                "mpr": float(metrics["mpr"]),
                **{f"hit_rate_{rate[0]}": float(rate[1]) for rate in metrics["hit_rates"]},
            }
            print(json.dumps(metrics, indent=2))
            for k, v in metrics.items():
                self.logger.experiment.add_scalar(k, v, global_step=self.global_step)

            recos_path = os.path.join(
                os.path.dirname(self.trainer.checkpoint_callback.last_model_path),
                f"{os.path.basename(self.trainer.checkpoint_callback.last_model_path)[:-5]}"
                f"_numSamples_{self.hparams.test_sample_size}",
            )
            torch.save(recos, recos_path)
            print("Saving recos in {}".format(recos_path))

        setattr(self.hparams, "recos_path", recos_path)
        return

    def dataloader(self, mode=None):
        if mode == "train":
            # m=2: each batch holds two samples per sampled article (dataset labels are article indices).
            sampler = MPerClassSampler(
                self.train_dataset.labels,
                2,
                batch_size=self.hparams.train_batch_size,
                length_before_new_iter=(self.hparams.limit_train_batches) * self.hparams.train_batch_size,
            )

            loader = DataLoader(
                self.train_dataset,
                num_workers=self.hparams.num_data_workers,
                sampler=sampler,
                batch_size=self.hparams.train_batch_size,
                collate_fn=partial(reco_sentence_collate, tokenizer=self.tokenizer,),
            )

        elif mode == "val":
            sampler = MPerClassSamplerDeter(
                self.val_dataset.labels,
                2,
                length_before_new_iter=self.hparams.limit_val_indices_batches,
                batch_size=self.hparams.val_batch_size,
            )

            loader = DataLoader(
                self.val_dataset,
                num_workers=self.hparams.num_data_workers,
                sampler=sampler,
                batch_size=self.hparams.val_batch_size,
                collate_fn=partial(reco_sentence_collate, tokenizer=self.tokenizer,),
            )

        else:
            loader = DataLoader(
                self.test_dataset,
                num_workers=self.hparams.num_data_workers,
                batch_size=self.hparams.test_batch_size,
                collate_fn=partial(reco_sentence_test_collate, tokenizer=self.tokenizer,),
            )
189 | return loader 190 | 191 | @staticmethod 192 | def add_model_specific_args(parent_parser, task_name, dataset_name, is_lowest_leaf=False): 193 | parser = TransformersBase.add_model_specific_args(parent_parser, task_name, dataset_name, is_lowest_leaf=False) 194 | parser.add_argument("--hard_mine", type=str2bool, nargs="?", const=True, default=True) 195 | parser.add_argument("--metric_loss_func", type=str, default="ContrastiveLoss") # TripletMarginLoss #CosineLoss 196 | parser.add_argument("--sim_loss_lambda", type=float, default=0.1) 197 | parser.add_argument("--limit_tokens", type=int, default=64) 198 | parser.add_argument("--limit_val_indices_batches", type=int, default=500) 199 | parser.add_argument("--metric_for_similarity", type=str, choices=["cosine", "norm_euc"], default="cosine") 200 | 201 | parser.set_defaults( 202 | check_val_every_n_epoch=1, 203 | batch_size=32, 204 | accumulate_grad_batches=2, 205 | metric_to_track="train_mlm_loss_epoch", 206 | ) 207 | 208 | return parser 209 | 210 | def prepare_data(self): 211 | block_size = ( 212 | self.hparams.block_size 213 | if hasattr(self.hparams, "block_size") 214 | and self.hparams.block_size > 0 215 | and self.hparams.block_size < self.tokenizer.max_len 216 | else self.tokenizer.max_len 217 | ) 218 | self.train_dataset = WikipediaTextDatasetParagraphsSentences( 219 | tokenizer=self.tokenizer, 220 | hparams=self.hparams, 221 | dataset_name=self.hparams.dataset_name, 222 | block_size=block_size, 223 | mode="train", 224 | ) 225 | self.val_dataset = WikipediaTextDatasetParagraphsSentences( 226 | tokenizer=self.tokenizer, 227 | hparams=self.hparams, 228 | dataset_name=self.hparams.dataset_name, 229 | block_size=block_size, 230 | mode="val", 231 | ) 232 | self.val_dataset.indices_map = self.val_dataset.indices_map[: self.hparams.limit_val_indices_batches] 233 | self.val_dataset.labels = self.val_dataset.labels[: self.hparams.limit_val_indices_batches] 234 | 235 | self.test_dataset = 
WikipediaTextDatasetParagraphsSentencesTest( 236 | tokenizer=self.tokenizer, 237 | hparams=self.hparams, 238 | dataset_name=self.hparams.dataset_name, 239 | block_size=block_size, 240 | mode="test", 241 | ) 242 | 243 | -------------------------------------------------------------------------------- /models/SDR/SDR_utils.py: -------------------------------------------------------------------------------- 1 | from pytorch_metric_learning.samplers.m_per_class_sampler import MPerClassSampler 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | from pytorch_metric_learning.utils import common_functions as c_f 5 | 6 | # modified from 7 | # https://raw.githubusercontent.com/bnulihaixia/Deep_metric/master/utils/sampler.py 8 | class MPerClassSamplerDeter(MPerClassSampler): 9 | """ 10 | At every iteration, this returns m samples per class. For example, 11 | if the dataloader's batch size is 100 and m = 5, each batch contains 20 classes with 5 samples 12 | each. Unlike MPerClassSampler, the shuffled index list is drawn on the first call to __iter__ and cached in self.shuffled_idx_list. 13 | """ 14 | 15 | def __init__(self, labels, m, batch_size=None, length_before_new_iter=100000): 16 | super(MPerClassSamplerDeter, self).__init__(labels, m, batch_size, int(length_before_new_iter)) 17 | self.shuffled_idx_list = None 18 | 19 | def __iter__(self): 20 | idx_list = [0] * self.list_size 21 | i = 0 22 | num_iters = self.calculate_num_iters() 23 | if self.shuffled_idx_list is None: 24 | for _ in range(num_iters): 25 | 26 | c_f.NUMPY_RANDOM.shuffle(self.labels) 27 | if self.batch_size is None: 28 | curr_label_set = self.labels 29 | else: 30 | curr_label_set = self.labels[: self.batch_size // self.m_per_class] 31 | for label in curr_label_set: 32 | t = self.labels_to_indices[label] 33 | idx_list[i : i + self.m_per_class] = c_f.safe_random_choice(t, size=self.m_per_class) 34 | i += self.m_per_class 35 | self.shuffled_idx_list = idx_list 36 | return iter(self.shuffled_idx_list) 37 | 
--------------------------------------------------------------------------------
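The caching idea behind `MPerClassSamplerDeter` can be illustrated without `pytorch_metric_learning`. The sketch below is a hypothetical, standalone class (`CachedMPerClassSampler` is not part of the repo) that draws m indices from each of `batch_size // m` distinct classes per batch using a seeded NumPy generator, stores the resulting index list on the first pass, and replays it on every later pass. It assumes that replaying a fixed order is the intended behaviour of the original sampler; it is a sketch, not the repo's implementation.

```python
from collections import defaultdict

import numpy as np


class CachedMPerClassSampler:
    """Hypothetical stand-in for MPerClassSamplerDeter (not the repo's class).

    Each batch holds `batch_size // m` distinct classes with `m` indices each.
    The full index list is drawn once and then replayed on every epoch.
    """

    def __init__(self, labels, m, batch_size, length, seed=0):
        assert m > 0 and batch_size >= m and batch_size % m == 0, "batch_size must be a positive multiple of m"
        self.m = m
        self.batch_size = batch_size
        self.length = length  # total number of indices yielded per epoch
        self.rng = np.random.default_rng(seed)
        self.labels_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.labels_to_indices[label].append(idx)
        self.classes = list(self.labels_to_indices)
        self._cached = None  # filled on the first call to __iter__

    def _sample_class(self, label):
        pool = self.labels_to_indices[label]
        # Sample with replacement only when the class has fewer than m members.
        return self.rng.choice(pool, size=self.m, replace=len(pool) < self.m).tolist()

    def __iter__(self):
        if self._cached is None:
            classes_per_batch = self.batch_size // self.m
            idx_list = []
            while len(idx_list) < self.length:
                # Pick distinct classes for this batch, m indices from each.
                chosen = self.rng.permutation(len(self.classes))[:classes_per_batch]
                for c in chosen:
                    idx_list.extend(self._sample_class(self.classes[c]))
            self._cached = idx_list[: self.length]
        return iter(self._cached)

    def __len__(self):
        return self.length


labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]
sampler = CachedMPerClassSampler(labels, m=2, batch_size=4, length=8, seed=0)
first_epoch, second_epoch = list(sampler), list(sampler)
print(first_epoch == second_epoch)  # the same order is replayed
```

Passing such an object as `sampler=` to `torch.utils.data.DataLoader` with the same `batch_size` would keep each batch made of m-sample class groups, since the DataLoader groups the sampler's indices in order. Caching trades batch diversity across epochs for validation metrics that are comparable from one epoch to the next.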
/models/SDR/similarity_modeling.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import logging 3 | import math 4 | from pytorch_metric_learning.distances.lp_distance import LpDistance 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import CrossEntropyLoss 9 | from torch.nn.functional import gelu 10 | from pytorch_metric_learning import miners, losses, reducers 11 | 12 | from transformers.configuration_roberta import RobertaConfig 13 | from transformers.modeling_bert import BertLayerNorm, BertPreTrainedModel 14 | from transformers.modeling_roberta import RobertaModel, RobertaLMHead 15 | from pytorch_metric_learning.distances import CosineSimilarity 16 | 17 | 18 | class SimilarityModeling(BertPreTrainedModel): 19 | config_class = RobertaConfig 20 | base_model_prefix = "roberta" 21 | 22 | def __init__(self, config, hparams): 23 | super().__init__(config) 24 | self.hparams = hparams 25 | config.output_hidden_states = True 26 | 27 | self.roberta = RobertaModel(config) 28 | self.lm_head = RobertaLMHead(config) 29 | self.init_weights() 30 | 31 | if self.hparams.metric_for_similarity == "cosine": 32 | self.metric = CosineSimilarity() 33 | pos_margin, neg_margin = 1, 0 34 | neg_margin = -1 if (getattr(self.hparams, "metric_loss_func", "ContrastiveLoss") == "CosineLoss") else 0 35 | elif self.hparams.metric_for_similarity == "norm_euc": 36 | self.metric = LpDistance(normalize_embeddings=True, p=2) 37 | pos_margin, neg_margin = 0, 1 38 | 39 | self.reducer = reducers.DoNothingReducer() 40 | if self.hparams.hard_mine: 41 | self.miner_func = miners.MultiSimilarityMiner() 42 | else: 43 | self.miner_func = miners.BatchEasyHardMiner( 44 | pos_strategy=miners.BatchEasyHardMiner.ALL, 45 | neg_strategy=miners.BatchEasyHardMiner.ALL, 46 | distance=CosineSimilarity(), 47 | ) 48 | 49 | if getattr(self.hparams, "metric_loss_func", "ContrastiveLoss") in ["ContrastiveLoss", "CosineLoss"]: 50 | 
self.similarity_loss_func = losses.ContrastiveLoss( 51 | pos_margin=pos_margin, neg_margin=neg_margin, distance=self.metric 52 | ) # with a similarity metric the loss is [pos_margin - s_pos]_+ + [s_neg - neg_margin]_+, so for cosine similarity we use pos_margin=1 and neg_margin=0 53 | else: 54 | self.similarity_loss_func = losses.TripletMarginLoss(margin=1, distance=self.metric) 55 | 56 | def get_output_embeddings(self): 57 | return self.lm_head.decoder 58 | 59 | @staticmethod 60 | def mean_mask(features, mask): 61 | return (features * mask.unsqueeze(-1)).sum(1) / mask.sum(-1, keepdim=True) # mean over the positions where mask == 1 62 | 63 | def forward( 64 | self, 65 | input_ids=None, 66 | sample_labels=None, 67 | samples_idxs=None, 68 | track_sim_dict=None, 69 | non_masked_input_ids=None, 70 | attention_mask=None, 71 | token_type_ids=None, 72 | position_ids=None, 73 | head_mask=None, 74 | inputs_embeds=None, 75 | masked_lm_labels=None, 76 | labels=None, 77 | output_hidden_states=False, 78 | return_dict=False, 79 | run_similarity=False, 80 | run_mlm=True, 81 | ): 82 | if run_mlm: 83 | outputs = list( 84 | self.roberta( 85 | input_ids, 86 | attention_mask=attention_mask, 87 | token_type_ids=token_type_ids, 88 | position_ids=position_ids, 89 | head_mask=head_mask, 90 | inputs_embeds=inputs_embeds, 91 | output_hidden_states=output_hidden_states, 92 | return_dict=return_dict, 93 | ) 94 | ) 95 | sequence_output = outputs[0] 96 | prediction_scores = self.lm_head(sequence_output) 97 | outputs = (prediction_scores, None, sequence_output) # Add hidden states and attention if they are here 98 | 99 | ####### 100 | # MLM 101 | ####### 102 | masked_lm_loss = torch.zeros(1, device=prediction_scores.device).float() 103 | 104 | if (masked_lm_labels is not None and (not (masked_lm_labels == -100).all())) and self.hparams.mlm: 105 | loss_fct = CrossEntropyLoss() 106 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 107 | else: 108 | masked_lm_loss = 0 109 | else: 110 | outputs = ( 111 | 
torch.zeros([*input_ids.shape, 
50265]).to(input_ids.device).float(), 112 | None, 113 | torch.zeros([*input_ids.shape, 1024]).to(input_ids.device).float(), 114 | ) 115 | masked_lm_loss = torch.zeros(1)[0].to(input_ids.device).float() 116 | 117 | ####### 118 | # Similarity 119 | ####### 120 | if run_similarity: 121 | non_masked_outputs = self.roberta( 122 | non_masked_input_ids, 123 | attention_mask=attention_mask, 124 | token_type_ids=token_type_ids, 125 | position_ids=position_ids, 126 | head_mask=head_mask, 127 | inputs_embeds=inputs_embeds, 128 | output_hidden_states=output_hidden_states, 129 | return_dict=return_dict, 130 | ) 131 | non_masked_seq_out = non_masked_outputs[0] 132 | 133 | meaned_sentences = non_masked_seq_out.mean(1) 134 | miner_output = list(self.miner_func(meaned_sentences, sample_labels)) 135 | 136 | sim_loss = self.similarity_loss_func(meaned_sentences, sample_labels, miner_output) 137 | outputs = (masked_lm_loss, sim_loss, torch.zeros(1)) + outputs 138 | else: 139 | outputs = ( 140 | masked_lm_loss, 141 | torch.zeros(1)[0].to(input_ids.device).float(), 142 | torch.zeros(1)[0].to(input_ids.device).float(), 143 | ) + outputs 144 | 145 | return outputs 146 | -------------------------------------------------------------------------------- /models/doc_similarity_pl_template.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from utils.torch_utils import to_numpy 3 | from pytorch_lightning import _logger as log 4 | from pytorch_lightning.core import LightningModule 5 | import torch 6 | from utils import argparse_init 7 | from utils import switch_functions 8 | from utils.model_utils import extract_model_path_for_hyperparams 9 | from subprocess import Popen 10 | import pandas as pd 11 | 12 | 13 | class DocEmbeddingTemplate(LightningModule): 14 | 15 | """ 16 | Author: Dvir Ginzburg. 17 | 18 | This is a template for future document templates using pytorch lightning. 
 19 | """ 20 | 21 | def __init__( 22 | self, hparams, 23 | ): 24 | super(DocEmbeddingTemplate, self).__init__() 25 | self.hparams = hparams 26 | self.hparams.hparams_dir = extract_model_path_for_hyperparams(self.hparams.default_root_dir, self) 27 | self.losses = {} 28 | self.tracks = {} 29 | self.hparams.mode = "val" 30 | 31 | 32 | def forward(self, data): 33 | """ 34 | Forward pass of the document-similarity network; must be implemented by subclasses. 35 | """ 36 | raise NotImplementedError() 37 | 38 | def training_step(self, batch, batch_idx, mode="train"): 39 | """ 40 | Lightning calls this inside the training loop with the 41 | data from the training dataloader passed in as `batch`. 42 | """ 43 | self.losses = {} 44 | self.tracks = {} 45 | self.hparams.batch_idx = batch_idx 46 | self.hparams.mode = mode 47 | self.batch = batch 48 | 49 | batch = self(batch) 50 | 51 | self.tracks[f"tot_loss"] = sum(self.losses.values()).mean() 52 | 53 | step_logs = {k: to_numpy(v) for k, v in {**self.tracks, **self.losses}.items()} 54 | getattr(self, f"{mode}_logs", None).append(step_logs) 55 | self.log_step(step_logs) 56 | 57 | output = collections.OrderedDict({"loss": self.tracks[f"tot_loss"]}) 58 | return output 59 | 60 | def validation_step(self, batch, batch_idx, mode="val"): 61 | """Lightning calls this inside the validation loop with the data from the validation dataloader passed in as `batch`.""" 62 | 63 | return self.training_step(batch, batch_idx, mode=mode) 64 | 65 | def log_step(self, step_logs): 66 | if not ( 67 | self.hparams.batch_idx # log only on every n-th batch of the current mode 68 | % (getattr(self.hparams, f"{self.hparams.mode}_log_every_n_steps")) 69 | == 0 70 | ): 71 | return 72 | for k, v in step_logs.items(): 73 | if v.shape != (): 74 | v = v.sum() 75 | self.logger.experiment.add_scalar(f"{self.hparams.mode}_{k}_step", v, global_step=self.global_step) 76 | 77 | def test_step(self, batch, batch_idx): 78 | return 
self.validation_step(batch, batch_idx, mode="test") 79 | 80 | def on_validation_epoch_start(self): 81 | self.val_logs = [] 82 
| self.hparams.mode = "val" 83 | 84 | def on_train_epoch_start(self): 85 | self.hparams.current_epoch = self.current_epoch 86 | self.train_logs = [] 87 | self.hparams.mode = "train" 88 | 89 | def on_test_epoch_start(self): 90 | self.test_logs = [] 91 | self.hparams.mode = "test" 92 | 93 | def on_epoch_end_generic(self): 94 | if self.trainer.running_sanity_check: 95 | return 96 | logs = getattr(self, f"{self.hparams.mode}_logs", None) 97 | 98 | self.log_dict(logs, prefix=self.hparams.mode) 99 | 100 | def log_dict(self, logs, prefix): 101 | dict_of_lists = pd.DataFrame(logs).to_dict("list") 102 | for lst in dict_of_lists: 103 | dict_of_lists[lst] = list(filter(lambda x: not pd.isnull(x), dict_of_lists[lst])) 104 | for key, lst in dict_of_lists.items(): 105 | s = 0 106 | for item in lst: 107 | s += item.sum() 108 | name = f"{prefix}_{key}_epoch" 109 | val = s / len(lst) 110 | self.logger.experiment.add_scalar(name, val, global_step=self.global_step) 111 | if self.hparams.metric_to_track == name: 112 | self.log(name, torch.tensor(val)) 113 | 114 | def on_train_epoch_end(self, outputs) -> None: 115 | self.on_epoch_end_generic() 116 | 117 | def on_validation_epoch_end(self) -> None: 118 | if self.trainer.running_sanity_check: 119 | return 120 | if self.current_epoch % 10 == 0: 121 | self.logger.experiment.add_text("Profiler", self.trainer.profiler.summary(), global_step=self.global_step) 122 | 123 | 124 | def on_test_epoch_end(self) -> None: 125 | self.on_epoch_end_generic() 126 | 127 | def validation_epoch_end(self, outputs): 128 | self.on_epoch_end_generic() 129 | 130 | # --------------------- 131 | # TRAINING SETUP 132 | # --------------------- 133 | def configure_optimizers(self): 134 | """ 135 | Return whatever optimizers and learning rate schedulers you want here. 136 | 137 | At least one optimizer is required. 
 138 | """ 139 | no_decay = ["bias", "LayerNorm.weight"] 140 | optimizer_grouped_parameters = [ 141 | { 142 | "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 143 | "weight_decay": self.hparams.weight_decay, 144 | }, 145 | {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 146 | ] 147 | 148 | optimizer = switch_functions.choose_optimizer(self.hparams, optimizer_grouped_parameters) 149 | scheduler = switch_functions.choose_scheduler( 150 | self.hparams.scheduler, optimizer, warmup_steps=0, params=self.hparams 151 | ) 152 | 153 | return [optimizer], [scheduler] 154 | 155 | def dataloader(self, mode=None): 156 | """ 157 | Returns the DataLoader for the requested split (called once per training). 158 | 159 | Args: 160 | mode (str, optional): The split to load: "train", "val" or "test". 161 | 162 | Returns: 163 | DataLoader 164 | """ 165 | raise NotImplementedError() 166 | 167 | def prepare_data(self): 168 | """ 169 | Loads the data (called once) and builds the train, validation and test splits.
 170 | 171 | Returns: 172 | Tuple of datasets: train,val and test dataset splits 173 | """ 174 | raise NotImplementedError() 175 | 176 | def train_dataloader(self): 177 | log.info("Training data loader called.") 178 | return self.dataloader(mode="train") 179 | 180 | def val_dataloader(self): 181 | log.info("Validation data loader called.") 182 | return self.dataloader(mode="val") 183 | 184 | def test_dataloader(self): 185 | log.info("Test data loader called.") 186 | return self.dataloader(mode="test") 187 | 188 | @staticmethod 189 | def add_model_specific_args(parser, task_name, dataset_name, is_lowest_leaf=False): 190 | """ 191 | Adds the arguments that are specific to this module. 192 | 193 | Args: 194 | parser (ArgparseManager): The general argument parser. 195 | 196 | Returns: 197 | ArgparseManager: The parser with this module's arguments added. 198 | """ 199 | parser.add_argument( 200 | "--test_sample_size", default=-1, type=int, help="The number of samples to evaluate recommendations on (-1 means all)." 201 | ) 202 | parser.add_argument("--top_k_size", default=-1, type=int, help="The number of top k correspondences.
(-1 is all)") 203 | 204 | parser.add_argument("--with_same_series", type=argparse_init.str2bool, nargs="?", const=True, default=True) 205 | 206 | return parser 207 | 208 | -------------------------------------------------------------------------------- /models/reco/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/models/reco/__init__.py -------------------------------------------------------------------------------- /models/reco/hierarchical_reco.py: -------------------------------------------------------------------------------- 1 | import json 2 | from data.data_utils import get_gt_seeds_titles 3 | from models.reco.wiki_recos_eval.eval_metrics import evaluate_wiki_recos 4 | from utils.torch_utils import mean_non_pad_value, to_numpy 5 | from models.reco.recos_utils import index_amp, sim_matrix 6 | import numpy as np 7 | import torch 8 | from tqdm import tqdm 9 | import pickle 10 | from sklearn.preprocessing import normalize 11 | 12 | 13 | def vectorize_reco_hierarchical(all_features, titles, gt_path, output_path=""): 14 | gt = pickle.load(open(gt_path, "rb")) 15 | to_reco_indices = [index_amp(titles, title) for title in gt.keys()] 16 | to_reco_indices = list(filter(lambda idx: idx is not None, to_reco_indices)) 17 | sections_per_article = np.array([len(article) for article in all_features]) 18 | sections_per_article_cumsum = np.array([0,] + [len(article) for article in all_features]).cumsum() 19 | features_per_section = [sec for article in all_features for sec in article] 20 | features_per_section_torch = [torch.from_numpy(feat) for feat in features_per_section] 21 | features_per_section_padded = torch.nn.utils.rnn.pad_sequence( 22 | features_per_section_torch, batch_first=True, padding_value=torch.tensor(float("nan")) 23 | ).cuda() 24 | 25 | num_samples, max_after_pad = features_per_section_padded.shape[:2] 26 | 27 | flattened = 
features_per_section_padded.reshape(-1, features_per_section_padded.shape[-1]) 28 | 29 | recos = [] 30 | for i in tqdm(to_reco_indices): 31 | if i >= len(all_features): 32 | print(f"GT article at index {i} was not evaluated") 33 | continue 34 | 35 | to_reco_flattened = features_per_section_padded[ 36 | sections_per_article_cumsum[i] : sections_per_article_cumsum[i + 1] 37 | ].reshape(-1, features_per_section_padded.shape[-1]) 38 | 39 | sim = sim_matrix(to_reco_flattened, flattened) 40 | reshaped_sim = sim.reshape( 41 | sections_per_article_cumsum[i + 1] - sections_per_article_cumsum[i], max_after_pad, num_samples, max_after_pad 42 | ) 43 | sim = reshaped_sim.permute(0, 2, 1, 3) 44 | sim[sim.isnan()] = float("-Inf") 45 | score_mat = sim.max(-1)[0] 46 | score = mean_non_pad_value(score_mat, axis=-1, pad_value=torch.tensor(float("-Inf")).cuda()) 47 | 48 | score_per_article = torch.split(score.t(), sections_per_article.tolist(), dim=0) 49 | score_per_article_padded = torch.nn.utils.rnn.pad_sequence( 50 | score_per_article, batch_first=True, padding_value=float("-Inf") 51 | ).permute(0, 2, 1) 52 | score_per_article_padded[torch.isnan(score_per_article_padded)] = float("-Inf") 53 | par_score_mat = score_per_article_padded.max(-1)[0] 54 | par_score = mean_non_pad_value(par_score_mat, axis=-1, pad_value=float("-Inf")) 55 | 56 | recos.append((i, to_numpy(par_score.argsort(descending=True)[1:]))) # skip the top-ranked article, expected to be the seed itself 57 | 58 | examples = [[None, title] for title in titles] # (sample, title) pairs, the format evaluate_wiki_recos expects 59 | _, mpr, _, mrr, _, hit_rate = evaluate_wiki_recos(recos, output_path, gt_path, examples=examples) 60 | metrics = {"mrr": mrr, "mpr": mpr, "hit_rates": hit_rate} 61 | return recos, metrics 62 | 63 | -------------------------------------------------------------------------------- /models/reco/recos_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 
import torch 5 | import pickle 6 | import numpy as np 7 | 8 | from
sklearn.preprocessing import normalize 9 | 10 | 11 | def index_amp(lst, k): 12 | try: 13 | return lst.index(k) if k in lst else lst.index(k.replace("&", "&amp;")) 14 | except ValueError: 15 | return None 16 | 17 | 18 | def sim_matrix(a, b, eps=1e-8): 19 | """ 20 | Cosine-similarity matrix between the rows of a and the rows of b (row norms are clamped to eps). 21 | """ 22 | a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None] 23 | a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n)) 24 | b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n)) 25 | sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1)) 26 | return sim_mt 27 | 28 | -------------------------------------------------------------------------------- /models/reco/wiki_recos_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/models/reco/wiki_recos_eval/__init__.py -------------------------------------------------------------------------------- /models/reco/wiki_recos_eval/eval_metrics.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from tqdm import tqdm 4 | from pathlib import Path 5 | import sys 6 | from utils.logging_utils import Unbuffered 7 | import json 8 | 9 | 10 | def evaluate_wiki_recos(recos, output_path, gt_path, examples): 11 | original_stdout = sys.stdout 12 | sys.stdout = Unbuffered(open(f"{output_path}_reco_scores", "w")) 13 | dataset_as_dict = {sample[1]: sample for sample in examples} 14 | recos_as_dict = {reco[0]: reco for reco in recos} 15 | names_to_id = {sample[1]: idx for idx, sample in enumerate(examples)} 16 | titles = [ex[1] for ex in examples] 17 | 18 | article_recos_per_article = pickle.load(open(gt_path, "rb")) 19 | 20 | percentiles, mpr = calculate_mpr(recos_as_dict, article_recos_per_article, dataset_as_dict, names_to_id, titles=titles) 21 | recepricals, mrr = calculate_mrr(recos_as_dict, article_recos_per_article, dataset_as_dict, 
names_to_id, titles=titles)
 22 | hit_rates, hit_rate = calculate_mean_hit_rate( 23 | recos_as_dict, 24 | article_recos_per_article, 25 | dataset_as_dict, 26 | names_to_id, 27 | rate_thresulds=[5, 10, 50, 100, 1000], 28 | titles=titles, 29 | ) 30 | 31 | metrics = {"mrr": float(mrr), "mpr": float(mpr), **{f"hit_rate_{rate[0]}": float(rate[1]) for rate in hit_rate}} 32 | print(json.dumps(metrics, indent=2)) 33 | 34 | sys.stdout = original_stdout 35 | return percentiles, mpr, recepricals, mrr, hit_rates, hit_rate 36 | 37 | 38 | def calculate_mpr( 39 | input_recommandations, article_article_gt, dataset, names_to_id, sample_size=-1, popular_titles=None, titles=[] 40 | ): 41 | percentiles = [] 42 | for reco_idx in tqdm(input_recommandations): 43 | wiki_title = titles[reco_idx] 44 | curr_gts, text = [], [] 45 | recommandations = input_recommandations[reco_idx][1] 46 | if wiki_title not in article_article_gt: 47 | continue 48 | for gt_title in article_article_gt[wiki_title].keys(): 49 | lookup = gt_title.replace("&", "&amp;") if "amp;" not in gt_title and gt_title not in names_to_id else gt_title 50 | if lookup not in names_to_id: 51 | print(f"{lookup} not in names_to_id") 52 | continue 53 | recommended_idx_ls = np.where(recommandations == names_to_id[lookup])[0] 54 | if recommended_idx_ls.shape[0] == 0: 55 | continue 56 | curr_gts.append(recommended_idx_ls[0]) 57 | percentiles.extend((recommended_idx_ls[0] / len(recommandations),) * article_article_gt[wiki_title][gt_title]) 58 | text.append("gt: {} gt place: {}".format(gt_title, recommended_idx_ls[0])) 59 | 60 | if len(curr_gts) > 0: 61 | print( 62 | "title: {}\n".format(wiki_title) 63 | + "\n".join(text) 64 | + "\ntopk: {}\n\n\n".format([titles[reco_i] for reco_i in recommandations[:10]]) 65 | ) 66 | 67 | percentiles = percentiles if percentiles != [] else [0] 68 | print("percentiles_mean:{}\n\n\n\n".format(sum(percentiles) / len(percentiles))) 69 | return percentiles, sum(percentiles) / len(percentiles) 70 | 71 | 72 | def calculate_mrr( 73 |
input_recommandations, article_article_gt, dataset, names_to_id, sample_size=-1, popular_titles=None, titles=[] 74 | ): 75 | """ 76 | input_recommandations - dict mapping each seed-article index to a (seed index, ranked candidate indices) pair 77 | article_article_gt - dict of dicts: for each seed title, maps every ground-truth title to its count 78 | sample_size - the number of candidates to compute the MRR on (currently unused) 79 | """ 80 | recepricals = [] 81 | for reco_idx in tqdm(input_recommandations): 82 | wiki_title = titles[reco_idx] 83 | text = [] 84 | recommandations = input_recommandations[reco_idx][1] 85 | top = len(input_recommandations) 86 | for gt_title in article_article_gt[wiki_title].keys(): 87 | lookup = gt_title.replace("&", "&amp;") if "amp;" not in gt_title and gt_title not in names_to_id else gt_title 88 | if lookup not in names_to_id: 89 | print(f"{lookup} not in names_to_id") 90 | continue 91 | recommended_idx_ls = np.where(recommandations == names_to_id[lookup])[0] 92 | if recommended_idx_ls.shape[0] > 0 and recommended_idx_ls[0] < top: 93 | top = recommended_idx_ls[0] 94 | if recommended_idx_ls.shape[0] == 0: 95 | continue 96 | text.append("gt: {} gt place: {} ".format(gt_title, recommended_idx_ls[0])) 97 | 98 | if top == 0: 99 | top = 1 100 | 101 | if len(text) > 0: 102 | recepricals.append(1 / (top)) 103 | text.append(f"\n reciprocal: {recepricals[-1]}") 104 | print( 105 | "title: {}\n".format(wiki_title) 106 | + "\n".join(text) 107 | + "\ntopk: {}\n\n\n".format([titles[reco_i] for reco_i in recommandations[:10]]) 108 | ) 109 | 110 | recepricals = recepricals if recepricals != [] else [0] 111 | print(f"Reciprocal mean:{sum(recepricals) / len(recepricals)}") 112 | print(f"Reciprocals \n {recepricals}") 113 | return recepricals, sum(recepricals) / len(recepricals) 114 | 115 | 116 | def calculate_mean_hit_rate( 117 | input_recommandations, 118 | article_article_gt, 119 | dataset, 120 | names_to_id, 121 | sample_size=-1, 122 | popular_titles=None, 123 |
 rate_thresulds=[100], 124 | titles=[], 125 | ): 126 | mean_hits = [[] for i in rate_thresulds] 127 | for reco_idx in tqdm(input_recommandations): 128 | wiki_title = titles[reco_idx] 129 | curr_gts, text = [], [] 130 | hit_by_rate = [0 for i in rate_thresulds] 131 | recommandations = input_recommandations[reco_idx][1] 132 | for gt_title in article_article_gt[wiki_title].keys(): 133 | 134 | lookup = gt_title.replace("&", "&amp;") if "amp;" not in gt_title and gt_title not in names_to_id else gt_title 135 | if lookup not in names_to_id: 136 | print(f"{lookup} not in names_to_id") 137 | continue 138 | recommended_idx_ls = np.where(recommandations == names_to_id[lookup])[0] 139 | for thr_idx, thresuld in enumerate(rate_thresulds): 140 | if recommended_idx_ls.shape[0] != 0 and recommended_idx_ls[0] < thresuld: 141 | hit_by_rate[thr_idx] += 1 142 | text.append(f"gt: {gt_title} gt place: {recommended_idx_ls}") 143 | 144 | if len(text) > 0: 145 | for thr_idx, thresuld in enumerate(rate_thresulds): 146 | print( 147 | f"title: {wiki_title} Hit rate at {thresuld}: {hit_by_rate[thr_idx]} \n \n {''.join(text)} \n topk: {[titles[reco_i] for reco_i in recommandations[:10]]}\n\n\n" 148 | ) 149 | hit_mean = hit_by_rate[thr_idx] / len(article_article_gt[wiki_title].keys()) if hit_by_rate[thr_idx] > 0 else 0 150 | mean_hits[thr_idx].append(hit_mean) 151 | 152 | mean_hits = mean_hits if mean_hits != [[] for rate_thresuld in rate_thresulds] else [[0] for rate_thresuld in rate_thresulds] 153 | mean_hit = [sum(mean_hit) / len(mean_hit) for mean_hit in mean_hits] 154 | mean_hits_with_thresuld = [(thresuld, mean) for (thresuld, mean) in zip(*[rate_thresulds, mean_hit])] 155 | print(f"Hit rate mean:{mean_hits_with_thresuld}") 156 | return mean_hits, mean_hits_with_thresuld 157 | -------------------------------------------------------------------------------- /models/transformer_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple
2 | import torch 3 | from transformers import PreTrainedTokenizer 4 | 5 | 6 | def mask_tokens(inputs: torch.Tensor, tokenizer: PreTrainedTokenizer, args) -> Tuple[torch.Tensor, torch.Tensor]: 7 | """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """ 8 | 9 | if tokenizer.mask_token is None: 10 | raise ValueError( 11 | "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer." 12 | ) 13 | 14 | labels = inputs.clone() 15 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) 16 | probability_matrix = torch.full(labels.shape, args.mlm_probability, device=labels.device) 17 | special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()] 18 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool, device=labels.device), value=0.0) 19 | if tokenizer._pad_token is not None: 20 | padding_mask = labels.eq(tokenizer.pad_token_id) 21 | probability_matrix.masked_fill_(padding_mask, value=0.0) 22 | masked_indices = torch.bernoulli(probability_matrix).bool() 23 | if (~masked_indices).all(): 24 | masked_indices = ~masked_indices # If we choose to not learn anything - learn everything 25 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 26 | 27 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 28 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8, device=labels.device)).bool() & masked_indices 29 | inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) 30 | 31 | # 10% of the time, we replace masked input tokens with random word 32 | indices_random = ( 33 | torch.bernoulli(torch.full(labels.shape, 0.5, device=labels.device)).bool() & masked_indices & ~indices_replaced 
34 | ) 35 | random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long, device=labels.device) 36 | inputs[indices_random] = random_words[indices_random] 37 | 38 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 39 | return inputs, labels 40 | 41 | -------------------------------------------------------------------------------- /models/transformers_base.py: -------------------------------------------------------------------------------- 1 | """Stub.""" 2 | from models.doc_similarity_pl_template import DocEmbeddingTemplate 3 | from utils import argparse_init 4 | from utils import switch_functions 5 | 6 | 7 | class TransformersBase(DocEmbeddingTemplate): 8 | 9 | """ 10 | Author: Dvir Ginzburg. 11 | 12 | This is a template for future document templates using transformers. 13 | """ 14 | 15 | def __init__( 16 | self, hparams, 17 | ): 18 | super(TransformersBase, self).__init__(hparams) 19 | (self.config_class, self.model_class, self.tokenizer_class,) = switch_functions.choose_model_class_configuration( 20 | self.hparams.arch, self.hparams.base_model_name 21 | ) 22 | if self.hparams.config_name: 23 | self.config = self.config_class.from_pretrained(self.hparams.config_name, cache_dir=None) 24 | elif self.hparams.arch_or_path: 25 | self.config = self.config_class.from_pretrained(self.hparams.arch_or_path) 26 | else: 27 | self.config = self.config_class() 28 | if self.hparams.tokenizer_name: 29 | self.tokenizer = self.tokenizer_class.from_pretrained(self.hparams.tokenizer_name) 30 | elif self.hparams.arch_or_path: 31 | self.tokenizer = self.tokenizer_class.from_pretrained(self.hparams.arch_or_path) 32 | else: 33 | raise ValueError( 34 | "You are instantiating a new {} tokenizer. 
 This is not supported, but you can do it from another script, save it, " 35 | "and load it from here, using --tokenizer_name".format(self.tokenizer_class.__name__) 36 | ) 37 | self.hparams.tokenizer_pad_id = self.tokenizer.pad_token_id 38 | self.model = self.model_class.from_pretrained( 39 | self.hparams.config_name, from_tf=bool(".ckpt" in self.hparams.config_name), config=self.config, hparams=self.hparams 40 | ) 41 | 42 | @staticmethod 43 | def add_model_specific_args(parent_parser, task_name, dataset_name, is_lowest_leaf=False): 44 | parser = DocEmbeddingTemplate.add_model_specific_args(parent_parser, task_name, dataset_name, is_lowest_leaf=False) 45 | parser.add_argument( 46 | "--mlm", 47 | type=argparse_init.str2bool, 48 | nargs="?", 49 | const=True, 50 | default=True, 51 | help="Train with masked-language modeling loss instead of language modeling.", 52 | ) 53 | 54 | parser.add_argument( 55 | "--mlm_probability", type=float, default=0.15, help="Ratio of tokens to mask for masked language modeling loss", 56 | ) 57 | parser.add_argument( 58 | "--base_model_name", type=str, default="roberta", help="The underlying BERT-like model of this architecture.", 59 | ) 60 | base_model_name = parser.parse_known_args()[0].base_model_name 61 | if base_model_name in ["roberta", "tnlr"]: 62 | default_config, default_tokenizer = "roberta-large", "roberta-large" 63 | elif base_model_name in ["bert", "tnlr3"]: 64 | default_config, default_tokenizer = "bert-large-uncased", "bert-large-uncased" 65 | elif base_model_name == "longformer": 66 | default_config, default_tokenizer = "allenai/longformer-base-4096", "allenai/longformer-base-4096" 67 | parser.set_defaults(batch_size=2) 68 | parser.add_argument( 69 | "--config_name", 70 | type=str, 71 | default=default_config, 72 | help="Optional pretrained config name or path, if not the same as arch_or_path.
 If both are unset, a new config is initialized.", 73 | ) 74 | parser.add_argument( 75 | "--tokenizer_name", 76 | default=default_tokenizer, 77 | type=str, 78 | help="Optional pretrained tokenizer name or path, if not the same as arch_or_path. If both are unset, an error is raised.", 79 | ) 80 | 81 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 82 | 83 | parser.set_defaults(lr=2e-5, weight_decay=0) 84 | 85 | arch, mlm = parser.parse_known_args()[0].arch, parser.parse_known_args()[0].mlm 86 | if arch in ["bert", "roberta", "distilbert", "camembert", "recoberta", "recoberta_cosine"] and not mlm: 87 | raise ValueError( 88 | "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm " 89 | "flag (masked language modeling)." 90 | ) 91 | 92 | return parser 93 | -------------------------------------------------------------------------------- /sdr_main.py: -------------------------------------------------------------------------------- 1 | """Top-level entry point: parse flags and call the training loop.""" 2 | import os 3 | from utils.pytorch_lightning_utils.pytorch_lightning_utils import load_params_from_checkpoint 4 | import torch 5 | from pytorch_lightning.profiler.profilers import SimpleProfiler 6 | from utils.pytorch_lightning_utils.callbacks import RunValidationOnStart 7 | from utils import switch_functions 8 | import pytorch_lightning 9 | from pytorch_lightning import Trainer 10 | from pytorch_lightning.callbacks import ModelCheckpoint 11 | from utils.argparse_init import default_arg_parser, init_parse_argparse_default_params 12 | import logging 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | from pytorch_lightning.loggers import TensorBoardLogger 16 | 17 | def main(): 18 | """Build all argument parsers, then start training.""" 19 | parser = default_arg_parser() 20 | parser = Trainer.add_argparse_args(parser) # Bug in PL 21 | parser = 
default_arg_parser(description="docBert", parents=[parser]) 22 | 23 | eager_flags = init_parse_argparse_default_params(parser) 24 | model_class_pointer = switch_functions.model_class_pointer(eager_flags["task_name"], eager_flags["architecture"]) 25 | parser = model_class_pointer.add_model_specific_args(parser, eager_flags["task_name"], eager_flags["dataset_name"]) 26 | 27 | hyperparams = parser.parse_args() 28 | main_train(model_class_pointer, hyperparams,parser) 29 | 30 | 31 | def main_train(model_class_pointer, hparams,parser): 32 | """Initialize the model, call training loop.""" 33 | pytorch_lightning.utilities.seed.seed_everything(seed=hparams.seed) 34 | 35 | if(hparams.resume_from_checkpoint not in [None,'']): 36 | hparams = load_params_from_checkpoint(hparams, parser) 37 | 38 | model = model_class_pointer(hparams) 39 | 40 | 41 | logger = TensorBoardLogger(save_dir=model.hparams.hparams_dir,name='',default_hp_metric=False) 42 | logger.log_hyperparams(model.hparams, metrics={model.hparams.metric_to_track: 0}) 43 | print(f"\nLog directory:\n{model.hparams.hparams_dir}\n") 44 | 45 | trainer = pytorch_lightning.Trainer( 46 | num_sanity_val_steps=2, 47 | gradient_clip_val=hparams.max_grad_norm, 48 | callbacks=[RunValidationOnStart()], 49 | checkpoint_callback=ModelCheckpoint( 50 | save_top_k=3, 51 | save_last=True, 52 | mode="min" if "acc" not in hparams.metric_to_track else "max", 53 | monitor=hparams.metric_to_track, 54 | filepath=os.path.join(model.hparams.hparams_dir, "{epoch}"), 55 | verbose=True, 56 | ), 57 | logger=logger, 58 | max_epochs=hparams.max_epochs, 59 | gpus=hparams.gpus, 60 | distributed_backend="dp", 61 | limit_val_batches=hparams.limit_val_batches, 62 | limit_train_batches=hparams.limit_train_batches, 63 | limit_test_batches=hparams.limit_test_batches, 64 | check_val_every_n_epoch=hparams.check_val_every_n_epoch, 65 | profiler=SimpleProfiler(), 66 | accumulate_grad_batches=hparams.accumulate_grad_batches, 67 | 
reload_dataloaders_every_epoch=True, 68 | # load 69 | resume_from_checkpoint=hparams.resume_from_checkpoint, 70 | ) 71 | if not hparams.test_only: 72 | trainer.fit(model) 73 | else: 74 | if hparams.resume_from_checkpoint is not None: 75 | model = model.load_from_checkpoint(hparams.resume_from_checkpoint, hparams=hparams, map_location=torch.device("cpu")) 76 | trainer.test(model) 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | 82 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/utils/__init__.py -------------------------------------------------------------------------------- /utils/argparse_init.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | from argparse import ArgumentDefaultsHelpFormatter 5 | 6 | 7 | def str2intlist(v): 8 | if v.isdigit(): 9 | return [int(v)] 10 | try: 11 | return [int(dig) for dig in v.split("_")] 12 | except ValueError: 13 | raise argparse.ArgumentTypeError('Expected int or "4_4"') 14 | 15 | 16 | def str2bool(v): 17 | if isinstance(v, bool): 18 | return v 19 | if v.lower() in ("yes", "true", "t", "y", "1"): 20 | return True 21 | elif v.lower() in ("no", "false", "f", "n", "0"): 22 | return False 23 | else: 24 | raise argparse.ArgumentTypeError("Boolean value expected.") 25 | 26 | 27 | def default_arg_parser(description="", conflict_handler="resolve", parents=[], is_lowest_leaf=False): 28 | """ 29 | Generate the default parser (a helper for readability). 30 | 31 | Args: 32 | description (str, optional): name of the parser - usually project name. Defaults to ''. 33 | conflict_handler (str, optional): whether to raise an error on conflict or resolve it (the last definition wins). Defaults to 'resolve'.
34 | parents (list, optional): parent parsers whose arguments are inherited. Defaults to []. 35 | 36 | Returns: 37 | argparse.ArgumentParser: the configured parser. 38 | """ 39 | description = ( 40 | parents[0].description + description 41 | if len(parents) != 0 and parents[0] is not None and parents[0].description is not None 42 | else description 43 | ) 44 | parser = argparse.ArgumentParser( 45 | description=description, 46 | add_help=is_lowest_leaf, 47 | formatter_class=ArgumentDefaultsHelpFormatter, 48 | conflict_handler=conflict_handler, 49 | parents=parents, 50 | ) 51 | 52 | return parser 53 | 54 | def get_non_default(parsed,parser): 55 | non_default = { 56 | opt.dest: getattr(parsed, opt.dest) 57 | for opt in parser._option_string_actions.values() 58 | if hasattr(parsed, opt.dest) and opt.default != getattr(parsed, opt.dest) 59 | } 60 | return non_default 61 | 62 | 63 | def init_parse_argparse_default_params(parser, dataset_name=None, arch=None): 64 | TASK_OPTIONS = ["document_similarity"] 65 | 66 | parser.add_argument( 67 | "--task_name", type=str, default="document_similarity", choices=TASK_OPTIONS, help="The task to solve", 68 | ) 69 | task_name = parser.parse_known_args()[0].task_name 70 | 71 | DATASET_OPTIONS = { 72 | "document_similarity": ["video_games", "wines",], 73 | } 74 | parser.add_argument( 75 | "--dataset_name", 76 | type=str, 77 | default=DATASET_OPTIONS[task_name][0], 78 | choices=DATASET_OPTIONS[task_name], 79 | help="The dataset to evaluate on", 80 | ) 81 | dataset_name = dataset_name or parser.parse_known_args()[0].dataset_name 82 | 83 | ## General learning parameters 84 | parser.add_argument( 85 | "--train_batch_size", default={"document_similarity": 32}[task_name], type=int, help="Number of samples in batch", 86 | ) 87 | parser.add_argument( 88 | "--max_epochs", default={"document_similarity": 50}[task_name], type=int, help="Number of epochs to train", 89 | ) 90 | parser.add_argument( 91 | "-lr", default={"document_similarity": 2e-5}[task_name], type=float,
help="Learning rate", 92 | ) 93 | 94 | parser.add_argument("--optimizer", default="adamW", help="Optimizer to use") 95 | parser.add_argument( 96 | "--scheduler", 97 | default="linear_with_warmup", 98 | choices=["linear_with_warmup", "cosine_annealing_lr"], 99 | help="Scheduler to use", 100 | ) 101 | parser.add_argument("--weight_decay", default=5e-3, type=float, help="Weight decay") 102 | 103 | ## Input Output parameters 104 | parser.add_argument( 105 | "--default_root_dir", default=os.path.join(os.getcwd(), "output", task_name), help="Directory in which to store this run's output", 106 | ) 107 | output_dir = parser.parse_known_args()[0].default_root_dir 108 | os.makedirs(output_dir, exist_ok=True) 109 | 110 | parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization") 111 | 112 | ### Model Parameters 113 | parser.add_argument( 114 | "--arch", "--architecture", default={"document_similarity": "SDR"}[task_name], help="Architecture", 115 | ) 116 | 117 | architecture = arch or parser.parse_known_args()[0].arch 118 | 119 | parser.add_argument("--accumulate_grad_batches", default=1, type=int) 120 | 121 | ### Auxiliary parameters 122 | parser.add_argument("--gpus", default=1, type=str, help="Number of GPUs to use") 123 | parser.add_argument("--num_data_workers", default=0, type=int, help="Number of worker processes for data loading") 124 | parser.add_argument("--overwrite_data_cache", type=str2bool, nargs="?", const=True, default=False) 125 | 126 | parser.add_argument("--train_val_ratio", default=0.90, type=float, help="Train/validation split ratio of the data") 127 | parser.add_argument( 128 | "--limit_train_batches", default=10000, type=int, 129 | ) 130 | 131 | parser.add_argument( 132 | "--train_log_every_n_steps", default=50, type=int, 133 | ) 134 | parser.add_argument( 135 | "--val_log_every_n_steps", default=1, type=int, 136 | ) 137 | parser.add_argument( 138 | "--test_log_every_n_steps", default=1, type=int, 139 | ) 140 | 141 | 142 | parser.add_argument("--resume_from_checkpoint", default=None,
type=str, help="Path to a checkpoint to resume from (its saved hyper-parameters are also restored)") 143 | parser.add_argument( 144 | "--metric_to_track", default=None, help="Metric to monitor when saving checkpoints", 145 | ) 146 | parser.add_argument("--val_batch_size", default=8, type=int) 147 | parser.add_argument("--test_batch_size", default=1, type=int) 148 | parser.add_argument("--test_only", type=str2bool, nargs="?", const=True, default=False) 149 | 150 | return { 151 | "dataset_name": dataset_name, 152 | "task_name": task_name, 153 | "architecture": architecture, 154 | } 155 | 156 | -------------------------------------------------------------------------------- /utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | class Unbuffered(object): 2 | """ 3 | Wrap a stream so that every write is flushed immediately. 4 | 5 | Example: sys.stdout = Unbuffered(open(path + '_output','w')) 6 | """ 7 | 8 | def __init__(self, stream): 9 | self.stream = stream 10 | 11 | def write(self, data): 12 | self.stream.write(data) 13 | self.stream.flush() 14 | 15 | def writelines(self, datas): 16 | self.stream.writelines(datas) 17 | self.stream.flush() 18 | 19 | def __getattr__(self, attr): 20 | return getattr(self.stream, attr) 21 | -------------------------------------------------------------------------------- /utils/metrics_utils.py: -------------------------------------------------------------------------------- 1 | def simple_accuracy(preds, labels): 2 | return (preds == labels).mean() 3 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | 5 | def extract_model_path_for_hyperparams(start_path, model): 6 | relevant_hparams = {} 7 | for key in [ 8 | "arch", 9 | "dataset_name", 10 | "test_only" 11 | ]: 12 | if hasattr(model.hparams, key): 13 | relevant_hparams[key] = getattr(model.hparams, key) 14 | 15 |
path = os.path.join(start_path, *["{}_{}".format(key, val) for key, val in relevant_hparams.items()]) 16 | 17 | dt_string = datetime.now().strftime("%d_%m_%Y-%H_%M_%S") 18 | path = os.path.join(path, dt_string) 19 | 20 | os.makedirs(path, exist_ok=True) 21 | 22 | return path -------------------------------------------------------------------------------- /utils/pytorch_lightning_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SDR/c913e591e81aa92f0d44629bbc8a8f8cfa4aaf79/utils/pytorch_lightning_utils/__init__.py -------------------------------------------------------------------------------- /utils/pytorch_lightning_utils/callbacks.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import Callback 2 | from pytorch_lightning.trainer.trainer import Trainer 3 | 4 | 5 | class RunValidationOnStart(Callback): 6 | def __init__(self): 7 | pass 8 | 9 | def on_train_start(self, trainer: Trainer, pl_module): 10 | return trainer.run_evaluation() 11 | -------------------------------------------------------------------------------- /utils/pytorch_lightning_utils/pytorch_lightning_utils.py: -------------------------------------------------------------------------------- 1 | """Diagnose your system and show basic information 2 | 3 | This serves mainly to collect detailed information for better bug reporting.
4 | 5 | """ 6 | 7 | import os 8 | import platform 9 | import re 10 | import sys 11 | from argparse import Namespace 12 | 13 | import numpy 14 | import tensorboard 15 | import torch 16 | import tqdm 17 | from utils.argparse_init import get_non_default 18 | 19 | sys.path += [os.path.abspath(".."), os.path.abspath(".")] 20 | import pytorch_lightning # noqa: E402 21 | 22 | LEVEL_OFFSET = "\t" 23 | KEY_PADDING = 20 24 | 25 | 26 | def run_and_parse_first_match(run_lambda, command, regex): 27 | """Runs command using run_lambda, returns the first regex match if it exists""" 28 | rc, out, _ = run_lambda(command) 29 | if rc != 0: 30 | return None 31 | match = re.search(regex, out) 32 | if match is None: 33 | return None 34 | return match.group(1) 35 | 36 | 37 | def get_running_cuda_version(run_lambda): 38 | return run_and_parse_first_match(run_lambda, "nvcc --version", r"V(.*)$") 39 | 40 | 41 | def info_system(): 42 | return { 43 | "OS": platform.system(), 44 | "architecture": platform.architecture(), 45 | "version": platform.version(), 46 | "processor": platform.processor(), 47 | "python": platform.python_version(), 48 | } 49 | 50 | 51 | def info_cuda(): 52 | return { 53 | "GPU": [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())], 54 | # 'nvidia_driver': get_nvidia_driver_version(run_lambda), 55 | "available": torch.cuda.is_available(), 56 | "version": torch.version.cuda, 57 | } 58 | 59 | 60 | def info_packages(): 61 | return { 62 | "numpy": numpy.__version__, 63 | "pyTorch_version": torch.__version__, 64 | "pyTorch_debug": torch.version.debug, 65 | "pytorch-lightning": pytorch_lightning.__version__, 66 | "tensorboard": tensorboard.__version__, 67 | "tqdm": tqdm.__version__, 68 | } 69 | 70 | 71 | def nice_print(details, level=0): 72 | lines = [] 73 | for k in sorted(details): 74 | key = f"* {k}:" if level == 0 else f"- {k}:" 75 | if isinstance(details[k], dict): 76 | lines += [level * LEVEL_OFFSET + key] 77 | lines += nice_print(details[k], level + 1) 
78 | elif isinstance(details[k], (set, list, tuple)): 79 | lines += [level * LEVEL_OFFSET + key] 80 | lines += [(level + 1) * LEVEL_OFFSET + "- " + v for v in details[k]] 81 | else: 82 | template = "{:%is} {}" % KEY_PADDING 83 | key_val = template.format(key, details[k]) 84 | lines += [(level * LEVEL_OFFSET) + key_val] 85 | return lines 86 | 87 | 88 | def main(): 89 | details = { 90 | "System": info_system(), 91 | "CUDA": info_cuda(), 92 | "Packages": info_packages(), 93 | } 94 | lines = nice_print(details) 95 | text = os.linesep.join(lines) 96 | print(text) 97 | 98 | def load_params_from_checkpoint(hparams, parser): 99 | path = hparams.resume_from_checkpoint 100 | hparams_model = Namespace(**torch.load(path,map_location=torch.device('cpu'))['hyper_parameters']) 101 | hparams_model.max_epochs = hparams_model.max_epochs + 30 # extend the saved schedule; a non-default --max_epochs overrides this below 102 | for k,v in get_non_default(hparams,parser).items(): 103 | setattr(hparams_model,k,v) 104 | hparams_model.gpus = hparams.gpus 105 | for key in vars(hparams): 106 | if key not in vars(hparams_model): 107 | setattr(hparams_model,key,getattr(hparams,key,None)) 108 | hparams = hparams_model 109 | return hparams 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /utils/switch_functions.py: -------------------------------------------------------------------------------- 1 | from models.SDR.similarity_modeling import SimilarityModeling 2 | import torch 3 | import transformers 4 | from transformers import get_linear_schedule_with_warmup 5 | from torch import optim 6 | 7 | 8 | def model_class_pointer(task_name, arch): 9 | """Return the model class selected by the flags.
10 | Arguments: 11 | task_name {str} -- the task name, e.g. document_similarity 12 | arch {str} -- the architecture name, e.g. SDR 13 | Raises: 14 | Exception: If the task/architecture combination is unknown 15 | 16 | Returns: 17 | type -- The model class to instantiate and train 18 | 19 | """ 20 | 21 | if task_name == "document_similarity": 22 | if arch == "SDR": 23 | from models.SDR.SDR import SDR 24 | 25 | return SDR 26 | raise Exception("Unknown task or architecture") 27 | 28 | 29 | def choose_optimizer(params, network_parameters): 30 | """ 31 | Choose the optimizer from the params.optimizer flag 32 | 33 | Args: 34 | params (argparse.Namespace): The input flags 35 | network_parameters (iterable): e.g. net.parameters() 36 | 37 | Raises: 38 | Exception: If no optimizer matches 39 | """ 40 | if params.optimizer == "adamW": 41 | optimizer = transformers.AdamW(network_parameters, lr=params.lr,) 42 | elif params.optimizer == "sgd": 43 | optimizer = torch.optim.SGD(network_parameters, lr=params.lr, weight_decay=params.weight_decay, momentum=0.9,) 44 | else: 45 | raise Exception("No valid optimizer provided") 46 | return optimizer 47 | 48 | 49 | def choose_scheduler(scheduler_name, optimizer, warmup_steps, params): 50 | 51 | if scheduler_name == "linear_with_warmup": 52 | return get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=params.max_epochs) 53 | elif scheduler_name == "cosine_annealing_lr": 54 | return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0,) 55 | 56 | else: 57 | raise Exception("No valid scheduler provided") 58 | 59 | 60 | from transformers import ( 61 | RobertaConfig, 62 | RobertaTokenizer, 63 | get_linear_schedule_with_warmup, 64 | ) 65 | 66 | 67 | def choose_model_class_configuration(arch, base_model_name): 68 | MODEL_CLASSES = { 69 | "SDR_roberta": (RobertaConfig, SimilarityModeling, RobertaTokenizer), 70 | } 71 | return MODEL_CLASSES[f"{arch}_{base_model_name}"] 72 | -------------------------------------------------------------------------------- /utils/torch_utils.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import numbers 4 | 5 | def mean_non_pad_value(tensor, axis, pad_value=0): 6 | mask = tensor != pad_value 7 | tensor = tensor.masked_fill(~mask, 0) # out-of-place, so the caller's tensor is not modified 8 | tensor_mean = (tensor * mask).sum(dim=axis) / (mask.sum(dim=axis)) 9 | 10 | ignore_mask = (mask.sum(dim=axis)) == 0 11 | tensor_mean[ignore_mask] = pad_value 12 | return tensor_mean 13 | 14 | 15 | def to_numpy(tensor): 16 | """Convert a tensor, NumPy array, or Python number to a NumPy array.""" 17 | if isinstance(tensor, torch.Tensor): 18 | return tensor.detach().cpu().numpy() 19 | elif isinstance(tensor, np.ndarray): 20 | return tensor 21 | elif isinstance(tensor, numbers.Number): 22 | return np.array(tensor) 23 | else: 24 | raise NotImplementedError 25 | --------------------------------------------------------------------------------
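`mean_non_pad_value` in `utils/torch_utils.py` averages along `axis` while ignoring entries equal to `pad_value`. For slices that contain only padding it returns `pad_value` instead of the NaN that 0/0 would produce. The following is a minimal, stdlib-only sketch of those semantics for a 2-D input reduced over its last axis. It avoids PyTorch so it can run anywhere. The helper name `mean_non_pad_rows` is hypothetical and is not part of the repository.

```python
from typing import List, Sequence


def mean_non_pad_rows(rows: Sequence[Sequence[float]], pad_value: float = 0.0) -> List[float]:
    """Per-row mean that skips entries equal to `pad_value`.

    Hypothetical helper illustrating the behaviour of `mean_non_pad_value`
    for a 2-D input reduced over its last axis.
    """
    means = []
    for row in rows:
        # Keep only the real (non-padding) entries of this row.
        kept = [x for x in row if x != pad_value]
        # An all-padding row would divide by zero; report `pad_value` instead.
        means.append(sum(kept) / len(kept) if kept else pad_value)
    return means


if __name__ == "__main__":
    # Row 0 averages [1, 3]; row 1 is all padding; row 2 averages [2, 2, 5].
    print(mean_non_pad_rows([[1.0, 3.0, 0.0], [0.0, 0.0, 0.0], [2.0, 2.0, 5.0]]))
    # [2.0, 0.0, 3.0]
```

As in the tensor version, a genuine value that happens to equal `pad_value` cannot be distinguished from padding, so choose a `pad_value` that never occurs in real data.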
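`load_params_from_checkpoint` (called from `sdr_main.py` when `--resume_from_checkpoint` is set) starts from the hyper-parameters stored in the checkpoint. Any flag the user changed from its default on the command line overrides them (via `get_non_default`). Flags that did not exist when the checkpoint was saved are back-filled from the current command line. The sketch below reproduces that precedence using only the standard library, under several assumptions. The checkpoint is replaced by a plain dict; in the repository it comes from `torch.load(path)['hyper_parameters']`. The helper names are hypothetical. The `max_epochs + 30` extension and the unconditional `gpus` copy are omitted. The sketch also uses the public `ArgumentParser.get_default` instead of the private `_option_string_actions` mapping.

```python
import argparse
from argparse import Namespace


def non_default_args(parsed: Namespace, parser: argparse.ArgumentParser) -> dict:
    """Options whose parsed value differs from the parser default (hypothetical helper)."""
    return {
        dest: value
        for dest, value in vars(parsed).items()
        if parser.get_default(dest) != value
    }


def merge_with_checkpoint(saved: dict, cli: Namespace, parser: argparse.ArgumentParser) -> Namespace:
    """Layer command-line overrides on top of saved hyper-parameters (hypothetical helper)."""
    merged = Namespace(**saved)
    # 1. Flags the user changed from their defaults win over the checkpoint values.
    for dest, value in non_default_args(cli, parser).items():
        setattr(merged, dest, value)
    # 2. Flags unknown to the checkpoint (e.g. added after it was saved) take the CLI value.
    for dest, value in vars(cli).items():
        if not hasattr(merged, dest):
            setattr(merged, dest, value)
    return merged


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--max_epochs", type=int, default=50)
    parser.add_argument("--test_only", action="store_true")

    cli = parser.parse_args(["--lr", "1e-4"])
    saved = {"lr": 3e-5, "max_epochs": 20}  # stand-in for the checkpoint's hyper_parameters
    merged = merge_with_checkpoint(saved, cli, parser)
    print(merged.lr, merged.max_epochs, merged.test_only)  # 0.0001 20 False
```

One consequence of comparing against defaults is that passing a flag explicitly with its default value (e.g. `--max_epochs 50`) does not override the checkpoint. The same holds for `get_non_default`.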