├── .gitignore ├── README.md ├── config ├── mcpredictor-sent.json ├── mcpredictor.json ├── scpredictor-sent.json └── scpredictor.json ├── data └── english_stopwords.txt ├── experiments ├── constituent_analysis.py ├── output_dev_results.py ├── output_dev_set.py ├── preprocess.py ├── scan_events.py ├── scan_multichain.py ├── scan_pos.py ├── scan_sents.py ├── test.py └── train.py ├── mcpredictor ├── models │ ├── base │ │ ├── attention.py │ │ ├── constraint.py │ │ ├── embedding.py │ │ ├── event_encoder.py │ │ ├── score.py │ │ ├── sentence_encoder.py │ │ └── sequence_model.py │ ├── basic_model.py │ ├── multi_chain_sent │ │ ├── model.py │ │ └── network.py │ ├── single_chain │ │ ├── model.py │ │ └── network.py │ └── single_chain_sent │ │ ├── model.py │ │ └── network.py ├── preprocess │ ├── multi_chain.py │ ├── negative_pool.py │ ├── single_chain.py │ ├── stop_event.py │ ├── word_dict.py │ └── word_embedding.py └── utils │ ├── common.py │ ├── config.py │ ├── document.py │ ├── entity.py │ ├── event.py │ └── mention.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # PyCharm 132 | .idea/ 133 | # VSCode 134 | .vscode/ 135 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MCPredictor 2 | Experiment code for: 3 | 4 | Long Bai, Saiping Guan, Jiafeng Guo, Zixuan Li, Xiaolong Jin, and Xueqi Cheng. 5 | "*Integrating Deep Event-Level and Script-Level Information for Script Event Prediction*", EMNLP 2021 6 | 7 | 8 | ## 1. Corpus 9 | The corpus can be found in the LDC catalog: 10 | https://catalog.ldc.upenn.edu/LDC2011T07 . 11 | Since this dataset uses documents from 1994 to 2004, please use at least the second edition. 12 | 13 | ## 2. MCNC dataset 14 | The MCNC dataset processing code can be found here: 15 | https://mark.granroth-wilding.co.uk/papers/what_happens_next/ . 16 | 17 | Please use a Python 2.7 environment to run this code. 18 | 19 | Please follow ```README.md``` and ```bin/event_pipeline/rich_docs/gigaword.txt``` to construct the dataset. ```bin/entity_narrative/eval/experiments/generate_sample.sh``` is used to generate the dev/test datasets. 20 | 21 | ### 2.1 Modification of Granroth-Wilding's code 22 | Please let me know if I forgot any changes. 23 | 24 | #### 2.1.1 modify ```bin/run``` 25 | Since some computers run with a non-English system language, which may raise errors 26 | in JWNL, please set the system language to English: 27 | ``` 28 | java -classpath $BUILD_DIR:$DIR/../src/main/java:$DIR/../lib/* \ 29 | -DWNSEARCHDIR=$DIR/../models/wordnet-dict \ 30 | -Duser.language=en \ 31 | $* 32 | ``` 33 | 34 | #### 2.1.2 modify ```bin/event_pipeline/1-parse/preprocess/gigaword/gigaword_split.py``` 35 | It is recommended to use an absolute path ```#!<root>/bin/run_py``` in the shebang instead of ```#!../run_py```, where ```<root>``` is the absolute path of Granroth-Wilding's code. 36 | 37 | It is recommended to use the lxml engine in BeautifulSoup: 38 | ```soup = BeautifulSoup(xml_data, "lxml")``` 39 | 40 | #### 2.1.3 modify directories 41 | Data directories in the following files should be changed to the user's data directory: 42 | - ```bin/event_pipeline/config/gigaword-nyt``` 43 | - ```bin/event_pipeline/rich_docs/gigaword.txt``` 44 | - ```bin/entity_narrative/eval/experiments/generate_sample.sh``` 45 | 46 | #### 2.1.4 modify ```bin/event_pipeline/1-parse/candc/parse_dir.sh``` 47 | Change it to: 48 | ```../../../run_py ../../../../lib/python/whim_common/candc/parsedir.py $*``` 49 | 50 | #### 2.1.5 modify unavailable URLs in ```lib/``` 51 | C&C tool: https://github.com/chbrown/candc 52 | 53 | OpenNLP: http://archive.apache.org/dist/opennlp/opennlp-1.5.3/apache-opennlp-1.5.3-bin.tar.gz 54 | 55 | Stanford-postagger: https://nlp.stanford.edu/software/stanford-postagger-full-2014-01-04.zip 56 | 57 | #### 2.1.6 extract tokenized documents 58 | Since the original texts are needed, 59 | ```<data_dir>/gigaword-nyt/tokenized.tar.gz``` should be decompressed 60 | into the same directory. 61 | Replace ```<data_dir>``` with the place where you want to store the extracted data.
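For example (the exact paths are illustrative, with ```<data_dir>``` as above): ```tar -xzf <data_dir>/gigaword-nyt/tokenized.tar.gz -C <data_dir>/gigaword-nyt/```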
62 | 63 | #### 2.1.7 build java files 64 | Change directory to the root of this code, then: 65 | ```bash 66 | mkdir build 67 | javac -classpath <root>/build:<root>/src/main/java:<root>/lib/* -d build/ src/main/java/cam/whim/opennlp/Tokenize.java 68 | javac -classpath <root>/build:<root>/src/main/java:<root>/lib/* -d build/ src/main/java/cam/whim/opennlp/Parse.java 69 | javac -classpath <root>/build:<root>/src/main/java:<root>/lib/* -d build/ src/main/java/cam/whim/opennlp/StreamEntitiesExtractor.java 70 | javac -classpath <root>/build:<root>/src/main/java:<root>/lib/* -d build/ src/main/java/cam/whim/opennlp/Coreference.java 71 | javac -classpath <root>/build:<root>/src/main/java:<root>/lib/* -d build/ src/main/java/cam/whim/opennlp/Tokenize.java 72 | ``` 73 | Replace ```<root>``` with the absolute path of the root of this code. 74 | 75 | Note: if you want to run PMI (i.e., Chambers and Jurafsky, 2008), please also build the java files in ```src/main/java/cam/whim/narrative/chambersJurafsky```. 76 | 77 | ## 3. Installation 78 | Use the command ```pip install -e .``` in 79 | the project root directory. 80 | 81 | Use the command ```pip install -r requirements.txt``` to 82 | install the dependencies. 83 | 84 | Environment: Python >= 3.6. 85 | 86 | ## 4. Preprocess 87 | Use the command ```python experiments/preprocess.py --data_dir <data_dir> --work_dir <work_dir>``` to preprocess the data. 88 | The following arguments should be specified: 89 | - ```--data_dir```: the directory of the MCNC dataset 90 | - ```--work_dir```: the directory for temporary data and results 91 | 92 | On my working platform, it takes about 7 hours to generate the single-chain train set, 93 | and about 10 hours to generate the multi-chain train set. 94 | Please make sure that the process will not be interrupted. 95 | 96 | ## 5. Training 97 | ### train mcpredictor: 98 | ```python experiments/train.py --work_dir <work_dir> --model_config config/mcpredictor-sent.json --device cuda:0 --multi``` 99 | 100 | ### train scpredictor: 101 | ```python experiments/train.py --work_dir <work_dir> --model_config config/scpredictor-sent.json --device cuda:0``` 102 | 103 | ## 6. Testing 104 | ### test mcpredictor: 105 | ```python experiments/test.py --work_dir <work_dir> --model_config config/mcpredictor-sent.json --device cuda:0 --multi``` 106 | 107 | ### test scpredictor: 108 | ```python experiments/test.py --work_dir <work_dir> --model_config config/scpredictor-sent.json --device cuda:0``` 109 | 110 | 111 | ## 7.
Citation 112 | 113 | If you find the resource in this repository helpful, please cite 114 | 115 | ``` 116 | @inproceedings{bai-etal-2021-integrating, 117 | title = "Integrating Deep Event-Level and Script-Level Information for Script Event Prediction", 118 | author = "Bai, Long and Guan, Saiping and Guo, Jiafeng and Li, Zixuan and Jin, Xiaolong and Cheng, Xueqi", 119 | booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing", 120 | month = nov, 121 | year = "2021", 122 | address = "Online and Punta Cana, Dominican Republic", 123 | publisher = "Association for Computational Linguistics", 124 | url = "https://aclanthology.org/2021.emnlp-main.777", 125 | pages = "9869--9878", 126 | } 127 | ``` 128 | -------------------------------------------------------------------------------- /config/mcpredictor-sent.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mcpredictor-sent-full", 3 | "sequence_model": "transformer", 4 | "score": "euclidean", 5 | "attention": "scaled-dot", 6 | "vocab_size": 31037, 7 | "embedding_size": 300, 8 | "seq_len": 8, 9 | "mention_len": 30, 10 | "event_repr_size": 128, 11 | "dim_feedforward": 1024, 12 | "num_layers": 2, 13 | "num_heads": 16, 14 | "dropout": 0.1, 15 | "lr": 1e-4, 16 | "batch_size": 100, 17 | "npoch": 30, 18 | "interval": 500, 19 | "use_sent": true 20 | } 21 | -------------------------------------------------------------------------------- /config/mcpredictor.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "mcpredictor-full", 3 | "sequence_model": "transformer", 4 | "score": "euclidean", 5 | "attention": "scaled-dot", 6 | "vocab_size": 31037, 7 | "embedding_size": 300, 8 | "seq_len": 8, 9 | "mention_len": 30, 10 | "event_repr_size": 128, 11 | "dim_feedforward": 1024, 12 | "num_layers": 2, 13 | "num_heads": 16, 14 | "dropout": 0.1, 15 | "lr": 1e-4, 16 | "batch_size": 1000, 17 | "npoch": 30, 18 | "interval": 50, 19 | "use_sent": false 20 | } 21 | -------------------------------------------------------------------------------- /config/scpredictor-sent.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "scpredictor-sent-full", 3 | "sequence_model": "transformer", 4 | "score": "euclidean", 5 | "attention": "scaled-dot", 6 | "vocab_size": 31037, 7 | "embedding_size": 300, 8 | "seq_len": 8, 9 | "mention_len": 30, 10 | "event_repr_size": 128, 11 | "dim_feedforward": 1024, 12 | "num_layers": 2, 13 | "num_heads": 16, 14 | "dropout": 0.1, 15 | "lr": 1e-4, 16 | "batch_size": 200, 17 | "npoch": 30, 18 | "interval": 250, 19 | "use_sent": true 20 | } 21 | -------------------------------------------------------------------------------- /config/scpredictor.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "scpredictor-full", 3 | "sequence_model": "transformer", 4 | "score": "euclidean", 5 | "attention": "scaled-dot", 6 | "vocab_size": 31037, 7 | "embedding_size": 300, 8 | "seq_len": 8, 9 | "mention_len": 30, 10 | "event_repr_size": 128, 11 | "dim_feedforward": 1024, 12 | "num_layers": 2, 13 | "num_heads": 16, 14 | "dropout": 0.1, 15 | "lr": 1e-4, 16 | "batch_size": 1000, 17 | "npoch": 30, 18 | "interval": 50, 19 | "use_sent": false 20 | } 21 | -------------------------------------------------------------------------------- /data/english_stopwords.txt: 
-------------------------------------------------------------------------------- 1 | a 2 | about 3 | above 4 | after 5 | again 6 | against 7 | all 8 | am 9 | an 10 | and 11 | any 12 | are 13 | aren't 14 | as 15 | at 16 | be 17 | because 18 | been 19 | before 20 | being 21 | below 22 | between 23 | both 24 | but 25 | by 26 | can't 27 | cannot 28 | could 29 | couldn't 30 | did 31 | didn't 32 | do 33 | does 34 | doesn't 35 | doing 36 | don't 37 | down 38 | during 39 | each 40 | few 41 | for 42 | from 43 | further 44 | had 45 | hadn't 46 | has 47 | hasn't 48 | have 49 | haven't 50 | having 51 | he 52 | he'd 53 | he'll 54 | he's 55 | her 56 | here 57 | here's 58 | hers 59 | herself 60 | him 61 | himself 62 | his 63 | how 64 | how's 65 | i 66 | i'd 67 | i'll 68 | i'm 69 | i've 70 | if 71 | in 72 | into 73 | is 74 | isn't 75 | it 76 | it's 77 | its 78 | itself 79 | let's 80 | me 81 | more 82 | most 83 | mustn't 84 | my 85 | myself 86 | no 87 | nor 88 | not 89 | of 90 | off 91 | on 92 | once 93 | only 94 | or 95 | other 96 | ought 97 | our 98 | ours 99 | ourselves 100 | out 101 | over 102 | own 103 | same 104 | shan't 105 | she 106 | she'd 107 | she'll 108 | she's 109 | should 110 | shouldn't 111 | so 112 | some 113 | such 114 | than 115 | that 116 | that's 117 | the 118 | their 119 | theirs 120 | them 121 | themselves 122 | then 123 | there 124 | there's 125 | these 126 | they 127 | they'd 128 | they'll 129 | they're 130 | they've 131 | this 132 | those 133 | through 134 | to 135 | too 136 | under 137 | until 138 | up 139 | very 140 | was 141 | wasn't 142 | we 143 | we'd 144 | we'll 145 | we're 146 | we've 147 | were 148 | weren't 149 | what 150 | what's 151 | when 152 | when's 153 | where 154 | where's 155 | which 156 | while 157 | who 158 | who's 159 | whom 160 | why 161 | why's 162 | with 163 | won't 164 | would 165 | wouldn't 166 | you 167 | you'd 168 | you'll 169 | you're 170 | you've 171 | your 172 | yours 173 | yourself 174 | yourselves 175 | -------------------------------------------------------------------------------- /experiments/constituent_analysis.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | import nltk 6 | import torch 7 | from torch.utils import data 8 | from tqdm import tqdm 9 | from transformers import BertTokenizerFast 10 | 11 | from mcpredictor.models.multi_chain_sent.model import MultiChainSentModel 12 | from mcpredictor.models.single_chain.model import SingleChainSentModel 13 | from mcpredictor.preprocess.multi_chain import generate_mask_list 14 | from mcpredictor.preprocess.stop_event import load_stop_event 15 | from mcpredictor.utils.config import CONFIG 16 | from mcpredictor.utils.document import document_iterator 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def align_tags_to_ids(words, tags, tokenizer): 23 | """This function aligns tags to bert tokenized results.""" 24 | inputs = tokenizer(words, 25 | is_split_into_words=True, 26 | return_offsets_mapping=True, 27 | padding="max_length", 28 | truncation=True, 29 | max_length=50) 30 | input_ids = inputs.pop("input_ids") 31 | offset_mapping = inputs.pop("offset_mapping") 32 | label_index = 0 33 | cur_label = "O" 34 | labels = [] 35 | for offset in offset_mapping: 36 | if offset[0] == 0 and offset[1] != 0 and label_index < len(tags): 37 | # Begin of a new word 38 | cur_label = tags[label_index] 39 | label_index += 1 40 | labels.append(cur_label) 41 | elif offset[0] == 0 and offset[1] == 0 or label_index 
>= len(tags): 42 | # Control tokens 43 | labels.append("O") 44 | else: 45 | # Subword 46 | labels.append(cur_label) 47 | assert len(labels) == 50 48 | return labels 49 | 50 | 51 | class MaskedDataSet(data.Dataset): 52 | """This dataset masks some constituents in sentence.""" 53 | 54 | def __init__(self, data, tags): 55 | self.data = data 56 | self.tags = tags 57 | 58 | def __len__(self): 59 | return len(self.data) 60 | 61 | def __getitem__(self, item): 62 | # Sample 63 | events, sents, masks, target = self.data[item] 64 | events = torch.tensor(events) 65 | sents = torch.tensor(sents) 66 | masks = torch.tensor(masks) 67 | target = torch.tensor(target) 68 | # Tag mask 69 | tags = self.tags[item] 70 | tags = torch.tensor(tags).to(torch.int) 71 | masks = torch.logical_and(masks, tags).to(torch.int) 72 | return events, sents, masks, target 73 | 74 | 75 | def tag_dev(data_dir, work_dir): 76 | """POS tagging results.""" 77 | dev_corp_dir = os.path.join(data_dir, "gigaword-nyt", "eval", "multiple_choice", "dev_10k") 78 | tokenized_dir = os.path.join(data_dir, "gigaword-nyt", "tokenized") 79 | pos_dir = os.path.join(data_dir, "gigaword-nyt", "candc", "tags") 80 | # Load stop event list 81 | stoplist = load_stop_event(work_dir) 82 | # Build tokenizer 83 | special_tokens = ["[subj]", "[obj]", "[iobj]"] 84 | tokenizer = BertTokenizerFast.from_pretrained("prajjwal1/bert-tiny", 85 | additional_special_tokens=special_tokens) 86 | control_tokens = special_tokens + ["[UNK]"] 87 | context_size = 8 88 | tags = [] 89 | with tqdm() as pbar: 90 | for doc in document_iterator(corp_dir=dev_corp_dir, 91 | tokenized_dir=tokenized_dir, 92 | pos_dir=pos_dir, 93 | file_type="txt", 94 | doc_type="eval"): 95 | target = doc.target 96 | choices = doc.choices 97 | context = doc.context 98 | verb_position = context[-1]["verb_position"] 99 | # Make tags 100 | sample_tags = [] 101 | for choice in choices: 102 | choice_tags = [] 103 | for choice_role in ["subject", "object", "iobject"]: 104 | protagonist = choice[choice_role] 105 | # Get chain by protagonist 106 | chain = doc.get_chain_for_entity(protagonist, end_pos=verb_position, stoplist=stoplist) 107 | mask_list = generate_mask_list(chain) 108 | # Truncate 109 | if len(chain) > context_size: 110 | chain = chain[-context_size:] 111 | if len(chain) < context_size: 112 | chain = [None] * (context_size - len(chain)) + chain 113 | chain_tags = [] 114 | for event in chain: 115 | if event is not None: 116 | verb, subj, obj, iobj, role = event.tuple(protagonist) 117 | tmp_mask_list = mask_list.difference(event.get_words()) 118 | sent_words, sent_tags = event.tagged_sent(role, mask_list=tmp_mask_list) 119 | vi = event["verb_position"][1] + 1 120 | sent_tags[vi] = "VBSelf" 121 | # Align 122 | # Align 123 | sent_tags = align_tags_to_ids(sent_words, sent_tags, tokenizer) 124 | else: 125 | sent_tags = ["O"] * 50 126 | chain_tags.append(sent_tags) 127 | choice_tags.append(chain_tags) 128 | sample_tags.append(choice_tags) 129 | tags.append(sample_tags) 130 | pbar.update() 131 | tag_path = os.path.join(work_dir, "dev_tags") 132 | with open(tag_path, "wb") as f: 133 | pickle.dump(tags, f) 134 | logger.info("Dev tags save to {}".format(tag_path)) 135 | 136 | 137 | def replace_verb_self(tags): 138 | result = [] 139 | verb_flag = False 140 | for t in tags: 141 | if t in ["[subj]", "[obj]", "[iobj]"]: 142 | result.append(t) 143 | verb_flag = not verb_flag 144 | elif verb_flag: 145 | result.append("VBSelf") 146 | return result 147 | 148 | 149 | if __name__ == "__main__": 150 | 
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 151 | level=logging.INFO) 152 | data_dir = CONFIG.data_dir 153 | work_dir = CONFIG.work_dir 154 | # Load tags 155 | dev_tag_path = os.path.join(work_dir, "dev_tags") 156 | if not os.path.exists(dev_tag_path): 157 | tag_dev(data_dir, work_dir) 158 | with open(dev_tag_path, "rb") as f: 159 | raw_tags = pickle.load(f) 160 | masked_tags = [ 161 | # "CC", # Conjunctions 162 | # "JJ", "JJR", "JJS", "PDT", # Adjectives 163 | # "NN", "NNS", "NNP", "NNPS", # Nouns 164 | # "RB", "RBR", "RBS", "RP", # Adverbs 165 | # "VB", "VBD", "VBG", "VBN", "VBP", "VBZ", # Verbs(Other) 166 | # "VBSelf", # Verbs(Self) 167 | ] 168 | tags = [] 169 | verb_flag = False 170 | for sample in raw_tags: 171 | sample_tags = [] 172 | for choice in sample: 173 | choice_tags = [] 174 | for chain in choice: 175 | chain_tags = [] 176 | for event in chain: 177 | event_tags = [t not in masked_tags for t in event] 178 | chain_tags.append(event_tags) 179 | choice_tags.append(chain_tags) 180 | sample_tags.append(choice_tags) 181 | tags.append(sample_tags) 182 | # Load original dataset 183 | dev_data_path = os.path.join(work_dir, "multi_dev") 184 | with open(dev_data_path, "rb") as f: 185 | dev_data = pickle.load(f) 186 | dev_set = MaskedDataSet(dev_data, tags) 187 | # Build model 188 | model = MultiChainSentModel(CONFIG.model_config) 189 | model.build_model() 190 | model.print_model_info() 191 | model.load_model() 192 | model.evaluate(dev_set) 193 | -------------------------------------------------------------------------------- /experiments/output_dev_results.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | import numpy 6 | 7 | from mcpredictor.models.multi_chain_sent.model import MCSDataset, MultiChainSentModel 8 | from mcpredictor.utils.config import CONFIG 9 | 10 | if __name__ == "__main__": 11 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 12 | level=logging.INFO) 13 | data_dir = CONFIG.data_dir 14 | work_dir = CONFIG.work_dir 15 | # Load original dataset 16 | dev_data_path = os.path.join(work_dir, "multi_dev") 17 | with open(dev_data_path, "rb") as f: 18 | dev_data = pickle.load(f) 19 | dev_set = MCSDataset(dev_data) 20 | # Build model 21 | model = MultiChainSentModel(CONFIG.model_config) 22 | model.build_model() 23 | model.print_model_info() 24 | model.load_model() 25 | prec, result = model.evaluate(dev_set, return_result=True) 26 | numpy.save("dev_result.npy", result) 27 | -------------------------------------------------------------------------------- /experiments/output_dev_set.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | from transformers import BertTokenizerFast 6 | 7 | from mcpredictor.models.multi_chain_sent.model import MCSDataset 8 | from mcpredictor.preprocess.word_dict import load_word_dict 9 | from mcpredictor.utils.config import CONFIG 10 | 11 | 12 | def idx2word(event, word_dict): 13 | return [word_dict[i.item()] for i in event] 14 | 15 | 16 | if __name__ == "__main__": 17 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 18 | level=logging.INFO) 19 | data_dir = CONFIG.data_dir 20 | work_dir = CONFIG.work_dir 21 | # Load original dataset 22 | dev_data_path = os.path.join(work_dir, "multi_dev") 23 | with open(dev_data_path, "rb") as f: 24 | dev_data = pickle.load(f) 25 | dev_set =
MCSDataset(dev_data) 26 | special_tokens = ["[subj]", "[obj]", "[iobj]"] 27 | tokenizer = BertTokenizerFast.from_pretrained("prajjwal1/bert-tiny", 28 | additional_special_tokens=special_tokens) 29 | # return map 30 | word_dict = load_word_dict(work_dir) 31 | rev_word_dict = dict([t[::-1] for t in word_dict.items()]) 32 | os.makedirs("dev_docs", exist_ok=True) 33 | for idx, (events, sents, _, _) in enumerate(dev_set): 34 | if idx == 100: 35 | break 36 | new_events = [ 37 | [ 38 | [ 39 | idx2word(event, rev_word_dict) 40 | for event in chain 41 | ] 42 | for chain in choice 43 | ] 44 | for choice in events 45 | ] 46 | new_sents = [ 47 | [ 48 | [ 49 | tokenizer.decode(event).replace("[PAD]", "").strip() 50 | for event in chain 51 | ] 52 | for chain in choice 53 | ] 54 | for choice in sents 55 | ] 56 | with open("dev_docs/{}.txt".format(idx), "w") as f: 57 | for choice_id in range(5): 58 | for chain_id in range(3): 59 | for event_id in range(9): 60 | f.write(" ".join(new_events[choice_id][chain_id][event_id])) 61 | f.write("\t") 62 | f.write("\n") 63 | for event_id in range(8): 64 | f.write(new_sents[choice_id][chain_id][event_id]) 65 | f.write("\t") 66 | f.write("\n") 67 | f.write("\n\n") 68 | -------------------------------------------------------------------------------- /experiments/preprocess.py: -------------------------------------------------------------------------------- 1 | """Preprocess.""" 2 | import logging 3 | import os 4 | 5 | from mcpredictor.preprocess.multi_chain import generate_multi_train, generate_multi_eval 6 | from mcpredictor.preprocess.negative_pool import generate_negative_pool 7 | from mcpredictor.preprocess.single_chain import generate_single_train, generate_single_eval 8 | from mcpredictor.preprocess.stop_event import count_stop_event 9 | from mcpredictor.preprocess.word_embedding import generate_word_embedding 10 | from mcpredictor.utils.config import CONFIG 11 | 12 | 13 | if __name__ == "__main__": 14 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 15 | level=logging.INFO) 16 | data_dir = CONFIG.data_dir 17 | work_dir = CONFIG.work_dir 18 | if not os.path.exists(work_dir): 19 | os.makedirs(work_dir) 20 | train_doc_dir = os.path.join(data_dir, "gigaword-nyt", "rich_docs", "training") 21 | dev_doc_dir = os.path.join(data_dir, "gigaword-nyt", "rich_docs", "dev") 22 | test_doc_dir = os.path.join(data_dir, "gigaword-nyt", "rich_docs", "test") 23 | tokenize_dir = os.path.join(data_dir, "gigaword-nyt", "tokenized") 24 | pos_dir = os.path.join(data_dir, "gigaword-nyt", "candc", "tags") 25 | count_stop_event(train_doc_dir, work_dir) 26 | generate_negative_pool(corp_dir=train_doc_dir, 27 | tokenized_dir=None, 28 | work_dir=work_dir, 29 | num_events=None, 30 | suffix="train", 31 | file_type="tar") 32 | generate_word_embedding(train_corp_dir=train_doc_dir, 33 | work_dir=work_dir) 34 | # Single chain 35 | generate_single_train(corp_dir=train_doc_dir, 36 | work_dir=work_dir, 37 | tokenized_dir=tokenize_dir, 38 | # pos_dir=pos_dir, 39 | overwrite=False) 40 | dev_corp_dir = os.path.join(data_dir, "gigaword-nyt", "eval", "multiple_choice", "dev_10k") 41 | generate_single_eval(corp_dir=dev_corp_dir, 42 | work_dir=work_dir, 43 | tokenized_dir=tokenize_dir, 44 | # pos_dir=pos_dir, 45 | mode="dev", 46 | overwrite=False) 47 | test_corp_dir = os.path.join(data_dir, "gigaword-nyt", "eval", "multiple_choice", "test_10k") 48 | generate_single_eval(corp_dir=test_corp_dir, 49 | work_dir=work_dir, 50 | tokenized_dir=tokenize_dir, 51 | # pos_dir=pos_dir, 
52 | mode="test", 53 | overwrite=False) 54 | # Multi chain 55 | generate_multi_train(corp_dir=train_doc_dir, 56 | work_dir=work_dir, 57 | tokenized_dir=tokenize_dir, 58 | # pos_dir=pos_dir, 59 | overwrite=False) 60 | generate_multi_eval(corp_dir=dev_corp_dir, 61 | work_dir=work_dir, 62 | tokenized_dir=tokenize_dir, 63 | # pos_dir=pos_dir, 64 | mode="dev", 65 | overwrite=False) 66 | generate_multi_eval(corp_dir=test_corp_dir, 67 | work_dir=work_dir, 68 | tokenized_dir=tokenize_dir, 69 | # pos_dir=pos_dir, 70 | mode="test", 71 | overwrite=False) 72 | -------------------------------------------------------------------------------- /experiments/scan_events.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pprint import pprint 3 | 4 | from tqdm import tqdm 5 | 6 | from mcpredictor.utils.document import document_iterator 7 | 8 | if __name__ == "__main__": 9 | # corp_dir = "/home/jinxiaolong/bl/data/gandc16/gigaword-nyt/rich_docs/training" 10 | tokenize_dir = "/home/jinxiaolong/bl/data/gandc16/gigaword-nyt/tokenized" 11 | # for doc in document_iterator(corp_dir, tokenize_dir): 12 | # for entity, chain in doc.get_chains(): 13 | # print(entity) 14 | # for event in chain: 15 | # pprint(event.filter) 16 | # input() 17 | data_dir = "/home/jinxiaolong/bl/data/gandc16" 18 | dev_corp_dir = os.path.join(data_dir, "gigaword-nyt", "eval", "multiple_choice", "dev_10k") 19 | tot, hit = 0, 0 20 | with tqdm() as pbar: 21 | for doc in document_iterator(corp_dir=dev_corp_dir, 22 | tokenized_dir=tokenize_dir, 23 | file_type="txt", doc_type="eval"): 24 | context = doc.context 25 | answer = doc.choices[doc.target] 26 | final_event = context[-1] 27 | if final_event["verb_position"][0] == answer["verb_position"][0]: 28 | hit += 1 29 | tot += 1 30 | pbar.update() 31 | print("{} / {}".format(hit, tot)) 32 | -------------------------------------------------------------------------------- /experiments/scan_multichain.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | 5 | from mcpredictor.preprocess.stop_event import load_stop_event 6 | from mcpredictor.utils.document import document_iterator 7 | 8 | if __name__ == "__main__": 9 | tokenize_dir = "/home/jinxiaolong/bl/data/gandc16/gigaword-nyt/tokenized" 10 | data_dir = "/home/jinxiaolong/bl/data/gandc16" 11 | dev_corp_dir = os.path.join(data_dir, "gigaword-nyt", "eval", "multiple_choice", "dev_10k") 12 | stoplist = load_stop_event("/home/jinxiaolong/bl/data/sent_event_data") 13 | more = 0 14 | less = 0 15 | with tqdm() as pbar: 16 | for doc in document_iterator(corp_dir=dev_corp_dir, 17 | tokenized_dir=tokenize_dir, 18 | file_type="txt", doc_type="eval"): 19 | # Count single chain sents 20 | entity = doc.entity 21 | verb_position = doc.context[-1]["verb_position"] 22 | context = doc.get_chain_for_entity(entity, end_pos=verb_position, stoplist=stoplist) 23 | if len(context) > 8: 24 | context = context[-8:] 25 | sent_set = set() 26 | for e in context: 27 | sent_id = e["verb_position"][0] 28 | sent_set.add(sent_id) 29 | # Count multi chain sents 30 | choices = doc.choices 31 | # verb_position = context[-1]["verb_position"] 32 | target = doc.target 33 | tmp_sent_set = set() 34 | for role in ["subject", "object", "iobject"]: 35 | protagonist = choices[target][role] 36 | chain = doc.get_chain_for_entity(protagonist, end_pos=verb_position, stoplist=stoplist) 37 | if len(chain) > 8: 38 | chain = chain[-8:] 39 | for e in chain: 40 | sent_id = 
e["verb_position"][0] 41 | tmp_sent_set.add(sent_id) 42 | if len(tmp_sent_set) > len(sent_set): 43 | more += 1 44 | elif len(tmp_sent_set) < len(sent_set): 45 | less += 1 46 | pbar.update() 47 | print(more, less) 48 | -------------------------------------------------------------------------------- /experiments/scan_pos.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from tqdm import tqdm 4 | 5 | from mcpredictor.preprocess.stop_event import load_stop_event 6 | from mcpredictor.utils.document import document_iterator 7 | 8 | 9 | if __name__ == "__main__": 10 | corp_dir = "/home/jinxiaolong/bl/data/gandc16/gigaword-nyt/rich_docs/training" 11 | tokenize_dir = "/home/jinxiaolong/bl/data/gandc16/gigaword-nyt/tokenized" 12 | pos_dir = "/home/jinxiaolong/bl/data/gandc16/gigaword-nyt/candc/tags" 13 | data_dir = "/home/jinxiaolong/bl/data/gandc16" 14 | stoplist = load_stop_event("/home/jinxiaolong/bl/data/sent_event_data") 15 | for doc in document_iterator(corp_dir=corp_dir, 16 | tokenized_dir=tokenize_dir, 17 | pos_dir=pos_dir): 18 | for event in doc.events: 19 | print(event["sent"]) 20 | print(event["pos"]) 21 | input() 22 | -------------------------------------------------------------------------------- /experiments/scan_sents.py: -------------------------------------------------------------------------------- 1 | """The longest sentence is more than 500 words, thus we need to extract a span. 2 | Following the distribution of train set, there are: 3 | 4.4% sentences within 10 words, 4 | 23.9% sentences between 10~20 words, 5 | 33.3% sentences between 20~30 words, 6 | 23.6% sentences between 30~40 words, 7 | 10.9% sentences between 40~50 words, 8 | 0.4% sentences more than 50 words. 9 | Thus, we extract a span that contains 50 words (25 words before verb, 25 words after verb). 10 | 11 | Notice: after bert tokenizer, the length will be longer than 50. 
12 | """ 13 | 14 | import os 15 | import pickle 16 | 17 | from tqdm import tqdm 18 | from transformers import AutoTokenizer 19 | 20 | from mcpredictor.utils.config import CONFIG 21 | 22 | if __name__ == "__main__": 23 | work_dir = CONFIG.work_dir 24 | train_dir = os.path.join(work_dir, "single_train") 25 | sent_len = 0 26 | dist = [0] * 11 27 | tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny") 28 | for fn in os.listdir(train_dir): 29 | with open(os.path.join(train_dir, fn), "rb") as f: 30 | s = pickle.load(f) 31 | for sample in tqdm(s): 32 | protagonist, context, choices, target = sample 33 | for e in context + choices: 34 | sent = e["sent"] 35 | sent = sent.split() 36 | old_len = len(sent) 37 | if old_len > 50: 38 | verb_position = e["verb_position"] 39 | token_idx = verb_position[1] 40 | sent = sent[max(0, token_idx-25):token_idx+25] 41 | sent = tokenizer(sent, is_split_into_words=True)["input_ids"] 42 | new_len = len(sent) 43 | sent_len = max(sent_len, new_len) 44 | if new_len // 10 < 10: 45 | dist[new_len // 10] += 1 46 | else: 47 | dist[10] += 1 48 | print(sent_len) 49 | print([i / sum(dist) for i in dist]) 50 | -------------------------------------------------------------------------------- /experiments/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from mcpredictor.models.single_chain_sent.model import SingleChainSentModel 4 | from mcpredictor.models.multi_chain_sent.model import MultiChainSentModel 5 | from mcpredictor.utils.config import CONFIG 6 | 7 | 8 | if __name__ == "__main__": 9 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 10 | level=logging.INFO) 11 | if CONFIG.multi: 12 | model = MultiChainSentModel(CONFIG.model_config) 13 | else: 14 | model = SingleChainSentModel(CONFIG.model_config) 15 | model.build_model() 16 | model.print_model_info() 17 | # model.train() 18 | model.load_model() 19 | model.evaluate() 20 | -------------------------------------------------------------------------------- /experiments/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from mcpredictor.models.single_chain_sent.model import SingleChainSentModel 4 | from mcpredictor.models.multi_chain_sent.model import MultiChainSentModel 5 | from mcpredictor.utils.config import CONFIG 6 | 7 | 8 | if __name__ == "__main__": 9 | logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 10 | level=logging.INFO) 11 | if CONFIG.multi: 12 | model = MultiChainSentModel(CONFIG.model_config) 13 | else: 14 | model = SingleChainSentModel(CONFIG.model_config) 15 | model.build_model() 16 | model.print_model_info() 17 | model.load_model() 18 | model.train() 19 | -------------------------------------------------------------------------------- /mcpredictor/models/base/attention.py: -------------------------------------------------------------------------------- 1 | """Attention functions.""" 2 | import logging 3 | import math 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class AdditiveAttention(nn.Module): 13 | """Additive attention function.""" 14 | 15 | def __init__(self, event_repr_size, directions=1): 16 | super(AdditiveAttention, self).__init__() 17 | self.ffn = nn.Linear(event_repr_size * directions * 2, 1) 18 | 19 | def forward(self, context, choice, mask=None): 20 | """Forward. 
21 | 22 | :param context: size(batch_size, *, seq_len, event_repr_size) 23 | :param choice: size(batch_size, *, 1 or seq_len, event_repr_size) 24 | :param mask: size(batch_size, *, seq_len) 25 | :return: size(batch_size, *, seq_len) 26 | """ 27 | if choice.size(-2) == 1: 28 | choice = choice.expand(context.size()) 29 | __input = torch.cat([context, choice], dim=-1) 30 | weight = self.ffn(__input).squeeze(-1) 31 | if mask is not None: 32 | weight = weight.masked_fill(mask, -1e9) 33 | attn = torch.softmax(weight, dim=-1) 34 | return attn 35 | 36 | 37 | class DotAttention(nn.Module): 38 | """Dot attention function.""" 39 | 40 | def __init__(self): 41 | super(DotAttention, self).__init__() 42 | 43 | def forward(self, context, choice, mask=None): 44 | """Forward. 45 | 46 | :param context: size(batch_size, *, seq_len, event_repr_size) 47 | :param choice: size(batch_size, *, 1 or seq_len, event_repr_size) 48 | :param mask: size(batch_size, *, seq_len) 49 | :return: size(batch_size, *, seq_len) 50 | """ 51 | weight = (context * choice).sum(-1) 52 | if mask is not None: 53 | weight = weight.masked_fill(mask, -1e9) 54 | attn = torch.softmax(weight, dim=-1) 55 | return attn 56 | 57 | 58 | class ScaledDotAttention(nn.Module): 59 | """Scaled dot attention function.""" 60 | 61 | def __init__(self): 62 | super(ScaledDotAttention, self).__init__() 63 | 64 | def forward(self, context, choice, mask=None): 65 | """Forward. 66 | 67 | :param context: size(batch_size, *, seq_len, event_repr_size) 68 | :param choice: size(batch_size, *, 1 or seq_len, event_repr_size) 69 | :param mask: size(batch_size, *, seq_len) 70 | :return: size(batch_size, *, seq_len) 71 | """ 72 | event_repr_size = context.size(-1) 73 | weight = (context * choice).sum(-1) / math.sqrt(event_repr_size) 74 | if mask is not None: 75 | weight = weight.masked_fill(mask, -1e9) 76 | attn = torch.softmax(weight, dim=-1) 77 | return attn 78 | 79 | 80 | class AverageAttention(nn.Module): 81 | """Average attention function.""" 82 | 83 | def __init__(self): 84 | super(AverageAttention, self).__init__() 85 | 86 | def forward(self, context, choice, mask=None): 87 | """Forward.
88 | 89 | :param context: size(batch_size, *, seq_len, event_repr_size) 90 | :param choice: size(batch_size, *, 1 or seq_len, event_repr_size) 91 | :param mask: size(batch_size, *, seq_len) 92 | :return: size(batch_size, *, seq_len) 93 | """ 94 | weight = context.new_ones(context.size()[:-1], dtype=torch.float) 95 | if mask is not None: 96 | weight = weight.masked_fill(mask, -1e9) 97 | attn = torch.softmax(weight, dim=-1) 98 | return attn 99 | 100 | 101 | def build_attention(config): 102 | """Build attention function.""" 103 | layer_name = config["attention"] 104 | if layer_name not in ["average", "additive", "dot", "scaled-dot"]: 105 | logger.info("Unknown attention function '{}', " 106 | "default to use scaled-dot.".format(layer_name)) 107 | layer_name = "scaled-dot" 108 | if layer_name == "average": 109 | layer = AverageAttention() 110 | elif layer_name == "additive": 111 | event_repr_size = config["event_repr_size"] 112 | directions = config["directions"] 113 | layer = AdditiveAttention(event_repr_size, directions) 114 | elif layer_name == "dot": 115 | layer = DotAttention() 116 | else: # layer_name == "scaled-dot" 117 | layer = ScaledDotAttention() 118 | return layer 119 | 120 | 121 | __all__ = ["build_attention"] 122 | -------------------------------------------------------------------------------- /mcpredictor/models/base/constraint.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/waltbai/MCPredictor/a3245516bd14127950697185316b0470d32e87c7/mcpredictor/models/base/constraint.py -------------------------------------------------------------------------------- /mcpredictor/models/base/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class EventEmbedding(nn.Module): 6 | """Word embedding""" 7 | def __init__(self, vocab_size, embedding_size, 8 | dropout=0., pretrain_embedding=None): 9 | super(EventEmbedding, self).__init__() 10 | # Define word embedding 11 | if pretrain_embedding is not None: 12 | # Fix embedding works, otherwise corrupts. 13 | # I don't know why, but it works. 
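# (A likely explanation: nn.Embedding.from_pretrained defaults to freeze=True, so the pretrained vectors stay fixed instead of being corrupted by gradient updates.)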
14 | self.embedding = nn.Embedding.from_pretrained( 15 | torch.tensor(pretrain_embedding), 16 | padding_idx=0) 17 | else: 18 | self.embedding = nn.Embedding( 19 | vocab_size, embedding_size, padding_idx=0) 20 | # Define dropout 21 | self.dropout = nn.Dropout(dropout) 22 | 23 | def forward(self, events): 24 | """Forward""" 25 | output = self.dropout(self.embedding(events)) 26 | return output 27 | 28 | 29 | def build_embedding(config, pretrain_embedding=None): 30 | """Build embedding layer.""" 31 | vocab_size = config["vocab_size"] 32 | embedding_size = config["embedding_size"] 33 | dropout = config["dropout"] 34 | return EventEmbedding(vocab_size=vocab_size, 35 | embedding_size=embedding_size, 36 | dropout=dropout, 37 | pretrain_embedding=pretrain_embedding) 38 | 39 | 40 | __all__ = ["build_embedding"] 41 | -------------------------------------------------------------------------------- /mcpredictor/models/base/event_encoder.py: -------------------------------------------------------------------------------- 1 | """Event encoder that encodes verb and arguments.""" 2 | from torch import nn 3 | 4 | 5 | class EventFusionEncoder(nn.Module): 6 | """Event encoder layer.""" 7 | 8 | def __init__(self, embedding_size, event_repr_size, dropout=0.): 9 | """Event encoder layer.""" 10 | super(EventFusionEncoder, self).__init__() 11 | self.linear = nn.Linear(embedding_size * 4, event_repr_size) 12 | self.activation = nn.Tanh() 13 | self.dropout = nn.Dropout(dropout) 14 | 15 | def forward(self, events): 16 | """Forward. 17 | 18 | input_dim: (*, 4, embedding_size) 19 | 20 | output_dim: (*, event_repr_size) 21 | """ 22 | shape = events.size() 23 | input_shape = shape[:-2] + (shape[-1] * 4, ) 24 | projections = self.activation(self.linear(events.view(input_shape))) 25 | projections = self.dropout(projections) 26 | return projections 27 | 28 | 29 | def build_event_encoder(config): 30 | """Build event encoder layer.""" 31 | embedding_size = config["embedding_size"] 32 | event_repr_size = config["event_repr_size"] 33 | dropout = config["dropout"] 34 | return EventFusionEncoder(embedding_size=embedding_size, 35 | event_repr_size=event_repr_size, 36 | dropout=dropout) 37 | 38 | 39 | __all__ = ["build_event_encoder"] 40 | -------------------------------------------------------------------------------- /mcpredictor/models/base/score.py: -------------------------------------------------------------------------------- 1 | """Score functions.""" 2 | import logging 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class FusionScore(nn.Module): 12 | """Fusion score.""" 13 | 14 | def __init__(self, event_repr_size, directions=1): 15 | super(FusionScore, self).__init__() 16 | self.context_ffn = nn.Linear(event_repr_size * directions, 1) 17 | self.choice_ffn = nn.Linear(event_repr_size * directions, 1) 18 | 19 | def forward(self, context, choice): 20 | """Forward. 21 | 22 | :param context: size(batch_size, *, seq_len, event_repr_size) 23 | :param choice: size(batch_size, *, 1, event_repr_size) 24 | :return: size(batch_size, *, seq_len) 25 | """ 26 | context_score = self.context_ffn(context).squeeze(-1) 27 | choice_score = self.choice_ffn(choice).squeeze(-1) 28 | return context_score + choice_score 29 | 30 | 31 | class EuclideanScore(nn.Module): 32 | """Euclidean score.""" 33 | 34 | def __init__(self): 35 | super(EuclideanScore, self).__init__() 36 | 37 | def forward(self, context, choice): 38 | """Forward. 
39 | 40 | :param context: size(batch_size, *, seq_len, event_repr_size) 41 | :param choice: size(batch_size, *, 1, event_repr_size) 42 | :return: size(batch_size, *, seq_len) 43 | """ 44 | return -torch.sqrt(torch.pow(context-choice, 2.).sum(-1)) 45 | 46 | 47 | class ManhattanScore(nn.Module): 48 | """Manhattan score.""" 49 | 50 | def __init__(self): 51 | super(ManhattanScore, self).__init__() 52 | 53 | def forward(self, context, choice): 54 | """Forward. 55 | 56 | :param context: size(batch_size, *, seq_len, event_repr_size) 57 | :param choice: size(batch_size, *, 1, event_repr_size) 58 | :return: size(batch_size, *, seq_len) 59 | """ 60 | return -torch.abs(context - choice).sum(-1) 61 | 62 | 63 | class CosineScore(nn.Module): 64 | """Cosine score.""" 65 | 66 | def __init__(self): 67 | super(CosineScore, self).__init__() 68 | 69 | def forward(self, context, choice): 70 | """Forward. 71 | 72 | :param context: size(batch_size, *, seq_len, event_repr_size) 73 | :param choice: size(batch_size, *, 1, event_repr_size) 74 | :return: size(batch_size, *, seq_len) 75 | """ 76 | inner_prod = (context * choice).sum(-1) 77 | context_length = torch.sqrt(torch.pow(context, 2.).sum(-1)) 78 | choice_length = torch.sqrt(torch.pow(choice, 2.).sum(-1)) 79 | score = inner_prod / context_length / choice_length 80 | return score 81 | 82 | 83 | class ConveScore(nn.Module): 84 | """ConvE score.""" 85 | 86 | def __init__(self, event_repr_size, 87 | num_out_channels=32, 88 | kernel_size=3, 89 | emb_2d_d1=8, 90 | emb_2d_d2=None, 91 | directions=1, dropout=0.1): 92 | super(ConveScore, self).__init__() 93 | # args 94 | event_repr_size = event_repr_size * directions 95 | if emb_2d_d2 is None: 96 | emb_2d_d2 = event_repr_size * directions // emb_2d_d1 97 | else: 98 | emb_2d_d2 = emb_2d_d2 * directions 99 | assert emb_2d_d1 * emb_2d_d2 == event_repr_size 100 | self.event_repr_size = event_repr_size 101 | self.emb_2d_d1 = emb_2d_d1 102 | self.emb_2d_d2 = emb_2d_d2 103 | self.num_out_channels = num_out_channels 104 | self.w_d = kernel_size 105 | h_out = 2 * self.emb_2d_d1 - self.w_d + 1 106 | w_out = self.emb_2d_d2 - self.w_d + 1 107 | self.feat_dim = self.num_out_channels * h_out * w_out 108 | # layers 109 | self.hidden_dropout = nn.Dropout(dropout) 110 | self.feature_dropout = nn.Dropout(dropout) 111 | self.conv1 = nn.Conv2d(1, self.num_out_channels, (self.w_d, self.w_d), 1, 0) 112 | self.bn0 = nn.BatchNorm2d(1) 113 | self.bn1 = nn.BatchNorm2d(self.num_out_channels) 114 | self.bn2 = nn.BatchNorm1d(self.event_repr_size) 115 | self.fc1 = nn.Linear(self.feat_dim, self.event_repr_size) 116 | self.fc2 = nn.Linear(self.event_repr_size, 1) 117 | self.relu = nn.ReLU() 118 | 119 | def forward(self, context, choice): 120 | """Forward. 
121 | 122 | :param context: size(batch_size, *, seq_len, event_repr_size) 123 | :param choice: size(batch_size, *, 1, event_repr_size) 124 | :return: size(batch_size, *, seq_len) 125 | """ 126 | ori_context_size = context.size() 127 | context = context.contiguous().view(-1, self.event_repr_size) 128 | choice = choice.expand(ori_context_size).contiguous().view(-1, self.event_repr_size) 129 | # context: size(batch_size, event_repr_size) 130 | # choice: size(batch_size, event_repr_size) 131 | context = context.view(-1, 1, self.emb_2d_d1, self.emb_2d_d2) 132 | choice = choice.view(-1, 1, self.emb_2d_d1, self.emb_2d_d2) 133 | stacked_inputs = self.bn0(torch.cat([context, choice], 2)) 134 | x = self.conv1(stacked_inputs) 135 | x = self.relu(x) 136 | x = self.feature_dropout(x) 137 | x = x.view(-1, self.feat_dim) 138 | x = self.fc1(x) 139 | x = self.hidden_dropout(x) 140 | x = self.bn2(x) 141 | x = self.relu(x) 142 | x = self.fc2(x) 143 | # x: size(batch_size, 1) 144 | x = x.view(ori_context_size[:-1]) 145 | return x 146 | 147 | 148 | def build_score(config): 149 | """Build score function.""" 150 | layer_name = config["score"] 151 | if layer_name not in ["fusion", "manhattan", "euclidean", "cosine", "conve"]: 152 | logger.info("Unknown score function '{}', " 153 | "default to use euclidean.".format(layer_name)) 154 | layer_name = "euclidean" 155 | if layer_name == "fusion": 156 | # Get layer specific hyper-parameters 157 | event_repr_size = config["event_repr_size"] 158 | directions = config["directions"] 159 | # Initialize layer 160 | layer = FusionScore(event_repr_size, directions) 161 | elif layer_name == "manhattan": 162 | layer = ManhattanScore() 163 | elif layer_name == "euclidean": 164 | layer = EuclideanScore() 165 | elif layer_name == "conve": 166 | event_repr_size = config["event_repr_size"] 167 | directions = config["directions"] 168 | # num_out_channels = config["num_out_channels"] 169 | # kernel_size = config["kernel_size"] 170 | # emb_2d_d1 = config["emb_2d_d1"] 171 | # emb_2d_d2 = config["emb_2d_d2"] 172 | dropout = config["dropout"] 173 | layer = ConveScore(event_repr_size=event_repr_size) 174 | else: # layer_name == "cosine" 175 | layer = CosineScore() 176 | return layer 177 | 178 | 179 | __all__ = ["build_score"] 180 | -------------------------------------------------------------------------------- /mcpredictor/models/base/sentence_encoder.py: -------------------------------------------------------------------------------- 1 | """Sentence encoder using bert.""" 2 | from torch import nn 3 | from transformers import AutoModel 4 | 5 | 6 | class BertEncoder(nn.Module): 7 | """Bert sentence encoder.""" 8 | 9 | def __init__(self, sent_repr_size=None, vocab_size=None): 10 | super(BertEncoder, self).__init__() 11 | self.bert = AutoModel.from_pretrained("prajjwal1/bert-tiny") 12 | if vocab_size is not None: 13 | self.bert.resize_token_embeddings(vocab_size) 14 | bert_repr_size = 128 15 | if sent_repr_size != bert_repr_size: 16 | self.linear = nn.Linear(bert_repr_size, sent_repr_size) 17 | else: 18 | self.linear = None 19 | 20 | def forward(self, sents, mask=None): 21 | """Forward.
22 | 23 | :param sents: size(*, sent_len) 24 | :param mask: size(*, sent_len) 25 | :return: size(*, sent_repr_size) 26 | """ 27 | original_size = sents.size() 28 | sent_len = original_size[-1] 29 | sents = sents.view(-1, sent_len) 30 | if mask is not None: 31 | mask = mask.view(-1, sent_len) 32 | result = self.bert(input_ids=sents, attention_mask=mask, return_dict=True) 33 | sent_embeddings = result.last_hidden_state 34 | sent_embeddings = sent_embeddings[:, 0, :] 35 | if self.linear is not None: 36 | sent_embeddings = self.linear(sent_embeddings) 37 | sent_embeddings = sent_embeddings.view(original_size[:-1] + (-1, )) 38 | return sent_embeddings 39 | 40 | 41 | def build_sent_encoder(config, vocab_size=None): 42 | """Build sentence encoder.""" 43 | event_repr_size = config["event_repr_size"] 44 | return BertEncoder(sent_repr_size=event_repr_size, vocab_size=vocab_size) 45 | 46 | 47 | __all__ = ["build_sent_encoder"] 48 | -------------------------------------------------------------------------------- /mcpredictor/models/base/sequence_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Variable 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class PositionEncoder(nn.Module): 13 | """Position encoder used by transformer.""" 14 | 15 | def __init__(self, d_model, seq_len, dropout=0.1): 16 | super(PositionEncoder, self).__init__() 17 | self.dropout = nn.Dropout(dropout) 18 | pe = torch.zeros(seq_len, d_model) 19 | position = torch.arange(0., seq_len).unsqueeze(1) 20 | div_term = torch.exp(torch.arange(0., d_model, 2) * 21 | -(math.log(10000.) / d_model)) 22 | pe[:, 0::2] = torch.sin(position * div_term) 23 | pe[:, 1::2] = torch.cos(position * div_term) 24 | pe = pe.unsqueeze(0) 25 | self.register_buffer("pe", pe) 26 | 27 | def forward(self, x): 28 | """Add position embedding to input tensor. 29 | 30 | :param x: size(batch_size, seq_len, d_model) 31 | """ 32 | x = x + Variable(self.pe[:, :x.size(1)], 33 | requires_grad=False) 34 | return self.dropout(x) 35 | 36 | 37 | class Transformer(nn.Module): 38 | """Transformer layer with position encoding.""" 39 | 40 | def __init__(self, d_model, seq_len=9, nhead=16, 41 | dim_feedforward=256, num_layers=1, dropout=0.1): 42 | super(Transformer, self).__init__() 43 | self.position_encoder = PositionEncoder(d_model, seq_len, dropout) 44 | encoder_layer = nn.TransformerEncoderLayer( 45 | d_model=d_model, 46 | nhead=nhead, 47 | dim_feedforward=dim_feedforward, 48 | dropout=dropout) 49 | self.transformer = nn.TransformerEncoder( 50 | encoder_layer=encoder_layer, 51 | num_layers=num_layers) 52 | 53 | def forward(self, x, mask=None): 54 | """Forward transformer layer. 
55 | 56 | :param x: size(batch_size, seq_len, d_model) 57 | :param mask: size(batch_size, seq_len) 58 | """ 59 | x = self.position_encoder(x) 60 | __input = x.transpose(0, 1) 61 | __output = self.transformer(__input, src_key_padding_mask=mask) 62 | return __output.transpose(0, 1) 63 | 64 | 65 | def build_sequence_model(config): 66 | """Build sequence modeling layer.""" 67 | # Get hyper-parameters 68 | event_repr_size = config["event_repr_size"] 69 | seq_len = config["seq_len"] + 1 70 | num_layers = config["num_layers"] 71 | dropout = config["dropout"] 72 | layer_name = config["sequence_model"] 73 | # Select and initialize layer 74 | if layer_name not in ["transformer", "lstm", "bilstm"]: 75 | logger.info("Unknown sequence model '{}', " 76 | "default to use transformer.".format(layer_name)) 77 | layer_name = "transformer" 78 | layer = None 79 | if layer_name == "transformer": 80 | # Get layer specific hyper-parameters 81 | num_heads = config["num_heads"] 82 | dim_feedforward = config["dim_feedforward"] 83 | config["directions"] = 1 84 | # Initialize transformer 85 | layer = Transformer(d_model=event_repr_size, 86 | seq_len=seq_len, 87 | nhead=num_heads, 88 | num_layers=num_layers, 89 | dim_feedforward=dim_feedforward, 90 | dropout=dropout) 91 | else: # layer_name in ["lstm", "bilstm"] 92 | if layer_name == "lstm": 93 | config["directions"] = 1 94 | else: 95 | config["directions"] = 2 96 | return layer 97 | 98 | 99 | __all__ = ["build_sequence_model"] 100 | -------------------------------------------------------------------------------- /mcpredictor/models/basic_model.py: -------------------------------------------------------------------------------- 1 | """Basic model.""" 2 | import json 3 | import os 4 | from abc import ABC, abstractmethod 5 | 6 | import torch 7 | 8 | from mcpredictor.utils.config import CONFIG 9 | 10 | 11 | class BasicModel(ABC): 12 | """Basic model.""" 13 | 14 | def __init__(self, config_path): 15 | with open(config_path, "r") as f: 16 | config = json.load(f) 17 | self._config = config 18 | self._model = None 19 | self._model_name = config["model_name"] 20 | self._data_dir = CONFIG.data_dir 21 | self._work_dir = CONFIG.work_dir 22 | self._device = CONFIG.device 23 | self._logger = None 24 | 25 | @abstractmethod 26 | def train(self, train_data=None, dev_data=None): 27 | """Train.""" 28 | 29 | @abstractmethod 30 | def evaluate(self, eval_data=None, verbose=True): 31 | """Evaluate.""" 32 | 33 | @abstractmethod 34 | def build_model(self): 35 | """Build model.""" 36 | 37 | def print_model_info(self): 38 | """Print model information in logger.""" 39 | model_info = "\n".join(["{0:<15} = {1:}".format(*t) for t in self._config.items()]) 40 | self._logger.info("\n===== Model hyper-parameters =====\n{}".format(model_info)) 41 | self._logger.info("\n===== Model architecture =====\n{}".format(repr(self._model))) 42 | 43 | def load_model(self, suffix="best"): 44 | """Load model.""" 45 | model_path = os.path.join( 46 | self._work_dir, "model", 47 | "{}.{}.pt".format(self._model_name, suffix)) 48 | if os.path.exists(model_path): 49 | self._logger.info("Load model from {}".format(model_path)) 50 | self._model.load_state_dict( 51 | torch.load(model_path, map_location=self._device), strict=False) 52 | else: 53 | self._logger.info("Fail to load model from {}".format(model_path)) 54 | 55 | def save_model(self, suffix="best", verbose=False): 56 | """Save model.""" 57 | model_dir = os.path.join(self._work_dir, "model") 58 | if not os.path.exists(model_dir): 59 | os.makedirs(model_dir) 60 | 
model_path = os.path.join( 61 | model_dir, "{}.{}.pt".format(self._model_name, suffix)) 62 | if verbose: 63 | self._logger.info("Save model to {}".format(model_path)) 64 | torch.save(self._model.state_dict(), model_path) 65 | -------------------------------------------------------------------------------- /mcpredictor/models/multi_chain_sent/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import random 5 | 6 | import numpy 7 | import torch 8 | from torch.nn import CrossEntropyLoss 9 | from torch.optim import Adam 10 | from torch.utils import data 11 | from tqdm import tqdm 12 | from transformers import AdamW 13 | 14 | from mcpredictor.models.basic_model import BasicModel 15 | from mcpredictor.models.multi_chain_sent.network import MCPredictorSent 16 | 17 | 18 | class MCSDataset(data.Dataset): 19 | """Single Chain Dataset for mention.""" 20 | 21 | def __init__(self, __data): 22 | super(MCSDataset, self).__init__() 23 | self.data = __data 24 | 25 | def __len__(self): 26 | return len(self.data) 27 | 28 | def __getitem__(self, item): 29 | events, sents, masks, target = self.data[item] 30 | events = torch.tensor(events) 31 | sents = torch.tensor(sents) 32 | masks = torch.tensor(masks) 33 | target = torch.tensor(target) 34 | return events, sents, masks, target 35 | 36 | 37 | class MultiChainSentModel(BasicModel): 38 | """Multi chain model combines event summary information.""" 39 | 40 | def __init__(self, config_path): 41 | super(MultiChainSentModel, self).__init__(config_path) 42 | self._logger = logging.getLogger(__name__) 43 | 44 | def build_model(self): 45 | """Build model.""" 46 | work_dir = self._work_dir 47 | device = self._device 48 | pretrain_embedding = numpy.load(os.path.join(work_dir, "pretrain_embedding.npy")) 49 | self._model = MCPredictorSent(self._config, pretrain_embedding).to(device) 50 | 51 | def train(self, train_data=None, dev_data=None): 52 | """Train.""" 53 | # Get hyper-parameters 54 | work_dir = self._work_dir 55 | device = self._device 56 | npoch = self._config["npoch"] 57 | batch_size = self._config["batch_size"] 58 | lr = self._config["lr"] 59 | interval = self._config["interval"] 60 | use_sent = self._config["use_sent"] 61 | # Use default datasets 62 | dev_path = os.path.join(work_dir, "multi_dev") 63 | with open(dev_path, "rb") as f: 64 | dev_set = MCSDataset(pickle.load(f)) 65 | # Model 66 | model = self._model.to(device) 67 | # model.sent_encoder.requires_grad_(False) 68 | # model.sent_sequence_model.requires_grad_(False) 69 | # Optimizer and loss function 70 | param_group = [ 71 | { 72 | "params": [p for n, p in model.named_parameters() if "bert" in n], 73 | "lr": 1e-5, 74 | }, 75 | { 76 | "params": [p for n, p in model.named_parameters() if "bert" not in n] 77 | } 78 | ] 79 | # optimizer = AdamW(param_group, lr=lr, weight_decay=1e-6) 80 | optimizer = Adam(param_group, lr=lr, weight_decay=1e-6) 81 | # Train 82 | tmp_dir = os.path.join(work_dir, "multi_train") 83 | # with open(os.path.join(tmp_dir, "train.0"), "rb") as f: 84 | # train_set = MCSDataset(pickle.load(f)) 85 | best_performance = 0. 
86 | for epoch in range(1, npoch + 1): 87 | self._logger.info("===== Epoch {} =====".format(epoch)) 88 | batch_loss = [] 89 | batch_event_loss = [] 90 | fn_list = os.listdir(tmp_dir) 91 | random.shuffle(fn_list) 92 | for fn in fn_list: 93 | # if True: 94 | self._logger.info("Processing slice {} ...".format(fn)) 95 | train_fp = os.path.join(tmp_dir, fn) 96 | with open(train_fp, "rb") as f: 97 | train_set = MCSDataset(pickle.load(f)) 98 | train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8) 99 | with tqdm(total=len(train_set)) as pbar: 100 | for iteration, (events, sents, masks, target) in enumerate(train_loader): 101 | events = events.to(device) 102 | sents = sents.to(device) 103 | masks = masks.to(device) 104 | target = target.to(device) 105 | model.train() 106 | if use_sent: 107 | event_loss = model(events=events, 108 | sents=sents, 109 | sent_mask=masks, 110 | target=target) 111 | else: 112 | event_loss = model(events=events, target=target) 113 | loss = event_loss 114 | # Get loss 115 | batch_loss.append(loss.item()) 116 | batch_event_loss.append(event_loss.item()) 117 | # Update gradient 118 | optimizer.zero_grad() 119 | loss.backward() 120 | optimizer.step() 121 | # Evaluate on dev set 122 | if (iteration + 1) % interval == 0: 123 | result = self.evaluate(eval_set=dev_set, verbose=False) 124 | if result > best_performance: 125 | best_performance = result 126 | self.save_model("best") 127 | # Update progress bar 128 | # pbar.set_description("Loss: {:.4f}".format(loss.item())) 129 | pbar.set_description("event loss: {:.4f}, best_performance: {:.2%}".format( 130 | sum(batch_event_loss) / len(batch_event_loss), 131 | best_performance 132 | )) 133 | pbar.update(len(events)) 134 | result = self.evaluate(eval_set=dev_set, verbose=False) 135 | if result > best_performance: 136 | best_performance = result 137 | self.save_model("best") 138 | self._logger.info("Average loss: {:.4f}".format( 139 | sum(batch_loss) / len(batch_loss))) 140 | self._logger.info("Best evaluation accuracy: {:.2%}".format(best_performance)) 141 | 142 | def evaluate(self, eval_set=None, verbose=True, return_result=False): 143 | """Evaluate.""" 144 | # Get hyper-parameters 145 | work_dir = self._work_dir 146 | device = self._device 147 | batch_size = self._config["batch_size"] 148 | use_sent = self._config["use_sent"] 149 | # Use default test data 150 | if eval_set is None: 151 | eval_path = os.path.join(work_dir, "multi_test") 152 | with open(eval_path, "rb") as f: 153 | eval_set = MCSDataset(pickle.load(f)) 154 | eval_loader = data.DataLoader(eval_set, batch_size, num_workers=8) 155 | # Evaluate 156 | model = self._model 157 | model.eval() 158 | tot, acc = 0, 0 159 | result = [] 160 | with torch.no_grad(): 161 | for events, sents, masks, target in eval_loader: 162 | events = events.to(device) 163 | sents = sents.to(device) 164 | masks = masks.to(device) 165 | target = target.to(device) 166 | if use_sent: 167 | pred = model(events=events, 168 | sents=sents, 169 | sent_mask=masks) 170 | else: 171 | pred = model(events=events) 172 | result.append(pred.argmax(1).cpu()) 173 | acc += pred.argmax(1).eq(target).sum().item() 174 | tot += len(events) 175 | result = torch.cat(result, dim=0).numpy() 176 | accuracy = acc / tot 177 | if verbose: 178 | self._logger.info("Evaluation accuracy: {:.2%}".format(accuracy)) 179 | if return_result: 180 | return accuracy, result 181 | else: 182 | return accuracy 183 | -------------------------------------------------------------------------------- 
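Taken together, `BasicModel` and the concrete model classes form the training and evaluation API used across the repository. A minimal usage sketch follows (illustrative only: it assumes preprocessed data already sits under `CONFIG.work_dir`; the actual entry points live under `experiments/`):

```python
# Sketch of the BasicModel workflow defined above -- not the shipped driver.
from mcpredictor.models.multi_chain_sent.model import MultiChainSentModel

model = MultiChainSentModel("config/mcpredictor-sent.json")
model.build_model()          # loads pretrain_embedding.npy, moves the net to CONFIG.device
model.train()                # reads <work_dir>/multi_train, keeps the best dev checkpoint
model.load_model("best")     # restores <work_dir>/model/<model_name>.best.pt
accuracy = model.evaluate()  # reports accuracy on <work_dir>/multi_test
```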
/mcpredictor/models/multi_chain_sent/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from mcpredictor.models.base.attention import build_attention
5 | from mcpredictor.models.base.embedding import build_embedding
6 | from mcpredictor.models.base.event_encoder import build_event_encoder
7 | from mcpredictor.models.base.score import build_score
8 | from mcpredictor.models.base.sentence_encoder import build_sent_encoder
9 | from mcpredictor.models.base.sequence_model import build_sequence_model
10 | 
11 | 
12 | class MCPredictorSent(nn.Module):
13 | """This model contains two parts: an event part and a sentence part.
14 | 
15 | The event part encodes events and predicts the next event,
16 | while the sentence part encodes the sentence each event came from.
17 | Sentence representations are added element-wise to the event representations before the event sequence model.
18 | """
19 | 
20 | def __init__(self, config, pretrain_embedding=None, tokenizer=None):
21 | super(MCPredictorSent, self).__init__()
22 | self.config = config
23 | # Event part
24 | self.embedding = build_embedding(config, pretrain_embedding)
25 | self.event_encoder = build_event_encoder(config)
26 | self.event_sequence_model = build_sequence_model(config)
27 | self.event_score = build_score(config)
28 | self.event_attention = build_attention(config)
29 | # Sentence part
30 | vocab_size = len(tokenizer) if tokenizer is not None else 30525  # bert-tiny vocab (30522) + 3 special tokens
31 | self.sent_encoder = build_sent_encoder(config, vocab_size=vocab_size)
32 | # Criterion
33 | self.criterion = nn.CrossEntropyLoss()
34 | 
35 | def forward(self, events, sents=None, sent_mask=None, target=None):
36 | """Forward function.
37 | 
38 | If "sents" and "target" are not None, return the loss;
39 | otherwise, only return the scores of the 5 choices.
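Each choice is scored against chain_num protagonist chains (one chain per
argument role: subject, object, indirect object), and the attention-weighted
scores are summed over both events and chains.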
40 | 41 | :param events: size(batch_size, choice_num, chain_num, seq_len + 1, 4) 42 | :param sents: size(batch_size, choice_num, chain_num, seq_len, sent_len) 43 | :param sent_mask: size(batch_size, choice_num, chain_num, seq_len, sent_len) 44 | :param target: 45 | """ 46 | batch_size = events.size(0) 47 | choice_num = events.size(1) 48 | chain_num = events.size(2) 49 | seq_len = events.size(3) - 1 50 | # Event mask 51 | event_mask = events.sum(-1)[:, :, :, :-1].to(torch.bool) 52 | # Event encoding 53 | # event_repr: size(batch_size, choice_num, chain_num, seq_len + 1, event_repr_size) 54 | event_repr = self.event_encoding(events) 55 | # Sentence encoding 56 | if sents is not None: 57 | sent_repr = self.sent_encoding(sents, sent_mask) 58 | event_repr_size = sent_repr.size(-1) 59 | zeros = sent_repr.new_zeros(batch_size, choice_num, chain_num, 1, event_repr_size) 60 | sent_repr = torch.cat([sent_repr, zeros], dim=-2) 61 | else: 62 | sent_repr = None 63 | # Event sequence modeling 64 | # updated_event_repr: size(batch_size, choice_num, chain_num, seq_len+1, event_repr_size) 65 | if sents is not None: 66 | event_repr = event_repr + sent_repr 67 | event_repr = event_repr.view(batch_size * choice_num * chain_num, seq_len + 1, -1) 68 | updated_event_repr = self.event_sequence_model(event_repr) 69 | updated_event_repr = updated_event_repr.view(batch_size, choice_num, chain_num, seq_len + 1, -1) 70 | # Event scoring and attention 71 | # event_context: size(batch_size, choice_num, chain_num, seq_len, event_repr_size) 72 | # event_choice: size(batch_size, choice_num, chain_num, 1, event_repr_size) 73 | event_context = updated_event_repr[:, :, :, :-1, :] 74 | event_choice = updated_event_repr[:, :, :, -1:, :] 75 | # Event loss 76 | # event_score: size(batch_size, choice_num) 77 | event_score = self.event_score(event_context, event_choice) 78 | event_attention = self.event_attention(event_context, event_choice, event_mask) 79 | event_score = (event_score * event_attention).sum(-1).sum(-1) 80 | if target is not None: 81 | return self.criterion(event_score, target) 82 | else: 83 | return event_score 84 | 85 | def event_encoding(self, events): 86 | """Encode events.""" 87 | # Embedding 88 | # event_repr: size(batch_size, choice_num, chain_num, seq_len + 1, 4, embedding_size) 89 | event_repr = self.embedding(events) 90 | # Encoding 91 | # event_repr: size(batch_size, choice_num, chain_num, seq_len + 1, event_repr_size) 92 | event_repr = self.event_encoder(event_repr) 93 | return event_repr 94 | 95 | def sent_encoding(self, sents, sent_mask): 96 | """Encode sentences.""" 97 | # size(batch_size, choice_num, chain_num, seq_len, event_repr_size) 98 | return self.sent_encoder(sents, sent_mask) 99 | -------------------------------------------------------------------------------- /mcpredictor/models/single_chain/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | import numpy 6 | import torch 7 | from torch.nn import CrossEntropyLoss 8 | from torch.utils import data 9 | from tqdm import tqdm 10 | from transformers import AdamW 11 | 12 | from mcpredictor.models.basic_model import BasicModel 13 | from mcpredictor.models.single_chain.network import SCPredictorSent 14 | 15 | 16 | class SCSDataset(data.Dataset): 17 | """Single Chain Dataset for mention.""" 18 | 19 | def __init__(self, __data): 20 | super(SCSDataset, self).__init__() 21 | self.data = __data 22 | 23 | def __len__(self): 24 | return len(self.data) 25 | 26 | 
def __getitem__(self, item): 27 | events, sents, masks, target = self.data[item] 28 | events = torch.tensor(events) 29 | sents = torch.tensor(sents) 30 | masks = torch.tensor(masks) 31 | target = torch.tensor(target) 32 | return events, sents, masks, target 33 | 34 | 35 | class SingleChainSentModel(BasicModel): 36 | """Single chain model combines next sentence prediction.""" 37 | 38 | def __init__(self, config_path): 39 | super(SingleChainSentModel, self).__init__(config_path) 40 | self._logger = logging.getLogger(__name__) 41 | 42 | def build_model(self): 43 | """Build model.""" 44 | work_dir = self._work_dir 45 | device = self._device 46 | pretrain_embedding = numpy.load(os.path.join(work_dir, "pretrain_embedding.npy")) 47 | self._model = SCPredictorSent(self._config, pretrain_embedding).to(device) 48 | 49 | def train(self, train_data=None, dev_data=None): 50 | """Train.""" 51 | # Get hyper-parameters 52 | work_dir = self._work_dir 53 | device = self._device 54 | npoch = self._config["npoch"] 55 | batch_size = self._config["batch_size"] 56 | lr = self._config["lr"] 57 | interval = self._config["interval"] 58 | # Use default datasets 59 | dev_path = os.path.join(work_dir, "single_dev") 60 | with open(dev_path, "rb") as f: 61 | dev_set = SCSDataset(pickle.load(f)) 62 | # Model 63 | model = self._model 64 | model.sent_encoder.requires_grad_(False) 65 | model.sent_sequence_model.requires_grad_(False) 66 | # Optimizer and loss function 67 | optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-6) 68 | lambda_sent = 1. 69 | lambda_dist = 0.1 70 | # Train 71 | tmp_dir = os.path.join(work_dir, "single_train") 72 | with open(os.path.join(tmp_dir, "train.0"), "rb") as f: 73 | train_set = SCSDataset(pickle.load(f)) 74 | best_performance = 0. 75 | for epoch in range(1, npoch + 1): 76 | self._logger.info("===== Epoch {} =====".format(epoch)) 77 | batch_loss = [] 78 | # for fn in sorted(os.listdir(tmp_dir)): 79 | if True: 80 | # self._logger.info("Processing slice {} ...".format(fn)) 81 | # train_fp = os.path.join(tmp_dir, fn) 82 | # with open(train_fp, "rb") as f: 83 | # train_set = SCSDataset(pickle.load(f)) 84 | train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8) 85 | with tqdm(total=len(train_set)) as pbar: 86 | for iteration, (events, sents, masks, target) in enumerate(train_loader): 87 | events = events.to(device) 88 | target = target.to(device) 89 | model.train() 90 | event_loss = model.forward_event(events, target) 91 | loss = event_loss 92 | # Get loss 93 | batch_loss.append(loss.item()) 94 | # Update gradient 95 | optimizer.zero_grad() 96 | loss.backward() 97 | optimizer.step() 98 | # Evaluate on dev set 99 | if (iteration + 1) % interval == 0: 100 | result = self.evaluate(eval_set=dev_set, verbose=False) 101 | if result > best_performance: 102 | best_performance = result 103 | self.save_model("best") 104 | # Update progress bar 105 | # pbar.set_description("Loss: {:.4f}".format(loss.item())) 106 | pbar.set_description("event loss: {:.4f}".format(event_loss.item())) 107 | pbar.update(len(events)) 108 | result = self.evaluate(eval_set=dev_set, verbose=False) 109 | if result > best_performance: 110 | best_performance = result 111 | self.save_model("best") 112 | self._logger.info("Average loss: {:.4f}".format( 113 | sum(batch_loss) / len(batch_loss))) 114 | self._logger.info("Best evaluation accuracy: {:.2%}".format(best_performance)) 115 | 116 | def evaluate(self, eval_set=None, verbose=True): 117 | """Evaluate.""" 118 | # Get hyper-parameters 
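# NOTE: evaluation here scores events only (model.forward_event), so the
# sents/masks yielded by the loader are never moved to the device or used.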
119 | work_dir = self._work_dir
120 | device = self._device
121 | batch_size = self._config["batch_size"]
122 | # Use default test data
123 | if eval_set is None:
124 | eval_path = os.path.join(work_dir, "single_test")
125 | with open(eval_path, "rb") as f:
126 | eval_set = SCSDataset(pickle.load(f))
127 | eval_loader = data.DataLoader(eval_set, batch_size, num_workers=8)
128 | # Evaluate
129 | model = self._model
130 | model.eval()
131 | tot, acc = 0, 0
132 | with torch.no_grad():
133 | for events, sents, masks, target in eval_loader:
134 | events = events.to(device)
135 | target = target.to(device)
136 | pred = model.forward_event(events)
137 | acc += pred.argmax(1).eq(target).sum().item()
138 | tot += len(events)
139 | accuracy = acc / tot
140 | if verbose:
141 | self._logger.info("Evaluation accuracy: {:.2%}".format(accuracy))
142 | return accuracy
143 | 
--------------------------------------------------------------------------------
/mcpredictor/models/single_chain/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from mcpredictor.models.base.attention import build_attention
5 | from mcpredictor.models.base.embedding import build_embedding
6 | from mcpredictor.models.base.event_encoder import build_event_encoder
7 | from mcpredictor.models.base.score import build_score
8 | from mcpredictor.models.base.sentence_encoder import build_sent_encoder
9 | from mcpredictor.models.base.sequence_model import build_sequence_model
10 | 
11 | 
12 | class SCPredictorSent(nn.Module):
13 | """This model contains two parts: an event part and a sentence part.
14 | 
15 | The event part encodes events and predicts the next event,
16 | while the sentence part encodes sentences and predicts the next sentence.
17 | A constraint is applied between events and sentences (in forward_all) to force their representations to be similar.
18 | """
19 | 
20 | def __init__(self, config, pretrain_embedding=None, tokenizer=None):
21 | super(SCPredictorSent, self).__init__()
22 | self.config = config
23 | # Event part
24 | self.embedding = build_embedding(config, pretrain_embedding)
25 | self.event_encoder = build_event_encoder(config)
26 | self.event_sequence_model = build_sequence_model(config)
27 | self.event_score = build_score(config)
28 | self.event_attention = build_attention(config)
29 | # Sentence part
30 | self.sent_encoder = build_sent_encoder(config, vocab_size=30525)  # bert-tiny vocab (30522) + 3 special tokens
31 | self.sent_sequence_model = build_sequence_model(config)
32 | # self.sent_score = nn.Linear(128, 1)
33 | self.sent_score = build_score(config)
34 | self.sent_attention = build_attention(config)
35 | # Criterion
36 | self.criterion = nn.CrossEntropyLoss()
37 | 
38 | def forward(self, events, sents=None, sent_mask=None, target=None):
39 | """Forward function.
40 | 
41 | If "sents" and "target" are not None, return the event loss
42 | (see forward_all for the three-loss variant); otherwise, only return the scores of the 5 choices.
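In this variant the sentence representation is concatenated to the event
representation along the feature dimension before the event sequence model.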
43 | 44 | :param events: size(batch_size, choice_num, seq_len + 1, 4) 45 | :param sents: size(batch_size, choice_num, seq_len, sent_len) 46 | :param sent_mask: size(batch_size, choice_num, seq_len, sent_len) 47 | :param target: 48 | """ 49 | batch_size = events.size(0) 50 | choice_num = events.size(1) 51 | seq_len = events.size(2) - 1 52 | # Event encoding 53 | # event_repr: size(batch_size, choice_num, seq_len + 1, event_repr_size) 54 | event_repr = self.event_encoding(events) 55 | # Event sequence modeling 56 | # updated_event_repr: size(batch_size, choice_num, seq_len+1, event_repr_size) 57 | event_repr = event_repr.view(batch_size * choice_num, seq_len + 1, -1) 58 | # Sentence encoding 59 | sent_repr = self.sent_encoding(sents, sent_mask) 60 | sent_repr = sent_repr.view(batch_size * choice_num, seq_len + 1, -1) 61 | # Sentence sequence modeling 62 | event_repr = torch.cat([event_repr, sent_repr], dim=-1) 63 | updated_event_repr = self.event_sequence_model(event_repr) 64 | updated_event_repr = updated_event_repr.view(batch_size, choice_num, seq_len + 1, -1) 65 | # Event scoring and attention 66 | # event_context: size(batch_size, choice_num, seq_len, event_repr_size) 67 | # event_choice: size(batch_size, choice_num, seq_len, event_repr_size) 68 | event_context = updated_event_repr[:, :, :-1, :] 69 | event_choice = updated_event_repr[:, :, -1:, :] 70 | # Event loss 71 | # event_score: size(batch_size, choice_num) 72 | event_score = self.event_score(event_context, event_choice) 73 | event_attention = self.event_attention(event_context, event_choice) 74 | event_score = (event_score * event_attention).sum(-1) 75 | if target is not None: 76 | event_loss = self.criterion(event_score, target) 77 | else: 78 | event_loss = event_score 79 | return event_loss 80 | 81 | def forward_event(self, events, target=None, return_hidden=False): 82 | """Only forward event part.""" 83 | batch_size = events.size(0) 84 | choice_num = events.size(1) 85 | seq_len = events.size(2) - 1 86 | event_repr = self.event_encoding(events) 87 | event_repr = event_repr.view(batch_size * choice_num, seq_len + 1, -1) 88 | updated_event_repr = self.event_sequence_model(event_repr) 89 | updated_event_repr = updated_event_repr.view(batch_size, choice_num, seq_len + 1, -1) 90 | event_context = updated_event_repr[:, :, :-1, :] 91 | event_choice = updated_event_repr[:, :, -1:, :] 92 | event_score = self.event_score(event_context, event_choice) 93 | event_attention = self.event_attention(event_context, event_choice) 94 | event_score = (event_score * event_attention).sum(-1) 95 | if target is None: 96 | event_loss = event_score 97 | else: 98 | event_loss = self.criterion(event_score, target) 99 | if return_hidden: 100 | return event_loss, updated_event_repr 101 | else: 102 | return event_loss 103 | 104 | def forward_sent(self, sents, sent_mask, target=None, return_hidden=False): 105 | """Only forward sentence part.""" 106 | batch_size = sents.size(0) 107 | choice_num = sents.size(1) 108 | seq_len = sents.size(2) - 1 109 | sent_repr = self.sent_encoding(sents, sent_mask) 110 | sent_repr = sent_repr.view(batch_size * choice_num, seq_len + 1, -1) 111 | updated_sent_repr = self.sent_sequence_model(sent_repr) 112 | updated_sent_repr = updated_sent_repr.view(batch_size, choice_num, seq_len + 1, -1) 113 | sent_context = updated_sent_repr[:, :, :-1, :] 114 | sent_choice = updated_sent_repr[:, :, -1:, :] 115 | sent_score = self.sent_score(sent_context, sent_choice) 116 | sent_attention = self.sent_attention(sent_context, sent_choice) 117 | 
sent_score = (sent_score * sent_attention).sum(-1) 118 | if target is None: 119 | sent_loss = sent_score 120 | else: 121 | sent_loss = self.criterion(sent_score, target) 122 | if return_hidden: 123 | return sent_loss, updated_sent_repr 124 | else: 125 | return sent_loss 126 | 127 | def forward_all(self, events, sents=None, sent_mask=None, target=None): 128 | """Forward function. 129 | 130 | If "sents" and "target" is not None, return the sum of three losses, 131 | otherwise, only return the scores of 5 choices. 132 | 133 | :param events: size(batch_size, choice_num, seq_len + 1, 5) 134 | :param sents: size(batch_size, choice_num, seq_len, sent_len) 135 | :param sent_mask: size(batch_size, choice_num, seq_len, sent_len) 136 | :param target: 137 | """ 138 | batch_size = events.size(0) 139 | choice_num = events.size(1) 140 | seq_len = events.size(2) - 1 141 | # Event encoding 142 | # event_repr: size(batch_size, choice_num, seq_len + 1, event_repr_size) 143 | event_repr = self.event_encoding(events) 144 | # Event sequence modeling 145 | # updated_event_repr: size(batch_size, choice_num, seq_len+1, event_repr_size) 146 | event_repr = event_repr.view(batch_size * choice_num, seq_len + 1, -1) 147 | updated_event_repr = self.event_sequence_model(event_repr) 148 | updated_event_repr = updated_event_repr.view(batch_size, choice_num, seq_len + 1, -1) 149 | # Sentence encoding 150 | if sents is not None: 151 | sent_repr = self.sent_encoding(sents, sent_mask) 152 | else: 153 | sent_repr = None 154 | # Sentence sequence modeling 155 | if sent_repr is not None: 156 | sent_repr = sent_repr.view(batch_size * choice_num, seq_len + 1, -1) 157 | updated_sent_repr = self.sent_sequence_model(sent_repr) 158 | updated_sent_repr = updated_sent_repr.view(batch_size, choice_num, seq_len + 1, -1) 159 | else: 160 | updated_sent_repr = None 161 | # Event scoring and attention 162 | # event_context: size(batch_size, choice_num, seq_len, event_repr_size) 163 | # event_choice: size(batch_size, choice_num, seq_len, event_repr_size) 164 | event_context = updated_event_repr[:, :, :-1, :] 165 | event_choice = updated_event_repr[:, :, -1:, :] 166 | # Event loss 167 | # event_score: size(batch_size, choice_num) 168 | event_score = self.event_score(event_context, event_choice) 169 | event_attention = self.event_attention(event_context, event_choice) 170 | event_score = (event_score * event_attention).sum(-1) 171 | if target is not None: 172 | event_loss = self.criterion(event_score, target) 173 | else: 174 | event_loss = event_score 175 | # Sentence scoring, only use last hidden state 176 | if updated_sent_repr is not None: 177 | sent_context = updated_sent_repr[:, :, :-1, :] 178 | sent_choice = updated_sent_repr[:, :, -1:, :] 179 | sent_score = self.sent_score(sent_context, sent_choice) 180 | sent_attention = self.sent_attention(sent_context, sent_choice) 181 | sent_score = (sent_score * sent_attention).sum(-1) 182 | else: 183 | sent_score = None 184 | # Sentence loss 185 | if target is not None and sent_score is not None: 186 | sent_loss = self.criterion(sent_score, target) 187 | else: 188 | sent_loss = None 189 | # Event-Sentence constraint 190 | if updated_sent_repr is not None: 191 | # event_repr: size(batch_size, choice_num, seq_len, event_repr_size) 192 | # sent_repr: size(batch_size, choice_num, seq_len, event_repr_size) 193 | dist = torch.sqrt(torch.pow(updated_event_repr - updated_sent_repr, 2.).sum(-1)).mean() 194 | else: 195 | dist = None 196 | # Return 197 | if sent_loss is None: 198 | if event_loss is None: 199 | 
return event_score 200 | else: 201 | return event_loss 202 | else: 203 | return event_loss, sent_loss, dist 204 | 205 | def event_encoding(self, events): 206 | """Encode events.""" 207 | # Embedding 208 | # event_repr: size(batch_size, choice_num, seq_len + 1, 4, embedding_size) 209 | event_repr = self.embedding(events) 210 | # Encoding 211 | # event_repr: size(batch_size, choice_num, seq_len + 1, event_repr_size) 212 | event_repr = self.event_encoder(event_repr) 213 | return event_repr 214 | 215 | def sent_encoding(self, sents, sent_mask): 216 | """Encode sentences.""" 217 | return self.sent_encoder(sents, sent_mask) 218 | -------------------------------------------------------------------------------- /mcpredictor/models/single_chain_sent/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import random 5 | 6 | import numpy 7 | import torch 8 | from torch.nn import CrossEntropyLoss 9 | from torch.optim import Adam 10 | from torch.utils import data 11 | from tqdm import tqdm 12 | from transformers import AdamW 13 | 14 | from mcpredictor.models.basic_model import BasicModel 15 | from mcpredictor.models.single_chain_sent.network import SCPredictorSent 16 | 17 | 18 | class SCSDataset(data.Dataset): 19 | """Single Chain Dataset for mention.""" 20 | 21 | def __init__(self, __data): 22 | super(SCSDataset, self).__init__() 23 | self.data = __data 24 | 25 | def __len__(self): 26 | return len(self.data) 27 | 28 | def __getitem__(self, item): 29 | events, sents, masks, target = self.data[item] 30 | events = torch.tensor(events) 31 | sents = torch.tensor(sents) 32 | masks = torch.tensor(masks) 33 | target = torch.tensor(target) 34 | return events, sents, masks, target 35 | 36 | 37 | class SingleChainSentModel(BasicModel): 38 | """Single chain model combines next sentence prediction.""" 39 | 40 | def __init__(self, config_path): 41 | super(SingleChainSentModel, self).__init__(config_path) 42 | self._logger = logging.getLogger(__name__) 43 | 44 | def build_model(self): 45 | """Build model.""" 46 | work_dir = self._work_dir 47 | device = self._device 48 | pretrain_embedding = numpy.load(os.path.join(work_dir, "pretrain_embedding.npy")) 49 | self._model = SCPredictorSent(self._config, pretrain_embedding).to(device) 50 | 51 | def train(self, train_data=None, dev_data=None): 52 | """Train.""" 53 | # Get hyper-parameters 54 | work_dir = self._work_dir 55 | device = self._device 56 | npoch = self._config["npoch"] 57 | batch_size = self._config["batch_size"] 58 | lr = self._config["lr"] 59 | interval = self._config["interval"] 60 | use_sent = self._config["use_sent"] 61 | # Use default datasets 62 | dev_path = os.path.join(work_dir, "single_dev") 63 | with open(dev_path, "rb") as f: 64 | dev_set = SCSDataset(pickle.load(f)) 65 | # Model 66 | model = self._model.to(device) 67 | # model.sent_encoder.requires_grad_(False) 68 | # model.sent_sequence_model.requires_grad_(False) 69 | # Optimizer and loss function 70 | param_group = [ 71 | { 72 | "params": [p for n, p in model.named_parameters() if "bert" in n], 73 | "lr": 1e-5, 74 | }, 75 | { 76 | "params": [p for n, p in model.named_parameters() if "bert" not in n] 77 | } 78 | ] 79 | # optimizer = AdamW(param_group, lr=lr, weight_decay=1e-6) 80 | optimizer = Adam(param_group, lr=lr, weight_decay=1e-6) 81 | # Train 82 | tmp_dir = os.path.join(work_dir, "single_train") 83 | # with open(os.path.join(tmp_dir, "train.0"), "rb") as f: 84 | # train_set = 
SCSDataset(pickle.load(f)[:100000]) 85 | best_performance = 0. 86 | for epoch in range(1, npoch + 1): 87 | self._logger.info("===== Epoch {} =====".format(epoch)) 88 | batch_loss = [] 89 | batch_event_loss = [] 90 | fn_list = os.listdir(tmp_dir) 91 | random.shuffle(fn_list) 92 | for fn in fn_list: 93 | # if True: 94 | self._logger.info("Processing slice {} ...".format(fn)) 95 | train_fp = os.path.join(tmp_dir, fn) 96 | with open(train_fp, "rb") as f: 97 | train_set = SCSDataset(pickle.load(f)) 98 | train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=8) 99 | with tqdm(total=len(train_set)) as pbar: 100 | for iteration, (events, sents, masks, target) in enumerate(train_loader): 101 | events = events.to(device) 102 | sents = sents.to(device) 103 | masks = masks.to(device) 104 | target = target.to(device) 105 | model.train() 106 | if use_sent: 107 | event_loss = model(events=events, 108 | sents=sents, 109 | sent_mask=masks, 110 | target=target) 111 | else: 112 | event_loss = model(events=events, target=target) 113 | loss = event_loss 114 | # Get loss 115 | batch_loss.append(loss.item()) 116 | batch_event_loss.append(event_loss.item()) 117 | # Update gradient 118 | optimizer.zero_grad() 119 | loss.backward() 120 | optimizer.step() 121 | # Evaluate on dev set 122 | if (iteration + 1) % interval == 0: 123 | result = self.evaluate(eval_set=dev_set, verbose=False) 124 | if result > best_performance: 125 | best_performance = result 126 | self.save_model("best") 127 | # Update progress bar 128 | # pbar.set_description("Loss: {:.4f}".format(loss.item())) 129 | pbar.set_description("event loss: {:.4f}, best_performance: {:.2%}".format( 130 | sum(batch_event_loss) / len(batch_event_loss), 131 | best_performance 132 | )) 133 | pbar.update(len(events)) 134 | result = self.evaluate(eval_set=dev_set, verbose=False) 135 | if result > best_performance: 136 | best_performance = result 137 | self.save_model("best") 138 | self._logger.info("Average loss: {:.4f}".format( 139 | sum(batch_loss) / len(batch_loss))) 140 | self._logger.info("Best evaluation accuracy: {:.2%}".format(best_performance)) 141 | 142 | def evaluate(self, eval_set=None, verbose=True): 143 | """Evaluate.""" 144 | # Get hyper-parameters 145 | work_dir = self._work_dir 146 | device = self._device 147 | batch_size = self._config["batch_size"] 148 | use_sent = self._config["use_sent"] 149 | # Use default test data 150 | if eval_set is None: 151 | eval_path = os.path.join(work_dir, "single_test") 152 | with open(eval_path, "rb") as f: 153 | eval_set = SCSDataset(pickle.load(f)) 154 | eval_loader = data.DataLoader(eval_set, batch_size, num_workers=8) 155 | # Evaluate 156 | model = self._model 157 | model.eval() 158 | tot, acc = 0, 0 159 | with torch.no_grad(): 160 | for events, sents, masks, target in eval_loader: 161 | events = events.to(device) 162 | sents = sents.to(device) 163 | masks = masks.to(device) 164 | target = target.to(device) 165 | if use_sent: 166 | pred = model(events=events, 167 | sents=sents, 168 | sent_mask=masks) 169 | else: 170 | pred = model(events=events) 171 | acc += pred.argmax(1).eq(target).sum().item() 172 | tot += len(events) 173 | accuracy = acc / tot 174 | if verbose: 175 | self._logger.info("Evaluation accuracy: {:.2%}".format(accuracy)) 176 | return accuracy 177 | -------------------------------------------------------------------------------- /mcpredictor/models/single_chain_sent/network.py: -------------------------------------------------------------------------------- 1 | 
import torch
2 | from torch import nn
3 | 
4 | from mcpredictor.models.base.attention import build_attention
5 | from mcpredictor.models.base.embedding import build_embedding
6 | from mcpredictor.models.base.event_encoder import build_event_encoder
7 | from mcpredictor.models.base.score import build_score
8 | from mcpredictor.models.base.sentence_encoder import build_sent_encoder
9 | from mcpredictor.models.base.sequence_model import build_sequence_model
10 | 
11 | 
12 | class SCPredictorSent(nn.Module):
13 | """This model contains two parts: an event part and a sentence part.
14 | 
15 | The event part encodes events and predicts the next event,
16 | while the sentence part encodes the sentence each event came from.
17 | Sentence representations are added element-wise to the event representations before the event sequence model.
18 | """
19 | 
20 | def __init__(self, config, pretrain_embedding=None, tokenizer=None):
21 | super(SCPredictorSent, self).__init__()
22 | self.config = config
23 | # Event part
24 | self.embedding = build_embedding(config, pretrain_embedding)
25 | self.event_encoder = build_event_encoder(config)
26 | self.event_sequence_model = build_sequence_model(config)
27 | self.event_score = build_score(config)
28 | self.event_attention = build_attention(config)
29 | # Sentence part
30 | vocab_size = len(tokenizer) if tokenizer is not None else 30525  # bert-tiny vocab (30522) + 3 special tokens
31 | self.sent_encoder = build_sent_encoder(config, vocab_size=vocab_size)
32 | # Criterion
33 | self.criterion = nn.CrossEntropyLoss()
34 | 
35 | def forward(self, events, sents=None, sent_mask=None, target=None):
36 | """Forward function.
37 | 
38 | If "sents" and "target" are not None, return the cross-entropy loss;
39 | otherwise, only return the scores of the 5 choices.
40 | 
41 | :param events: size(batch_size, choice_num, seq_len + 1, 4)
42 | :param sents: size(batch_size, choice_num, seq_len, sent_len)
43 | :param sent_mask: size(batch_size, choice_num, seq_len, sent_len)
44 | :param target: size(batch_size,), index of the correct choice
45 | """
46 | batch_size = events.size(0)
47 | choice_num = events.size(1)
48 | seq_len = events.size(2) - 1
49 | # Event encoding
50 | # event_repr: size(batch_size, choice_num, seq_len + 1, event_repr_size)
51 | event_repr = self.event_encoding(events)
52 | # Sentence encoding
53 | if sents is not None:
54 | sent_repr = self.sent_encoding(sents, sent_mask)
55 | event_repr_size = sent_repr.size(-1)
56 | sent_repr = torch.cat([sent_repr, sent_repr.new_zeros(batch_size, choice_num, 1, event_repr_size)], dim=-2)
57 | else:
58 | sent_repr = None
59 | # Event sequence modeling
60 | # updated_event_repr: size(batch_size, choice_num, seq_len+1, event_repr_size)
61 | if sents is not None:
62 | event_repr = event_repr + sent_repr
63 | event_repr = event_repr.view(batch_size * choice_num, seq_len + 1, -1)
64 | updated_event_repr = self.event_sequence_model(event_repr)
65 | updated_event_repr = updated_event_repr.view(batch_size, choice_num, seq_len + 1, -1)
66 | # Event scoring and attention
67 | # event_context: size(batch_size, choice_num, seq_len, event_repr_size)
68 | # event_choice: size(batch_size, choice_num, 1, event_repr_size)
69 | event_context = updated_event_repr[:, :, :-1, :]
70 | event_choice = updated_event_repr[:, :, -1:, :]
71 | # Event loss
72 | # event_score: size(batch_size, choice_num)
73 | event_score = self.event_score(event_context, event_choice)
74 | event_attention = self.event_attention(event_context, event_choice)
75 | event_score = (event_score * event_attention).sum(-1)
76 | if target is not None:
77 | return self.criterion(event_score,
target) 78 | else: 79 | return event_score 80 | 81 | def event_encoding(self, events): 82 | """Encode events.""" 83 | # Embedding 84 | # event_repr: size(batch_size, choice_num, seq_len + 1, 4, embedding_size) 85 | event_repr = self.embedding(events) 86 | # Encoding 87 | # event_repr: size(batch_size, choice_num, seq_len + 1, event_repr_size) 88 | event_repr = self.event_encoder(event_repr) 89 | return event_repr 90 | 91 | def sent_encoding(self, sents, sent_mask): 92 | """Encode sentences.""" 93 | return self.sent_encoder(sents, sent_mask) 94 | -------------------------------------------------------------------------------- /mcpredictor/preprocess/multi_chain.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | import random 5 | 6 | from tqdm import tqdm 7 | from transformers import BertTokenizerFast 8 | 9 | from mcpredictor.preprocess.negative_pool import load_negative_pool 10 | from mcpredictor.preprocess.single_chain import negative_sampling, align_pos_to_token 11 | from mcpredictor.preprocess.stop_event import load_stop_event 12 | from mcpredictor.preprocess.word_dict import load_word_dict 13 | from mcpredictor.utils.document import document_iterator 14 | from mcpredictor.utils.entity import Entity 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def generate_mask_list(chain): 20 | """Generate masked words in chain.""" 21 | masked_list = set() 22 | for event in chain: 23 | masked_list.update(event.get_words()) 24 | return masked_list 25 | 26 | 27 | def make_sample(doc, 28 | choices, 29 | target, 30 | context_size, 31 | verb_position, 32 | word_dict, 33 | stoplist, 34 | tokenizer): 35 | """Make sample.""" 36 | sample_event = [] 37 | sample_sent = [] 38 | sample_mask = [] 39 | sample_pos = [] 40 | for choice in choices: 41 | choice_event = [] 42 | choice_sent = [] 43 | choice_mask = [] 44 | # choice_pos = [] 45 | for choice_role in ["subject", "object", "iobject"]: 46 | protagonist = choice[choice_role] 47 | # Get chain by protagonist 48 | chain = doc.get_chain_for_entity(protagonist, end_pos=verb_position, stoplist=stoplist) 49 | mask_list = generate_mask_list(chain) 50 | # Truncate 51 | if len(chain) > context_size: 52 | chain = chain[-context_size:] 53 | if len(chain) < context_size: 54 | chain = [None] * (context_size - len(chain)) + chain 55 | # Make sample 56 | chain_event = [] 57 | chain_sent = [] 58 | chain_mask = [] 59 | # chain_pos = [] 60 | if isinstance(protagonist, Entity): 61 | p_head = protagonist["head"] 62 | else: 63 | p_head = protagonist 64 | for event in chain: 65 | if event is not None: 66 | verb, subj, obj, iobj, role = event.tuple(protagonist) 67 | predicate_gr = "{}:{}".format(verb, role) if protagonist != "None" else "None" 68 | tmp_mask_list = mask_list.difference(event.get_words()) 69 | # sent, pos = event.tagged_sent(role, mask_list=tmp_mask_list) 70 | sent = event.tagged_sent(role, mask_list=tmp_mask_list) 71 | else: 72 | predicate_gr = subj = obj = iobj = "None" 73 | # sent, pos = [], [] 74 | sent = [] 75 | tmp = [predicate_gr, subj, obj, iobj] 76 | event_input = [word_dict[w] if w in word_dict else word_dict["None"] for w in tmp] 77 | # input_ids, attention_mask, aligned_pos = align_pos_to_token(sent, pos, tokenizer) 78 | input_ids, attention_mask = align_pos_to_token(sent, tokenizer) 79 | chain_event.append(event_input) 80 | chain_sent.append(input_ids) 81 | chain_mask.append(attention_mask) 82 | # chain_pos.append(aligned_pos) 83 | # Choice event 84 | 
verb, subj, obj, iobj, role = choice.tuple(protagonist)
85 | predicate_gr = "{}:{}".format(verb, role) if protagonist != "None" else "None"
86 | tmp = [predicate_gr, subj, obj, iobj]
87 | event_input = [word_dict[w] if w in word_dict else word_dict["None"] for w in tmp]
88 | chain_event.append(event_input)  # the candidate becomes the (seq_len + 1)-th chain element; it has no sentence, matching the zero padding in the network
89 | # Add to list
90 | choice_event.append(chain_event)
91 | choice_sent.append(chain_sent)
92 | choice_mask.append(chain_mask)
93 | # choice_pos.append(chain_pos)
94 | sample_event.append(choice_event)
95 | sample_sent.append(choice_sent)
96 | sample_mask.append(choice_mask)
97 | # sample_pos.append(choice_pos)
98 | # Adding pos makes a lot of changes, ignore it.
99 | # return sample_event, sample_sent, sample_mask, sample_pos, target
100 | return sample_event, sample_sent, sample_mask, target
101 | 
102 | 
103 | def generate_multi_train(corp_dir,
104 | work_dir,
105 | tokenized_dir,
106 | # pos_dir,
107 | part_size=100000,
108 | file_type="tar",
109 | context_size=8,
110 | overwrite=False):
111 | """Generate multi-chain train data.
112 | 
113 | :param corp_dir: train corpus directory
114 | :param work_dir: workspace directory
115 | :param tokenized_dir: tokenized raw text directory
116 | :param pos_dir: pos tagging directory (currently disabled)
117 | :param part_size: size of each partition
118 | :param file_type: "tar" or "txt"
119 | :param context_size: length of the context chain
120 | :param overwrite: whether to overwrite old data
121 | """
122 | # All parts of the dataset will be stored in a sub-directory.
123 | data_dir = os.path.join(work_dir, "multi_train")
124 | if os.path.exists(data_dir) and not overwrite:
125 | logger.info("{} already exists.".format(data_dir))
126 | else:
127 | # Load stop list
128 | stoplist = load_stop_event(work_dir)
129 | # Load negative pool
130 | neg_pool = load_negative_pool(work_dir, "train")
131 | # Load word dictionary
132 | word_dict = load_word_dict(work_dir)
133 | # Load tokenizer
134 | special_tokens = ["[subj]", "[obj]", "[iobj]"]
135 | tokenizer = BertTokenizerFast.from_pretrained("prajjwal1/bert-tiny",
136 | additional_special_tokens=special_tokens)
137 | # Make sub directory
138 | os.makedirs(data_dir, exist_ok=True)
139 | partition = []
140 | partition_id = 0
141 | total_num = 0
142 | with tqdm() as pbar:
143 | for doc in document_iterator(corp_dir=corp_dir,
144 | tokenized_dir=tokenized_dir,
145 | # pos_dir=pos_dir,
146 | file_type=file_type,
147 | doc_type="train"):
148 | for protagonist, chain in doc.get_chains(stoplist):
149 | # Context + Answer
150 | if len(chain) < context_size + 1:
151 | continue
152 | # Get non protagonist entities
153 | non_protagonist_entities = doc.non_protagonist_entities(protagonist)
154 | # Make sample
155 | n = len(chain)
156 | for begin, end in zip(range(n), range(context_size, n)):  # slide a window of length context_size
157 | context = chain[begin:end]
158 | answer = chain[end]
159 | # Negative sampling
160 | neg_choices = negative_sampling(positive_event=answer,
161 | negative_pool=neg_pool,
162 | protagonist=protagonist,
163 | non_protagonist_entities=non_protagonist_entities)
164 | # Make choices
165 | choices = [answer] + neg_choices
166 | random.shuffle(choices)
167 | target = choices.index(answer)
168 | # Make sample
169 | sample = make_sample(doc=doc,
170 | choices=choices,
171 | target=target,
172 | context_size=context_size,
173 | verb_position=context[-1]["verb_position"],
174 | word_dict=word_dict,
175 | stoplist=stoplist,
176 | tokenizer=tokenizer)
177 | partition.append(sample)
178 | if len(partition) == part_size:
179 | partition_path = os.path.join(data_dir,
"train.{}".format(partition_id)) 180 | with open(partition_path, "wb") as f: 181 | pickle.dump(partition, f) 182 | total_num += len(partition) 183 | partition_id += 1 184 | partition = [] 185 | pbar.update(1) 186 | if len(partition) > 0: 187 | partition_path = os.path.join(data_dir, "train.{}".format(partition_id)) 188 | with open(partition_path, "wb") as f: 189 | pickle.dump(partition, f) 190 | total_num += len(partition) 191 | logger.info("Totally {} samples generated.".format(total_num)) 192 | 193 | 194 | def generate_multi_eval(corp_dir, 195 | work_dir, 196 | tokenized_dir, 197 | # pos_dir, 198 | mode="dev", 199 | file_type="txt", 200 | context_size=8, 201 | overwrite=False): 202 | """Generate multi chain evaluate data.""" 203 | data_path = os.path.join(work_dir, "multi_{}".format(mode)) 204 | if os.path.exists(data_path) and not overwrite: 205 | logger.info("{} already exists.".format(data_path)) 206 | else: 207 | # Load stop list 208 | stoplist = load_stop_event(work_dir) 209 | # Load word dictionary 210 | word_dict = load_word_dict(work_dir) 211 | # Load tokenizer 212 | special_tokens = ["[subj]", "[obj]", "[iobj]"] 213 | tokenizer = BertTokenizerFast.from_pretrained("prajjwal1/bert-tiny", 214 | additional_special_tokens=special_tokens) 215 | # Make sample 216 | eval_data = [] 217 | with tqdm() as pbar: 218 | for doc in document_iterator(corp_dir=corp_dir, 219 | tokenized_dir=tokenized_dir, 220 | # pos_dir=pos_dir, 221 | file_type=file_type, 222 | doc_type="eval"): 223 | # protagonist = doc.entity 224 | context = doc.context 225 | choices = doc.choices 226 | target = doc.target 227 | # Make sample 228 | sample = make_sample(doc=doc, 229 | choices=choices, 230 | target=target, 231 | context_size=context_size, 232 | verb_position=context[-1]["verb_position"], 233 | word_dict=word_dict, 234 | stoplist=stoplist, 235 | tokenizer=tokenizer) 236 | eval_data.append(sample) 237 | pbar.update(1) 238 | with open(data_path, "wb") as f: 239 | pickle.dump(eval_data, f) 240 | logger.info("Totally {} samples generated.".format(len(eval_data))) 241 | -------------------------------------------------------------------------------- /mcpredictor/preprocess/negative_pool.py: -------------------------------------------------------------------------------- 1 | """Generate negative event pool.""" 2 | import json 3 | import logging 4 | import os 5 | import pickle 6 | import random 7 | 8 | from tqdm import tqdm 9 | 10 | from mcpredictor.utils.document import document_iterator 11 | from mcpredictor.utils.entity import Entity 12 | from mcpredictor.utils.event import Event 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def entity_check(event): 18 | """Check if the given event contains an entity.""" 19 | return isinstance(event["subject"], Entity) or \ 20 | isinstance(event["object"], Entity) or \ 21 | isinstance(event["iobject"], Entity) 22 | 23 | 24 | def generate_negative_pool(corp_dir, tokenized_dir, work_dir, num_events=None, suffix="train", file_type="tar"): 25 | """Sample a number of negative events.""" 26 | neg_pool_path = os.path.join(work_dir, "negative_pool_{}.json".format(suffix)) 27 | if os.path.exists(neg_pool_path): 28 | logger.info("{} already exists".format(neg_pool_path)) 29 | else: 30 | neg_pool = [] 31 | with tqdm() as pbar: 32 | for doc in document_iterator(corp_dir, tokenized_dir, shuffle=True, file_type=file_type): 33 | if num_events is not None and len(neg_pool) >= num_events: 34 | break 35 | else: 36 | for ent in doc.entities: 37 | ent.clear_mentions() 38 | # events = [e 
for e in doc.events] 39 | events = doc.events 40 | # If event less than 10, pick all events, 41 | # else randomly pick 10 events from event list. 42 | # Notice: all events should have 43 | # at least one argument that is an entity! 44 | events = [e for e in events if entity_check(e)] 45 | if len(events) < 10: 46 | neg_pool.extend(events) 47 | else: 48 | neg_pool.extend(random.sample(events, 10)) 49 | if num_events is not None and len(neg_pool) > num_events: 50 | neg_pool = neg_pool[:num_events] 51 | pbar.update(1) 52 | with open(neg_pool_path, "w") as f: 53 | json.dump(neg_pool, f) 54 | logger.info("Save negative pool to {}".format(neg_pool_path)) 55 | 56 | 57 | def load_negative_pool(work_dir, suffix="train"): 58 | """Load negative event pool.""" 59 | neg_pool_path = os.path.join(work_dir, "negative_pool_{}.json".format(suffix)) 60 | with open(neg_pool_path, "r") as f: 61 | neg_pool = json.load(f) 62 | neg_pool = [Event(**e) for e in neg_pool] 63 | return neg_pool 64 | 65 | 66 | __all__ = ["generate_negative_pool", "load_negative_pool"] 67 | -------------------------------------------------------------------------------- /mcpredictor/preprocess/single_chain.py: -------------------------------------------------------------------------------- 1 | """Generate single chain data.""" 2 | import logging 3 | import os 4 | import pickle 5 | import random 6 | from copy import copy 7 | 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, BertTokenizerFast 10 | 11 | from mcpredictor.preprocess.negative_pool import load_negative_pool 12 | from mcpredictor.preprocess.stop_event import load_stop_event 13 | from mcpredictor.preprocess.word_dict import load_word_dict 14 | from mcpredictor.utils.document import document_iterator 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def replace_mention(sentence, old_mention, new_mention): 20 | """Replace old mention in sentence with new mention.""" 21 | return sentence.replace(old_mention["text"], new_mention["text"]) 22 | 23 | 24 | def negative_sampling(positive_event, 25 | negative_pool, 26 | protagonist, 27 | non_protagonist_entities, 28 | num_events=4): 29 | """Sampling negative events from negative pool. 30 | 31 | Entities in negative events are replaced with protagonist 32 | and random non-protagonist entities. 33 | Entity mentions in sentences are replaced, too. 
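Candidates are re-drawn until the sampled event's verb lemma differs from
the positive event's, so a negative choice never shares its predicate.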
34 | 35 | :param positive_event: positive event 36 | :param negative_pool: negative event pool 37 | :param protagonist: protagonist entity 38 | :param non_protagonist_entities: non-protagonist entities 39 | :param num_events: number of negative events 40 | """ 41 | negative_events = [] 42 | for _ in range(num_events): 43 | # Sample a negative event 44 | negative_event = random.choice(negative_pool) 45 | while negative_event["verb_lemma"] == positive_event["verb_lemma"]: 46 | negative_event = random.choice(negative_pool) 47 | negative_event = copy(negative_event) 48 | # Assign entity mapping 49 | negative_entities = negative_event.get_entities() 50 | negative_protagonist = random.choice(negative_entities) 51 | # Replace mention and argument 52 | for old_ent in negative_entities: 53 | if old_ent is not negative_protagonist: 54 | # Select new entity 55 | if len(non_protagonist_entities) > 0: 56 | new_ent = random.choice(non_protagonist_entities) 57 | else: 58 | new_ent = old_ent 59 | else: 60 | new_ent = protagonist 61 | # Replace entity 62 | negative_event.replace_argument(old_ent, new_ent) 63 | negative_events.append(negative_event) 64 | return negative_events 65 | 66 | 67 | def generate_mask_list(chain): 68 | """Generate masked words in chain.""" 69 | masked_list = set() 70 | for event in chain: 71 | if event is not None: 72 | masked_list.update(event.get_words()) 73 | return masked_list 74 | 75 | 76 | def align_pos_to_token(words, tokenizer): 77 | """Align pos tagging to bert tokenized result.""" 78 | inputs = tokenizer(words, 79 | is_split_into_words=True, 80 | return_offsets_mapping=True, 81 | padding="max_length", 82 | truncation=True, 83 | max_length=50) 84 | input_ids = inputs.pop("input_ids") 85 | attention_mask = inputs.pop("attention_mask") 86 | # offset_mapping = inputs.pop("offset_mapping") 87 | # tag_index = 0 88 | # cur_tag = "O" 89 | # aligned_pos = [] 90 | # for offset in offset_mapping: 91 | # if offset[0] == 0 and offset[1] != 0 and tag_index < len(pos): 92 | # # Begin of a new word 93 | # cur_tag = pos[tag_index] 94 | # tag_index += 1 95 | # aligned_pos.append(cur_tag) 96 | # elif offset[0] == 0 and offset[1] == 0 or tag_index >= len(pos): 97 | # # Control tokens 98 | # aligned_pos.append("O") 99 | # else: 100 | # # Subword 101 | # aligned_pos.append(cur_tag) 102 | # return input_ids, attention_mask, aligned_pos 103 | return input_ids, attention_mask 104 | 105 | 106 | def make_sample(protagonist, 107 | context, 108 | choices, 109 | target, 110 | word_dict, 111 | tokenizer): 112 | """Make sample.""" 113 | sample_event = [] 114 | sample_sent = [] 115 | sample_mask = [] 116 | # sample_pos = [] 117 | for choice_id, choice in enumerate(choices): 118 | # chain = context + [choice] 119 | chain_event = [] 120 | chain_sent = [] 121 | chain_mask = [] 122 | # chain_pos = [] 123 | mask_list = generate_mask_list(context) 124 | # Context 125 | for event in context: 126 | if event is not None: 127 | verb, subj, obj, iobj, role = event.tuple(protagonist) 128 | predicate_gr = "{}:{}".format(verb, role) 129 | # Convert sentence 130 | tmp_mask_list = mask_list.difference(event.get_words()) 131 | # sent, pos = event.tagged_sent(role, mask_list=tmp_mask_list) 132 | sent = event.tagged_sent(role, mask_list=tmp_mask_list) 133 | else: 134 | predicate_gr = subj = obj = iobj = "None" 135 | # sent, pos = [], [] 136 | sent = [] 137 | # Convert event 138 | tmp = [predicate_gr, subj, obj, iobj] 139 | tmp = [word_dict[w] if w in word_dict else word_dict["None"] for w in tmp] 140 | 
chain_event.append(tmp)  # OOV words fall back to word_dict["None"] (index 0)
141 | # input_ids, attention_mask, aligned_pos = align_pos_to_token(sent, pos, tokenizer)
142 | input_ids, attention_mask = align_pos_to_token(sent, tokenizer)
143 | chain_sent.append(input_ids)
144 | chain_mask.append(attention_mask)
145 | # chain_pos.append(aligned_pos)
146 | # Choice
147 | verb, subj, obj, iobj, role = choice.tuple(protagonist)
148 | predicate_gr = "{}:{}".format(verb, role)
149 | tmp = [predicate_gr, subj, obj, iobj]
150 | tmp = [word_dict[w] if w in word_dict else word_dict["None"] for w in tmp]
151 | chain_event.append(tmp)
152 | # Add to sample
153 | sample_event.append(chain_event)
154 | sample_sent.append(chain_sent)
155 | sample_mask.append(chain_mask)
156 | # sample_pos.append(chain_pos)
157 | # return sample_event, sample_sent, sample_mask, sample_pos, target
158 | return sample_event, sample_sent, sample_mask, target
159 | 
160 | 
161 | def generate_single_train(corp_dir,
162 | work_dir,
163 | tokenized_dir,
164 | # pos_dir,
165 | part_size=200000,
166 | file_type="tar",
167 | context_size=8,
168 | overwrite=False):
169 | """Generate single chain train data.
170 | 
171 | :param corp_dir: train corpus directory
172 | :param work_dir: workspace directory
173 | :param tokenized_dir: tokenized raw text directory
174 | :param part_size: size of each partition
175 | :param file_type: "tar" or "txt"
176 | :param context_size: length of the context chain
177 | :param overwrite: whether to overwrite old data
178 | """
179 | # All parts of the dataset will be stored in a sub-directory.
180 | data_dir = os.path.join(work_dir, "single_train")
181 | if os.path.exists(data_dir) and not overwrite:
182 | logger.info("{} already exists.".format(data_dir))
183 | else:
184 | # Load stop list
185 | stoplist = load_stop_event(work_dir)
186 | # Load negative pool
187 | neg_pool = load_negative_pool(work_dir, "train")
188 | # Load word dictionary
189 | word_dict = load_word_dict(work_dir)
190 | # Load tokenizer
191 | special_tokens = ["[subj]", "[obj]", "[iobj]"]
192 | tokenizer = BertTokenizerFast.from_pretrained("prajjwal1/bert-tiny",
193 | additional_special_tokens=special_tokens)
194 | # Make sub directory
195 | os.makedirs(data_dir, exist_ok=True)
196 | partition = []
197 | partition_id = 0
198 | total_num = 0
199 | with tqdm() as pbar:
200 | for doc in document_iterator(corp_dir=corp_dir,
201 | tokenized_dir=tokenized_dir,
202 | # pos_dir=pos_dir,
203 | file_type=file_type,
204 | doc_type="train"):
205 | for protagonist, chain in doc.get_chains(stoplist):
206 | # Context + Answer
207 | if len(chain) < context_size + 1:
208 | continue
209 | # Get non protagonist entities
210 | non_protagonist_entities = doc.non_protagonist_entities(protagonist)
211 | # Make sample
212 | n = len(chain)
213 | for begin, end in zip(range(n), range(context_size, n)):  # slide a window of length context_size
214 | context = chain[begin:end]
215 | answer = chain[end]
216 | # Negative sampling
217 | neg_choices = negative_sampling(positive_event=answer,
218 | negative_pool=neg_pool,
219 | protagonist=protagonist,
220 | non_protagonist_entities=non_protagonist_entities)
221 | # Make choices
222 | choices = [answer] + neg_choices
223 | random.shuffle(choices)
224 | target = choices.index(answer)
225 | # Make sample
226 | sample = make_sample(protagonist=protagonist,
227 | context=context,
228 | choices=choices,
229 | target=target,
230 | word_dict=word_dict,
231 | tokenizer=tokenizer)
232 | partition.append(sample)
233 | if len(partition) == part_size:
234 | partition_path = os.path.join(data_dir, "train.{}".format(partition_id))
235 
| with open(partition_path, "wb") as f: 236 | pickle.dump(partition, f) 237 | total_num += len(partition) 238 | partition_id += 1 239 | partition = [] 240 | pbar.update(1) 241 | if len(partition) > 0: 242 | partition_path = os.path.join(data_dir, "train.{}".format(partition_id)) 243 | with open(partition_path, "wb") as f: 244 | pickle.dump(partition, f) 245 | total_num += len(partition) 246 | logger.info("Totally {} samples generated.".format(total_num)) 247 | 248 | 249 | def generate_single_eval(corp_dir, 250 | work_dir, 251 | tokenized_dir, 252 | # pos_dir, 253 | mode="dev", 254 | file_type="txt", 255 | context_size=8, 256 | overwrite=False): 257 | """Generate single chain evaluate data.""" 258 | data_path = os.path.join(work_dir, "single_{}".format(mode)) 259 | if os.path.exists(data_path) and not overwrite: 260 | logger.info("{} already exists.".format(data_path)) 261 | else: 262 | # Load stop event list 263 | stoplist = load_stop_event(work_dir) 264 | # Load word dictionary 265 | word_dict = load_word_dict(work_dir) 266 | # Load tokenizer 267 | special_tokens = ["[subj]", "[obj]", "[iobj]"] 268 | tokenizer = BertTokenizerFast.from_pretrained("prajjwal1/bert-tiny", 269 | additional_special_tokens=special_tokens) 270 | # Make sample 271 | eval_data = [] 272 | with tqdm() as pbar: 273 | for doc in document_iterator(corp_dir=corp_dir, 274 | tokenized_dir=tokenized_dir, 275 | # pos_dir=pos_dir, 276 | file_type=file_type, 277 | doc_type="eval"): 278 | protagonist = doc.entity 279 | # context = doc.context 280 | # Context cannot be directly used, since there are slight differences 281 | context = doc.get_chain_for_entity(protagonist, 282 | end_pos=doc.context[-1]["verb_position"], 283 | stoplist=stoplist) 284 | if len(context) > context_size: 285 | context = context[-context_size:] 286 | if len(context) < context_size: 287 | context = [None] * (context_size - len(context)) + context 288 | target = doc.target 289 | choices = doc.choices 290 | # Make sample 291 | sample = make_sample(protagonist=protagonist, 292 | context=context, 293 | choices=choices, 294 | target=target, 295 | word_dict=word_dict, 296 | tokenizer=tokenizer) 297 | eval_data.append(sample) 298 | pbar.update(1) 299 | with open(data_path, "wb") as f: 300 | pickle.dump(eval_data, f) 301 | logger.info("Totally {} samples generated.".format(len(eval_data))) 302 | -------------------------------------------------------------------------------- /mcpredictor/preprocess/stop_event.py: -------------------------------------------------------------------------------- 1 | """Count frequent event(verb) type.""" 2 | import logging 3 | import os 4 | from collections import Counter 5 | 6 | from tqdm import tqdm 7 | 8 | from mcpredictor.utils.document import document_iterator 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def count_stop_event(corp_dir, 15 | work_dir, 16 | file_type="tar", 17 | num_events=10, 18 | overwrite=False, 19 | ): 20 | """Count frequent event(verb) type.""" 21 | stop_event_path = os.path.join(work_dir, "stoplist.txt") 22 | if os.path.exists(stop_event_path) and not overwrite: 23 | logger.info("{} already exists".format(stop_event_path)) 24 | else: 25 | # Count verb:role occurrence 26 | counter = Counter() 27 | logger.info("Scanning training documents ...") 28 | with tqdm() as pbar: 29 | for doc in document_iterator(corp_dir=corp_dir, 30 | file_type=file_type): 31 | pbar.set_description("Processing {}".format(doc.doc_id)) 32 | for entity, chain in doc.get_chains(): 33 | preds = [e.predicate_gr(entity) 
for e in chain] 34 | counter.update(preds) 35 | pbar.update(1) 36 | # be:subj should be preserved 37 | del counter["be:subj"] 38 | # Select top N verb:role 39 | stop_events = [t[0] for t in counter.most_common(num_events)] 40 | logger.info("Top {} frequent predicates are: {}".format(num_events, ", ".join(stop_events))) 41 | # Save to file 42 | with open(stop_event_path, "w") as f: 43 | f.write("\n".join(stop_events)) 44 | logger.info("Save stop_events to {}".format(stop_event_path)) 45 | 46 | 47 | def load_stop_event(work_dir): 48 | """Load stop event list.""" 49 | stop_event_path = os.path.join(work_dir, "stoplist.txt") 50 | with open(stop_event_path, "r") as f: 51 | stop_events = f.read().splitlines() 52 | return stop_events 53 | 54 | 55 | __all__ = ["count_stop_event", "load_stop_event"] 56 | -------------------------------------------------------------------------------- /mcpredictor/preprocess/word_dict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | def generate_word_dict(corp_dir, work_dir): 6 | """Generate word dictionary.""" 7 | 8 | 9 | def load_word_dict(work_dir): 10 | """Load word dictionary.""" 11 | word_dict_path = os.path.join(work_dir, "word_dict.json") 12 | with open(word_dict_path, "r") as f: 13 | word_dict = json.load(f) 14 | return word_dict 15 | -------------------------------------------------------------------------------- /mcpredictor/preprocess/word_embedding.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import numpy 6 | from gensim.models import Word2Vec 7 | 8 | from mcpredictor.utils.document import document_iterator 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class ChainIterator: 14 | """Narrative event chain iterator for word2vec.""" 15 | 16 | def __init__(self, corp_dir): 17 | self.corp_dir = corp_dir 18 | 19 | def __iter__(self): 20 | for doc in document_iterator(self.corp_dir, 21 | tokenized_dir=None, 22 | file_type="tar", 23 | doc_type="train", 24 | shuffle=False, 25 | pos_dir=None): 26 | for entity, chain in doc.get_chains(): 27 | sequence = [] 28 | for event in chain: 29 | _, a0, a1, a2 = event.tuple() 30 | sequence.extend([event.predicate_gr(entity), a0, a1, a2]) 31 | sequence = [t for t in sequence if t != "None"] 32 | yield sequence 33 | 34 | 35 | def generate_word_embedding(train_corp_dir, work_dir, embedding_size=300, force=False): 36 | """Generate word embeddings and dictionary.""" 37 | # Train word2vec 38 | word2vec_path = os.path.join(work_dir, "word2vec_{}.bin".format(embedding_size)) 39 | if os.path.exists(word2vec_path) and not force: 40 | logger.info("Word embeddings generated.") 41 | else: 42 | logger.info("Generating word embeddings ...") 43 | word2vec = Word2Vec(ChainIterator(train_corp_dir), 44 | size=embedding_size, 45 | window=15, 46 | workers=8, 47 | min_count=20) 48 | word2vec.save(word2vec_path) 49 | logger.info("Save word2vec to {}".format(word2vec_path)) 50 | # Generate pretrain embedding matrix 51 | pretrain_embedding_path = os.path.join(work_dir, "pretrain_embedding.npy") 52 | if os.path.exists(pretrain_embedding_path) and not force: 53 | logger.info("Pretrain embedding matrix generated.") 54 | else: 55 | logger.info("Generating pretrain embedding matrix ...") 56 | word2vec = Word2Vec.load(word2vec_path) 57 | word2vec.init_sims() 58 | total_words = len(word2vec.wv.vocab) + 1 59 | pretrain_embedding = numpy.zeros( 60 | (total_words, 
--------------------------------------------------------------------------------
/mcpredictor/utils/common.py:
--------------------------------------------------------------------------------
"""Common functions used by several modules."""


def unescape(text, space_slashes=False):
    """Reverse the escaping performed on pipeline output (function copied from G&C16).

    :param text: text to be unescaped
    :type text: str
    :param space_slashes: whether to surround the restored slashes with spaces
    :type space_slashes: bool
    :return: the unescaped text
    """
    # Reverse various substitutions made on output
    text = text.replace("@semicolon@", ";")
    text = text.replace("@comma@", ",")
    if space_slashes:
        text = text.replace("@slash@", " / ")
        text = text.replace("@slashes@", " // ")
    else:
        text = text.replace("@slash@", "/")
    return text
--------------------------------------------------------------------------------
/mcpredictor/utils/config.py:
--------------------------------------------------------------------------------
"""Command-line arguments."""
import argparse


def parse_args():
    """Parse input arguments."""
    # Load config files
    parser = argparse.ArgumentParser(prog="MCPredictor")
    # Set basic arguments
    parser.add_argument("--data_dir", default="/home/jinxiaolong/bl/data/gandc16",
                        type=str, help="MCNC corpus directory")
    parser.add_argument("--work_dir", default="/home/jinxiaolong/bl/data/mc_data",
                        type=str, help="Workspace directory")
    parser.add_argument("--device", default="cuda:0",
                        choices=["cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3"],
                        help="Device used for models.")
    # parser.add_argument("--mode", default="preprocess",
    #                     choices=["preprocess", "train", "dev", "test"],
    #                     type=str, help="Experiment mode")
    parser.add_argument("--model_config", default="config/scpredictor-sent.json",
                        type=str, help="Model configuration file")
    parser.add_argument("--multi", action="store_true", default=False)
    # Set model arguments
    return parser.parse_args()


CONFIG = parse_args()
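A quick round-trip of `unescape`, matching the substitutions listed above:
```
from mcpredictor.utils.common import unescape

print(unescape("rough@comma@ tough@semicolon@ stuff"))  # rough, tough; stuff
print(unescape("either@slash@or", space_slashes=True))  # either / or
```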
--------------------------------------------------------------------------------
/mcpredictor/utils/document.py:
--------------------------------------------------------------------------------
"""Document class for gigaword processed corpus."""
import os
import random
import tarfile

from mcpredictor.utils.entity import Entity
from mcpredictor.utils.event import Event


def _parse_document(text, tokenized_dir=None, pos_dir=None):
    """Parse document.

    Refer to G&C16.

    :param text: document content.
    :param tokenized_dir: raw text directory
    :param pos_dir: pos result directory
    :return: doc_id, entities, events
    """
    lines = [_.strip() for _ in text.splitlines()]
    # Get doc_id
    doc_id = lines[0]
    # Get entity position and event position
    entity_pos = lines.index("Entities:")
    event_pos = lines.index("Events:")
    # Add entities
    entities = []
    for line in lines[entity_pos + 1:event_pos]:
        if line:
            entities.append(Entity.from_text(line))
    # Read raw text if tokenized_dir is given
    if tokenized_dir is not None:
        raw_path = os.path.join(tokenized_dir, doc_id[:14].lower(), doc_id + ".txt")
        with open(raw_path, "r") as f:
            content = f.read().splitlines()
    else:
        content = None
    # Read pos tags if pos_dir is given
    if pos_dir is not None:
        pos_path = os.path.join(pos_dir, doc_id[:14].lower(), doc_id + ".txt")
        with open(pos_path, "r") as f:
            pos = f.read().splitlines()
        pos = [[t.split("|")[2] for t in s.split()] for s in pos]
    else:
        pos = None
    # Add events
    events = []
    for line in lines[event_pos + 1:]:
        if line:
            cur_event = Event.from_text(line, entities, doc_text=content, doc_pos=pos)
            # Check if the current event is a duplicate.
            # Since events are sorted by verb_pos,
            # we only need to look back one event.
            # if len(events) == 0 or events[-1]["verb_position"] != cur_event["verb_position"]:
            #     events.append(cur_event)
            # TODO: In the old code, the duplicate check was invalid,
            # yet the code worked well, so duplicate checking is skipped temporarily.
            events.append(cur_event)
    return doc_id, entities, events


class Document:
    """Document class."""

    def __init__(self, doc_id, entities=None, events=None):
        self.doc_id = doc_id
        self.entities = entities or []
        self.events = events or []

    @classmethod
    def from_text(cls, text, tokenized_dir=None, pos_dir=None):
        """Initialize Document from text.

        :param text: document content
        :type text: str
        :param tokenized_dir: raw text (tokenized) directory
        :type tokenized_dir: str
        :param pos_dir: pos tag directory
        :type pos_dir: str
        """
        doc_id, entities, events = _parse_document(text, tokenized_dir, pos_dir)
        return cls(doc_id, entities, events)

    def get_chain_for_entity(self, entity, end_pos=None, duplicate=False, stoplist=None):
        """Get the chain for a specified entity.

        :param entity: protagonist
        :type entity: Entity
        :param end_pos: only keep events up to this position
        :type end_pos: tuple[int, int] or None
        :param duplicate: whether to keep events with duplicate verb positions
        :param stoplist: stop predicate list
        :return: list of events containing the entity
        """
        # Get chain
        result = [event for event in self.events if event.contain(entity)]
        if not duplicate:
            result = [event for idx, event in enumerate(result)
                      if idx == 0 or event["verb_position"] != result[idx-1]["verb_position"]]
        if end_pos is not None:
            result = [event for event in result if event.verb_position <= end_pos]
        if stoplist is not None:
            result = [event for event in result if event.predicate_gr(entity) not in stoplist]
        return result

    def get_chains(self, stoplist=None):
        """Get all (protagonist, chain) pairs.

        :param stoplist: stop verb list
        :type stoplist: None or list[str]
        """
        # Get entities
        result = [(entity, self.get_chain_for_entity(entity))
                  for entity in self.entities]
        if stoplist is not None:
            result = [(entity, [event for event in chain
                                if event.predicate_gr(entity) not in stoplist
                                and event.verb_lemma not in stoplist])
                      for (entity, chain) in result]
        # Filter empty chains
        result = [(entity, chain) for (entity, chain) in result
                  if len(chain) > 0]
        return result

    def non_protagonist_entities(self, entity):
        """Return the list of non-protagonist entities.

        :param entity: protagonist
        :type entity: Entity
        :return: all entities other than the protagonist
        """
        result = [e for e in self.entities if e is not entity]
        return result
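A toy sketch of chain extraction with the Document class (run from the repository root, since entity.py loads data/english_stopwords.txt at import time; all names below are invented):
```
from mcpredictor.utils.document import Document
from mcpredictor.utils.entity import Entity
from mcpredictor.utils.event import Event

# One protagonist appearing as the subject of two events.
police = Entity(entity_id="entity-0",
                mentions=[{"text": "the police", "sentence_num": 0, "head": "police"}])
events = [Event(verb=v, verb_lemma=v, verb_position=(i, 0), subject=police,
                object="None", iobject="None", iobject_prep="None", sent=None, pos=None)
          for i, v in enumerate(["arrest", "charge"])]
doc = Document("TOY_DOC", entities=[police], events=events)
for entity, chain in doc.get_chains():
    print(entity.get_head_word(), [e.predicate_gr(entity) for e in chain])
# police ['arrest:subj', 'charge:subj']
```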
def _parse_question(text, entities, doc_id, tokenized_dir=None, pos_dir=None):
    """Parse question.

    :param text: question text
    :param entities: entity list of the document
    :param doc_id: document id
    :param tokenized_dir: raw text directory
    :param pos_dir: pos result directory
    :return: entity, context, choices, target
    """
    lines = text.splitlines()
    entity_pos = lines.index("Entity:")
    context_pos = lines.index("Context:")
    choices_pos = lines.index("Choices:")
    target_pos = lines.index("Target:")
    entity = entities[int(lines[entity_pos + 1])]
    # Read raw text if tokenized_dir is given
    if tokenized_dir is not None:
        raw_path = os.path.join(tokenized_dir, doc_id[:14].lower(), doc_id + ".txt")
        with open(raw_path, "r") as f:
            content = f.read().splitlines()
    else:
        # Guard against an unbound name when no tokenized text is available.
        content = None
    # Read pos tags if pos_dir is given
    if pos_dir is not None:
        pos_path = os.path.join(pos_dir, doc_id[:14].lower(), doc_id + ".txt")
        with open(pos_path, "r") as f:
            pos = f.read().splitlines()
        pos = [[t.split("|")[2] for t in s.split()] for s in pos]
    else:
        pos = None
    # Context
    context = [Event.from_text(e, entities, doc_text=content, doc_pos=pos)
               for e in lines[context_pos + 1:choices_pos - 1] if e]
    choices = [Event.from_text(e, entities)
               for e in lines[choices_pos + 1:target_pos - 1] if e]
    target = int(lines[target_pos + 1])
    # Assign the source sentence to the answer
    answer = choices[target]
    if content is not None:
        answer["sent"] = content[answer["verb_position"][0]]
    return entity, context, choices, target


class TestDocument(Document):
    """TestDocument class."""

    def __init__(self, doc_id, entities=None, events=None,
                 entity=None, context=None, choices=None, target=None):
        super(TestDocument, self).__init__(doc_id, entities, events)
        self.entity = entity
        self.context = context or []
        self.choices = choices or []
        self.target = target

    @classmethod
    def from_text(cls, text, tokenized_dir=None, pos_dir=None):
        """Split content into a question part and a document part, then parse both.

        :param text: text to be processed.
        :type text: str
        :param tokenized_dir: raw text directory
        :type tokenized_dir: str
        :param pos_dir: pos directory
        :type pos_dir: str
        """
        # Split lines
        lines = text.splitlines()
        # Get positions
        document_pos = lines.index("Document:")
        # Parse document part
        document_part = "\n".join(lines[document_pos+1:])
        doc_id, entities, events = _parse_document(document_part, tokenized_dir, pos_dir)
        # Parse question part
        question_part = "\n".join(lines[:document_pos])
        entity, context, choices, target = _parse_question(question_part, entities, doc_id, tokenized_dir, pos_dir)
        return cls(doc_id, entities, events, entity, context, choices, target)

    def get_question(self):
        """Transform the entity into its head word, and the context and choice
        events into predicate-grammar-role form.
        """
        entity = self.entity.get_head_word()
        context = [event.predicate_gr(self.entity) for event in self.context]
        choices = [event.predicate_gr(self.entity) for event in self.choices]
        target = self.target
        return entity, context, choices, target
def document_iterator(corp_dir,
                      tokenized_dir=None,
                      file_type="tar",
                      doc_type="train",
                      shuffle=False,
                      pos_dir=None):
    """Iterate over documents."""
    # Check file_type
    assert file_type in ["tar", "txt"], "Only accept tar/txt as file_type!"
    # Check doc_type
    assert doc_type in ["train", "eval"], "Only accept train/eval as doc_type!"
    # Read file list
    fn_list = os.listdir(corp_dir)
    if shuffle:
        random.shuffle(fn_list)
    fn_list = [fn for fn in fn_list if fn.endswith(file_type)]
    if file_type == "txt":
        for fn in fn_list:
            fpath = os.path.join(corp_dir, fn)
            with open(fpath, "r") as f:
                content = f.read()
            if doc_type == "train":
                yield Document.from_text(content, tokenized_dir, pos_dir)
            else:  # doc_type == "eval"
                yield TestDocument.from_text(content, tokenized_dir, pos_dir)
    else:  # file_type == "tar"
        for fn in fn_list:
            fpath = os.path.join(corp_dir, fn)
            with tarfile.open(fpath, "r") as f:
                members = f.getmembers()
                if shuffle:
                    random.shuffle(members)
                for member in members:
                    content = f.extractfile(member).read().decode("utf-8")
                    if doc_type == "train":
                        yield Document.from_text(content, tokenized_dir, pos_dir)
                    else:
                        yield TestDocument.from_text(content, tokenized_dir, pos_dir)


__all__ = ["Document", "TestDocument", "document_iterator"]
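Typical usage of document_iterator for evaluation data (the directory name is hypothetical; doc_type="eval" yields TestDocument objects):
```
from mcpredictor.utils.document import document_iterator

for doc in document_iterator("data/dev_questions", tokenized_dir=None,
                             file_type="txt", doc_type="eval"):
    entity, context, choices, target = doc.get_question()
    print(doc.doc_id, entity, choices[target])
```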
"english_stopwords.txt"), "r") as f: 21 | STOPWORDS = f.read().splitlines() 22 | 23 | 24 | def field_value(field, text): 25 | """Find value for certain field. 26 | 27 | :param field: 28 | :type field: str 29 | :param text: 30 | :type text: str 31 | :return: 32 | """ 33 | parts = text.split("=") 34 | assert field == parts[0] 35 | return parts[1] 36 | 37 | 38 | def filter_words_in_mention(words): 39 | """Filter stop words and pronouns in mention.""" 40 | return [w for w in words if w not in PRONOUNS and w not in STOPWORDS] 41 | 42 | 43 | def get_head_word(mentions): 44 | """Get head word of mentions. 45 | 46 | Copy from G&C16 47 | """ 48 | entity_head_words = set() 49 | for mention in mentions: 50 | # Get head word from mention 51 | mention_head = mention["head"] 52 | # Remove punctuation 53 | mention_head = punct_re.sub(" ", mention_head) 54 | # Split words 55 | head_words = mention_head.split() 56 | # Get rid of words that won't help us: stopwords and pronouns 57 | head_words = filter_words_in_mention(head_words) 58 | # Don't use any 1-letter words 59 | head_words = [w for w in head_words if len(w) > 1] 60 | # If there are no words left, we can't get a headword from this mention 61 | # If there are multiple (a minority of cases), use the rightmost, 62 | # which usually is the headword 63 | if head_words: 64 | entity_head_words.add(head_words[-1]) 65 | if len(entity_head_words): 66 | return list(sorted(entity_head_words))[0] 67 | else: 68 | return "None" 69 | 70 | 71 | class Entity(dict): 72 | """Entity class.""" 73 | def __init__(self, **kwargs): 74 | # Convert dict mention to Mention object 75 | mentions = [] 76 | for mention in kwargs["mentions"]: 77 | if isinstance(mention, Mention): 78 | mentions.append(mention) 79 | else: 80 | mentions.append(Mention(**mention)) 81 | kwargs["mentions"] = mentions 82 | # Get head word 83 | kwargs.setdefault("head", get_head_word(mentions)) 84 | super(Entity, self).__init__(**kwargs) 85 | 86 | def __getstate__(self): 87 | return self.__dict__ 88 | 89 | def __setstate__(self, state): 90 | self.__dict__ = state 91 | 92 | def __getattr__(self, item): 93 | return self[item] 94 | 95 | def __repr__(self): 96 | return "<{entity_id:}:{head:}>".format(**self) 97 | 98 | def find_mention_by_pos(self, verb_position): 99 | """Find mention by verb_position.""" 100 | sent_id = verb_position[0] 101 | mentions = [m["text"] for m in self["mentions"] if m["sentence_num"] == sent_id] 102 | if len(mentions) == 0: 103 | return "None" 104 | else: 105 | return max(mentions, key=lambda x: len(x)) 106 | 107 | def find_longest_mention(self): 108 | """Find longest mention.""" 109 | mentions = [m["text"] for m in self["mentions"]] 110 | return max(mentions, key=lambda x: len(x)) 111 | 112 | def clear_mentions(self): 113 | """Clear mentions field in order to save space during storing.""" 114 | self["mentions"] = [] 115 | 116 | def get_head_word(self): 117 | """Return head word.""" 118 | return self["head"] 119 | 120 | @classmethod 121 | def from_text(cls, text): 122 | """Construct Entity object from text. 123 | 124 | :param text: text to be parsed. 125 | """ 126 | # Though parsing can be done by regex, 127 | # it is not necessary. 
class Entity(dict):
    """Entity class."""
    def __init__(self, **kwargs):
        # Convert dict mentions to Mention objects
        mentions = []
        for mention in kwargs["mentions"]:
            if isinstance(mention, Mention):
                mentions.append(mention)
            else:
                mentions.append(Mention(**mention))
        kwargs["mentions"] = mentions
        # Get head word
        kwargs.setdefault("head", get_head_word(mentions))
        super(Entity, self).__init__(**kwargs)

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__ = state

    def __getattr__(self, item):
        return self[item]

    def __repr__(self):
        return "<{entity_id:}:{head:}>".format(**self)

    def find_mention_by_pos(self, verb_position):
        """Find the longest mention in the sentence given by verb_position."""
        sent_id = verb_position[0]
        mentions = [m["text"] for m in self["mentions"] if m["sentence_num"] == sent_id]
        if len(mentions) == 0:
            return "None"
        else:
            return max(mentions, key=lambda x: len(x))

    def find_longest_mention(self):
        """Find the longest mention."""
        mentions = [m["text"] for m in self["mentions"]]
        return max(mentions, key=lambda x: len(x))

    def clear_mentions(self):
        """Clear the mentions field to save space during storage."""
        self["mentions"] = []

    def get_head_word(self):
        """Return the head word."""
        return self["head"]

    @classmethod
    def from_text(cls, text):
        """Construct an Entity object from text.

        :param text: text to be parsed.
        """
        # Though parsing could be done with a regex,
        # it is not necessary.
        entity_id, text = text.split(":", 1)
        # Find entity_id
        entity_id = "entity-{}".format(entity_id)
        # Find other attributes
        parts = [p.strip() for p in text.split(" // ")]
        category = field_value("category", parts[0])
        gender = field_value("gender", parts[1])
        gender_prob = float(field_value("genderProb", parts[2]))
        number = field_value("number", parts[3])
        number_prob = float(field_value("numberProb", parts[4]))
        # Parse mentions
        mentions = field_value("mentions", parts[5]).split(" / ")
        mentions = [Mention.from_text(m) for m in mentions]
        mentions = [m for m in mentions if m.text]
        # Parse type
        if len(parts) > 6:
            type_ = field_value("type", parts[6])
        else:
            type_ = "misc"
        return cls(entity_id=entity_id,
                   # category=category,
                   # gender=gender,
                   # gender_prob=gender_prob,
                   # number=number,
                   # number_prob=number_prob,
                   mentions=mentions,
                   # type=type_
                   )


def transform_entity(entity, verb_position=None):
    """Transform an entity/str into a JSON-style object."""
    item = {}
    if isinstance(entity, Entity):
        item["head"] = entity["head"]
        item["entity"] = int(entity["entity_id"][7:])
        if verb_position is not None:
            item["mention"] = entity.find_mention_by_pos(verb_position)
        else:
            item["mention"] = entity.find_longest_mention()
    else:
        item["head"] = entity
        item["mention"] = entity
        item["entity"] = -1
    return item


__all__ = ["Entity", "filter_words_in_mention", "transform_entity"]
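And a sketch of transform_entity, which produces the JSON-style argument objects used by Event.filter below (toy entity; pronoun mentions are dropped from head-word candidates):
```
from mcpredictor.utils.entity import Entity, transform_entity

ent = Entity(entity_id="entity-4",
             mentions=[{"text": "the lawyers", "sentence_num": 0, "head": "lawyers"},
                       {"text": "they", "sentence_num": 1, "head": "they"}])
print(transform_entity(ent))
# {'head': 'lawyers', 'entity': 4, 'mention': 'the lawyers'}
```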
--------------------------------------------------------------------------------
/mcpredictor/utils/event.py:
--------------------------------------------------------------------------------
"""Event class for gigaword processed document."""
import re

from mcpredictor.utils.common import unescape
from mcpredictor.utils.entity import Entity, transform_entity

event_re = re.compile(r'(?P<verb>[^/]*) / (?P<verb_lemma>[^/]*) / '
                      r'verb_pos=\((?P<sentence_num>\d+),(?P<word_index>\d+)\) / '
                      r'type=(?P<type>[^/]*) / subj=(?P<subj>[^/]*) / obj=(?P<obj>[^/]*) / '
                      r'iobj=(?P<iobj>[^/]*)')


HEAD_MODE = "h"
MENTION_MODE = "m"
GLOBAL_MODE = "g"


def find_entity_by_id(s, entity_list):
    """Return the entity referred to by an entity id.

    This helps reduce memory cost, since each event stores a pointer
    to the Entity object instead of a string.

    :param s: could be a word or an entity id
    :type s: str
    :param entity_list: all mentioned entities in the document
    :type entity_list: list[Entity]
    :return: str or Entity
    """
    if s.startswith("entity-"):
        return entity_list[int(s[7:])]
    else:
        return unescape(s)
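A toy parse against event_re (the event line is invented but follows the processed-corpus format consumed by Event.from_text below):
```
from mcpredictor.utils.event import event_re

line = "says / say / verb_pos=(3,7) / type=normal / subj=entity-0 / obj=reporters / iobj=None"
m = event_re.match(line)
print(m.group("verb_lemma"), m.group("sentence_num"), m.group("subj"))  # say 3 entity-0
```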
class Event(dict):
    """Event class."""
    def __init__(self, **kwargs):
        for key in ["subject", "object", "iobject"]:
            if isinstance(kwargs[key], dict) and not isinstance(kwargs[key], Entity):
                kwargs[key] = Entity(**kwargs[key])
        super(Event, self).__init__(**kwargs)

    @property
    def filter(self):
        """In filtered format, an event is represented as follows:

        event: {
            sent: str,
            verb_lemma: str,
            verb_position: [int, int],
            subject: {head: str, mention: str, entity: int},
            object: {head: str, mention: str, entity: int},
            iobject: {head: str, mention: str, entity: int},
            iobject_prep: str
        }
        """
        item = {
            "verb_lemma": self["verb_lemma"],
            "verb_position": self["verb_position"],
            "subject": transform_entity(self["subject"], self["verb_position"]),
            "object": transform_entity(self["object"], self["verb_position"]),
            "iobject": transform_entity(self["iobject"], self["verb_position"]),
            "iobject_prep": self["iobject_prep"]
        }
        # Negative events have no corresponding sentences.
        if "sent" in self:
            item["sent"] = self["sent"]
        return item

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__ = state

    def __getattr__(self, item):
        return self[item]

    def __repr__(self):
        return "[{verb_lemma:}:" \
               "{subject:}," \
               "{object:}," \
               "{iobject:}]".format(**self)

    def contain(self, argument):
        """Check whether the event contains the argument."""
        if argument == "None":
            return False
        else:
            return self["subject"] == argument or \
                self["object"] == argument or \
                self["iobject"] == argument

    def find_role(self, argument, stoplist=None):
        """Find the role of the argument."""
        if argument == self["subject"]:
            return "subj"
        elif argument == self["object"]:
            return "obj"
        elif argument == self["iobject"] and self["iobject"] != "None":
            return "prep_{}".format(self["iobject_prep"])
        else:
            return "None"

    def predicate_gr(self, argument):
        """Convert the event to a predicate-grammar-role representation."""
        return "{}:{}".format(self["verb_lemma"], self.find_role(argument))

    def tuple(self, protagonist=None, mode=HEAD_MODE, last_verb_pos=None):
        """Convert the event to a tuple.

        If protagonist is None, return a quadruple (verb, subj, obj, iobj);
        otherwise return a quintuple (verb, subj, obj, iobj, role).

        If last_verb_pos is given, mentions are gathered up to that position
        instead of the event's own verb position.
        """
        verb_position = self["verb_position"]
        t = (self["verb_lemma"], )
        for role in ["subject", "object", "iobject"]:
            if isinstance(self[role], Entity):
                if mode == HEAD_MODE:
                    t = t + (self[role].head, )
                elif mode == MENTION_MODE:
                    t = t + (self[role].find_mention_by_pos(verb_position).replace(" ", "_"), )
                else:  # GLOBAL_MODE
                    if last_verb_pos is not None:
                        t = t + ("##".join(self[role].find_mentions_by_pos(last_verb_pos)).replace(" ", "_"), )
                    else:
                        t = t + ("##".join(self[role].find_mentions_by_pos(verb_position)).replace(" ", "_"), )
            else:
                t = t + (self[role], )
        if protagonist is not None:
            t = t + (self.find_role(protagonist), )
        return t

    def get_words(self):
        """Get the words of the event."""
        ret_val = [self["verb"]]
        for role in ["subject", "object", "iobject"]:
            if isinstance(self[role], Entity):
                ret_val.append(self[role]["head"])
            elif self[role] != "None":
                ret_val.append(self[role])
        return ret_val

    def get_entities(self):
        """Get the entity arguments of the event."""
        entities = []
        for key in ["subject", "object", "iobject"]:
            if isinstance(self[key], Entity):
                entities.append(self[key])
        return entities

    def replace_mention(self, __old: str, __new: str):
        """Replace a mention in the sentence."""
        # After the replacement, the verb position will change,
        # so the sentence is split around the verb first.
        token_index = self["verb_position"][1]
        sent = self["sent"].split()
        before_verb = " ".join(sent[:token_index])
        verb = sent[token_index]
        after_verb = " ".join(sent[token_index+1:])
        before_verb = before_verb.replace(__old, __new).split()
        after_verb = after_verb.replace(__old, __new).split()
        new_sent = before_verb + [verb] + after_verb
        token_index = len(before_verb)
        self["sent"] = " ".join(new_sent)
        self["verb_position"] = [self["verb_position"][0], token_index]

    def tagged_sent(self, role, mask_list=None):
        """Wrap the verb of the sentence in role-marker tokens."""
        sent = self["sent"].lower().split()
        # pos = self["pos"]
        if role not in ["subj", "obj"]:
            role = "iobj"
        sent_id = self["verb_position"][0]
        verb_index = self["verb_position"][1]
        token_list = []
        # pos_list = []
        # Use "O" to represent control tokens
        for index, token in enumerate(sent):
            if index == verb_index:
                token_list.extend(["[{}]".format(role), sent[verb_index], "[{}]".format(role)])
                # pos_list.extend(["O", pos[verb_index], "O"])
            elif mask_list is not None and token in mask_list:
                token_list.append("[UNK]")
                # pos_list.append("O")
            else:
                token_list.append(token)
                # pos_list.append(pos[index])
        # Extract sent
        # return token_list, pos_list
        return token_list

    def replace_argument(self, __old, __new):
        """Replace an argument with a new one."""
        for key in ["subject", "object", "iobject"]:
            if self[key] == __old:
                self[key] = __new
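A minimal sketch of tagged_sent: the verb is wrapped in the same role markers registered as additional_special_tokens for the BERT tokenizer during preprocessing (the event below is hand-built; run from the repository root so the stop-word file resolves on import):
```
from mcpredictor.utils.event import Event

e = Event(verb="arrested", verb_lemma="arrest", verb_position=(0, 2),
          subject="police", object="him", iobject="None", iobject_prep="None",
          sent="The police arrested him", pos=None)
print(e.tagged_sent("subj"))
# ['the', 'police', '[subj]', 'arrested', '[subj]', 'him']
```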
    @classmethod
    def from_text(cls, text, entities, doc_text=None, doc_pos=None):
        """Construct an Event object from text.

        :param text: text to be parsed.
        :type text: str
        :param entities: entity list from the document
        :type entities: list[Entity]
        :param doc_text: document text
        :type doc_text: list[str]
        :param doc_pos: document pos tags
        :type doc_pos: list[list[str]]
        """
        result = event_re.match(text)
        groups = result.groupdict()
        # Get verb info
        verb = groups["verb"]
        verb_lemma = groups["verb_lemma"]
        verb_position = (int(groups["sentence_num"]), int(groups["word_index"]))
        type = groups["type"]
        # Get subject
        subject = find_entity_by_id(groups["subj"], entities)
        # Get object
        object = find_entity_by_id(groups["obj"], entities)
        # Get indirect object
        if groups["iobj"] == "None":
            iobject_prep = "None"
            iobject = "None"
        else:
            parts = groups["iobj"].split(",")
            iobject_prep = parts[0]
            iobject = find_entity_by_id(parts[1], entities)
        # Get sentence
        if doc_text is not None:
            sent = doc_text[verb_position[0]]
        else:
            sent = None
        if doc_pos is not None:
            pos = doc_pos[verb_position[0]]
        else:
            pos = None
        return cls(verb=verb,
                   verb_lemma=verb_lemma,
                   verb_position=verb_position,
                   # type=type,
                   subject=subject,
                   object=object,
                   iobject_prep=iobject_prep,
                   iobject=iobject,
                   sent=sent,
                   pos=pos
                   )


__all__ = ["Event", "transform_entity"]
--------------------------------------------------------------------------------
/mcpredictor/utils/mention.py:
--------------------------------------------------------------------------------
"""Mention class for coreferenced entities."""
import re

from mcpredictor.utils.common import unescape


mention_re = re.compile(r"\((?P<char_start>\d+),(?P<char_end>\d+)\);"
                        r"(?P<text>.+);"
                        r"(?P<np_sentence_position>\d+);"
                        r"(?P<np_doc_position>\d+);"
                        r"(?P<nps_in_sentence>\d+);"
                        r"(?P<sentence_num>\d+);"
                        r"\((?P<head_start>\d+),(?P<head_end>\d+)\)")


class Mention(dict):
    """Mention class."""

    def __init__(self, **kwargs):
        """Entity mention."""
        super(Mention, self).__init__(**kwargs)

    def __getattr__(self, item):
        """Get attribute."""
        return self[item]

    def get_head_word(self):
        """Return the head word of this mention in lower case."""
        # char_start = self["char_span"][0]
        # head_start, head_end = self["head_span"]
        # return self["text"][head_start-char_start:head_end-char_start].lower()
        return self["head"]

    @classmethod
    def from_text(cls, text):
        """Construct a Mention object from text.

        Refer to G&C16.

        :param text: text to be parsed.
        """
        text = unescape(text, space_slashes=True)
        result = mention_re.match(text)
        # Match succeeded.
        groups = result.groupdict()
        # Though only char_span, text, and head_span are used,
        # all fields are parsed here.
        char_span = (int(groups["char_start"]), int(groups["char_end"]))
        text = groups["text"]
        np_sentence_position = int(groups["np_sentence_position"])
        np_doc_position = int(groups["np_doc_position"])
        nps_in_sentence = int(groups["nps_in_sentence"])
        sentence_num = int(groups["sentence_num"])
        head_span = (int(groups["head_start"]), int(groups["head_end"]))
        # Get head
        char_start = char_span[0]
        head_start, head_end = head_span
        head = text[head_start-char_start:head_end-char_start].lower()
        # return cls(char_span=char_span,
        #            text=text,
        #            np_sentence_position=np_sentence_position,
        #            np_doc_position=np_doc_position,
        #            nps_in_sentence=nps_in_sentence,
        #            sentence_num=sentence_num,
        #            head_span=head_span)
        return cls(text=text, sentence_num=sentence_num, head=head)


__all__ = ["Mention"]
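A toy parse with mention_re via Mention.from_text (offsets invented; the head is sliced out of the mention text using the char/head spans):
```
from mcpredictor.utils.mention import Mention

m = Mention.from_text("(10,19);Mr. Smith;1;5;3;2;(14,19)")
print(m.text, m.sentence_num, m.get_head_word())  # Mr. Smith 2 smith
```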
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
setuptools>=41.2.0
tqdm>=4.59.0
transformers==3.5.1
nltk==3.5
numpy>=1.20.1
torch>=1.7.1
gensim==3.8.3
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(name="MCPredictor", packages=find_packages())
--------------------------------------------------------------------------------