├── .gitignore ├── README.md ├── fig1.png ├── fig_delta_validation_loss.png └── src ├── alternative_masking_strategies.py └── semantic_types_information.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data Efficient Masked Language Modeling for Vision and Language 2 | Repository for the paper "Data Efficient Masked Language Modeling for Vision and Language", accepted to Findings of EMNLP 2021. 3 | https://arxiv.org/abs/2109.02040. 4 | 5 | ![](fig1.png) 6 | The baseline MLM masks a random token with 15\% probability (where ~50\% of the masked tokens are stop-words or punctuation). Our method masks words that require the image in order to be predicted (e.g., physical objects). 7 | Our experiments show that our pretrain masking strategy consistently improves over the baseline strategy in two evaluation setups. 
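To make the contrast concrete, here is a toy sketch of the two strategies on a whitespace-tokenized sentence. It is illustrative only and is **not** the repository's implementation: the object and stop-word sets below are made up, and the real strategies (which operate on BERT word-pieces with the standard 80/10/10 replacement rule, and fall back to arbitrary content words when no object is found) live in `src/alternative_masking_strategies.py` and `src/semantic_types_information.py`.

```python
# Toy contrast between baseline random masking and object-focused masking.
# Illustrative only: OBJECTS and STOPWORDS are made-up stand-ins for the
# GQA/Visual Genome vocabularies used by the real code.
import random

OBJECTS = {"motorcycle", "helmet", "dog"}
STOPWORDS = {"a", "the", "is", "on", "."}

def baseline_mask(tokens, prob=0.15):
    """Baseline MLM: each token is masked independently with 15% probability."""
    return ["[MASK]" if random.random() < prob else t for t in tokens]

def object_mask(tokens):
    """Proposed strategy: mask one word that requires the image, e.g. a physical object."""
    candidates = [t for t in tokens if t in OBJECTS] or [t for t in tokens if t not in STOPWORDS]
    target = random.choice(candidates)
    return ["[MASK]" if t == target else t for t in tokens]

sentence = "a man is riding a motorcycle on the road".split()
print(baseline_mask(sentence))  # often masks stop-words such as "a" or "the"
print(object_mask(sentence))    # masks "motorcycle", whose prediction needs the image
```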
8 |
9 | ## Intro
10 |
11 | The code for pretraining is based on the great LXMERT repository: https://github.com/airsplay/lxmert
12 |
13 | This repository includes:
14 |
15 | ### Data:
16 |
17 | The `data_directory` is available here: https://drive.google.com/drive/folders/1smFCIwNbIm4QhNHf4gn5RKRfcvGh4_Vl?usp=sharing
18 |
19 | - Pretrained and fine-tuned models are available in `data_directory/models`.
20 |
21 | - Sets of annotated _Objects_, _Attributes_, and _Relationships_ from GQA and Visual Genome are available in `data_directory/all_objects_attributes_relationships.pickle`.
22 |
23 | - Aggregated data, where we extracted the _Δ Validation loss_ (the loss without the image minus the loss with the image) for the LXMERT validation set. This score is used to quantify how necessary the image is for predicting a masked word during MLM. Available in `data_directory/aggregated_data_detla_val_loss.csv` (a minimal loading sketch is shown after this list).
24 | The structure of the CSV is as follows: ![](fig_delta_validation_loss.png)
25 | - Each row contains the sentence, the image, and the masked token (here, _motorcycle_).
26 | - `ind_loss_with_img` is the loss with the image, `ind_loss_false_img` is the loss without the image, and `loss_gap` is the _Δ Validation loss_.
27 | - Similarly, `conf_gap_of_label_with_img`, `conf_gap_of_label_false_img`, and `conf_gap` report the model's confidence in the label (the logits at the position of the masked word) with and without the image, and the gap between them.
28 | - `top_5_preds_token_with_img`, `top_5_preds_token_false_img` - the model's top-5 predictions, with and without the image.
29 | - `tagged_pos` - the spaCy POS tags for the sentence.
30 | - `label_in_top_5_with_img`, `label_in_top_5_false_img` - boolean values indicating whether the label is among the top-5 predictions. In this example the label is not in the top 5 without the image, but it is with the image.
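For convenience, here is a minimal sketch of how the aggregated CSV could be loaded and inspected with pandas. It is not part of the repository; the column names follow the description above, while the local path and the column dtypes are assumptions.

```python
# Minimal sketch (not part of the repository): inspect the aggregated
# Δ Validation loss CSV described above. Column names follow the README;
# the local path and the column dtypes are assumptions.
import pandas as pd

df = pd.read_csv("data_directory/aggregated_data_detla_val_loss.csv")

# loss_gap = loss without the image minus loss with the image (Δ Validation loss).
print(df[["ind_loss_with_img", "ind_loss_false_img", "loss_gap"]].describe())

# Masked words whose prediction benefits most from the image have a large positive loss_gap.
most_image_dependent = df.sort_values("loss_gap", ascending=False).head(20)
print(most_image_dependent[["loss_gap", "top_5_preds_token_with_img", "top_5_preds_token_false_img"]])

# How often is the label recovered in the top-5 predictions, with vs. without the image?
print(df["label_in_top_5_with_img"].value_counts(normalize=True))
print(df["label_in_top_5_false_img"].value_counts(normalize=True))
```

Rows with a large positive `loss_gap` are exactly the masked words for which the image matters most, which is what the proposed masking strategy targets during pretraining.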
31 |
32 |
33 | ### Code:
34 | - Code for the alternative masking strategies, available in `src/alternative_masking_strategies.py`
35 |
36 | - Semantic class information, including functions to detect _Objects_, _Attributes_, and _Relationships_, available in `src/semantic_types_information.py`
37 |
38 | ## Reference
39 |
40 | ```bibtex
41 | @article{bitton2021data,
42 |   title={Data Efficient Masked Language Modeling for Vision and Language},
43 |   author={Bitton, Yonatan and Stanovsky, Gabriel and Elhadad, Michael and Schwartz, Roy},
44 |   journal={arXiv preprint arXiv:2109.02040},
45 |   year={2021}
46 | }
47 | ```
48 |
--------------------------------------------------------------------------------
/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yonatanbitton/data_efficient_masked_language_modeling_for_vision_and_language/9b27024b8ca11959a3da70514c0043aee90ddfdb/fig1.png
--------------------------------------------------------------------------------
/fig_delta_validation_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yonatanbitton/data_efficient_masked_language_modeling_for_vision_and_language/9b27024b8ca11959a3da70514c0043aee90ddfdb/fig_delta_validation_loss.png
--------------------------------------------------------------------------------
/src/alternative_masking_strategies.py:
--------------------------------------------------------------------------------
1 | import heapq
2 | import random
3 |
4 | from src.semantic_types_information import stopwords_and_punctuations, is_content_word, is_object
5 |
6 |
7 | def random_word_objects(tokens, tokenizer, original_sent, pos_cache):
8 |     """
9 |     Masking strategy: Objects. Chooses one object word to mask.
10 |     :param tokens: same argument as original random_word function.
11 |     :param tokenizer: the BertTokenizer used for the original implementation.
12 |     :param original_sent: The sentence (used as the key into pos_cache).
13 |     :param pos_cache: A cache holding, for each sentence, its
14 |         (word, pos tag, lemma) tuples.
15 |     :return: (list of str, list of int), masked tokens and related labels for LM prediction
16 |     """
17 |     tagged_pos = pos_cache[original_sent]
18 |
19 |     words_by_masking_strategies = [token for (token, pos, lemma) in tagged_pos if is_object(token, pos, lemma)]
20 |
21 |     words_by_masking_strategies = filter_unreachable_words(tokenizer, tokens, words_by_masking_strategies)  # keep only words whose word-pieces all appear in the (possibly truncated) token list
22 |
23 |     if len(words_by_masking_strategies) == 0:  # no grounded object found: fall back to content words
24 |         content_tokens = [token for token, pos, lemma in tagged_pos if
25 |                           is_content_word(token, pos, lemma)]
26 |         words_by_masking_strategies = content_tokens
27 |         words_by_masking_strategies = filter_unreachable_words(tokenizer, tokens,
28 |                                                                words_by_masking_strategies)
29 |     if len(words_by_masking_strategies) > 0:
30 |         chosen_word_to_mask = random.choice(words_by_masking_strategies)
31 |         chosen_tokens_to_mask = tokenizer.tokenize(chosen_word_to_mask)
32 |     else:
33 |         chosen_tokens_to_mask = []
34 |     output_label = []
35 |
36 |     iterate_and_mask(chosen_tokens_to_mask, output_label, tokenizer, tokens)
37 |
38 |     return tokens, output_label
39 |
40 |
41 | def random_word_content_words_high(tokens, tokenizer, original_sent, pos_cache):
42 |     """
43 |     Masking strategy: Content words high. Masks 1 word: 80% of the time a content word, 20% of the time a stop-word or punctuation.
44 |     :param tokens: same argument as original random_word function.
45 | :param tokenizer: the BertTokenizer used for the original implementation. 46 | :param original_sent: The sentence. 47 | :param pos_cache: A cache to hold the word, pos tag, lemma 48 | :return: (list of str, list of int), masked tokens and related labels for LM prediction 49 | """ 50 | prob_dist = {'cw': 0.8, 'sw': 0.2} 51 | 52 | tagged_pos = pos_cache[original_sent] 53 | 54 | content_words = [word for word, pos, lemma in tagged_pos if 55 | is_content_word(word, pos, lemma)] 56 | content_words = filter_unreachable_words(tokenizer, tokens, content_words) 57 | stop_words = [word for word, pos, lemma in tagged_pos if 58 | not is_content_word(word, pos, lemma)] 59 | stop_words = filter_unreachable_words(tokenizer, tokens, stop_words) 60 | 61 | if random.random() < prob_dist['cw'] and len(content_words) > 0: 62 | words_by_masking_strategies = content_words 63 | elif len(stop_words) > 0: 64 | words_by_masking_strategies = stop_words 65 | else: 66 | words_by_masking_strategies = content_words 67 | 68 | chosen_word_to_mask = random.choice(words_by_masking_strategies) 69 | chosen_tokens_to_mask = tokenizer.tokenize(chosen_word_to_mask) 70 | 71 | output_label = [] 72 | 73 | iterate_and_mask(chosen_tokens_to_mask, output_label, tokenizer, tokens) 74 | 75 | return tokens, output_label 76 | 77 | 78 | def random_word_top_concrete(tokens, tokenizer, original_sent, pos_cache_concrete, max_seq_length): 79 | """ 80 | Masking strategy: Top Concrete 81 | :param tokens: same argument as original random_word function. 82 | :param tokenizer: the BertTokenizer used for the original implementation. 83 | :param original_sent: The sentence. 84 | :param max_seq_length: Maximum sequence length of the model (20 in the original implementation). 85 | :param pos_cache_concrete: A cache to hold the word, pos tag, lemma, and concreteness annotated value 86 | :return: (list of str, list of int), masked tokens and related labels for LM prediction 87 | """ 88 | if original_sent in pos_cache_concrete: 89 | tagged_pos = pos_cache_concrete[original_sent] 90 | tagged_pos = [{'word': x[0], 'pos': x[1], 'lemma': x[2], 'concreteness': x[3]} for x in tagged_pos] 91 | else: 92 | print(f'Sent not in cache: {original_sent}') 93 | raise Exception(f'Sent not in cache: {original_sent}') 94 | for x in tagged_pos: 95 | x_tokens = tokenizer.tokenize(x['word']) 96 | x['tokens'] = x_tokens 97 | 98 | changed_tokens = False 99 | changed_tokens, tokens = change_tokens_if_necessary(changed_tokens, max_seq_length, tokens, tagged_pos) 100 | 101 | all_concrete_values_and_words = [(x['concreteness'], x['word']) for x in tagged_pos if x['word'] not in stopwords_and_punctuations] 102 | top_3_concrete = heapq.nlargest(3, all_concrete_values_and_words) 103 | 104 | if len(top_3_concrete) == 0: 105 | chosen_word_to_mask = random.choice([x['word'] for x in tagged_pos]) 106 | else: 107 | if len(top_3_concrete) == 1: 108 | chosen_conc_word_to_mask = top_3_concrete[0] 109 | elif len(top_3_concrete) == 2: 110 | chosen_conc_word_to_mask = random.choices(top_3_concrete, weights=[0.75, 0.25], k=1)[0] 111 | else: 112 | chosen_conc_word_to_mask = random.choices(top_3_concrete, weights=[0.55, 0.30, 0.15], k=1)[0] 113 | chosen_word_to_mask_concreteness, chosen_word_to_mask = chosen_conc_word_to_mask[0], chosen_conc_word_to_mask[1] 114 | 115 | chosen_tokens_to_mask = tokenizer.tokenize(chosen_word_to_mask) 116 | 117 | output_label = [] 118 | 119 | iterate_and_mask(chosen_tokens_to_mask, output_label, tokenizer, tokens) 120 | 121 | return tokens, output_label 122 | 123 
| 124 | 125 | def iterate_and_mask(chosen_tokens_to_mask, output_label, tokenizer, tokens): 126 | for i, token in enumerate(tokens): 127 | if token in chosen_tokens_to_mask: 128 | random_num = random.random() 129 | 130 | # 80% randomly change token to mask token 131 | if random_num < 0.8: 132 | tokens[i] = "[MASK]" 133 | 134 | # 10% randomly change token to random token 135 | elif random_num < 0.9: 136 | tokens[i] = random.choice(list(tokenizer.vocab.items()))[0] 137 | 138 | # -> rest 10% randomly keep current token 139 | 140 | # append current token to output (we will predict these later) 141 | try: 142 | output_label.append(tokenizer.vocab[token]) 143 | except KeyError: 144 | # For unknown words (should not occur with BPE vocab) 145 | output_label.append(tokenizer.vocab["[UNK]"]) 146 | else: 147 | # no masking token (will be ignored by loss function later) 148 | output_label.append(-1) 149 | 150 | 151 | def change_tokens_if_necessary(changed_tokens, max_seq_length, tokens, tagged_pos): 152 | inferred_pos = [] 153 | for t in tagged_pos: 154 | tokenized_tokens = t['tokens'] 155 | tokenized_tokens_with_pos = [{'token': x, 'pos': t['pos'], 'lemma': t['lemma'], 'word': t['word']} for x in tokenized_tokens] 156 | inferred_pos += tokenized_tokens_with_pos 157 | inferred_pos = inferred_pos[: (max_seq_length - 2)] 158 | if [x['token'] for x in inferred_pos] != tokens: 159 | tokens = [x['token'] for x in inferred_pos] 160 | changed_tokens = True 161 | return changed_tokens, tokens 162 | 163 | 164 | def filter_unreachable_words(tokenizer, tokens, words_by_masking_strategies): 165 | tokenized_words_by_masking_strategies = [tokenizer.tokenize(w) for w in words_by_masking_strategies] 166 | 167 | words_by_masking_strategies_filtered_unreachable = [] 168 | tokenized_words_by_masking_strategies_filtered_unreachable = [] 169 | assert len(words_by_masking_strategies) == len(tokenized_words_by_masking_strategies) 170 | for w, w_tokenized_lst in zip(words_by_masking_strategies, tokenized_words_by_masking_strategies): 171 | if all(x in tokens for x in w_tokenized_lst): 172 | words_by_masking_strategies_filtered_unreachable.append(w) 173 | tokenized_words_by_masking_strategies_filtered_unreachable += w_tokenized_lst 174 | 175 | if len(words_by_masking_strategies) != len(words_by_masking_strategies_filtered_unreachable): 176 | assert all(xi in tokens for xi in tokenized_words_by_masking_strategies_filtered_unreachable) 177 | 178 | return words_by_masking_strategies_filtered_unreachable -------------------------------------------------------------------------------- /src/semantic_types_information.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | stopwords = {'per', '’ll', 'could', 'fifteen', 'been', "isn't", 'whoever', 'any', 'whole', 'front', "won't", 'upon', 'there', 's', 'am', 'via', 'the', 'as', "haven't", 'on', 'km', 'further', 'their', 'quite', 'have', 'twenty', 'during', 'full', 'it', 'thin', 'so', 'what', 'an', 't', 'less', 'if', 'sixty', 'everyone', 'us', 'were', 'side', 'she', 'cannot', 'thereby', '‘ve', 'amount', 'n’t', 'be', 'nine', 'isn', 'wouldn', 'by', 'along', "'ll", 'themselves', 'forty', 'everywhere', "'d", 'thru', 'sometimes', 'hasnt', 'seeming', 'own', 'that', "'ve", 'least', 'with', 'inc', 'really', 'afterwards', 'due', 'for', 'sometime', 'last', 'find', 'therein', 'all', 'thick', 'detail', 'few', 'hundred', 'some', 'even', 'off', '’m', 'ain', '’re', 'hence', 'etc', 'into', 'rather', 'where', 'm', 'its', 'onto', '’s', 'get', 
'other', 'moreover', 'noone', 'being', 'must', 'bill', "wasn't", 'system', 'neither', "you'll", 'third', 'whereby', 'nobody', 'among', 'throughout', 'except', 'beforehand', "didn't", 'was', 'without', 'whose', 'hasn', '‘d', 'or', 'theirs', 'various', 'name', 'twelve', 'myself', 'former', 'though', 'we', 'ours', 'many', 'sincere', 'regarding', 'had', 'before', 'mustn', 'either', 'doing', 'why', 'fill', 'eight', 'won', 'anything', 'hereupon', 'this', 'amoungst', '‘s', 'of', 'yourselves', 'beside', 'within', 'ourselves', '‘re', 'about', 'elsewhere', 'latter', 'through', 'll', 'i', 'wasn', 'anywhere', 'weren', 'just', 'itself', "you're", 'wherein', 'four', 'keep', 'whether', 'nothing', 'found', 'back', 'needn', "aren't", 'has', 'one', 'wherever', 'serious', 'everything', 'hadn', 'first', 'anyway', 'co', 'still', 'five', 'becomes', "don't", 'formerly', 'ever', 'part', 'nowhere', 'made', 'himself', "couldn't", 'none', 'others', 'now', 'doesn', 'at', 'another', 'does', 'kg', 'see', 'often', 'them', 'shan', 'fifty', 'ltd', 'namely', 'they', 'somewhere', 'haven', 'take', 'latterly', 'well', 'whatever', 'nor', 'whereafter', 'might', 'only', 'de', 'our', 'hers', "mustn't", 'aren', 'you', 'his', "wouldn't", 'please', 'empty', 'but', 'mightn', 'then', 'should', 'and', 'each', 'such', 'a', 'yet', 'y', 'enough', 'someone', 'would', 'since', 'however', 'make', 'alone', 'anyone', 'amongst', 'these', 'whereupon', 'fire', "hasn't", 'shouldn', 'didn', 'do', 'me', 'becoming', 'after', 'several', 'seem', 'her', 'three', 'out', 'ten', 'whence', 'eg', 'couldn', 'un', 'did', "she's", 'whither', 'toward', 'once', "should've", 'call', "weren't", 'again', 'more', 'show', 'seems', "needn't", 'thereupon', 'used', 'most', 'hereby', 'put', 'ie', 've', 'my', 'your', 'thence', 'already', 'always', 'having', 'much', 'move', 'eleven', "'re", 'here', 'yours', 'con', 'done', 'up', 'over', 'yourself', "it's", 'o', 'six', 'can', 'how', "hadn't", 'anyhow', 'below', 'also', 'say', 'together', 'down', 'using', 'while', 'almost', 'cry', "you've", '’ve', 'two', 'towards', 'meanwhile', 'perhaps', 'when', 'ma', "shouldn't", 'both', 'hereafter', 'he', 'describe', 'ca', 'which', 'every', 'between', 'give', 'go', 'very', '’d', 'nevertheless', 'is', 'n‘t', 'therefore', '‘ll', 'unless', 'next', 'who', 'became', 'mill', 'him', 'don', 'same', "'s", 'seemed', 'mostly', 'will', 're', "you'd", 'no', 'in', 'too', "mightn't", 'besides', 'are', 'because', 'couldnt', 'd', 'against', "doesn't", 'cant', 'whenever', 'somehow', 'thereafter', 'although', 'beyond', 'from', 'whereas', 'thus', 'than', "shan't", 'to', 'top', 'until', 'those', 'whom', 'bottom', 'else', 'herein', 'something', '‘m', 'may', 'not', "that'll", "'m", 'indeed', 'never', 'herself', 'interest', "n't", 'become', 'mine', 'otherwise'} 4 | punctuations = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', ' ', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'] 5 | NOT_GROUNDED_OBJECTS_WORDS = {'photo', 'image', 'picture', 'pic', 'side', 'part', 'background'} 6 | stopwords_and_punctuations = stopwords.union(punctuations) 7 | stopwords_and_punctuations = stopwords_and_punctuations.union(NOT_GROUNDED_OBJECTS_WORDS) 8 | 9 | from collections import Counter 10 | print(f"stopwords: {len(stopwords)}, punctuations: {len(punctuations)}, stopwords_and_punctuations: {len(stopwords_and_punctuations)}") 11 | MY_RELATIONSHIP_MASKS = ['VERB', 'ADP'] 12 | TOP_COMMON_RELATIONSHIPS = ['on', 'of', 'in', 'and', 'with', 'to', 'or', 'at', 'by', 'as'] 13 
| MY_OBJECT_MASKS = ['NOUN'] 14 | MY_ATTRIBUTES_MASKS = ['ADJ'] 15 | my_definite_attributes_list = ['white', 'black', 'green', 'blue', 'brown', 'gray', 'large', 'small', 'wood', 'yellow', 'tall', 'metal', 'long', 'dark', 'silver', 'pink', 'round', 'short', 'plastic', 'tan', 'purple', 'colorful', 'concrete', 'blond', 'young', 'empty', 'happy', 'bright', 'wet', 'gold', 'dirty', 'shiny', 'square', 'thin', 'little', 'leafy', 'thick', 'beige', 'calm', 'rectangular', 'dry', 'leather', 'snowy', 'pointy', 'fluffy', 'clean', 'plaid', 'electric', 'grassy', 'lit', 'blurry', 'leafless', 'flat', 'decorative', 'beautiful', 'sandy', 'steel', 'overcast', 'wide', 'stainless', 'ceramic', 'rusty', 'furry', 'ripe', 'hazy', 'high', 'cloudless', 'fresh', 'tiny', 'huge', 'skinny', 'rocky', 'curly', 'maroon', 'porcelain', 'lush', 'floral', 'reflective', 'iron', 'bald', 'rubber', 'puffy', 'broken', 'chrome', 'smooth', 'evergreen', 'low', 'narrow', 'denim', 'hardwood', 'wicker', 'straight', 'triangular', 'sunny', 'bushy', 'hairy', 'wavy', 'khaki', 'shirtless', 'marble', 'ornate', 'overhead', 'muddy', 'fuzzy', 'burnt', 'wild', 'rough', 'sharp', 'pale', 'floppy', 'barefoot', 'plain', 'delicious', 'healthy', 'soft', 'choppy', 'neon', 'aluminum', 'knit', 'wispy', 'vertical', 'patchy', 'granite', 'messy', 'pretty', 'deep', 'sleeveless', 'fallen', 'modern', 'murky', 'antique', 'heavy', 'fancy', 'transparent', 'teal', 'vintage', 'horizontal', 'gravel', 'octagonal', 'sparse', 'cotton', 'shallow', 'fat', 'overgrown', 'foggy', 'giant', 'barren', 'shaggy', 'dusty', 'wireless', 'plush', 'mesh', 'warm', 'woven', 'raw', 'clay', 'brass', 'foamy', 'brunette', 'copper', 'athletic', 'spread', 'crispy', 'unripe', 'styrofoam', 'sheer', 'palm', 'grey', 'golden', 'wooden', 'blonde', 'bloody', "striped", "arched", "checkered", "patterned", "piled", "wrinkled", "stuffed", "decorated", "rounded", "rolled", "grilled"] 16 | 17 | all_objects_attributes_relationships = pickle.load(open('all_objects_attributes_relationships.pickle', 'rb')) 18 | 19 | objects_list = [x for x in all_objects_attributes_relationships['objects']['joint'] if len(x.split(" ")) == 1] 20 | attributes_list = [x for x in all_objects_attributes_relationships['attributes']['joint'] if len(x.split(" ")) == 1] 21 | relationships_list = [x for x in all_objects_attributes_relationships['relationships']['joint'] if len(x.split(" ")) == 1] 22 | print(f"Loaded objects_list, # {len(objects_list)}") 23 | print(f"Loaded attributes_list, # {len(attributes_list)}") 24 | print(f"Loaded relationships_list, # {len(relationships_list)}") 25 | 26 | def is_content_word(token, pos=None, lemma=None): 27 | return token.lower() not in stopwords_and_punctuations and len(token) > 1 28 | 29 | def is_object(word, pos, lemma=None): 30 | return pos in MY_OBJECT_MASKS and word not in stopwords_and_punctuations and \ 31 | len(word) > 1 and (word in objects_list or lemma in objects_list) 32 | 33 | def is_relationship(word, pos, lemma): 34 | return (word in relationships_list or lemma in relationships_list) and pos in MY_RELATIONSHIP_MASKS 35 | 36 | def is_attribute(word, pos, lemma): 37 | if word in my_definite_attributes_list: 38 | return True 39 | else: 40 | return (word in attributes_list or lemma in attributes_list) and pos in MY_ATTRIBUTES_MASKS 41 | 42 | 43 | token_is_relevant_for_masking_strategy = {'sw': is_content_word, 'obj': is_object, 'att': is_attribute, 'rel': is_relationship, } 44 | --------------------------------------------------------------------------------
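A note on usage (not part of the repository): the masking functions above expect a BERT-style tokenizer exposing `.tokenize` and a `.vocab` dictionary, plus a `pos_cache` mapping each raw sentence to its `(word, POS tag, lemma)` tuples, and `semantic_types_information.py` loads `all_objects_attributes_relationships.pickle` from the working directory. Under those assumptions, a minimal invocation of `random_word_objects` could look like the sketch below; the tokenizer class, the spaCy model, and the example sentence are illustrative choices, not what the original LXMERT-based pipeline uses.

```python
# Minimal usage sketch under stated assumptions: transformers' BertTokenizer and
# spaCy's en_core_web_sm stand in for the tokenizer and tagger of the original
# LXMERT pipeline, and all_objects_attributes_relationships.pickle has been copied
# into the working directory so that src/semantic_types_information.py can load it.
import spacy
from transformers import BertTokenizer

from src.alternative_masking_strategies import random_word_objects

nlp = spacy.load("en_core_web_sm")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

sentence = "a man is riding a motorcycle on the road"

# pos_cache maps the raw sentence to its (word, POS tag, lemma) tuples.
pos_cache = {sentence: [(tok.text, tok.pos_, tok.lemma_) for tok in nlp(sentence)]}

# The strategies operate on the already word-pieced sentence (without [CLS]/[SEP]).
tokens = tokenizer.tokenize(sentence)

masked_tokens, output_label = random_word_objects(tokens, tokenizer, sentence, pos_cache)
print(masked_tokens)   # e.g. "motorcycle" replaced by [MASK] (or randomly swapped/kept, 80/10/10)
print(output_label)    # vocab id of the chosen word at its positions, -1 everywhere else
```

The same pattern applies to the other strategies; `random_word_top_concrete` additionally expects a concreteness-annotated cache (`pos_cache_concrete`) and a `max_seq_length` argument.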