├── data ├── .gitkeep ├── Olympic │ └── raw.pickle ├── PsychExp │ └── raw.pickle ├── filtering │ └── wanted_emojis.csv └── emoji_codes.json ├── examples ├── .gitkeep ├── __init__.py ├── example_helper.py ├── create_twitter_vocab.py ├── tokenize_dataset.py ├── vocab_extension.py ├── finetune_youtube_last.py ├── encode_texts.py ├── README.md ├── finetune_insults_chain-thaw.py ├── dataset_split.py ├── finetune_semeval_class-avg_f1.py ├── text_emojize.py └── score_texts_emojis.py ├── model └── .gitkeep ├── torchmoji ├── __init__.py ├── .gitkeep ├── global_variables.py ├── filter_input.py ├── attlayer.py ├── tokenizer.py ├── filter_utils.py ├── create_vocab.py ├── sentence_tokenizer.py ├── word_generator.py ├── lstm.py ├── class_avg_finetuning.py ├── model_def.py └── finetuning.py ├── scripts ├── results │ └── .gitkeep ├── analyze_results.py ├── analyze_all_results.py ├── download_weights.py ├── calculate_coverages.py ├── convert_all_datasets.py └── finetune_dataset.py ├── emoji_overview.png ├── tests ├── test_helper.py ├── test_word_generator.py ├── test_sentence_tokenizer.py ├── test_tokenizer.py └── test_finetuning.py ├── setup.py ├── .travis.yml ├── LICENSE ├── .gitignore └── README.md /data/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /examples/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /torchmoji/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/results/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /torchmoji/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /emoji_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/torchMoji/HEAD/emoji_overview.png -------------------------------------------------------------------------------- /data/Olympic/raw.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/torchMoji/HEAD/data/Olympic/raw.pickle -------------------------------------------------------------------------------- /data/PsychExp/raw.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/torchMoji/HEAD/data/PsychExp/raw.pickle -------------------------------------------------------------------------------- /tests/test_helper.py: -------------------------------------------------------------------------------- 1 | """ Module import helper. 2 | Modifies PATH in order to allow us to import the torchmoji directory. 
3 | """ 4 | import sys 5 | from os.path import abspath, dirname 6 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 7 | -------------------------------------------------------------------------------- /examples/example_helper.py: -------------------------------------------------------------------------------- 1 | """ Module import helper. 2 | Modifies PATH in order to allow us to import the torchmoji directory. 3 | """ 4 | import sys 5 | from os.path import abspath, dirname 6 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='torchmoji', 5 | version='1.0', 6 | packages=['torchmoji'], 7 | description='torchMoji', 8 | include_package_data=True, 9 | install_requires=[ 10 | 'emoji==0.4.5', 11 | 'numpy==1.13.1', 12 | 'scipy==0.19.1', 13 | 'scikit-learn==0.19.0', 14 | 'text-unidecode==1.0', 15 | ], 16 | ) 17 | -------------------------------------------------------------------------------- /examples/create_twitter_vocab.py: -------------------------------------------------------------------------------- 1 | """ Creates a vocabulary from a tsv file. 2 | """ 3 | 4 | import codecs 5 | import example_helper 6 | from torchmoji.create_vocab import VocabBuilder 7 | from torchmoji.word_generator import TweetWordGenerator 8 | 9 | with codecs.open('../../twitterdata/tweets.2016-09-01', 'rU', 'utf-8') as stream: 10 | wg = TweetWordGenerator(stream) 11 | vb = VocabBuilder(wg) 12 | vb.count_all_words() 13 | vb.save_vocab() 14 | -------------------------------------------------------------------------------- /examples/tokenize_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Take a given list of sentences and turn it into a numpy array, where each 3 | number corresponds to a word. Padding is used (number 0) to ensure fixed length 4 | of sentences. 5 | """ 6 | 7 | from __future__ import print_function, unicode_literals 8 | import example_helper 9 | import json 10 | from torchmoji.sentence_tokenizer import SentenceTokenizer 11 | 12 | with open('../model/vocabulary.json', 'r') as f: 13 | vocabulary = json.load(f) 14 | 15 | st = SentenceTokenizer(vocabulary, 30) 16 | test_sentences = [ 17 | '\u2014 -- \u203c !!\U0001F602', 18 | 'Hello world!', 19 | 'This is a sample tweet #example', 20 | ] 21 | 22 | tokens, infos, stats = st.tokenize_sentences(test_sentences) 23 | 24 | print(tokens) 25 | print(infos) 26 | print(stats) 27 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | group: travis_latest 2 | language: python 3 | cache: pip 4 | python: 5 | - 2.7 6 | - 3.6 7 | #- nightly 8 | #- pypy 9 | #- pypy3 10 | matrix: 11 | allow_failures: 12 | - python: nightly 13 | - python: pypy 14 | - python: pypy3 15 | install: 16 | #- pip install -r requirements.txt 17 | - pip install flake8 # pytest # add another testing frameworks later 18 | before_script: 19 | # stop the build if there are Python syntax errors or undefined names 20 | - flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics 21 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 22 | - flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 23 | script: 24 | - true # pytest --capture=sys # add other tests here 25 | notifications: 26 | on_success: change 27 | on_failure: change # `always` will be the setting once code changes slow down 28 | -------------------------------------------------------------------------------- /data/filtering/wanted_emojis.csv: -------------------------------------------------------------------------------- 1 | \U0001f602 2 | \U0001f612 3 | \U0001f629 4 | \U0001f62d 5 | \U0001f60d 6 | \U0001f614 7 | \U0001f44c 8 | \U0001f60a 9 | \u2764 10 | \U0001f60f 11 | \U0001f601 12 | \U0001f3b6 13 | \U0001f633 14 | \U0001f4af 15 | \U0001f634 16 | \U0001f60c 17 | \u263a 18 | \U0001f64c 19 | \U0001f495 20 | \U0001f611 21 | \U0001f605 22 | \U0001f64f 23 | \U0001f615 24 | \U0001f618 25 | \u2665 26 | \U0001f610 27 | \U0001f481 28 | \U0001f61e 29 | \U0001f648 30 | \U0001f62b 31 | \u270c 32 | \U0001f60e 33 | \U0001f621 34 | \U0001f44d 35 | \U0001f622 36 | \U0001f62a 37 | \U0001f60b 38 | \U0001f624 39 | \u270b 40 | \U0001f637 41 | \U0001f44f 42 | \U0001f440 43 | \U0001f52b 44 | \U0001f623 45 | \U0001f608 46 | \U0001f613 47 | \U0001f494 48 | \u2661 49 | \U0001f3a7 50 | \U0001f64a 51 | \U0001f609 52 | \U0001f480 53 | \U0001f616 54 | \U0001f604 55 | \U0001f61c 56 | \U0001f620 57 | \U0001f645 58 | \U0001f4aa 59 | \U0001f44a 60 | \U0001f49c 61 | \U0001f496 62 | \U0001f499 63 | \U0001f62c 64 | \u2728 65 | -------------------------------------------------------------------------------- /torchmoji/global_variables.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Global variables. 3 | """ 4 | import tempfile 5 | from os.path import abspath, dirname 6 | 7 | # The ordering of these special tokens matter 8 | # blank tokens can be used for new purposes 9 | # Tokenizer should be updated if special token prefix is changed 10 | SPECIAL_PREFIX = 'CUSTOM_' 11 | SPECIAL_TOKENS = ['CUSTOM_MASK', 12 | 'CUSTOM_UNKNOWN', 13 | 'CUSTOM_AT', 14 | 'CUSTOM_URL', 15 | 'CUSTOM_NUMBER', 16 | 'CUSTOM_BREAK'] 17 | SPECIAL_TOKENS.extend(['{}BLANK_{}'.format(SPECIAL_PREFIX, i) for i in range(6, 10)]) 18 | 19 | ROOT_PATH = dirname(dirname(abspath(__file__))) 20 | VOCAB_PATH = '{}/model/vocabulary.json'.format(ROOT_PATH) 21 | PRETRAINED_PATH = '{}/model/pytorch_model.bin'.format(ROOT_PATH) 22 | 23 | WEIGHTS_DIR = tempfile.mkdtemp() 24 | 25 | NB_TOKENS = 50000 26 | NB_EMOJI_CLASSES = 64 27 | FINETUNING_METHODS = ['last', 'full', 'new', 'chain-thaw'] 28 | FINETUNING_METRICS = ['acc', 'weighted'] 29 | -------------------------------------------------------------------------------- /examples/vocab_extension.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extend the given vocabulary using dataset-specific words. 3 | 4 | 1. First create a vocabulary for the specific dataset. 5 | 2. Find all words not in our vocabulary, but in the dataset vocabulary. 6 | 3. Take top X (default=1000) of these words and add them to the vocabulary. 7 | 4. Save this combined vocabulary and embedding matrix, which can now be used. 
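Conceptually, steps 2-3 boil down to appending the most frequent unseen words
to the end of the existing index space. A rough sketch only (not the exact
torchmoji.create_vocab implementation; word_counts and max_tokens here are
stand-ins for the builder's internals):

    candidates = [w for w in word_counts if w not in vocab]
    candidates.sort(key=word_counts.get, reverse=True)
    for word in candidates[:max_tokens]:
        vocab[word] = len(vocab)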
8 | """ 9 | 10 | from __future__ import print_function, unicode_literals 11 | import example_helper 12 | import json 13 | from torchmoji.create_vocab import extend_vocab, VocabBuilder 14 | from torchmoji.word_generator import WordGenerator 15 | 16 | new_words = ['#zzzzaaazzz', 'newword', 'newword'] 17 | word_gen = WordGenerator(new_words) 18 | vb = VocabBuilder(word_gen) 19 | vb.count_all_words() 20 | 21 | with open('../model/vocabulary.json') as f: 22 | vocab = json.load(f) 23 | 24 | print(len(vocab)) 25 | print(vb.word_counts) 26 | extend_vocab(vocab, vb, max_tokens=1) 27 | 28 | # 'newword' should be added because it's more frequent in the given vocab 29 | print(vocab['newword']) 30 | print(len(vocab)) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Bjarke Felbo, Han Thi Nguyen, Thomas Wolf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /torchmoji/filter_input.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | import codecs 4 | import csv 5 | import numpy as np 6 | from emoji import UNICODE_EMOJI 7 | 8 | def read_english(path="english_words.txt", add_emojis=True): 9 | # read english words for filtering (includes emojis as part of set) 10 | english = set() 11 | with codecs.open(path, "r", "utf-8") as f: 12 | for line in f: 13 | line = line.strip().lower().replace('\n', '') 14 | if len(line): 15 | english.add(line) 16 | if add_emojis: 17 | for e in UNICODE_EMOJI: 18 | english.add(e) 19 | return english 20 | 21 | def read_wanted_emojis(path="wanted_emojis.csv"): 22 | emojis = [] 23 | with open(path, 'rb') as f: 24 | reader = csv.reader(f) 25 | for line in reader: 26 | line = line[0].strip().replace('\n', '') 27 | line = line.decode('unicode-escape') 28 | emojis.append(line) 29 | return emojis 30 | 31 | def read_non_english_users(path="unwanted_users.npz"): 32 | try: 33 | neu_set = set(np.load(path)['userids']) 34 | except IOError: 35 | neu_set = set() 36 | return neu_set 37 | -------------------------------------------------------------------------------- /examples/finetune_youtube_last.py: -------------------------------------------------------------------------------- 1 | """Finetuning example. 2 | 3 | Trains the torchMoji model on the SS-Youtube dataset, using the 'last' 4 | finetuning method and the accuracy metric. 5 | 6 | The 'last' method does the following: 7 | 0) Load all weights except for the softmax layer. Do not add tokens to the 8 | vocabulary and do not extend the embedding layer. 9 | 1) Freeze all layers except for the softmax layer. 10 | 2) Train. 11 | """ 12 | 13 | from __future__ import print_function 14 | import example_helper 15 | import json 16 | from torchmoji.model_def import torchmoji_transfer 17 | from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH, ROOT_PATH 18 | from torchmoji.finetuning import ( 19 | load_benchmark, 20 | finetune) 21 | 22 | DATASET_PATH = '{}/data/SS-Youtube/raw.pickle'.format(ROOT_PATH) 23 | nb_classes = 2 24 | 25 | with open(VOCAB_PATH, 'r') as f: 26 | vocab = json.load(f) 27 | 28 | # Load dataset. 
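# load_benchmark returns a dict; the keys used by the examples in this repo are
# 'texts' and 'labels' (each a [train, val, test] list, so e.g. data['texts'][0]
# holds the training texts), 'batch_size', and 'added' (the number of tokens
# appended to the vocabulary when extend_with > 0).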
29 | data = load_benchmark(DATASET_PATH, vocab) 30 | 31 | # Set up model and finetune 32 | model = torchmoji_transfer(nb_classes, PRETRAINED_PATH) 33 | print(model) 34 | model, acc = finetune(model, data['texts'], data['labels'], nb_classes, data['batch_size'], method='last') 35 | print('Acc: {}'.format(acc)) 36 | -------------------------------------------------------------------------------- /scripts/analyze_results.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import glob 5 | import numpy as np 6 | 7 | DATASET = 'SS-Twitter' # 'SE1604' excluded due to Twitter's ToS 8 | METHOD = 'new' 9 | 10 | # Optional usage: analyze_results.py 11 | if len(sys.argv) == 3: 12 | DATASET = sys.argv[1] 13 | METHOD = sys.argv[2] 14 | 15 | RESULTS_DIR = 'results/' 16 | RESULT_PATHS = glob.glob('{}/{}_{}_*_results.txt'.format(RESULTS_DIR, DATASET, METHOD)) 17 | 18 | if not RESULT_PATHS: 19 | print('Could not find results for \'{}\' using \'{}\' in directory \'{}\'.'.format(DATASET, METHOD, RESULTS_DIR)) 20 | else: 21 | scores = [] 22 | for path in RESULT_PATHS: 23 | with open(path) as f: 24 | score = f.readline().split(':')[1] 25 | scores.append(float(score)) 26 | 27 | average = np.mean(scores) 28 | maximum = max(scores) 29 | minimum = min(scores) 30 | std = np.std(scores) 31 | 32 | print('Dataset: {}'.format(DATASET)) 33 | print('Method: {}'.format(METHOD)) 34 | print('Number of results: {}'.format(len(scores))) 35 | print('--------------------------') 36 | print('Average: {}'.format(average)) 37 | print('Maximum: {}'.format(maximum)) 38 | print('Minimum: {}'.format(minimum)) 39 | print('Standard deviaton: {}'.format(std)) 40 | -------------------------------------------------------------------------------- /scripts/analyze_all_results.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | # allow us to import the codebase directory 4 | import sys 5 | import glob 6 | import numpy as np 7 | from os.path import dirname, abspath 8 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 9 | 10 | DATASETS = ['SE0714', 'Olympic', 'PsychExp', 'SS-Twitter', 'SS-Youtube', 11 | 'SCv1', 'SV2-GEN'] # 'SE1604' excluded due to Twitter's ToS 12 | 13 | def get_results(dset): 14 | METHOD = 'last' 15 | RESULTS_DIR = 'results/' 16 | RESULT_PATHS = glob.glob('{}/{}_{}_*_results.txt'.format(RESULTS_DIR, dset, METHOD)) 17 | assert len(RESULT_PATHS) 18 | 19 | scores = [] 20 | for path in RESULT_PATHS: 21 | with open(path) as f: 22 | score = f.readline().split(':')[1] 23 | scores.append(float(score)) 24 | 25 | average = np.mean(scores) 26 | maximum = max(scores) 27 | minimum = min(scores) 28 | std = np.std(scores) 29 | 30 | print('Dataset: {}'.format(dset)) 31 | print('Method: {}'.format(METHOD)) 32 | print('Number of results: {}'.format(len(scores))) 33 | print('--------------------------') 34 | print('Average: {}'.format(average)) 35 | print('Maximum: {}'.format(maximum)) 36 | print('Minimum: {}'.format(minimum)) 37 | print('Standard deviaton: {}'.format(std)) 38 | 39 | for dset in DATASETS: 40 | get_results(dset) 41 | -------------------------------------------------------------------------------- /examples/encode_texts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ Use torchMoji to encode texts into emotional feature vectors. 
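Each text is mapped to a 2304-dimensional feature vector (see examples/README.md)
that can be fed to any downstream model. A minimal sketch of the follow-up
suggested at the bottom of this script, assuming a hypothetical `labels` array
for the sentences (not part of this repo), and converting the encoding to a
numpy array first if it comes back as a torch tensor:

    from sklearn.linear_model import LogisticRegression
    clf = LogisticRegression()
    clf.fit(encoding, labels)        # encoding is produced by the model call below
    predictions = clf.predict(encoding)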
4 | """ 5 | from __future__ import print_function, division, unicode_literals 6 | import json 7 | 8 | from torchmoji.sentence_tokenizer import SentenceTokenizer 9 | from torchmoji.model_def import torchmoji_feature_encoding 10 | from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH 11 | 12 | TEST_SENTENCES = ['I love mom\'s cooking', 13 | 'I love how you never reply back..', 14 | 'I love cruising with my homies', 15 | 'I love messing with yo mind!!', 16 | 'I love you and now you\'re just gone..', 17 | 'This is shit', 18 | 'This is the shit'] 19 | 20 | maxlen = 30 21 | batch_size = 32 22 | 23 | print('Tokenizing using dictionary from {}'.format(VOCAB_PATH)) 24 | with open(VOCAB_PATH, 'r') as f: 25 | vocabulary = json.load(f) 26 | st = SentenceTokenizer(vocabulary, maxlen) 27 | tokenized, _, _ = st.tokenize_sentences(TEST_SENTENCES) 28 | 29 | print('Loading model from {}.'.format(PRETRAINED_PATH)) 30 | model = torchmoji_feature_encoding(PRETRAINED_PATH) 31 | print(model) 32 | 33 | print('Encoding texts..') 34 | encoding = model(tokenized) 35 | 36 | print('First 5 dimensions for sentence: {}'.format(TEST_SENTENCES[0])) 37 | print(encoding[0,:5]) 38 | 39 | # Now you could visualize the encodings to see differences, 40 | # run a logistic regression classifier on top, 41 | # or basically anything you'd like to do. -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # torchMoji examples 2 | 3 | ## Initialization 4 | [create_twitter_vocab.py](create_twitter_vocab.py) 5 | Create a new vocabulary from a tsv file. 6 | 7 | [tokenize_dataset.py](tokenize_dataset.py) 8 | Tokenize a given dataset using the prebuilt vocabulary. 9 | 10 | [vocab_extension.py](vocab_extension.py) 11 | Extend the given vocabulary using dataset-specific words. 12 | 13 | [dataset_split.py](dataset_split.py) 14 | Split a given dataset into training, validation and testing. 15 | 16 | ## Use pretrained model/architecture 17 | [score_texts_emojis.py](score_texts_emojis.py) 18 | Use torchMoji to score texts for emoji distribution. 19 | 20 | [text_emojize.py](text_emojize.py) 21 | Use torchMoji to output emoji visualization from a single text input (mapped from `emoji_overview.png`) 22 | 23 | ```sh 24 | python examples/text_emojize.py --text "I love mom's cooking\!" 25 | # => I love mom's cooking! 😋 😍 💓 💛 ❤ 26 | ``` 27 | 28 | [encode_texts.py](encode_texts.py) 29 | Use torchMoji to encode the text into 2304-dimensional feature vectors for further modeling/analysis. 30 | 31 | ## Transfer learning 32 | [finetune_youtube_last.py](finetune_youtube_last.py) 33 | Finetune the model on the SS-Youtube dataset using the 'last' method. 34 | 35 | [finetune_insults_chain-thaw.py](finetune_insults_chain-thaw.py) 36 | Finetune the model on the Kaggle insults dataset (from blog post) using the 'chain-thaw' method. 37 | 38 | [finetune_semeval_class-avg_f1.py](finetune_semeval_class-avg_f1.py) 39 | Finetune the model on the SemeEval emotion dataset using the 'full' method and evaluate using the class average F1 metric. 
40 | -------------------------------------------------------------------------------- /data/emoji_codes.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": ":joy:", 3 | "1": ":unamused:", 4 | "2": ":weary:", 5 | "3": ":sob:", 6 | "4": ":heart_eyes:", 7 | "5": ":pensive:", 8 | "6": ":ok_hand:", 9 | "7": ":blush:", 10 | "8": ":heart:", 11 | "9": ":smirk:", 12 | "10":":grin:", 13 | "11":":notes:", 14 | "12":":flushed:", 15 | "13":":100:", 16 | "14":":sleeping:", 17 | "15":":relieved:", 18 | "16":":relaxed:", 19 | "17":":raised_hands:", 20 | "18":":two_hearts:", 21 | "19":":expressionless:", 22 | "20":":sweat_smile:", 23 | "21":":pray:", 24 | "22":":confused:", 25 | "23":":kissing_heart:", 26 | "24":":hearts:", 27 | "25":":neutral_face:", 28 | "26":":information_desk_person:", 29 | "27":":disappointed:", 30 | "28":":see_no_evil:", 31 | "29":":tired_face:", 32 | "30":":v:", 33 | "31":":sunglasses:", 34 | "32":":rage:", 35 | "33":":thumbsup:", 36 | "34":":cry:", 37 | "35":":sleepy:", 38 | "36":":stuck_out_tongue_winking_eye:", 39 | "37":":triumph:", 40 | "38":":raised_hand:", 41 | "39":":mask:", 42 | "40":":clap:", 43 | "41":":eyes:", 44 | "42":":gun:", 45 | "43":":persevere:", 46 | "44":":imp:", 47 | "45":":sweat:", 48 | "46":":broken_heart:", 49 | "47":":blue_heart:", 50 | "48":":headphones:", 51 | "49":":speak_no_evil:", 52 | "50":":wink:", 53 | "51":":skull:", 54 | "52":":confounded:", 55 | "53":":smile:", 56 | "54":":stuck_out_tongue_winking_eye:", 57 | "55":":angry:", 58 | "56":":no_good:", 59 | "57":":muscle:", 60 | "58":":punch:", 61 | "59":":purple_heart:", 62 | "60":":sparkling_heart:", 63 | "61":":blue_heart:", 64 | "62":":grimacing:", 65 | "63":":sparkles:" 66 | } 67 | 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | # Local data 92 | /data/local 93 | 94 | # Vim swapfiles 95 | *.swp 96 | *.swo 97 | 98 | # nosetests 99 | .noseids 100 | 101 | # pyTorch model 102 | pytorch_model.bin 103 | 104 | # VSCODE 105 | .vscode/* 106 | 107 | # data 108 | *.csv 109 | -------------------------------------------------------------------------------- /examples/finetune_insults_chain-thaw.py: -------------------------------------------------------------------------------- 1 | """Finetuning example. 2 | 3 | Trains the torchMoji model on the kaggle insults dataset, using the 'chain-thaw' 4 | finetuning method and the accuracy metric. See the blog post at 5 | https://medium.com/@bjarkefelbo/what-can-we-learn-from-emojis-6beb165a5ea0 6 | for more information. Note that results may differ a bit due to slight 7 | changes in preprocessing and train/val/test split. 8 | 9 | The 'chain-thaw' method does the following: 10 | 0) Load all weights except for the softmax layer. Extend the embedding layer if 11 | necessary, initialising the new weights with random values. 12 | 1) Freeze every layer except the last (softmax) layer and train it. 13 | 2) Freeze every layer except the first layer and train it. 14 | 3) Freeze every layer except the second etc., until the second last layer. 15 | 4) Unfreeze all layers and train entire model. 16 | """ 17 | 18 | from __future__ import print_function 19 | import example_helper 20 | import json 21 | from torchmoji.model_def import torchmoji_transfer 22 | from torchmoji.global_variables import PRETRAINED_PATH 23 | from torchmoji.finetuning import ( 24 | load_benchmark, 25 | finetune) 26 | 27 | 28 | DATASET_PATH = '../data/kaggle-insults/raw.pickle' 29 | nb_classes = 2 30 | 31 | with open('../model/vocabulary.json', 'r') as f: 32 | vocab = json.load(f) 33 | 34 | # Load dataset. Extend the existing vocabulary with up to 10000 tokens from 35 | # the training dataset. 36 | data = load_benchmark(DATASET_PATH, vocab, extend_with=10000) 37 | 38 | # Set up model and finetune. Note that we have to extend the embedding layer 39 | # with the number of tokens added to the vocabulary. 
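# data['added'] is the number of new tokens that load_benchmark appended to the
# vocabulary (at most the extend_with=10000 requested above). Passing it as
# extend_embedding grows the embedding matrix by that many randomly initialised
# rows, so the new tokens have weights to train during chain-thaw.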
40 | model = torchmoji_transfer(nb_classes, PRETRAINED_PATH, extend_embedding=data['added']) 41 | print(model) 42 | model, acc = finetune(model, data['texts'], data['labels'], nb_classes, 43 | data['batch_size'], method='chain-thaw') 44 | print('Acc: {}'.format(acc)) 45 | -------------------------------------------------------------------------------- /examples/dataset_split.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Split a given dataset into three different datasets: training, validation and 3 | testing. 4 | 5 | This is achieved by splitting the given list of sentences into three separate 6 | lists according to either a given ratio (e.g. [0.7, 0.1, 0.2]) or by an 7 | explicit enumeration. The sentences are also tokenised using the given 8 | vocabulary. 9 | 10 | Also splits a given list of dictionaries containing information about 11 | each sentence. 12 | 13 | An additional parameter can be set 'extend_with', which will extend the given 14 | vocabulary with up to 'extend_with' tokens, taken from the training dataset. 15 | ''' 16 | from __future__ import print_function, unicode_literals 17 | import example_helper 18 | import json 19 | 20 | from torchmoji.sentence_tokenizer import SentenceTokenizer 21 | 22 | DATASET = [ 23 | 'I am sentence 0', 24 | 'I am sentence 1', 25 | 'I am sentence 2', 26 | 'I am sentence 3', 27 | 'I am sentence 4', 28 | 'I am sentence 5', 29 | 'I am sentence 6', 30 | 'I am sentence 7', 31 | 'I am sentence 8', 32 | 'I am sentence 9 newword', 33 | ] 34 | 35 | INFO_DICTS = [ 36 | {'label': 'sentence 0'}, 37 | {'label': 'sentence 1'}, 38 | {'label': 'sentence 2'}, 39 | {'label': 'sentence 3'}, 40 | {'label': 'sentence 4'}, 41 | {'label': 'sentence 5'}, 42 | {'label': 'sentence 6'}, 43 | {'label': 'sentence 7'}, 44 | {'label': 'sentence 8'}, 45 | {'label': 'sentence 9'}, 46 | ] 47 | 48 | with open('../model/vocabulary.json', 'r') as f: 49 | vocab = json.load(f) 50 | st = SentenceTokenizer(vocab, 30) 51 | 52 | # Split using the default split ratio 53 | print(st.split_train_val_test(DATASET, INFO_DICTS)) 54 | 55 | # Split explicitly 56 | print(st.split_train_val_test(DATASET, 57 | INFO_DICTS, 58 | [[0, 1, 2, 4, 9], [5, 6], [7, 8, 3]], 59 | extend_with=1)) 60 | -------------------------------------------------------------------------------- /examples/finetune_semeval_class-avg_f1.py: -------------------------------------------------------------------------------- 1 | """Finetuning example. 2 | 3 | Trains the torchMoji model on the SemEval emotion dataset, using the 'last' 4 | finetuning method and the class average F1 metric. 5 | 6 | The 'last' method does the following: 7 | 0) Load all weights except for the softmax layer. Do not add tokens to the 8 | vocabulary and do not extend the embedding layer. 9 | 1) Freeze all layers except for the softmax layer. 10 | 2) Train. 11 | 12 | The class average F1 metric does the following: 13 | 1) For each class, relabel the dataset into binary classification 14 | (belongs to/does not belong to this class). 15 | 2) Calculate F1 score for each class. 16 | 3) Compute the average of all F1 scores. 
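
As a rough illustration only (not the torchmoji.class_avg_finetuning code;
y_true and probs are hypothetical arrays of true labels and per-class scores):

    from sklearn.metrics import f1_score
    f1s = [f1_score(y_true == c, probs[:, c] > 0.5) for c in range(nb_classes)]
    class_avg_f1 = sum(f1s) / len(f1s)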
17 | """ 18 | 19 | from __future__ import print_function 20 | import example_helper 21 | import json 22 | from torchmoji.finetuning import load_benchmark 23 | from torchmoji.class_avg_finetuning import class_avg_finetune 24 | from torchmoji.model_def import torchmoji_transfer 25 | from torchmoji.global_variables import PRETRAINED_PATH 26 | 27 | DATASET_PATH = '../data/SE0714/raw.pickle' 28 | nb_classes = 3 29 | 30 | with open('../model/vocabulary.json', 'r') as f: 31 | vocab = json.load(f) 32 | 33 | 34 | # Load dataset. Extend the existing vocabulary with up to 10000 tokens from 35 | # the training dataset. 36 | data = load_benchmark(DATASET_PATH, vocab, extend_with=10000) 37 | 38 | # Set up model and finetune. Note that we have to extend the embedding layer 39 | # with the number of tokens added to the vocabulary. 40 | # 41 | # Also note that when using class average F1 to evaluate, the model has to be 42 | # defined with two classes, since the model will be trained for each class 43 | # separately. 44 | model = torchmoji_transfer(2, PRETRAINED_PATH, extend_embedding=data['added']) 45 | print(model) 46 | 47 | # For finetuning however, pass in the actual number of classes. 48 | model, f1 = class_avg_finetune(model, data['texts'], data['labels'], 49 | nb_classes, data['batch_size'], method='last') 50 | print('F1: {}'.format(f1)) 51 | -------------------------------------------------------------------------------- /tests/test_word_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | from os.path import dirname, abspath 4 | sys.path.append(dirname(dirname(abspath(__file__)))) 5 | from nose.tools import raises 6 | from torchmoji.word_generator import WordGenerator 7 | 8 | IS_PYTHON2 = int(sys.version[0]) == 2 9 | 10 | @raises(ValueError) 11 | def test_only_unicode_accepted(): 12 | """ Non-Unicode strings raise a ValueError. 13 | In Python 3 all string are Unicode 14 | """ 15 | if not IS_PYTHON2: 16 | raise ValueError("You are using python 3 so this test should always pass") 17 | 18 | sentences = [ 19 | u'Hello world', 20 | u'I am unicode', 21 | 'I am not unicode', 22 | ] 23 | 24 | wg = WordGenerator(sentences) 25 | for w in wg: 26 | pass 27 | 28 | 29 | def test_unicode_sentences_ignored_if_set(): 30 | """ Strings with Unicode characters tokenize to empty array if they're not allowed. 31 | """ 32 | sentence = [u'Dobrý den, jak se máš?'] 33 | wg = WordGenerator(sentence, allow_unicode_text=False) 34 | assert wg.get_words(sentence[0]) == [] 35 | 36 | 37 | def test_check_ascii(): 38 | """ check_ascii recognises ASCII words properly. 39 | In Python 3 all string are Unicode 40 | """ 41 | if not IS_PYTHON2: 42 | return 43 | 44 | wg = WordGenerator([]) 45 | assert wg.check_ascii('ASCII') 46 | assert not wg.check_ascii('ščřžýá') 47 | assert not wg.check_ascii('❤ ☀ ☆ ☂ ☻ ♞ ☯ ☭ ☢') 48 | 49 | 50 | def test_convert_unicode_word(): 51 | """ convert_unicode_word converts Unicode words correctly. 52 | """ 53 | wg = WordGenerator([], allow_unicode_text=True) 54 | 55 | result = wg.convert_unicode_word(u'č') 56 | assert result == (True, u'\u010d'), '{}'.format(result) 57 | 58 | 59 | def test_convert_unicode_word_ignores_if_set(): 60 | """ convert_unicode_word ignores Unicode words if set. 
61 | """ 62 | wg = WordGenerator([], allow_unicode_text=False) 63 | 64 | result = wg.convert_unicode_word(u'č') 65 | assert result == (False, ''), '{}'.format(result) 66 | 67 | 68 | def test_convert_unicode_chars(): 69 | """ convert_unicode_word correctly converts accented characters. 70 | """ 71 | wg = WordGenerator([], allow_unicode_text=True) 72 | result = wg.convert_unicode_word(u'ěščřžýáíé') 73 | assert result == (True, u'\u011b\u0161\u010d\u0159\u017e\xfd\xe1\xed\xe9'), '{}'.format(result) 74 | -------------------------------------------------------------------------------- /examples/text_emojize.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ Use torchMoji to predict emojis from a single text input 4 | """ 5 | 6 | from __future__ import print_function, division, unicode_literals 7 | import example_helper 8 | import json 9 | import csv 10 | import argparse 11 | 12 | import numpy as np 13 | import emoji 14 | 15 | from torchmoji.sentence_tokenizer import SentenceTokenizer 16 | from torchmoji.model_def import torchmoji_emojis 17 | from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH 18 | 19 | # Emoji map in emoji_overview.png 20 | EMOJIS = ":joy: :unamused: :weary: :sob: :heart_eyes: \ 21 | :pensive: :ok_hand: :blush: :heart: :smirk: \ 22 | :grin: :notes: :flushed: :100: :sleeping: \ 23 | :relieved: :relaxed: :raised_hands: :two_hearts: :expressionless: \ 24 | :sweat_smile: :pray: :confused: :kissing_heart: :heartbeat: \ 25 | :neutral_face: :information_desk_person: :disappointed: :see_no_evil: :tired_face: \ 26 | :v: :sunglasses: :rage: :thumbsup: :cry: \ 27 | :sleepy: :yum: :triumph: :hand: :mask: \ 28 | :clap: :eyes: :gun: :persevere: :smiling_imp: \ 29 | :sweat: :broken_heart: :yellow_heart: :musical_note: :speak_no_evil: \ 30 | :wink: :skull: :confounded: :smile: :stuck_out_tongue_winking_eye: \ 31 | :angry: :no_good: :muscle: :facepunch: :purple_heart: \ 32 | :sparkling_heart: :blue_heart: :grimacing: :sparkles:".split(' ') 33 | 34 | def top_elements(array, k): 35 | ind = np.argpartition(array, -k)[-k:] 36 | return ind[np.argsort(array[ind])][::-1] 37 | 38 | if __name__ == "__main__": 39 | argparser = argparse.ArgumentParser() 40 | argparser.add_argument('--text', type=str, required=True, help="Input text to emojize") 41 | argparser.add_argument('--maxlen', type=int, default=30, help="Max length of input text") 42 | args = argparser.parse_args() 43 | 44 | # Tokenizing using dictionary 45 | with open(VOCAB_PATH, 'r') as f: 46 | vocabulary = json.load(f) 47 | 48 | st = SentenceTokenizer(vocabulary, args.maxlen) 49 | 50 | # Loading model 51 | model = torchmoji_emojis(PRETRAINED_PATH) 52 | # Running predictions 53 | tokenized, _, _ = st.tokenize_sentences([args.text]) 54 | # Get sentence probability 55 | prob = model(tokenized)[0] 56 | 57 | # Top emoji id 58 | emoji_ids = top_elements(prob, 5) 59 | 60 | # map to emojis 61 | emojis = map(lambda x: EMOJIS[x], emoji_ids) 62 | 63 | print(emoji.emojize("{} {}".format(args.text,' '.join(emojis)), use_aliases=True)) 64 | -------------------------------------------------------------------------------- /scripts/download_weights.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | from subprocess import call 4 | from builtins import input 5 | 6 | curr_folder = os.path.basename(os.path.normpath(os.getcwd())) 7 | 8 | weights_filename = 'pytorch_model.bin' 9 | 
weights_folder = 'model' 10 | weights_path = '{}/{}'.format(weights_folder, weights_filename) 11 | if curr_folder == 'scripts': 12 | weights_path = '../' + weights_path 13 | weights_download_link = 'https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=0#' 14 | 15 | 16 | MB_FACTOR = float(1<<20) 17 | 18 | def prompt(): 19 | while True: 20 | valid = { 21 | 'y': True, 22 | 'ye': True, 23 | 'yes': True, 24 | 'n': False, 25 | 'no': False, 26 | } 27 | choice = input().lower() 28 | if choice in valid: 29 | return valid[choice] 30 | else: 31 | print('Please respond with \'y\' or \'n\' (or \'yes\' or \'no\')') 32 | 33 | download = True 34 | if os.path.exists(weights_path): 35 | print('Weight file already exists at {}. Would you like to redownload it anyway? [y/n]'.format(weights_path)) 36 | download = prompt() 37 | already_exists = True 38 | else: 39 | already_exists = False 40 | 41 | if download: 42 | print('About to download the pretrained weights file from {}'.format(weights_download_link)) 43 | if already_exists == False: 44 | print('The size of the file is roughly 85MB. Continue? [y/n]') 45 | else: 46 | os.unlink(weights_path) 47 | 48 | if already_exists or prompt(): 49 | print('Downloading...') 50 | 51 | #urllib.urlretrieve(weights_download_link, weights_path) 52 | #with open(weights_path,'wb') as f: 53 | # f.write(requests.get(weights_download_link).content) 54 | 55 | # downloading using wget due to issues with urlretrieve and requests 56 | sys_call = 'wget {} -O {}'.format(weights_download_link, os.path.abspath(weights_path)) 57 | print("Running system call: {}".format(sys_call)) 58 | call(sys_call, shell=True) 59 | 60 | if os.path.getsize(weights_path) / MB_FACTOR < 80: 61 | raise ValueError("Download finished, but the resulting file is too small! " + 62 | "It\'s only {} bytes.".format(os.path.getsize(weights_path))) 63 | print('Downloaded weights to {}'.format(weights_path)) 64 | else: 65 | print('Exiting.') 66 | -------------------------------------------------------------------------------- /examples/score_texts_emojis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ Use torchMoji to score texts for emoji distribution. 4 | 5 | The resulting emoji ids (0-63) correspond to the mapping 6 | in emoji_overview.png file at the root of the torchMoji repo. 7 | 8 | Writes the result to a csv file. 
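
Each row holds the text, the summed probability of its top 5 emojis, the five
emoji ids and their individual probabilities (see the header written below).
The integer ids can be mapped to readable aliases with data/emoji_codes.json,
e.g. (a small sketch, assuming the script is run from the examples/ directory):

    import json
    with open('../data/emoji_codes.json') as f:
        id_to_alias = json.load(f)   # keys are emoji ids as strings
    print(id_to_alias['0'])          # ':joy:'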
9 | """ 10 | from __future__ import print_function, division, unicode_literals 11 | import example_helper 12 | import json 13 | import csv 14 | import numpy as np 15 | 16 | from torchmoji.sentence_tokenizer import SentenceTokenizer 17 | from torchmoji.model_def import torchmoji_emojis 18 | from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH 19 | 20 | OUTPUT_PATH = 'test_sentences.csv' 21 | 22 | TEST_SENTENCES = ['I love mom\'s cooking', 23 | 'I love how you never reply back..', 24 | 'I love cruising with my homies', 25 | 'I love messing with yo mind!!', 26 | 'I love you and now you\'re just gone..', 27 | 'This is shit', 28 | 'This is the shit'] 29 | 30 | 31 | def top_elements(array, k): 32 | ind = np.argpartition(array, -k)[-k:] 33 | return ind[np.argsort(array[ind])][::-1] 34 | 35 | maxlen = 30 36 | 37 | print('Tokenizing using dictionary from {}'.format(VOCAB_PATH)) 38 | with open(VOCAB_PATH, 'r') as f: 39 | vocabulary = json.load(f) 40 | 41 | st = SentenceTokenizer(vocabulary, maxlen) 42 | 43 | print('Loading model from {}.'.format(PRETRAINED_PATH)) 44 | model = torchmoji_emojis(PRETRAINED_PATH) 45 | print(model) 46 | print('Running predictions.') 47 | tokenized, _, _ = st.tokenize_sentences(TEST_SENTENCES) 48 | prob = model(tokenized) 49 | 50 | for prob in [prob]: 51 | # Find top emojis for each sentence. Emoji ids (0-63) 52 | # correspond to the mapping in emoji_overview.png 53 | # at the root of the torchMoji repo. 54 | print('Writing results to {}'.format(OUTPUT_PATH)) 55 | scores = [] 56 | for i, t in enumerate(TEST_SENTENCES): 57 | t_tokens = tokenized[i] 58 | t_score = [t] 59 | t_prob = prob[i] 60 | ind_top = top_elements(t_prob, 5) 61 | t_score.append(sum(t_prob[ind_top])) 62 | t_score.extend(ind_top) 63 | t_score.extend([t_prob[ind] for ind in ind_top]) 64 | scores.append(t_score) 65 | print(t_score) 66 | 67 | with open(OUTPUT_PATH, 'w') as csvfile: 68 | writer = csv.writer(csvfile, delimiter=str(','), lineterminator='\n') 69 | writer.writerow(['Text', 'Top5%', 70 | 'Emoji_1', 'Emoji_2', 'Emoji_3', 'Emoji_4', 'Emoji_5', 71 | 'Pct_1', 'Pct_2', 'Pct_3', 'Pct_4', 'Pct_5']) 72 | for i, row in enumerate(scores): 73 | try: 74 | writer.writerow(row) 75 | except: 76 | print("Exception at row {}!".format(i)) 77 | -------------------------------------------------------------------------------- /torchmoji/attlayer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Define the Attention Layer of the model. 3 | """ 4 | 5 | from __future__ import print_function, division 6 | 7 | import torch 8 | 9 | from torch.autograd import Variable 10 | from torch.nn import Module 11 | from torch.nn.parameter import Parameter 12 | 13 | class Attention(Module): 14 | """ 15 | Computes a weighted average of the different channels across timesteps. 16 | Uses 1 parameter pr. channel to compute the attention value for a single timestep. 17 | """ 18 | 19 | def __init__(self, attention_size, return_attention=False): 20 | """ Initialize the attention layer 21 | 22 | # Arguments: 23 | attention_size: Size of the attention vector. 
24 | return_attention: If true, output will include the weight for each input token 25 | used for the prediction 26 | 27 | """ 28 | super(Attention, self).__init__() 29 | self.return_attention = return_attention 30 | self.attention_size = attention_size 31 | self.attention_vector = Parameter(torch.FloatTensor(attention_size)) 32 | self.attention_vector.data.normal_(std=0.05) # Initialize attention vector 33 | 34 | def __repr__(self): 35 | s = '{name}({attention_size}, return attention={return_attention})' 36 | return s.format(name=self.__class__.__name__, **self.__dict__) 37 | 38 | def forward(self, inputs, input_lengths): 39 | """ Forward pass. 40 | 41 | # Arguments: 42 | inputs (Torch.Variable): Tensor of input sequences 43 | input_lengths (torch.LongTensor): Lengths of the sequences 44 | 45 | # Return: 46 | Tuple with (representations and attentions if self.return_attention else None). 47 | """ 48 | logits = inputs.matmul(self.attention_vector) 49 | unnorm_ai = (logits - logits.max()).exp() 50 | 51 | # Compute a mask for the attention on the padded sequences 52 | # See e.g. https://discuss.pytorch.org/t/self-attention-on-words-and-masking/5671/5 53 | max_len = unnorm_ai.size(1) 54 | idxes = torch.arange(0, max_len, out=torch.LongTensor(max_len)).unsqueeze(0) 55 | mask = Variable((idxes < input_lengths.unsqueeze(1)).float()) 56 | 57 | # apply mask and renormalize attention scores (weights) 58 | masked_weights = unnorm_ai * mask 59 | att_sums = masked_weights.sum(dim=1, keepdim=True) # sums per sequence 60 | attentions = masked_weights.div(att_sums) 61 | 62 | # apply attention weights 63 | weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs)) 64 | 65 | # get the final fixed vector representations of the sentences 66 | representations = weighted.sum(dim=1) 67 | 68 | return (representations, attentions if self.return_attention else None) 69 | -------------------------------------------------------------------------------- /scripts/calculate_coverages.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import pickle 3 | import json 4 | import csv 5 | import sys 6 | from io import open 7 | 8 | # Allow us to import the torchmoji directory 9 | from os.path import dirname, abspath 10 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 11 | 12 | from torchmoji.sentence_tokenizer import SentenceTokenizer, coverage 13 | 14 | try: 15 | unicode # Python 2 16 | except NameError: 17 | unicode = str # Python 3 18 | 19 | IS_PYTHON2 = int(sys.version[0]) == 2 20 | 21 | OUTPUT_PATH = 'coverage.csv' 22 | DATASET_PATHS = [ 23 | '../data/Olympic/raw.pickle', 24 | '../data/PsychExp/raw.pickle', 25 | '../data/SCv1/raw.pickle', 26 | '../data/SCv2-GEN/raw.pickle', 27 | '../data/SE0714/raw.pickle', 28 | #'../data/SE1604/raw.pickle', # Excluded due to Twitter's ToS 29 | '../data/SS-Twitter/raw.pickle', 30 | '../data/SS-Youtube/raw.pickle', 31 | ] 32 | 33 | with open('../model/vocabulary.json', 'r') as f: 34 | vocab = json.load(f) 35 | 36 | results = [] 37 | for p in DATASET_PATHS: 38 | coverage_result = [p] 39 | print('Calculating coverage for {}'.format(p)) 40 | with open(p, 'rb') as f: 41 | if IS_PYTHON2: 42 | s = pickle.load(f) 43 | else: 44 | s = pickle.load(f, fix_imports=True) 45 | 46 | # Decode data 47 | try: 48 | s['texts'] = [unicode(x) for x in s['texts']] 49 | except UnicodeDecodeError: 50 | s['texts'] = [x.decode('utf-8') for x in s['texts']] 51 | 52 | # Own 53 | st = SentenceTokenizer({}, 30) 54 | 
tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'], 55 | [s['train_ind'], 56 | s['val_ind'], 57 | s['test_ind']], 58 | extend_with=10000) 59 | coverage_result.append(coverage(tests[2])) 60 | 61 | # Last 62 | st = SentenceTokenizer(vocab, 30) 63 | tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'], 64 | [s['train_ind'], 65 | s['val_ind'], 66 | s['test_ind']], 67 | extend_with=0) 68 | coverage_result.append(coverage(tests[2])) 69 | 70 | # Full 71 | st = SentenceTokenizer(vocab, 30) 72 | tests, dicts, _ = st.split_train_val_test(s['texts'], s['info'], 73 | [s['train_ind'], 74 | s['val_ind'], 75 | s['test_ind']], 76 | extend_with=10000) 77 | coverage_result.append(coverage(tests[2])) 78 | 79 | results.append(coverage_result) 80 | 81 | with open(OUTPUT_PATH, 'wb') as csvfile: 82 | writer = csv.writer(csvfile, delimiter='\t', lineterminator='\n') 83 | writer.writerow(['Dataset', 'Own', 'Last', 'Full']) 84 | for i, row in enumerate(results): 85 | try: 86 | writer.writerow(row) 87 | except: 88 | print("Exception at row {}!".format(i)) 89 | 90 | print('Saved to {}'.format(OUTPUT_PATH)) 91 | -------------------------------------------------------------------------------- /scripts/convert_all_datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import json 4 | import math 5 | import pickle 6 | import sys 7 | from io import open 8 | import numpy as np 9 | from os.path import abspath, dirname 10 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 11 | 12 | from torchmoji.word_generator import WordGenerator 13 | from torchmoji.create_vocab import VocabBuilder 14 | from torchmoji.sentence_tokenizer import SentenceTokenizer, extend_vocab, coverage 15 | from torchmoji.tokenizer import tokenize 16 | 17 | try: 18 | unicode # Python 2 19 | except NameError: 20 | unicode = str # Python 3 21 | 22 | IS_PYTHON2 = int(sys.version[0]) == 2 23 | 24 | DATASETS = [ 25 | 'Olympic', 26 | 'PsychExp', 27 | 'SCv1', 28 | 'SCv2-GEN', 29 | 'SE0714', 30 | #'SE1604', # Excluded due to Twitter's ToS 31 | 'SS-Twitter', 32 | 'SS-Youtube', 33 | ] 34 | 35 | DIR = '../data' 36 | FILENAME_RAW = 'raw.pickle' 37 | FILENAME_OWN = 'own_vocab.pickle' 38 | FILENAME_OUR = 'twitter_vocab.pickle' 39 | FILENAME_COMBINED = 'combined_vocab.pickle' 40 | 41 | 42 | def roundup(x): 43 | return int(math.ceil(x / 10.0)) * 10 44 | 45 | 46 | def format_pickle(dset, train_texts, val_texts, test_texts, train_labels, val_labels, test_labels): 47 | return {'dataset': dset, 48 | 'train_texts': train_texts, 49 | 'val_texts': val_texts, 50 | 'test_texts': test_texts, 51 | 'train_labels': train_labels, 52 | 'val_labels': val_labels, 53 | 'test_labels': test_labels} 54 | 55 | def convert_dataset(filepath, extend_with, vocab): 56 | print('-- Generating {} '.format(filepath)) 57 | sys.stdout.flush() 58 | st = SentenceTokenizer(vocab, maxlen) 59 | tokenized, dicts, _ = st.split_train_val_test(texts, 60 | labels, 61 | [data['train_ind'], 62 | data['val_ind'], 63 | data['test_ind']], 64 | extend_with=extend_with) 65 | pick = format_pickle(dset, tokenized[0], tokenized[1], tokenized[2], 66 | dicts[0], dicts[1], dicts[2]) 67 | with open(filepath, 'w') as f: 68 | pickle.dump(pick, f) 69 | cover = coverage(tokenized[2]) 70 | 71 | print(' done. 
Coverage: {}'.format(cover)) 72 | 73 | with open('../model/vocabulary.json', 'r') as f: 74 | vocab = json.load(f) 75 | 76 | for dset in DATASETS: 77 | print('Converting {}'.format(dset)) 78 | 79 | PATH_RAW = '{}/{}/{}'.format(DIR, dset, FILENAME_RAW) 80 | PATH_OWN = '{}/{}/{}'.format(DIR, dset, FILENAME_OWN) 81 | PATH_OUR = '{}/{}/{}'.format(DIR, dset, FILENAME_OUR) 82 | PATH_COMBINED = '{}/{}/{}'.format(DIR, dset, FILENAME_COMBINED) 83 | 84 | with open(PATH_RAW, 'rb') as dataset: 85 | if IS_PYTHON2: 86 | data = pickle.load(dataset) 87 | else: 88 | data = pickle.load(dataset, fix_imports=True) 89 | 90 | # Decode data 91 | try: 92 | texts = [unicode(x) for x in data['texts']] 93 | except UnicodeDecodeError: 94 | texts = [x.decode('utf-8') for x in data['texts']] 95 | 96 | wg = WordGenerator(texts) 97 | vb = VocabBuilder(wg) 98 | vb.count_all_words() 99 | 100 | # Calculate max length of sequences considered 101 | # Adjust batch_size accordingly to prevent GPU overflow 102 | lengths = [len(tokenize(t)) for t in texts] 103 | maxlen = roundup(np.percentile(lengths, 80.0)) 104 | 105 | # Extract labels 106 | labels = [x['label'] for x in data['info']] 107 | 108 | convert_dataset(PATH_OWN, 50000, {}) 109 | convert_dataset(PATH_OUR, 0, vocab) 110 | convert_dataset(PATH_COMBINED, 10000, vocab) 111 | -------------------------------------------------------------------------------- /tests/test_sentence_tokenizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function, division, unicode_literals 2 | import test_helper 3 | import json 4 | 5 | from torchmoji.sentence_tokenizer import SentenceTokenizer 6 | 7 | sentences = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'] 8 | 9 | dicts = [ 10 | {'label': 0}, 11 | {'label': 1}, 12 | {'label': 2}, 13 | {'label': 3}, 14 | {'label': 4}, 15 | {'label': 5}, 16 | {'label': 6}, 17 | {'label': 7}, 18 | {'label': 8}, 19 | {'label': 9}, 20 | ] 21 | 22 | train_ind = [0, 5, 3, 6, 8] 23 | val_ind = [9, 2, 1] 24 | test_ind = [4, 7] 25 | 26 | with open('../model/vocabulary.json', 'r') as f: 27 | vocab = json.load(f) 28 | 29 | def test_dataset_split_parameter(): 30 | """ Dataset is split in the desired ratios 31 | """ 32 | split_parameter = [0.7, 0.1, 0.2] 33 | st = SentenceTokenizer(vocab, 30) 34 | 35 | result, result_dicts, _ = st.split_train_val_test(sentences, dicts, 36 | split_parameter, extend_with=0) 37 | train = result[0] 38 | val = result[1] 39 | test = result[2] 40 | 41 | train_dicts = result_dicts[0] 42 | val_dicts = result_dicts[1] 43 | test_dicts = result_dicts[2] 44 | 45 | assert len(train) == len(sentences) * split_parameter[0] 46 | assert len(val) == len(sentences) * split_parameter[1] 47 | assert len(test) == len(sentences) * split_parameter[2] 48 | 49 | assert len(train_dicts) == len(dicts) * split_parameter[0] 50 | assert len(val_dicts) == len(dicts) * split_parameter[1] 51 | assert len(test_dicts) == len(dicts) * split_parameter[2] 52 | 53 | def test_dataset_split_explicit(): 54 | """ Dataset is split according to given indices 55 | """ 56 | split_parameter = [train_ind, val_ind, test_ind] 57 | st = SentenceTokenizer(vocab, 30) 58 | tokenized, _, _ = st.tokenize_sentences(sentences) 59 | 60 | result, result_dicts, added = st.split_train_val_test(sentences, dicts, split_parameter, extend_with=0) 61 | train = result[0] 62 | val = result[1] 63 | test = result[2] 64 | 65 | train_dicts = result_dicts[0] 66 | val_dicts = result_dicts[1] 67 | test_dicts = result_dicts[2] 68 
| 69 | tokenized = tokenized 70 | 71 | for i, sentence in enumerate(sentences): 72 | if i in train_ind: 73 | assert tokenized[i] in train 74 | assert dicts[i] in train_dicts 75 | elif i in val_ind: 76 | assert tokenized[i] in val 77 | assert dicts[i] in val_dicts 78 | elif i in test_ind: 79 | assert tokenized[i] in test 80 | assert dicts[i] in test_dicts 81 | 82 | assert len(train) == len(train_ind) 83 | assert len(val) == len(val_ind) 84 | assert len(test) == len(test_ind) 85 | assert len(train_dicts) == len(train_ind) 86 | assert len(val_dicts) == len(val_ind) 87 | assert len(test_dicts) == len(test_ind) 88 | 89 | def test_id_to_sentence(): 90 | """Tokenizing and converting back preserves the input. 91 | """ 92 | vb = {'CUSTOM_MASK': 0, 93 | 'aasdf': 1000, 94 | 'basdf': 2000} 95 | 96 | sentence = 'aasdf basdf basdf basdf' 97 | st = SentenceTokenizer(vb, 30) 98 | token, _, _ = st.tokenize_sentences([sentence]) 99 | assert st.to_sentence(token[0]) == sentence 100 | 101 | def test_id_to_sentence_with_unknown(): 102 | """Tokenizing and converting back preserves the input, except for unknowns. 103 | """ 104 | vb = {'CUSTOM_MASK': 0, 105 | 'CUSTOM_UNKNOWN': 1, 106 | 'aasdf': 1000, 107 | 'basdf': 2000} 108 | 109 | sentence = 'aasdf basdf ccc' 110 | expected = 'aasdf basdf CUSTOM_UNKNOWN' 111 | st = SentenceTokenizer(vb, 30) 112 | token, _, _ = st.tokenize_sentences([sentence]) 113 | assert st.to_sentence(token[0]) == expected 114 | -------------------------------------------------------------------------------- /torchmoji/tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Splits up a Unicode string into a list of tokens. 4 | Recognises: 5 | - Abbreviations 6 | - URLs 7 | - Emails 8 | - #hashtags 9 | - @mentions 10 | - emojis 11 | - emoticons (limited support) 12 | 13 | Multiple consecutive symbols are also treated as a single token. 14 | ''' 15 | from __future__ import absolute_import, division, print_function, unicode_literals 16 | 17 | import re 18 | 19 | # Basic patterns. 20 | RE_NUM = r'[0-9]+' 21 | RE_WORD = r'[a-zA-Z]+' 22 | RE_WHITESPACE = r'\s+' 23 | RE_ANY = r'.' 24 | 25 | # Combined words such as 'red-haired' or 'CUSTOM_TOKEN' 26 | RE_COMB = r'[a-zA-Z]+[-_][a-zA-Z]+' 27 | 28 | # English-specific patterns 29 | RE_CONTRACTIONS = RE_WORD + r'\'' + RE_WORD 30 | 31 | TITLES = [ 32 | r'Mr\.', 33 | r'Ms\.', 34 | r'Mrs\.', 35 | r'Dr\.', 36 | r'Prof\.', 37 | ] 38 | # Ensure case insensitivity 39 | RE_TITLES = r'|'.join([r'(?i)' + t for t in TITLES]) 40 | 41 | # Symbols have to be created as separate patterns in order to match consecutive 42 | # identical symbols. 43 | SYMBOLS = r'(){}~$^&*;:%+\xa3€`' 44 | RE_SYMBOL = r'|'.join([re.escape(s) + r'+' for s in SYMBOLS]) 45 | 46 | # Hash symbols and at symbols have to be defined separately in order to not 47 | # clash with hashtags and mentions if there are multiple - i.e. 
48 | # ##hello -> ['#', '#hello'] instead of ['##', 'hello'] 49 | SPECIAL_SYMBOLS = r'|#+(?=#[a-zA-Z0-9_]+)|@+(?=@[a-zA-Z0-9_]+)|#+|@+' 50 | RE_SYMBOL += SPECIAL_SYMBOLS 51 | 52 | RE_ABBREVIATIONS = r'\b(?:', 65 | r':', 66 | r'=', 67 | r';', 68 | ] 69 | EMOTICONS_MID = [ 70 | r'-', 71 | r',', 72 | r'^', 73 | '\'', 74 | '\"', 75 | ] 76 | EMOTICONS_END = [ 77 | r'D', 78 | r'd', 79 | r'p', 80 | r'P', 81 | r'v', 82 | r')', 83 | r'o', 84 | r'O', 85 | r'(', 86 | r'3', 87 | r'/', 88 | r'|', 89 | '\\', 90 | ] 91 | EMOTICONS_EXTRA = [ 92 | r'-_-', 93 | r'x_x', 94 | r'^_^', 95 | r'o.o', 96 | r'o_o', 97 | r'(:', 98 | r'):', 99 | r');', 100 | r'(;', 101 | ] 102 | 103 | RE_EMOTICON = r'|'.join([re.escape(s) for s in EMOTICONS_EXTRA]) 104 | for s in EMOTICONS_START: 105 | for m in EMOTICONS_MID: 106 | for e in EMOTICONS_END: 107 | RE_EMOTICON += '|{0}{1}?{2}+'.format(re.escape(s), re.escape(m), re.escape(e)) 108 | 109 | # requires ucs4 in python2.7 or python3+ 110 | # RE_EMOJI = r"""[\U0001F300-\U0001F64F\U0001F680-\U0001F6FF\u2600-\u26FF\u2700-\u27BF]""" 111 | # safe for all python 112 | RE_EMOJI = r"""\ud83c[\udf00-\udfff]|\ud83d[\udc00-\ude4f\ude80-\udeff]|[\u2600-\u26FF\u2700-\u27BF]""" 113 | 114 | # List of matched token patterns, ordered from most specific to least specific. 115 | TOKENS = [ 116 | RE_URL, 117 | RE_EMAIL, 118 | RE_COMB, 119 | RE_HASHTAG, 120 | RE_MENTION, 121 | RE_HEART, 122 | RE_EMOTICON, 123 | RE_CONTRACTIONS, 124 | RE_TITLES, 125 | RE_ABBREVIATIONS, 126 | RE_NUM, 127 | RE_WORD, 128 | RE_SYMBOL, 129 | RE_EMOJI, 130 | RE_ANY 131 | ] 132 | 133 | # List of ignored token patterns 134 | IGNORED = [ 135 | RE_WHITESPACE 136 | ] 137 | 138 | # Final pattern 139 | RE_PATTERN = re.compile(r'|'.join(IGNORED) + r'|(' + r'|'.join(TOKENS) + r')', 140 | re.UNICODE) 141 | 142 | 143 | def tokenize(text): 144 | '''Splits given input string into a list of tokens. 145 | 146 | # Arguments: 147 | text: Input string to be tokenized. 148 | 149 | # Returns: 150 | List of strings (tokens). 151 | ''' 152 | result = RE_PATTERN.findall(text) 153 | 154 | # Remove empty strings 155 | result = [t for t in result if t.strip()] 156 | return result 157 | -------------------------------------------------------------------------------- /scripts/finetune_dataset.py: -------------------------------------------------------------------------------- 1 | """ Finetuning example. 
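Runs finetuning for every dataset enabled in DATASETS below, repeating each run
five times, and writes one '<dataset>_<method>_<iteration>_results.txt' file per
run into the results/ directory, which is where scripts/analyze_results.py and
scripts/analyze_all_results.py look for results.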
2 | """ 3 | from __future__ import print_function 4 | import sys 5 | import numpy as np 6 | from os.path import abspath, dirname 7 | sys.path.insert(0, dirname(dirname(abspath(__file__)))) 8 | 9 | import json 10 | import math 11 | from torchmoji.model_def import torchmoji_transfer 12 | from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH 13 | from torchmoji.finetuning import ( 14 | load_benchmark, 15 | finetune) 16 | from torchmoji.class_avg_finetuning import class_avg_finetune 17 | 18 | def roundup(x): 19 | return int(math.ceil(x / 10.0)) * 10 20 | 21 | 22 | # Format: (dataset_name, 23 | # path_to_dataset, 24 | # nb_classes, 25 | # use_f1_score) 26 | DATASETS = [ 27 | #('SE0714', '../data/SE0714/raw.pickle', 3, True), 28 | #('Olympic', '../data/Olympic/raw.pickle', 4, True), 29 | #('PsychExp', '../data/PsychExp/raw.pickle', 7, True), 30 | #('SS-Twitter', '../data/SS-Twitter/raw.pickle', 2, False), 31 | ('SS-Youtube', '../data/SS-Youtube/raw.pickle', 2, False), 32 | #('SE1604', '../data/SE1604/raw.pickle', 3, False), # Excluded due to Twitter's ToS 33 | #('SCv1', '../data/SCv1/raw.pickle', 2, True), 34 | #('SCv2-GEN', '../data/SCv2-GEN/raw.pickle', 2, True) 35 | ] 36 | 37 | RESULTS_DIR = 'results' 38 | 39 | # 'new' | 'last' | 'full' | 'chain-thaw' 40 | FINETUNE_METHOD = 'last' 41 | VERBOSE = 1 42 | 43 | nb_tokens = 50000 44 | nb_epochs = 1000 45 | epoch_size = 1000 46 | 47 | with open(VOCAB_PATH, 'r') as f: 48 | vocab = json.load(f) 49 | 50 | for rerun_iter in range(5): 51 | for p in DATASETS: 52 | 53 | # debugging 54 | assert len(vocab) == nb_tokens 55 | 56 | dset = p[0] 57 | path = p[1] 58 | nb_classes = p[2] 59 | use_f1_score = p[3] 60 | 61 | if FINETUNE_METHOD == 'last': 62 | extend_with = 0 63 | elif FINETUNE_METHOD in ['new', 'full', 'chain-thaw']: 64 | extend_with = 10000 65 | else: 66 | raise ValueError('Finetuning method not recognised!') 67 | 68 | # Load dataset. 69 | data = load_benchmark(path, vocab, extend_with=extend_with) 70 | 71 | (X_train, y_train) = (data['texts'][0], data['labels'][0]) 72 | (X_val, y_val) = (data['texts'][1], data['labels'][1]) 73 | (X_test, y_test) = (data['texts'][2], data['labels'][2]) 74 | 75 | weight_path = PRETRAINED_PATH if FINETUNE_METHOD != 'new' else None 76 | nb_model_classes = 2 if use_f1_score else nb_classes 77 | model = torchmoji_transfer( 78 | nb_model_classes, 79 | weight_path, 80 | extend_embedding=data['added']) 81 | print(model) 82 | 83 | # Training 84 | print('Training: {}'.format(path)) 85 | if use_f1_score: 86 | model, result = class_avg_finetune(model, data['texts'], 87 | data['labels'], 88 | nb_classes, data['batch_size'], 89 | FINETUNE_METHOD, 90 | verbose=VERBOSE) 91 | else: 92 | model, result = finetune(model, data['texts'], data['labels'], 93 | nb_classes, data['batch_size'], 94 | FINETUNE_METHOD, metric='acc', 95 | verbose=VERBOSE) 96 | 97 | # Write results 98 | if use_f1_score: 99 | print('Overall F1 score (dset = {}): {}'.format(dset, result)) 100 | with open('{}/{}_{}_{}_results.txt'. 101 | format(RESULTS_DIR, dset, FINETUNE_METHOD, rerun_iter), 102 | "w") as f: 103 | f.write("F1: {}\n".format(result)) 104 | else: 105 | print('Test accuracy (dset = {}): {}'.format(dset, result)) 106 | with open('{}/{}_{}_{}_results.txt'. 
107 | format(RESULTS_DIR, dset, FINETUNE_METHOD, rerun_iter), 108 | "w") as f: 109 | f.write("Acc: {}\n".format(result)) 110 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Tokenization tests. 3 | """ 4 | from __future__ import absolute_import, print_function, division, unicode_literals 5 | 6 | import sys 7 | from nose.tools import nottest 8 | from os.path import dirname, abspath 9 | sys.path.append(dirname(dirname(abspath(__file__)))) 10 | from torchmoji.tokenizer import tokenize 11 | 12 | TESTS_NORMAL = [ 13 | ('200K words!', ['200', 'K', 'words', '!']), 14 | ] 15 | 16 | TESTS_EMOJIS = [ 17 | ('i \U0001f496 you to the moon and back', 18 | ['i', '\U0001f496', 'you', 'to', 'the', 'moon', 'and', 'back']), 19 | ("i\U0001f496you to the \u2605's and back", 20 | ['i', '\U0001f496', 'you', 'to', 'the', 21 | '\u2605', "'", 's', 'and', 'back']), 22 | ('~<3~', ['~', '<3', '~']), 23 | ('<333', ['<333']), 24 | (':-)', [':-)']), 25 | ('>:-(', ['>:-(']), 26 | ('\u266b\u266a\u2605\u2606\u2665\u2764\u2661', 27 | ['\u266b', '\u266a', '\u2605', '\u2606', 28 | '\u2665', '\u2764', '\u2661']), 29 | ] 30 | 31 | TESTS_URLS = [ 32 | ('www.sample.com', ['www.sample.com']), 33 | ('http://endless.horse', ['http://endless.horse']), 34 | ('https://github.mit.ed', ['https://github.mit.ed']), 35 | ] 36 | 37 | TESTS_TWITTER = [ 38 | ('#blacklivesmatter', ['#blacklivesmatter']), 39 | ('#99_percent.', ['#99_percent', '.']), 40 | ('the#99%', ['the', '#99', '%']), 41 | ('@golden_zenith', ['@golden_zenith']), 42 | ('@99_percent', ['@99_percent']), 43 | ('latte-express@mit.ed', ['latte-express@mit.ed']), 44 | ] 45 | 46 | TESTS_PHONE_NUMS = [ 47 | ('518)528-0252', ['518', ')', '528', '-', '0252']), 48 | ('1200-0221-0234', ['1200', '-', '0221', '-', '0234']), 49 | ('1200.0221.0234', ['1200', '.', '0221', '.', '0234']), 50 | ] 51 | 52 | TESTS_DATETIME = [ 53 | ('15:00', ['15', ':', '00']), 54 | ('2:00pm', ['2', ':', '00', 'pm']), 55 | ('9/14/16', ['9', '/', '14', '/', '16']), 56 | ] 57 | 58 | TESTS_CURRENCIES = [ 59 | ('517.933\xa3', ['517', '.', '933', '\xa3']), 60 | ('$517.87', ['$', '517', '.', '87']), 61 | ('1201.6598', ['1201', '.', '6598']), 62 | ('120,6', ['120', ',', '6']), 63 | ('10,00\u20ac', ['10', ',', '00', '\u20ac']), 64 | ('1,000', ['1', ',', '000']), 65 | ('1200pesos', ['1200', 'pesos']), 66 | ] 67 | 68 | TESTS_NUM_SYM = [ 69 | ('5162f', ['5162', 'f']), 70 | ('f5162', ['f', '5162']), 71 | ('1203(', ['1203', '(']), 72 | ('(1203)', ['(', '1203', ')']), 73 | ('1200/', ['1200', '/']), 74 | ('1200+', ['1200', '+']), 75 | ('1202o-east', ['1202', 'o-east']), 76 | ('1200r', ['1200', 'r']), 77 | ('1200-1400', ['1200', '-', '1400']), 78 | ('120/today', ['120', '/', 'today']), 79 | ('today/120', ['today', '/', '120']), 80 | ('120/5', ['120', '/', '5']), 81 | ("120'/5", ['120', "'", '/', '5']), 82 | ('120/5pro', ['120', '/', '5', 'pro']), 83 | ("1200's,)", ['1200', "'", 's', ',', ')']), 84 | ('120.76.218.207', ['120', '.', '76', '.', '218', '.', '207']), 85 | ] 86 | 87 | TESTS_PUNCTUATION = [ 88 | ("don''t", ['don', "''", 't']), 89 | ("don'tcha", ["don'tcha"]), 90 | ('no?!?!;', ['no', '?', '!', '?', '!', ';']), 91 | ('no??!!..', ['no', '??', '!!', '..']), 92 | ('a.m.', ['a.m.']), 93 | ('.s.u', ['.', 's', '.', 'u']), 94 | ('!!i..n__', ['!!', 'i', '..', 'n', '__']), 95 | ('lv(<3)w(3>)u Mr.!', ['lv', '(', '<3', ')', 'w', '(', '3', 96 | '>', ')', 'u', 
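The script writes one small text file per (dataset, finetuning method, rerun) combination into `results/`, using the `Acc:`/`F1:` prefixes shown above. A tiny helper for collecting those numbers afterwards; this is only a sketch built on the file-name pattern used by the script, not something shipped in `scripts/`:

```python
import glob
import os

def collect_results(results_dir='results'):
    """Gather scores from '{dset}_{method}_{rerun}_results.txt' files."""
    scores = {}
    for path in sorted(glob.glob(os.path.join(results_dir, '*_results.txt'))):
        with open(path) as f:
            metric, value = f.readline().strip().split(': ')  # e.g. 'Acc: 0.88' or 'F1: 0.67'
        scores[os.path.basename(path)] = (metric, float(value))
    return scores

for name, (metric, value) in collect_results().items():
    print('{:<40} {} = {:.4f}'.format(name, metric, value))
```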
'Mr.', '!']), 97 | ('-->', ['--', '>']), 98 | ('->', ['-', '>']), 99 | ('<-', ['<', '-']), 100 | ('<--', ['<', '--']), 101 | ('hello (@person)', ['hello', '(', '@person', ')']), 102 | ] 103 | 104 | 105 | def test_normal(): 106 | """ Normal/combined usage. 107 | """ 108 | test_base(TESTS_NORMAL) 109 | 110 | 111 | def test_emojis(): 112 | """ Tokenizing emojis/emoticons/decorations. 113 | """ 114 | test_base(TESTS_EMOJIS) 115 | 116 | 117 | def test_urls(): 118 | """ Tokenizing URLs. 119 | """ 120 | test_base(TESTS_URLS) 121 | 122 | 123 | def test_twitter(): 124 | """ Tokenizing hashtags, mentions and emails. 125 | """ 126 | test_base(TESTS_TWITTER) 127 | 128 | 129 | def test_phone_nums(): 130 | """ Tokenizing phone numbers. 131 | """ 132 | test_base(TESTS_PHONE_NUMS) 133 | 134 | 135 | def test_datetime(): 136 | """ Tokenizing dates and times. 137 | """ 138 | test_base(TESTS_DATETIME) 139 | 140 | 141 | def test_currencies(): 142 | """ Tokenizing currencies. 143 | """ 144 | test_base(TESTS_CURRENCIES) 145 | 146 | 147 | def test_num_sym(): 148 | """ Tokenizing combinations of numbers and symbols. 149 | """ 150 | test_base(TESTS_NUM_SYM) 151 | 152 | 153 | def test_punctuation(): 154 | """ Tokenizing punctuation and contractions. 155 | """ 156 | test_base(TESTS_PUNCTUATION) 157 | 158 | 159 | @nottest 160 | def test_base(tests): 161 | """ Base function for running tests. 162 | """ 163 | for (test, expected) in tests: 164 | actual = tokenize(test) 165 | assert actual == expected, \ 166 | "Tokenization of \'{}\' failed, expected: {}, actual: {}"\ 167 | .format(test, expected, actual) 168 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### ------ Update September 2018 ------ 2 | It's been a year since TorchMoji and DeepMoji were released. We're trying to understand how it's being used so that we can make improvements and design better models in the future. 3 | 4 | You can help us achieve this by answering this [4-question Google Form](https://docs.google.com/forms/d/e/1FAIpQLSe1h4NSQD30YM8dsbJQEnki-02_9KVQD34qgP9to0bwAHBvBA/viewform "DeepMoji Google Form"). Thanks for your support! 5 | 6 | # 😇 TorchMoji 7 | 8 | > **Read our blog post about the implementation process [here](https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983).** 9 | 10 | TorchMoji is a [pyTorch](http://pytorch.org/) implementation of the [DeepMoji](https://github.com/bfelbo/DeepMoji) model developed by Bjarke Felbo, Alan Mislove, Anders Søgaard, Iyad Rahwan and Sune Lehmann. 11 | 12 | This model was trained on 1.2 billion tweets with emojis to understand how language is used to express emotions. Through transfer learning the model can obtain state-of-the-art performance on many emotion-related text modeling tasks. 13 | 14 | Try the online demo of DeepMoji on this 🤗 [Space](https://huggingface.co/spaces/Pendrokar/DeepMoji)! See the [paper](https://arxiv.org/abs/1708.00524), [blog post](https://medium.com/@bjarkefelbo/what-can-we-learn-from-emojis-6beb165a5ea0) or [FAQ](https://www.media.mit.edu/projects/deepmoji/overview/) for more details. 15 | 16 | ## Overview 17 | * [torchmoji/](torchmoji) contains all the underlying code needed to convert a dataset to the vocabulary and use the model. 18 | * [examples/](examples) contains short code snippets showing how to convert a dataset to the vocabulary, load up the model and run it on that dataset.
19 | * [scripts/](scripts) contains code for processing and analysing datasets to reproduce results in the paper. 20 | * [model/](model) contains the pretrained model and vocabulary. 21 | * [data/](data) contains raw and processed datasets that we include in this repository for testing. 22 | * [tests/](tests) contains unit tests for the codebase. 23 | 24 | To start out with, have a look inside the [examples/](examples) directory. See [score_texts_emojis.py](examples/score_texts_emojis.py) for how to use DeepMoji to extract emoji predictions, [encode_texts.py](examples/encode_texts.py) for how to convert text into 2304-dimensional emotional feature vectors or [finetune_youtube_last.py](examples/finetune_youtube_last.py) for how to use the model for transfer learning on a new dataset. 25 | 26 | Please consider citing the [paper](https://arxiv.org/abs/1708.00524) of DeepMoji if you use the model or code (see below for citation). 27 | 28 | ## Installation 29 | 30 | We assume that you're using [Python 2.7-3.5](https://www.python.org/downloads/) with [pip](https://pip.pypa.io/en/stable/installing/) installed. 31 | 32 | First you need to install [pyTorch (version 0.2+)](http://pytorch.org/), currently by: 33 | ```bash 34 | conda install pytorch -c pytorch 35 | ``` 36 | At the present stage the model can't make efficient use of CUDA. See details in the [Hugging Face blog post](https://medium.com/huggingface/understanding-emotions-from-keras-to-pytorch-3ccb61d5a983). 37 | 38 | When pyTorch is installed, run the following in the root directory to install the remaining dependencies: 39 | 40 | ```bash 41 | pip install -e . 42 | ``` 43 | This will install the following dependencies: 44 | * [scikit-learn](https://github.com/scikit-learn/scikit-learn) 45 | * [text-unidecode](https://github.com/kmike/text-unidecode) 46 | * [emoji](https://github.com/carpedm20/emoji) 47 | 48 | Then, run the download script to download the pretrained torchMoji weights (~85MB) from [here](https://www.dropbox.com/s/q8lax9ary32c7t9/pytorch_model.bin?dl=0) and put them in the model/ directory: 49 | 50 | ```bash 51 | python scripts/download_weights.py 52 | ``` 53 | 54 | ## Testing 55 | To run the tests, install [nose](http://nose.readthedocs.io/en/latest/). After installing, navigate to the [tests/](tests) directory and run: 56 | 57 | ```bash 58 | cd tests 59 | nosetests -v 60 | ``` 61 | 62 | By default, this will also run finetuning tests. These tests train the model for one epoch and then check the resulting accuracy, which may take several minutes to finish. If you'd prefer to exclude those, run the following instead: 63 | 64 | ```bash 65 | cd tests 66 | nosetests -v -a '!slow' 67 | ``` 68 | 69 | ## Disclaimer 70 | This code has been tested to work with Python 2.7 and 3.5 on Ubuntu 16.04 and macOS Sierra machines. It has not been optimized for efficiency, but should be fast enough for most purposes. We do not give any guarantees that there are no bugs - use the code at your own risk! 71 | 72 | ## Contributions 73 | We welcome pull requests if you feel like something could be improved. You can also greatly help us by telling us how you felt when writing your most recent tweets. Just click [here](http://deepmoji.mit.edu/contribute/) to contribute. 74 | 75 | ## License 76 | This code and the pretrained model are licensed under the MIT license. 77 | 78 | ## Benchmark datasets 79 | The benchmark datasets are uploaded to this repository for convenience purposes only.
They were not released by us and we do not claim any rights on them. Use the datasets at your responsibility and make sure you fulfill the licenses that they were released with. If you use any of the benchmark datasets please consider citing the original authors. 80 | 81 | ## Citation 82 | ``` 83 | @inproceedings{felbo2017, 84 | title={Using millions of emoji occurrences to learn any-domain representations for detecting sentiment, emotion and sarcasm}, 85 | author={Felbo, Bjarke and Mislove, Alan and S{\o}gaard, Anders and Rahwan, Iyad and Lehmann, Sune}, 86 | booktitle={Conference on Empirical Methods in Natural Language Processing (EMNLP)}, 87 | year={2017} 88 | } 89 | ``` 90 | -------------------------------------------------------------------------------- /torchmoji/filter_utils.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function, division, unicode_literals 4 | import sys 5 | import re 6 | import string 7 | import emoji 8 | from itertools import groupby 9 | 10 | import numpy as np 11 | from torchmoji.tokenizer import RE_MENTION, RE_URL 12 | from torchmoji.global_variables import SPECIAL_TOKENS 13 | 14 | try: 15 | unichr # Python 2 16 | except NameError: 17 | unichr = chr # Python 3 18 | 19 | 20 | AtMentionRegex = re.compile(RE_MENTION) 21 | urlRegex = re.compile(RE_URL) 22 | 23 | # from http://bit.ly/2rdjgjE (UTF-8 encodings and Unicode chars) 24 | VARIATION_SELECTORS = [ '\ufe00', 25 | '\ufe01', 26 | '\ufe02', 27 | '\ufe03', 28 | '\ufe04', 29 | '\ufe05', 30 | '\ufe06', 31 | '\ufe07', 32 | '\ufe08', 33 | '\ufe09', 34 | '\ufe0a', 35 | '\ufe0b', 36 | '\ufe0c', 37 | '\ufe0d', 38 | '\ufe0e', 39 | '\ufe0f'] 40 | 41 | # from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python 42 | ALL_CHARS = (unichr(i) for i in range(sys.maxunicode)) 43 | CONTROL_CHARS = ''.join(map(unichr, list(range(0,32)) + list(range(127,160)))) 44 | CONTROL_CHAR_REGEX = re.compile('[%s]' % re.escape(CONTROL_CHARS)) 45 | 46 | def is_special_token(word): 47 | equal = False 48 | for spec in SPECIAL_TOKENS: 49 | if word == spec: 50 | equal = True 51 | break 52 | return equal 53 | 54 | def mostly_english(words, english, pct_eng_short=0.5, pct_eng_long=0.6, ignore_special_tokens=True, min_length=2): 55 | """ Ensure text meets threshold for containing English words """ 56 | 57 | n_words = 0 58 | n_english = 0 59 | 60 | if english is None: 61 | return True, 0, 0 62 | 63 | for w in words: 64 | if len(w) < min_length: 65 | continue 66 | if punct_word(w): 67 | continue 68 | if ignore_special_tokens and is_special_token(w): 69 | continue 70 | n_words += 1 71 | if w in english: 72 | n_english += 1 73 | 74 | if n_words < 2: 75 | return True, n_words, n_english 76 | if n_words < 5: 77 | valid_english = n_english >= n_words * pct_eng_short 78 | else: 79 | valid_english = n_english >= n_words * pct_eng_long 80 | return valid_english, n_words, n_english 81 | 82 | def correct_length(words, min_words, max_words, ignore_special_tokens=True): 83 | """ Ensure text meets threshold for containing English words 84 | and that it's within the min and max words limits. 
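The README's overview points to `examples/score_texts_emojis.py` for emoji prediction. The core of that flow is short; a sketch, assuming the vocabulary and pretrained weights are already in `model/` (i.e. `scripts/download_weights.py` has been run):

```python
import json
import numpy as np

from torchmoji.sentence_tokenizer import SentenceTokenizer
from torchmoji.model_def import torchmoji_emojis
from torchmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH

with open(VOCAB_PATH, 'r') as f:
    vocabulary = json.load(f)

st = SentenceTokenizer(vocabulary, 30)                  # fixed sentence length of 30 tokens
model = torchmoji_emojis(weight_path=PRETRAINED_PATH)

tokens, _, _ = st.tokenize_sentences([u"I love mom's cooking"])
prob = model(tokens)[0]                                 # one probability per emoji class
print(np.argsort(prob)[::-1][:5])                       # indices of the five most likely emojis
```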
""" 85 | 86 | if min_words is None: 87 | min_words = 0 88 | 89 | if max_words is None: 90 | max_words = 99999 91 | 92 | n_words = 0 93 | for w in words: 94 | if punct_word(w): 95 | continue 96 | if ignore_special_tokens and is_special_token(w): 97 | continue 98 | n_words += 1 99 | valid = min_words <= n_words and n_words <= max_words 100 | return valid 101 | 102 | def punct_word(word, punctuation=string.punctuation): 103 | return all([True if c in punctuation else False for c in word]) 104 | 105 | def load_non_english_user_set(): 106 | non_english_user_set = set(np.load('uids.npz')['data']) 107 | return non_english_user_set 108 | 109 | def non_english_user(userid, non_english_user_set): 110 | neu_found = int(userid) in non_english_user_set 111 | return neu_found 112 | 113 | def separate_emojis_and_text(text): 114 | emoji_chars = [] 115 | non_emoji_chars = [] 116 | for c in text: 117 | if c in emoji.UNICODE_EMOJI: 118 | emoji_chars.append(c) 119 | else: 120 | non_emoji_chars.append(c) 121 | return ''.join(emoji_chars), ''.join(non_emoji_chars) 122 | 123 | def extract_emojis(text, wanted_emojis): 124 | text = remove_variation_selectors(text) 125 | return [c for c in text if c in wanted_emojis] 126 | 127 | def remove_variation_selectors(text): 128 | """ Remove styling glyph variants for Unicode characters. 129 | For instance, remove skin color from emojis. 130 | """ 131 | for var in VARIATION_SELECTORS: 132 | text = text.replace(var, '') 133 | return text 134 | 135 | def shorten_word(word): 136 | """ Shorten groupings of 3+ identical consecutive chars to 2, e.g. '!!!!' --> '!!' 137 | """ 138 | 139 | # only shorten ASCII words 140 | try: 141 | word.decode('ascii') 142 | except (UnicodeDecodeError, UnicodeEncodeError, AttributeError) as e: 143 | return word 144 | 145 | # must have at least 3 char to be shortened 146 | if len(word) < 3: 147 | return word 148 | 149 | # find groups of 3+ consecutive letters 150 | letter_groups = [list(g) for k, g in groupby(word)] 151 | triple_or_more = [''.join(g) for g in letter_groups if len(g) >= 3] 152 | if len(triple_or_more) == 0: 153 | return word 154 | 155 | # replace letters to find the short word 156 | short_word = word 157 | for trip in triple_or_more: 158 | short_word = short_word.replace(trip, trip[0]*2) 159 | 160 | return short_word 161 | 162 | def detect_special_tokens(word): 163 | try: 164 | int(word) 165 | word = SPECIAL_TOKENS[4] 166 | except ValueError: 167 | if AtMentionRegex.findall(word): 168 | word = SPECIAL_TOKENS[2] 169 | elif urlRegex.findall(word): 170 | word = SPECIAL_TOKENS[3] 171 | return word 172 | 173 | def process_word(word): 174 | """ Shortening and converting the word to a special token if relevant. 175 | """ 176 | word = shorten_word(word) 177 | word = detect_special_tokens(word) 178 | return word 179 | 180 | def remove_control_chars(text): 181 | return CONTROL_CHAR_REGEX.sub('', text) 182 | 183 | def convert_nonbreaking_space(text): 184 | # ugly hack handling non-breaking space no matter how badly it's been encoded in the input 185 | for r in ['\\\\xc2', '\\xc2', '\xc2', '\\\\xa0', '\\xa0', '\xa0']: 186 | text = text.replace(r, ' ') 187 | return text 188 | 189 | def convert_linebreaks(text): 190 | # ugly hack handling non-breaking space no matter how badly it's been encoded in the input 191 | # space around to ensure proper tokenization 192 | for r in ['\\\\n', '\\n', '\n', '\\\\r', '\\r', '\r', '
<br>']:  # '<br>' is also treated as a line break 193 | text = text.replace(r, ' ' + SPECIAL_TOKENS[5] + ' ') 194 | return text 195 | -------------------------------------------------------------------------------- /tests/test_finetuning.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function, division, unicode_literals 2 | 3 | import test_helper 4 | 5 | from nose.plugins.attrib import attr 6 | import json 7 | import numpy as np 8 | 9 | from torchmoji.class_avg_finetuning import relabel 10 | from torchmoji.sentence_tokenizer import SentenceTokenizer 11 | 12 | from torchmoji.finetuning import ( 13 | calculate_batchsize_maxlen, 14 | freeze_layers, 15 | change_trainable, 16 | finetune, 17 | load_benchmark 18 | ) 19 | from torchmoji.model_def import ( 20 | torchmoji_transfer, 21 | torchmoji_feature_encoding, 22 | torchmoji_emojis 23 | ) 24 | from torchmoji.global_variables import ( 25 | PRETRAINED_PATH, 26 | NB_TOKENS, 27 | VOCAB_PATH, 28 | ROOT_PATH 29 | ) 30 | 31 | 32 | def test_calculate_batchsize_maxlen(): 33 | """ Batch size and max length are calculated properly. 34 | """ 35 | texts = ['a b c d', 36 | 'e f g h i'] 37 | batch_size, maxlen = calculate_batchsize_maxlen(texts) 38 | 39 | assert batch_size == 250 40 | assert maxlen == 10, maxlen 41 | 42 | 43 | def test_freeze_layers(): 44 | """ Correct layers are frozen. 45 | """ 46 | model = torchmoji_transfer(5) 47 | keyword = 'output_layer' 48 | 49 | model = freeze_layers(model, unfrozen_keyword=keyword) 50 | 51 | for name, module in model.named_children(): 52 | trainable = keyword.lower() in name.lower() 53 | assert all(p.requires_grad == trainable for p in module.parameters()) 54 | 55 | 56 | def test_change_trainable(): 57 | """ change_trainable() changes trainability of layers. 58 | """ 59 | model = torchmoji_transfer(5) 60 | change_trainable(model.embed, False) 61 | assert not any(p.requires_grad for p in model.embed.parameters()) 62 | change_trainable(model.embed, True) 63 | assert all(p.requires_grad for p in model.embed.parameters()) 64 | 65 | 66 | def test_torchmoji_transfer_extend_embedding(): 67 | """ Defining torchmoji with extension. 68 | """ 69 | extend_with = 50 70 | model = torchmoji_transfer(5, weight_path=PRETRAINED_PATH, 71 | extend_embedding=extend_with) 72 | embedding_layer = model.embed 73 | assert embedding_layer.weight.size()[0] == NB_TOKENS + extend_with 74 | 75 | 76 | def test_torchmoji_return_attention(): 77 | seq_tensor = np.array([[1]]) 78 | # test the output of the normal model 79 | model = torchmoji_emojis(weight_path=PRETRAINED_PATH) 80 | # check correct number of outputs 81 | assert len(model(seq_tensor)) == 1 82 | # repeat above described tests when returning attention weights 83 | model = torchmoji_emojis(weight_path=PRETRAINED_PATH, return_attention=True) 84 | assert len(model(seq_tensor)) == 2 85 | 86 | 87 | def test_relabel(): 88 | """ relabel() works with multi-class labels.
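One subtlety in `filter_utils.shorten_word()` above: it probes for ASCII with `word.decode('ascii')`, which works on Python 2 byte strings but not on Python 3, where `str` has no `.decode()`; the resulting `AttributeError` is caught and the word is returned unshortened, so repeated-character squashing effectively only runs under Python 2. A check that behaves the same on both interpreters could look like this (a sketch, not the repository's code; `WordGenerator.check_ascii()` further down has the same quirk):

```python
def is_ascii(word):
    """True if every character in `word` is ASCII; works on Python 2 and Python 3."""
    try:
        word.encode('ascii')
        return True
    except (UnicodeDecodeError, UnicodeEncodeError):
        return False

# shorten_word() could then start with:  if not is_ascii(word): return word
```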
89 | """ 90 | nb_classes = 3 91 | inputs = np.array([ 92 | [True, False, False], 93 | [False, True, False], 94 | [True, False, True], 95 | ]) 96 | expected_0 = np.array([True, False, True]) 97 | expected_1 = np.array([False, True, False]) 98 | expected_2 = np.array([False, False, True]) 99 | 100 | assert np.array_equal(relabel(inputs, 0, nb_classes), expected_0) 101 | assert np.array_equal(relabel(inputs, 1, nb_classes), expected_1) 102 | assert np.array_equal(relabel(inputs, 2, nb_classes), expected_2) 103 | 104 | 105 | def test_relabel_binary(): 106 | """ relabel() works with binary classification (no changes to labels) 107 | """ 108 | nb_classes = 2 109 | inputs = np.array([True, False, False]) 110 | 111 | assert np.array_equal(relabel(inputs, 0, nb_classes), inputs) 112 | 113 | 114 | @attr('slow') 115 | def test_finetune_full(): 116 | """ finetuning using 'full'. 117 | """ 118 | DATASET_PATH = ROOT_PATH+'/data/SS-Youtube/raw.pickle' 119 | nb_classes = 2 120 | # Keras and pyTorch implementation of the Adam optimizer are slightly different and change a bit the results 121 | # We reduce the min accuracy needed here to pass the test 122 | # See e.g. https://discuss.pytorch.org/t/suboptimal-convergence-when-compared-with-tensorflow-model/5099/11 123 | min_acc = 0.68 124 | 125 | with open(VOCAB_PATH, 'r') as f: 126 | vocab = json.load(f) 127 | 128 | data = load_benchmark(DATASET_PATH, vocab, extend_with=10000) 129 | print('Loading pyTorch model from {}.'.format(PRETRAINED_PATH)) 130 | model = torchmoji_transfer(nb_classes, PRETRAINED_PATH, extend_embedding=data['added']) 131 | print(model) 132 | model, acc = finetune(model, data['texts'], data['labels'], nb_classes, 133 | data['batch_size'], method='full', nb_epochs=1) 134 | 135 | print("Finetune full SS-Youtube 1 epoch acc: {}".format(acc)) 136 | assert acc >= min_acc 137 | 138 | 139 | @attr('slow') 140 | def test_finetune_last(): 141 | """ finetuning using 'last'. 142 | """ 143 | dataset_path = ROOT_PATH + '/data/SS-Youtube/raw.pickle' 144 | nb_classes = 2 145 | min_acc = 0.68 146 | 147 | with open(VOCAB_PATH, 'r') as f: 148 | vocab = json.load(f) 149 | 150 | data = load_benchmark(dataset_path, vocab) 151 | print('Loading model from {}.'.format(PRETRAINED_PATH)) 152 | model = torchmoji_transfer(nb_classes, PRETRAINED_PATH) 153 | print(model) 154 | model, acc = finetune(model, data['texts'], data['labels'], nb_classes, 155 | data['batch_size'], method='last', nb_epochs=1) 156 | 157 | print("Finetune last SS-Youtube 1 epoch acc: {}".format(acc)) 158 | 159 | assert acc >= min_acc 160 | 161 | 162 | def test_score_emoji(): 163 | """ Emoji predictions make sense. 
164 | """ 165 | test_sentences = [ 166 | 'I love mom\'s cooking', 167 | 'I love how you never reply back..', 168 | 'I love cruising with my homies', 169 | 'I love messing with yo mind!!', 170 | 'I love you and now you\'re just gone..', 171 | 'This is shit', 172 | 'This is the shit' 173 | ] 174 | 175 | expected = [ 176 | np.array([36, 4, 8, 16, 47]), 177 | np.array([1, 19, 55, 25, 46]), 178 | np.array([31, 6, 30, 15, 13]), 179 | np.array([54, 44, 9, 50, 49]), 180 | np.array([46, 5, 27, 35, 34]), 181 | np.array([55, 32, 27, 1, 37]), 182 | np.array([48, 11, 6, 31, 9]) 183 | ] 184 | 185 | def top_elements(array, k): 186 | ind = np.argpartition(array, -k)[-k:] 187 | return ind[np.argsort(array[ind])][::-1] 188 | 189 | # Initialize by loading dictionary and tokenize texts 190 | with open(VOCAB_PATH, 'r') as f: 191 | vocabulary = json.load(f) 192 | 193 | st = SentenceTokenizer(vocabulary, 30) 194 | tokens, _, _ = st.tokenize_sentences(test_sentences) 195 | 196 | # Load model and run 197 | model = torchmoji_emojis(weight_path=PRETRAINED_PATH) 198 | prob = model(tokens) 199 | 200 | # Find top emojis for each sentence 201 | for i, t_prob in enumerate(list(prob)): 202 | assert np.array_equal(top_elements(t_prob, 5), expected[i]) 203 | 204 | 205 | def test_encode_texts(): 206 | """ Text encoding is stable. 207 | """ 208 | 209 | TEST_SENTENCES = ['I love mom\'s cooking', 210 | 'I love how you never reply back..', 211 | 'I love cruising with my homies', 212 | 'I love messing with yo mind!!', 213 | 'I love you and now you\'re just gone..', 214 | 'This is shit', 215 | 'This is the shit'] 216 | 217 | 218 | maxlen = 30 219 | batch_size = 32 220 | 221 | with open(VOCAB_PATH, 'r') as f: 222 | vocabulary = json.load(f) 223 | 224 | st = SentenceTokenizer(vocabulary, maxlen) 225 | 226 | print('Loading model from {}.'.format(PRETRAINED_PATH)) 227 | model = torchmoji_feature_encoding(PRETRAINED_PATH) 228 | print(model) 229 | tokenized, _, _ = st.tokenize_sentences(TEST_SENTENCES) 230 | encoding = model(tokenized) 231 | 232 | avg_across_sentences = np.around(np.mean(encoding, axis=0)[:5], 3) 233 | assert np.allclose(avg_across_sentences, np.array([-0.023, 0.021, -0.037, -0.001, -0.005])) 234 | 235 | test_encode_texts() -------------------------------------------------------------------------------- /torchmoji/create_vocab.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function, division 3 | 4 | import glob 5 | import json 6 | import uuid 7 | from copy import deepcopy 8 | from collections import defaultdict, OrderedDict 9 | import numpy as np 10 | 11 | from torchmoji.filter_utils import is_special_token 12 | from torchmoji.word_generator import WordGenerator 13 | from torchmoji.global_variables import SPECIAL_TOKENS, VOCAB_PATH 14 | 15 | class VocabBuilder(): 16 | """ Create vocabulary with words extracted from sentences as fed from a 17 | word generator. 18 | """ 19 | def __init__(self, word_gen): 20 | # initialize any new key with value of 0 21 | self.word_counts = defaultdict(lambda: 0, {}) 22 | self.word_length_limit=30 23 | 24 | for token in SPECIAL_TOKENS: 25 | assert len(token) < self.word_length_limit 26 | self.word_counts[token] = 0 27 | self.word_gen = word_gen 28 | 29 | def count_words_in_sentence(self, words): 30 | """ Generates word counts for all tokens in the given sentence. 31 | 32 | # Arguments: 33 | words: Tokenized sentence whose words should be counted. 
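`tests/test_finetuning.py` above checks that `freeze_layers()` leaves only the module whose name contains `unfrozen_keyword` trainable, which is what the 'last' finetuning method relies on. A quick way to see the effect is to count trainable parameters before and after freezing (a sketch; `torchmoji_transfer(5)` builds the model with a 5-class output layer and random weights, as in the tests):

```python
from torchmoji.model_def import torchmoji_transfer
from torchmoji.finetuning import freeze_layers

model = torchmoji_transfer(5)
total = sum(p.numel() for p in model.parameters())

model = freeze_layers(model, unfrozen_keyword='output_layer')
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print('trainable / total parameters: {} / {}'.format(trainable, total))
```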
34 | """ 35 | for word in words: 36 | if 0 < len(word) and len(word) <= self.word_length_limit: 37 | try: 38 | self.word_counts[word] += 1 39 | except KeyError: 40 | self.word_counts[word] = 1 41 | 42 | def save_vocab(self, path=None): 43 | """ Saves the vocabulary into a file. 44 | 45 | # Arguments: 46 | path: Where the vocabulary should be saved. If not specified, a 47 | randomly generated filename is used instead. 48 | """ 49 | dtype = ([('word','|S{}'.format(self.word_length_limit)),('count','int')]) 50 | np_dict = np.array(self.word_counts.items(), dtype=dtype) 51 | 52 | # sort from highest to lowest frequency 53 | np_dict[::-1].sort(order='count') 54 | data = np_dict 55 | 56 | if path is None: 57 | path = str(uuid.uuid4()) 58 | 59 | np.savez_compressed(path, data=data) 60 | print("Saved dict to {}".format(path)) 61 | 62 | def get_next_word(self): 63 | """ Returns next tokenized sentence from the word geneerator. 64 | 65 | # Returns: 66 | List of strings, representing the next tokenized sentence. 67 | """ 68 | return self.word_gen.__iter__().next() 69 | 70 | def count_all_words(self): 71 | """ Generates word counts for all words in all sentences of the word 72 | generator. 73 | """ 74 | for words, _ in self.word_gen: 75 | self.count_words_in_sentence(words) 76 | 77 | class MasterVocab(): 78 | """ Combines vocabularies. 79 | """ 80 | def __init__(self): 81 | 82 | # initialize custom tokens 83 | self.master_vocab = {} 84 | 85 | def populate_master_vocab(self, vocab_path, min_words=1, force_appearance=None): 86 | """ Populates the master vocabulary using all vocabularies found in the 87 | given path. Vocabularies should be named *.npz. Expects the 88 | vocabularies to be numpy arrays with counts. Normalizes the counts 89 | and combines them. 90 | 91 | # Arguments: 92 | vocab_path: Path containing vocabularies to be combined. 93 | min_words: Minimum amount of occurences a word must have in order 94 | to be included in the master vocabulary. 95 | force_appearance: Optional vocabulary filename that will be added 96 | to the master vocabulary no matter what. This vocabulary must 97 | be present in vocab_path. 
98 | """ 99 | 100 | paths = glob.glob(vocab_path + '*.npz') 101 | sizes = {path: 0 for path in paths} 102 | dicts = {path: {} for path in paths} 103 | 104 | # set up and get sizes of individual dictionaries 105 | for path in paths: 106 | np_data = np.load(path)['data'] 107 | 108 | for entry in np_data: 109 | word, count = entry 110 | if count < min_words: 111 | continue 112 | if is_special_token(word): 113 | continue 114 | dicts[path][word] = count 115 | 116 | sizes[path] = sum(dicts[path].values()) 117 | print('Overall word count for {} -> {}'.format(path, sizes[path])) 118 | print('Overall word number for {} -> {}'.format(path, len(dicts[path]))) 119 | 120 | vocab_of_max_size = max(sizes, key=sizes.get) 121 | max_size = sizes[vocab_of_max_size] 122 | print('Min: {}, {}, {}'.format(sizes, vocab_of_max_size, max_size)) 123 | 124 | # can force one vocabulary to always be present 125 | if force_appearance is not None: 126 | force_appearance_path = [p for p in paths if force_appearance in p][0] 127 | force_appearance_vocab = deepcopy(dicts[force_appearance_path]) 128 | print(force_appearance_path) 129 | else: 130 | force_appearance_path, force_appearance_vocab = None, None 131 | 132 | # normalize word counts before inserting into master dict 133 | for path in paths: 134 | normalization_factor = max_size / sizes[path] 135 | print('Norm factor for path {} -> {}'.format(path, normalization_factor)) 136 | 137 | for word in dicts[path]: 138 | if is_special_token(word): 139 | print("SPECIAL - ", word) 140 | continue 141 | normalized_count = dicts[path][word] * normalization_factor 142 | 143 | # can force one vocabulary to always be present 144 | if force_appearance_vocab is not None: 145 | try: 146 | force_word_count = force_appearance_vocab[word] 147 | except KeyError: 148 | continue 149 | #if force_word_count < 5: 150 | #continue 151 | 152 | if word in self.master_vocab: 153 | self.master_vocab[word] += normalized_count 154 | else: 155 | self.master_vocab[word] = normalized_count 156 | 157 | print('Size of master_dict {}'.format(len(self.master_vocab))) 158 | print("Hashes for master dict: {}".format( 159 | len([w for w in self.master_vocab if '#' in w[0]]))) 160 | 161 | def save_vocab(self, path_count, path_vocab, word_limit=100000): 162 | """ Saves the master vocabulary into a file. 163 | """ 164 | 165 | # reserve space for 10 special tokens 166 | words = OrderedDict() 167 | for token in SPECIAL_TOKENS: 168 | # store -1 instead of np.inf, which can overflow 169 | words[token] = -1 170 | 171 | # sort words by frequency 172 | desc_order = OrderedDict(sorted(self.master_vocab.items(), 173 | key=lambda kv: kv[1], reverse=True)) 174 | words.update(desc_order) 175 | 176 | # use encoding of up to 30 characters (no token conversions) 177 | # use float to store large numbers (we don't care about precision loss) 178 | np_vocab = np.array(words.items(), 179 | dtype=([('word','|S30'),('count','float')])) 180 | 181 | # output count for debugging 182 | counts = np_vocab[:word_limit] 183 | np.savez_compressed(path_count, counts=counts) 184 | 185 | # output the index of each word for easy lookup 186 | final_words = OrderedDict() 187 | for i, w in enumerate(words.keys()[:word_limit]): 188 | final_words.update({w:i}) 189 | with open(path_vocab, 'w') as f: 190 | f.write(json.dumps(final_words, indent=4, separators=(',', ': '))) 191 | 192 | 193 | def all_words_in_sentences(sentences): 194 | """ Extracts all unique words from a given list of sentences. 
195 | 196 | # Arguments: 197 | sentences: List or word generator of sentences to be processed. 198 | 199 | # Returns: 200 | List of all unique words contained in the given sentences. 201 | """ 202 | vocab = [] 203 | if isinstance(sentences, WordGenerator): 204 | sentences = [s for s, _ in sentences] 205 | 206 | for sentence in sentences: 207 | for word in sentence: 208 | if word not in vocab: 209 | vocab.append(word) 210 | 211 | return vocab 212 | 213 | 214 | def extend_vocab_in_file(vocab, max_tokens=10000, vocab_path=VOCAB_PATH): 215 | """ Extends JSON-formatted vocabulary with words from vocab that are not 216 | present in the current vocabulary. Adds up to max_tokens words. 217 | Overwrites file in vocab_path. 218 | 219 | # Arguments: 220 | new_vocab: Vocabulary to be added. MUST have word_counts populated, i.e. 221 | must have run count_all_words() previously. 222 | max_tokens: Maximum number of words to be added. 223 | vocab_path: Path to the vocabulary json which is to be extended. 224 | """ 225 | try: 226 | with open(vocab_path, 'r') as f: 227 | current_vocab = json.load(f) 228 | except IOError: 229 | print('Vocabulary file not found, expected at ' + vocab_path) 230 | return 231 | 232 | extend_vocab(current_vocab, vocab, max_tokens) 233 | 234 | # Save back to file 235 | with open(vocab_path, 'w') as f: 236 | json.dump(current_vocab, f, sort_keys=True, indent=4, separators=(',',': ')) 237 | 238 | 239 | def extend_vocab(current_vocab, new_vocab, max_tokens=10000): 240 | """ Extends current vocabulary with words from vocab that are not 241 | present in the current vocabulary. Adds up to max_tokens words. 242 | 243 | # Arguments: 244 | current_vocab: Current dictionary of tokens. 245 | new_vocab: Vocabulary to be added. MUST have word_counts populated, i.e. 246 | must have run count_all_words() previously. 247 | max_tokens: Maximum number of words to be added. 248 | 249 | # Returns: 250 | How many new tokens have been added. 251 | """ 252 | if max_tokens < 0: 253 | max_tokens = 10000 254 | 255 | words = OrderedDict() 256 | 257 | # sort words by frequency 258 | desc_order = OrderedDict(sorted(new_vocab.word_counts.items(), 259 | key=lambda kv: kv[1], reverse=True)) 260 | words.update(desc_order) 261 | 262 | base_index = len(current_vocab.keys()) 263 | added = 0 264 | for word in words: 265 | if added >= max_tokens: 266 | break 267 | if word not in current_vocab.keys(): 268 | current_vocab[word] = base_index + added 269 | added += 1 270 | 271 | return added 272 | -------------------------------------------------------------------------------- /torchmoji/sentence_tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Provides functionality for converting a given list of tokens (words) into 4 | numbers, according to the given vocabulary. 5 | ''' 6 | from __future__ import print_function, division, unicode_literals 7 | 8 | import numbers 9 | import numpy as np 10 | 11 | from torchmoji.create_vocab import extend_vocab, VocabBuilder 12 | from torchmoji.word_generator import WordGenerator 13 | from torchmoji.global_variables import SPECIAL_TOKENS 14 | 15 | # import torch 16 | 17 | from sklearn.model_selection import train_test_split 18 | 19 | from copy import deepcopy 20 | 21 | class SentenceTokenizer(): 22 | """ Create numpy array of tokens corresponding to input sentences. 23 | The vocabulary can include Unicode tokens. 
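`create_vocab.py` above is what `SentenceTokenizer.split_train_val_test()` (later in this listing) uses to grow a vocabulary from the training split. Note that `save_vocab()` and `get_next_word()` still rely on Python 2 idioms (a dict view passed straight to `np.array`, iterator `.next()`), so the sketch below sticks to `count_all_words()` plus `extend_vocab()`, which also work on Python 3. The sentences are made up:

```python
from torchmoji.create_vocab import VocabBuilder, extend_vocab
from torchmoji.word_generator import WordGenerator

sentences = [u'i love torchmoji', u'torchmoji loves emojis']
vb = VocabBuilder(WordGenerator(sentences, allow_unicode_text=True))
vb.count_all_words()                              # fills vb.word_counts

vocab = {'CUSTOM_MASK': 0, 'CUSTOM_UNKNOWN': 1}   # toy existing vocabulary
added = extend_vocab(vocab, vb, max_tokens=5)     # add up to 5 new words, most frequent first
print(added, vocab)
```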
24 | """ 25 | def __init__(self, vocabulary, fixed_length, custom_wordgen=None, 26 | ignore_sentences_with_only_custom=False, masking_value=0, 27 | unknown_value=1): 28 | """ Needs a dictionary as input for the vocabulary. 29 | """ 30 | 31 | if len(vocabulary) > np.iinfo('uint16').max: 32 | raise ValueError('Dictionary is too big ({} tokens) for the numpy ' 33 | 'datatypes used (max limit={}). Reduce vocabulary' 34 | ' or adjust code accordingly!' 35 | .format(len(vocabulary), np.iinfo('uint16').max)) 36 | 37 | # Shouldn't be able to modify the given vocabulary 38 | self.vocabulary = deepcopy(vocabulary) 39 | self.fixed_length = fixed_length 40 | self.ignore_sentences_with_only_custom = ignore_sentences_with_only_custom 41 | self.masking_value = masking_value 42 | self.unknown_value = unknown_value 43 | 44 | # Initialized with an empty stream of sentences that must then be fed 45 | # to the generator at a later point for reusability. 46 | # A custom word generator can be used for domain-specific filtering etc 47 | if custom_wordgen is not None: 48 | assert custom_wordgen.stream is None 49 | self.wordgen = custom_wordgen 50 | self.uses_custom_wordgen = True 51 | else: 52 | self.wordgen = WordGenerator(None, allow_unicode_text=True, 53 | ignore_emojis=False, 54 | remove_variation_selectors=True, 55 | break_replacement=True) 56 | self.uses_custom_wordgen = False 57 | 58 | def tokenize_sentences(self, sentences, reset_stats=True, max_sentences=None): 59 | """ Converts a given list of sentences into a numpy array according to 60 | its vocabulary. 61 | 62 | # Arguments: 63 | sentences: List of sentences to be tokenized. 64 | reset_stats: Whether the word generator's stats should be reset. 65 | max_sentences: Maximum length of sentences. Must be set if the 66 | length cannot be inferred from the input. 67 | 68 | # Returns: 69 | Numpy array of the tokenization sentences with masking, 70 | infos, 71 | stats 72 | 73 | # Raises: 74 | ValueError: When maximum length is not set and cannot be inferred. 75 | """ 76 | 77 | if max_sentences is None and not hasattr(sentences, '__len__'): 78 | raise ValueError('Either you must provide an array with a length' 79 | 'attribute (e.g. a list) or specify the maximum ' 80 | 'length yourself using `max_sentences`!') 81 | n_sentences = (max_sentences if max_sentences is not None 82 | else len(sentences)) 83 | 84 | if self.masking_value == 0: 85 | tokens = np.zeros((n_sentences, self.fixed_length), dtype='uint16') 86 | else: 87 | tokens = (np.ones((n_sentences, self.fixed_length), dtype='uint16') 88 | * self.masking_value) 89 | 90 | if reset_stats: 91 | self.wordgen.reset_stats() 92 | 93 | # With a custom word generator info can be extracted from each 94 | # sentence (e.g. 
labels) 95 | infos = [] 96 | 97 | # Returns words as strings and then map them to vocabulary 98 | self.wordgen.stream = sentences 99 | next_insert = 0 100 | n_ignored_unknowns = 0 101 | for s_words, s_info in self.wordgen: 102 | s_tokens = self.find_tokens(s_words) 103 | 104 | if (self.ignore_sentences_with_only_custom and 105 | np.all([True if t < len(SPECIAL_TOKENS) 106 | else False for t in s_tokens])): 107 | n_ignored_unknowns += 1 108 | continue 109 | if len(s_tokens) > self.fixed_length: 110 | s_tokens = s_tokens[:self.fixed_length] 111 | tokens[next_insert,:len(s_tokens)] = s_tokens 112 | infos.append(s_info) 113 | next_insert += 1 114 | 115 | # For standard word generators all sentences should be tokenized 116 | # this is not necessarily the case for custom wordgenerators as they 117 | # may filter the sentences etc. 118 | if not self.uses_custom_wordgen and not self.ignore_sentences_with_only_custom: 119 | assert len(sentences) == next_insert 120 | else: 121 | # adjust based on actual tokens received 122 | tokens = tokens[:next_insert] 123 | infos = infos[:next_insert] 124 | 125 | return tokens, infos, self.wordgen.stats 126 | 127 | def find_tokens(self, words): 128 | assert len(words) > 0 129 | tokens = [] 130 | for w in words: 131 | try: 132 | tokens.append(self.vocabulary[w]) 133 | except KeyError: 134 | tokens.append(self.unknown_value) 135 | return tokens 136 | 137 | def split_train_val_test(self, sentences, info_dicts, 138 | split_parameter=[0.7, 0.1, 0.2], extend_with=0): 139 | """ Splits given sentences into three different datasets: training, 140 | validation and testing. 141 | 142 | # Arguments: 143 | sentences: The sentences to be tokenized. 144 | info_dicts: A list of dicts that contain information about each 145 | sentence (e.g. a label). 146 | split_parameter: A parameter for deciding the splits between the 147 | three different datasets. If instead of being passed three 148 | values, three lists are passed, then these will be used to 149 | specify which observation belong to which dataset. 150 | extend_with: An optional parameter. If > 0 then this is the number 151 | of tokens added to the vocabulary from this dataset. The 152 | expanded vocab will be generated using only the training set, 153 | but is applied to all three sets. 154 | 155 | # Returns: 156 | List of three lists of tokenized sentences, 157 | 158 | List of three corresponding dictionaries with information, 159 | 160 | How many tokens have been added to the vocab. Make sure to extend 161 | the embedding layer of the model accordingly. 
162 | """ 163 | 164 | # If passed three lists, use those directly 165 | if isinstance(split_parameter, list) and \ 166 | all(isinstance(x, list) for x in split_parameter) and \ 167 | len(split_parameter) == 3: 168 | 169 | # Helper function to verify provided indices are numbers in range 170 | def verify_indices(inds): 171 | return list(filter(lambda i: isinstance(i, numbers.Number) 172 | and i < len(sentences), inds)) 173 | 174 | ind_train = verify_indices(split_parameter[0]) 175 | ind_val = verify_indices(split_parameter[1]) 176 | ind_test = verify_indices(split_parameter[2]) 177 | else: 178 | # Split sentences and dicts 179 | ind = list(range(len(sentences))) 180 | ind_train, ind_test = train_test_split(ind, test_size=split_parameter[2]) 181 | ind_train, ind_val = train_test_split(ind_train, test_size=split_parameter[1]) 182 | 183 | # Map indices to data 184 | train = np.array([sentences[x] for x in ind_train]) 185 | test = np.array([sentences[x] for x in ind_test]) 186 | val = np.array([sentences[x] for x in ind_val]) 187 | 188 | info_train = np.array([info_dicts[x] for x in ind_train]) 189 | info_test = np.array([info_dicts[x] for x in ind_test]) 190 | info_val = np.array([info_dicts[x] for x in ind_val]) 191 | 192 | added = 0 193 | # Extend vocabulary with training set tokens 194 | if extend_with > 0: 195 | wg = WordGenerator(train) 196 | vb = VocabBuilder(wg) 197 | vb.count_all_words() 198 | added = extend_vocab(self.vocabulary, vb, max_tokens=extend_with) 199 | 200 | # Wrap results 201 | result = [self.tokenize_sentences(s)[0] for s in [train, val, test]] 202 | result_infos = [info_train, info_val, info_test] 203 | # if type(result_infos[0][0]) in [np.double, np.float, np.int64, np.int32, np.uint8]: 204 | # result_infos = [torch.from_numpy(label).long() for label in result_infos] 205 | 206 | return result, result_infos, added 207 | 208 | def to_sentence(self, sentence_idx): 209 | """ Converts a tokenized sentence back to a list of words. 210 | 211 | # Arguments: 212 | sentence_idx: List of numbers, representing a tokenized sentence 213 | given the current vocabulary. 214 | 215 | # Returns: 216 | String created by converting all numbers back to words and joined 217 | together with spaces. 218 | """ 219 | # Have to recalculate the mappings in case the vocab was extended. 220 | ind_to_word = {ind: word for word, ind in self.vocabulary.items()} 221 | 222 | sentence_as_list = [ind_to_word[x] for x in sentence_idx] 223 | cleaned_list = [x for x in sentence_as_list if x != 'CUSTOM_MASK'] 224 | return " ".join(cleaned_list) 225 | 226 | 227 | def coverage(dataset, verbose=False): 228 | """ Computes the percentage of words in a given dataset that are unknown. 229 | 230 | # Arguments: 231 | dataset: Tokenized dataset to be checked. 232 | verbose: Verbosity flag. 233 | 234 | # Returns: 235 | Percentage of unknown tokens. 
236 | """ 237 | n_total = np.count_nonzero(dataset) 238 | n_unknown = np.sum(dataset == 1) 239 | coverage = 1.0 - float(n_unknown) / n_total 240 | 241 | if verbose: 242 | print("Unknown words: {}".format(n_unknown)) 243 | print("Total words: {}".format(n_total)) 244 | print("Coverage: {}".format(coverage)) 245 | return coverage 246 | -------------------------------------------------------------------------------- /torchmoji/word_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' Extracts lists of words from a given input to be used for later vocabulary 3 | generation or for creating tokenized datasets. 4 | Supports functionality for handling different file types and 5 | filtering/processing of this input. 6 | ''' 7 | 8 | from __future__ import division, print_function, unicode_literals 9 | 10 | import re 11 | import unicodedata 12 | import numpy as np 13 | from text_unidecode import unidecode 14 | 15 | from torchmoji.tokenizer import RE_MENTION, tokenize 16 | from torchmoji.filter_utils import (convert_linebreaks, 17 | convert_nonbreaking_space, 18 | correct_length, 19 | extract_emojis, 20 | mostly_english, 21 | non_english_user, 22 | process_word, 23 | punct_word, 24 | remove_control_chars, 25 | remove_variation_selectors, 26 | separate_emojis_and_text) 27 | 28 | try: 29 | unicode # Python 2 30 | except NameError: 31 | unicode = str # Python 3 32 | 33 | # Only catch retweets in the beginning of the tweet as those are the 34 | # automatically added ones. 35 | # We do not want to remove tweets like "Omg.. please RT this!!" 36 | RETWEETS_RE = re.compile(r'^[rR][tT]') 37 | 38 | # Use fast and less precise regex for removing tweets with URLs 39 | # It doesn't matter too much if a few tweets with URL's make it through 40 | URLS_RE = re.compile(r'https?://|www\.') 41 | 42 | MENTION_RE = re.compile(RE_MENTION) 43 | ALLOWED_CONVERTED_UNICODE_PUNCTUATION = """!"#$'()+,-.:;<=>?@`~""" 44 | 45 | 46 | class WordGenerator(): 47 | ''' Cleanses input and converts into words. Needs all sentences to be in 48 | Unicode format. Has subclasses that read sentences differently based on 49 | file type. 50 | 51 | Takes a generator as input. This can be from e.g. a file. 52 | unicode_handling in ['ignore_sentence', 'convert_punctuation', 'allow'] 53 | unicode_handling in ['ignore_emoji', 'ignore_sentence', 'allow'] 54 | ''' 55 | def __init__(self, stream, allow_unicode_text=False, ignore_emojis=True, 56 | remove_variation_selectors=True, break_replacement=True): 57 | self.stream = stream 58 | self.allow_unicode_text = allow_unicode_text 59 | self.remove_variation_selectors = remove_variation_selectors 60 | self.ignore_emojis = ignore_emojis 61 | self.break_replacement = break_replacement 62 | self.reset_stats() 63 | 64 | def get_words(self, sentence): 65 | """ Tokenizes a sentence into individual words. 66 | Converts Unicode punctuation into ASCII if that option is set. 67 | Ignores sentences with Unicode if that option is set. 68 | Returns an empty list of words if the sentence has Unicode and 69 | that is not allowed. 
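`sentence_tokenizer.py` above closes with the `coverage()` helper. Tying it together with `SentenceTokenizer` (the toy vocabulary mirrors `tests/test_sentence_tokenizer.py`: masked positions map to 0, unknown words to 1):

```python
from torchmoji.sentence_tokenizer import SentenceTokenizer, coverage

vocab = {'CUSTOM_MASK': 0, 'CUSTOM_UNKNOWN': 1, 'aasdf': 1000, 'basdf': 2000}
st = SentenceTokenizer(vocab, 10)

tokens, infos, stats = st.tokenize_sentences([u'aasdf basdf ccc'])
print(tokens)                          # uint16 array of shape (1, 10), zero-padded
print(st.to_sentence(tokens[0]))       # 'aasdf basdf CUSTOM_UNKNOWN'
print(coverage(tokens))                # 2 of the 3 non-masked tokens are known -> ~0.67
```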
70 | """ 71 | 72 | if not isinstance(sentence, unicode): 73 | raise ValueError("All sentences should be Unicode-encoded!") 74 | sentence = sentence.strip().lower() 75 | 76 | if self.break_replacement: 77 | sentence = convert_linebreaks(sentence) 78 | 79 | if self.remove_variation_selectors: 80 | sentence = remove_variation_selectors(sentence) 81 | 82 | # Split into words using simple whitespace splitting and convert 83 | # Unicode. This is done to prevent word splitting issues with 84 | # twokenize and Unicode 85 | words = sentence.split() 86 | converted_words = [] 87 | for w in words: 88 | accept_sentence, c_w = self.convert_unicode_word(w) 89 | # Unicode word detected and not allowed 90 | if not accept_sentence: 91 | return [] 92 | else: 93 | converted_words.append(c_w) 94 | sentence = ' '.join(converted_words) 95 | 96 | words = tokenize(sentence) 97 | words = [process_word(w) for w in words] 98 | return words 99 | 100 | def check_ascii(self, word): 101 | """ Returns whether a word is ASCII """ 102 | 103 | try: 104 | word.decode('ascii') 105 | return True 106 | except (UnicodeDecodeError, UnicodeEncodeError, AttributeError): 107 | return False 108 | 109 | def convert_unicode_punctuation(self, word): 110 | word_converted_punct = [] 111 | for c in word: 112 | decoded_c = unidecode(c).lower() 113 | if len(decoded_c) == 0: 114 | # Cannot decode to anything reasonable 115 | word_converted_punct.append(c) 116 | else: 117 | # Check if all punctuation and therefore fine 118 | # to include unidecoded version 119 | allowed_punct = punct_word( 120 | decoded_c, 121 | punctuation=ALLOWED_CONVERTED_UNICODE_PUNCTUATION) 122 | 123 | if allowed_punct: 124 | word_converted_punct.append(decoded_c) 125 | else: 126 | word_converted_punct.append(c) 127 | return ''.join(word_converted_punct) 128 | 129 | def convert_unicode_word(self, word): 130 | """ Converts Unicode words to ASCII using unidecode. If Unicode is not 131 | allowed (set as a variable during initialization), then only 132 | punctuation that can be converted to ASCII will be allowed. 133 | """ 134 | if self.check_ascii(word): 135 | return True, word 136 | 137 | # First we ensure that the Unicode is normalized so it's 138 | # always a single character. 139 | word = unicodedata.normalize("NFKC", word) 140 | 141 | # Convert Unicode punctuation to ASCII equivalent. We want 142 | # e.g. "\u203c" (double exclamation mark) to be treated the same 143 | # as "!!" no matter if we allow other Unicode characters or not. 144 | word = self.convert_unicode_punctuation(word) 145 | 146 | if self.ignore_emojis: 147 | _, word = separate_emojis_and_text(word) 148 | 149 | # If conversion of punctuation and removal of emojis took care 150 | # of all the Unicode or if we allow Unicode then everything is fine 151 | if self.check_ascii(word) or self.allow_unicode_text: 152 | return True, word 153 | else: 154 | # Sometimes we might want to simply ignore Unicode sentences 155 | # (e.g. for vocabulary creation). This is another way to prevent 156 | # "polution" of strange Unicode tokens from low quality datasets 157 | return False, '' 158 | 159 | def data_preprocess_filtering(self, line, iter_i): 160 | """ To be overridden with specific preprocessing/filtering behavior 161 | if desired. 162 | 163 | Returns a boolean of whether the line should be accepted and the 164 | preprocessed text. 165 | 166 | Runs prior to tokenization. 
167 | """ 168 | return True, line, {} 169 | 170 | def data_postprocess_filtering(self, words, iter_i): 171 | """ To be overridden with specific postprocessing/filtering behavior 172 | if desired. 173 | 174 | Returns a boolean of whether the line should be accepted and the 175 | postprocessed text. 176 | 177 | Runs after tokenization. 178 | """ 179 | return True, words, {} 180 | 181 | def extract_valid_sentence_words(self, line): 182 | """ Line may either a string of a list of strings depending on how 183 | the stream is being parsed. 184 | Domain-specific processing and filtering can be done both prior to 185 | and after tokenization. 186 | Custom information about the line can be extracted during the 187 | processing phases and returned as a dict. 188 | """ 189 | 190 | info = {} 191 | 192 | pre_valid, pre_line, pre_info = \ 193 | self.data_preprocess_filtering(line, self.stats['total']) 194 | info.update(pre_info) 195 | if not pre_valid: 196 | self.stats['pretokenization_filtered'] += 1 197 | return False, [], info 198 | 199 | words = self.get_words(pre_line) 200 | if len(words) == 0: 201 | self.stats['unicode_filtered'] += 1 202 | return False, [], info 203 | 204 | post_valid, post_words, post_info = \ 205 | self.data_postprocess_filtering(words, self.stats['total']) 206 | info.update(post_info) 207 | if not post_valid: 208 | self.stats['posttokenization_filtered'] += 1 209 | return post_valid, post_words, info 210 | 211 | def generate_array_from_input(self): 212 | sentences = [] 213 | for words in self: 214 | sentences.append(words) 215 | return sentences 216 | 217 | def reset_stats(self): 218 | self.stats = {'pretokenization_filtered': 0, 219 | 'unicode_filtered': 0, 220 | 'posttokenization_filtered': 0, 221 | 'total': 0, 222 | 'valid': 0} 223 | 224 | def __iter__(self): 225 | if self.stream is None: 226 | raise ValueError("Stream should be set before iterating over it!") 227 | 228 | for line in self.stream: 229 | valid, words, info = self.extract_valid_sentence_words(line) 230 | 231 | # Words may be filtered away due to unidecode etc. 232 | # In that case the words should not be passed on. 233 | if valid and len(words): 234 | self.stats['valid'] += 1 235 | yield words, info 236 | 237 | self.stats['total'] += 1 238 | 239 | 240 | class TweetWordGenerator(WordGenerator): 241 | ''' Returns np array or generator of ASCII sentences for given tweet input. 242 | Any file opening/closing should be handled outside of this class. 243 | ''' 244 | def __init__(self, stream, wanted_emojis=None, english_words=None, 245 | non_english_user_set=None, allow_unicode_text=False, 246 | ignore_retweets=True, ignore_url_tweets=True, 247 | ignore_mention_tweets=False): 248 | 249 | self.wanted_emojis = wanted_emojis 250 | self.english_words = english_words 251 | self.non_english_user_set = non_english_user_set 252 | self.ignore_retweets = ignore_retweets 253 | self.ignore_url_tweets = ignore_url_tweets 254 | self.ignore_mention_tweets = ignore_mention_tweets 255 | WordGenerator.__init__(self, stream, 256 | allow_unicode_text=allow_unicode_text) 257 | 258 | def validated_tweet(self, data): 259 | ''' A bunch of checks to determine whether the tweet is valid. 260 | Also returns emojis contained by the tweet. 
261 | ''' 262 | 263 | # Ordering of validations is important for speed 264 | # If it passes all checks, then the tweet is validated for usage 265 | 266 | # Skips incomplete tweets 267 | if len(data) <= 9: 268 | return False, [] 269 | 270 | text = data[9] 271 | 272 | if self.ignore_retweets and RETWEETS_RE.search(text): 273 | return False, [] 274 | 275 | if self.ignore_url_tweets and URLS_RE.search(text): 276 | return False, [] 277 | 278 | if self.ignore_mention_tweets and MENTION_RE.search(text): 279 | return False, [] 280 | 281 | if self.wanted_emojis is not None: 282 | uniq_emojis = np.unique(extract_emojis(text, self.wanted_emojis)) 283 | if len(uniq_emojis) == 0: 284 | return False, [] 285 | else: 286 | uniq_emojis = [] 287 | 288 | if self.non_english_user_set is not None and \ 289 | non_english_user(data[1], self.non_english_user_set): 290 | return False, [] 291 | return True, uniq_emojis 292 | 293 | def data_preprocess_filtering(self, line, iter_i): 294 | fields = line.strip().split("\t") 295 | valid, emojis = self.validated_tweet(fields) 296 | text = fields[9].replace('\\n', '') \ 297 | .replace('\\r', '') \ 298 | .replace('&', '&') if valid else '' 299 | return valid, text, {'emojis': emojis} 300 | 301 | def data_postprocess_filtering(self, words, iter_i): 302 | valid_length = correct_length(words, 1, None) 303 | valid_english, n_words, n_english = mostly_english(words, 304 | self.english_words) 305 | if valid_length and valid_english: 306 | return True, words, {'length': len(words), 307 | 'n_normal_words': n_words, 308 | 'n_english': n_english} 309 | else: 310 | return False, [], {'length': len(words), 311 | 'n_normal_words': n_words, 312 | 'n_english': n_english} 313 | -------------------------------------------------------------------------------- /torchmoji/lstm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Implement a pyTorch LSTM with hard sigmoid reccurent activation functions. 
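`word_generator.py` above is a plain iterator: give it a stream of Unicode sentences and it yields `(words, info)` pairs while keeping filtering statistics. A short sketch with made-up sentences (`allow_unicode_text=True` is the same setting `SentenceTokenizer` uses for its internal generator):

```python
from torchmoji.word_generator import WordGenerator

sentences = [u'I LOVE mom\u2019s cooking!!',
             u'@someone check www.example.com 123']
wg = WordGenerator(sentences, allow_unicode_text=True)

for words, info in wg:
    print(words)      # lowercased tokens; mentions, URLs and numbers become SPECIAL_TOKENS entries
print(wg.stats)       # e.g. {'valid': 2, 'total': 2, ...} when nothing is filtered out
```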
3 | Adapted from the non-cuda variant of pyTorch LSTM at 4 | https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py 5 | """ 6 | 7 | from __future__ import print_function, division 8 | import math 9 | import torch 10 | 11 | from torch.nn import Module 12 | from torch.nn.parameter import Parameter 13 | from torch.nn.utils.rnn import PackedSequence 14 | import torch.nn.functional as F 15 | 16 | class LSTMHardSigmoid(Module): 17 | 18 | def __init__(self, input_size, hidden_size, 19 | num_layers=1, bias=True, batch_first=False, 20 | dropout=0, bidirectional=False): 21 | super(LSTMHardSigmoid, self).__init__() 22 | self.input_size = input_size 23 | self.hidden_size = hidden_size 24 | self.num_layers = num_layers 25 | self.bias = bias 26 | self.batch_first = batch_first 27 | self.dropout = dropout 28 | self.dropout_state = {} 29 | self.bidirectional = bidirectional 30 | num_directions = 2 if bidirectional else 1 31 | 32 | gate_size = 4 * hidden_size 33 | 34 | self._all_weights = [] 35 | for layer in range(num_layers): 36 | for direction in range(num_directions): 37 | layer_input_size = input_size if layer == 0 else hidden_size * num_directions 38 | 39 | w_ih = Parameter(torch.Tensor(gate_size, layer_input_size)) 40 | w_hh = Parameter(torch.Tensor(gate_size, hidden_size)) 41 | b_ih = Parameter(torch.Tensor(gate_size)) 42 | b_hh = Parameter(torch.Tensor(gate_size)) 43 | layer_params = (w_ih, w_hh, b_ih, b_hh) 44 | 45 | suffix = '_reverse' if direction == 1 else '' 46 | param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}'] 47 | if bias: 48 | param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}'] 49 | param_names = [x.format(layer, suffix) for x in param_names] 50 | 51 | for name, param in zip(param_names, layer_params): 52 | setattr(self, name, param) 53 | self._all_weights.append(param_names) 54 | 55 | self.flatten_parameters() 56 | self.reset_parameters() 57 | 58 | def flatten_parameters(self): 59 | """Resets parameter data pointer so that they can use faster code paths. 60 | 61 | Right now, this is a no-op wince we don't use CUDA acceleration. 
62 | """ 63 | self._data_ptrs = [] 64 | 65 | def _apply(self, fn): 66 | ret = super(LSTMHardSigmoid, self)._apply(fn) 67 | self.flatten_parameters() 68 | return ret 69 | 70 | def reset_parameters(self): 71 | stdv = 1.0 / math.sqrt(self.hidden_size) 72 | for weight in self.parameters(): 73 | weight.data.uniform_(-stdv, stdv) 74 | 75 | def forward(self, input, hx=None): 76 | is_packed = isinstance(input, PackedSequence) 77 | if is_packed: 78 | input, batch_sizes = input 79 | max_batch_size = batch_sizes[0] 80 | else: 81 | batch_sizes = None 82 | max_batch_size = input.size(0) if self.batch_first else input.size(1) 83 | 84 | if hx is None: 85 | num_directions = 2 if self.bidirectional else 1 86 | hx = torch.autograd.Variable(input.data.new(self.num_layers * 87 | num_directions, 88 | max_batch_size, 89 | self.hidden_size).zero_(), requires_grad=False) 90 | hx = (hx, hx) 91 | 92 | has_flat_weights = list(p.data.data_ptr() for p in self.parameters()) == self._data_ptrs 93 | if has_flat_weights: 94 | first_data = next(self.parameters()).data 95 | assert first_data.storage().size() == self._param_buf_size 96 | flat_weight = first_data.new().set_(first_data.storage(), 0, torch.Size([self._param_buf_size])) 97 | else: 98 | flat_weight = None 99 | func = AutogradRNN( 100 | self.input_size, 101 | self.hidden_size, 102 | num_layers=self.num_layers, 103 | batch_first=self.batch_first, 104 | dropout=self.dropout, 105 | train=self.training, 106 | bidirectional=self.bidirectional, 107 | batch_sizes=batch_sizes, 108 | dropout_state=self.dropout_state, 109 | flat_weight=flat_weight 110 | ) 111 | output, hidden = func(input, self.all_weights, hx) 112 | if is_packed: 113 | output = PackedSequence(output, batch_sizes) 114 | return output, hidden 115 | 116 | def __repr__(self): 117 | s = '{name}({input_size}, {hidden_size}' 118 | if self.num_layers != 1: 119 | s += ', num_layers={num_layers}' 120 | if self.bias is not True: 121 | s += ', bias={bias}' 122 | if self.batch_first is not False: 123 | s += ', batch_first={batch_first}' 124 | if self.dropout != 0: 125 | s += ', dropout={dropout}' 126 | if self.bidirectional is not False: 127 | s += ', bidirectional={bidirectional}' 128 | s += ')' 129 | return s.format(name=self.__class__.__name__, **self.__dict__) 130 | 131 | def __setstate__(self, d): 132 | super(LSTMHardSigmoid, self).__setstate__(d) 133 | self.__dict__.setdefault('_data_ptrs', []) 134 | if 'all_weights' in d: 135 | self._all_weights = d['all_weights'] 136 | if isinstance(self._all_weights[0][0], str): 137 | return 138 | num_layers = self.num_layers 139 | num_directions = 2 if self.bidirectional else 1 140 | self._all_weights = [] 141 | for layer in range(num_layers): 142 | for direction in range(num_directions): 143 | suffix = '_reverse' if direction == 1 else '' 144 | weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}'] 145 | weights = [x.format(layer, suffix) for x in weights] 146 | if self.bias: 147 | self._all_weights += [weights] 148 | else: 149 | self._all_weights += [weights[:2]] 150 | 151 | @property 152 | def all_weights(self): 153 | return [[getattr(self, weight) for weight in weights] for weights in self._all_weights] 154 | 155 | def AutogradRNN(input_size, hidden_size, num_layers=1, batch_first=False, 156 | dropout=0, train=True, bidirectional=False, batch_sizes=None, 157 | dropout_state=None, flat_weight=None): 158 | 159 | cell = LSTMCell 160 | 161 | if batch_sizes is None: 162 | rec_factory = Recurrent 163 | else: 164 | rec_factory = 
variable_recurrent_factory(batch_sizes) 165 | 166 | if bidirectional: 167 | layer = (rec_factory(cell), rec_factory(cell, reverse=True)) 168 | else: 169 | layer = (rec_factory(cell),) 170 | 171 | func = StackedRNN(layer, 172 | num_layers, 173 | True, 174 | dropout=dropout, 175 | train=train) 176 | 177 | def forward(input, weight, hidden): 178 | if batch_first and batch_sizes is None: 179 | input = input.transpose(0, 1) 180 | 181 | nexth, output = func(input, hidden, weight) 182 | 183 | if batch_first and batch_sizes is None: 184 | output = output.transpose(0, 1) 185 | 186 | return output, nexth 187 | 188 | return forward 189 | 190 | def Recurrent(inner, reverse=False): 191 | def forward(input, hidden, weight): 192 | output = [] 193 | steps = range(input.size(0) - 1, -1, -1) if reverse else range(input.size(0)) 194 | for i in steps: 195 | hidden = inner(input[i], hidden, *weight) 196 | # hack to handle LSTM 197 | output.append(hidden[0] if isinstance(hidden, tuple) else hidden) 198 | 199 | if reverse: 200 | output.reverse() 201 | output = torch.cat(output, 0).view(input.size(0), *output[0].size()) 202 | 203 | return hidden, output 204 | 205 | return forward 206 | 207 | 208 | def variable_recurrent_factory(batch_sizes): 209 | def fac(inner, reverse=False): 210 | if reverse: 211 | return VariableRecurrentReverse(batch_sizes, inner) 212 | else: 213 | return VariableRecurrent(batch_sizes, inner) 214 | return fac 215 | 216 | def VariableRecurrent(batch_sizes, inner): 217 | def forward(input, hidden, weight): 218 | output = [] 219 | input_offset = 0 220 | last_batch_size = batch_sizes[0] 221 | hiddens = [] 222 | flat_hidden = not isinstance(hidden, tuple) 223 | if flat_hidden: 224 | hidden = (hidden,) 225 | for batch_size in batch_sizes: 226 | step_input = input[input_offset:input_offset + batch_size] 227 | input_offset += batch_size 228 | 229 | dec = last_batch_size - batch_size 230 | if dec > 0: 231 | hiddens.append(tuple(h[-dec:] for h in hidden)) 232 | hidden = tuple(h[:-dec] for h in hidden) 233 | last_batch_size = batch_size 234 | 235 | if flat_hidden: 236 | hidden = (inner(step_input, hidden[0], *weight),) 237 | else: 238 | hidden = inner(step_input, hidden, *weight) 239 | 240 | output.append(hidden[0]) 241 | hiddens.append(hidden) 242 | hiddens.reverse() 243 | 244 | hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens)) 245 | assert hidden[0].size(0) == batch_sizes[0] 246 | if flat_hidden: 247 | hidden = hidden[0] 248 | output = torch.cat(output, 0) 249 | 250 | return hidden, output 251 | 252 | return forward 253 | 254 | 255 | def VariableRecurrentReverse(batch_sizes, inner): 256 | def forward(input, hidden, weight): 257 | output = [] 258 | input_offset = input.size(0) 259 | last_batch_size = batch_sizes[-1] 260 | initial_hidden = hidden 261 | flat_hidden = not isinstance(hidden, tuple) 262 | if flat_hidden: 263 | hidden = (hidden,) 264 | initial_hidden = (initial_hidden,) 265 | hidden = tuple(h[:batch_sizes[-1]] for h in hidden) 266 | for batch_size in reversed(batch_sizes): 267 | inc = batch_size - last_batch_size 268 | if inc > 0: 269 | hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0) 270 | for h, ih in zip(hidden, initial_hidden)) 271 | last_batch_size = batch_size 272 | step_input = input[input_offset - batch_size:input_offset] 273 | input_offset -= batch_size 274 | 275 | if flat_hidden: 276 | hidden = (inner(step_input, hidden[0], *weight),) 277 | else: 278 | hidden = inner(step_input, hidden, *weight) 279 | output.append(hidden[0]) 280 | 281 | output.reverse() 282 | 
output = torch.cat(output, 0) 283 | if flat_hidden: 284 | hidden = hidden[0] 285 | return hidden, output 286 | 287 | return forward 288 | 289 | def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True): 290 | 291 | num_directions = len(inners) 292 | total_layers = num_layers * num_directions 293 | 294 | def forward(input, hidden, weight): 295 | assert(len(weight) == total_layers) 296 | next_hidden = [] 297 | 298 | if lstm: 299 | hidden = list(zip(*hidden)) 300 | 301 | for i in range(num_layers): 302 | all_output = [] 303 | for j, inner in enumerate(inners): 304 | l = i * num_directions + j 305 | 306 | hy, output = inner(input, hidden[l], weight[l]) 307 | next_hidden.append(hy) 308 | all_output.append(output) 309 | 310 | input = torch.cat(all_output, input.dim() - 1) 311 | 312 | if dropout != 0 and i < num_layers - 1: 313 | input = F.dropout(input, p=dropout, training=train, inplace=False) 314 | 315 | if lstm: 316 | next_h, next_c = zip(*next_hidden) 317 | next_hidden = ( 318 | torch.cat(next_h, 0).view(total_layers, *next_h[0].size()), 319 | torch.cat(next_c, 0).view(total_layers, *next_c[0].size()) 320 | ) 321 | else: 322 | next_hidden = torch.cat(next_hidden, 0).view( 323 | total_layers, *next_hidden[0].size()) 324 | 325 | return next_hidden, input 326 | 327 | return forward 328 | 329 | def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None): 330 | """ 331 | A modified LSTM cell with hard sigmoid activation on the input, forget and output gates. 332 | """ 333 | hx, cx = hidden 334 | gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh) 335 | 336 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 337 | 338 | ingate = hard_sigmoid(ingate) 339 | forgetgate = hard_sigmoid(forgetgate) 340 | cellgate = F.tanh(cellgate) 341 | outgate = hard_sigmoid(outgate) 342 | 343 | cy = (forgetgate * cx) + (ingate * cellgate) 344 | hy = outgate * F.tanh(cy) 345 | 346 | return hy, cy 347 | 348 | def hard_sigmoid(x): 349 | """ 350 | Computes element-wise hard sigmoid of x. 351 | See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279 352 | """ 353 | x = (0.2 * x) + 0.5 354 | x = F.threshold(-x, -1, -1) 355 | x = F.threshold(-x, 0, 0) 356 | return x 357 | -------------------------------------------------------------------------------- /torchmoji/class_avg_finetuning.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Class average finetuning functions. Before using any of these finetuning 3 | functions, ensure that the model is set up with nb_classes=2. 4 | """ 5 | from __future__ import print_function 6 | 7 | import uuid 8 | from time import sleep 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | 15 | from torchmoji.global_variables import ( 16 | FINETUNING_METHODS, 17 | WEIGHTS_DIR) 18 | from torchmoji.finetuning import ( 19 | freeze_layers, 20 | get_data_loader, 21 | fit_model, 22 | train_by_chain_thaw, 23 | find_f1_threshold) 24 | 25 | def relabel(y, current_label_nr, nb_classes): 26 | """ Makes a binary classification for a specific class in a 27 | multi-class dataset. 28 | 29 | # Arguments: 30 | y: Outputs to be relabelled. 31 | current_label_nr: Current label number. 32 | nb_classes: Total number of classes. 33 | 34 | # Returns: 35 | Relabelled outputs of a given multi-class dataset into a binary 36 | classification dataset. 
37 | """ 38 | 39 | # Handling binary classification 40 | if nb_classes == 2 and len(y.shape) == 1: 41 | return y 42 | 43 | y_new = np.zeros(len(y)) 44 | y_cut = y[:, current_label_nr] 45 | label_pos = np.where(y_cut == 1)[0] 46 | y_new[label_pos] = 1 47 | return y_new 48 | 49 | 50 | def class_avg_finetune(model, texts, labels, nb_classes, batch_size, 51 | method, epoch_size=5000, nb_epochs=1000, embed_l2=1E-6, 52 | verbose=True): 53 | """ Compiles and finetunes the given model. 54 | 55 | # Arguments: 56 | model: Model to be finetuned 57 | texts: List of three lists, containing tokenized inputs for training, 58 | validation and testing (in that order). 59 | labels: List of three lists, containing labels for training, 60 | validation and testing (in that order). 61 | nb_classes: Number of classes in the dataset. 62 | batch_size: Batch size. 63 | method: Finetuning method to be used. For available methods, see 64 | FINETUNING_METHODS in global_variables.py. Note that the model 65 | should be defined accordingly (see docstring for torchmoji_transfer()) 66 | epoch_size: Number of samples in an epoch. 67 | nb_epochs: Number of epochs. Doesn't matter much as early stopping is used. 68 | embed_l2: L2 regularization for the embedding layer. 69 | verbose: Verbosity flag. 70 | 71 | # Returns: 72 | Model after finetuning, 73 | score after finetuning using the class average F1 metric. 74 | """ 75 | 76 | if method not in FINETUNING_METHODS: 77 | raise ValueError('ERROR (class_avg_tune_trainable): ' 78 | 'Invalid method parameter. ' 79 | 'Available options: {}'.format(FINETUNING_METHODS)) 80 | 81 | (X_train, y_train) = (texts[0], labels[0]) 82 | (X_val, y_val) = (texts[1], labels[1]) 83 | (X_test, y_test) = (texts[2], labels[2]) 84 | 85 | checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \ 86 | .format(WEIGHTS_DIR, str(uuid.uuid4())) 87 | 88 | f1_init_path = '{}/torchmoji-f1-init-{}.bin' \ 89 | .format(WEIGHTS_DIR, str(uuid.uuid4())) 90 | 91 | if method in ['last', 'new']: 92 | lr = 0.001 93 | elif method in ['full', 'chain-thaw']: 94 | lr = 0.0001 95 | 96 | loss_op = nn.BCEWithLogitsLoss() 97 | 98 | # Freeze layers if using last 99 | if method == 'last': 100 | model = freeze_layers(model, unfrozen_keyword='output_layer') 101 | 102 | # Define optimizer, for chain-thaw we define it later (after freezing) 103 | if method == 'last': 104 | adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr) 105 | elif method in ['full', 'new']: 106 | # Add L2 regulation on embeddings only 107 | special_params = [id(p) for p in model.embed.parameters()] 108 | base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad] 109 | embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad] 110 | adam = optim.Adam([ 111 | {'params': base_params}, 112 | {'params': embed_parameters, 'weight_decay': embed_l2}, 113 | ], lr=lr) 114 | 115 | # Training 116 | if verbose: 117 | print('Method: {}'.format(method)) 118 | print('Classes: {}'.format(nb_classes)) 119 | 120 | if method == 'chain-thaw': 121 | result = class_avg_chainthaw(model, nb_classes=nb_classes, 122 | loss_op=loss_op, 123 | train=(X_train, y_train), 124 | val=(X_val, y_val), 125 | test=(X_test, y_test), 126 | batch_size=batch_size, 127 | epoch_size=epoch_size, 128 | nb_epochs=nb_epochs, 129 | checkpoint_weight_path=checkpoint_path, 130 | f1_init_weight_path=f1_init_path, 131 | verbose=verbose) 132 | else: 133 | result = class_avg_tune_trainable(model, nb_classes=nb_classes, 134 | 
loss_op=loss_op, 135 | optim_op=adam, 136 | train=(X_train, y_train), 137 | val=(X_val, y_val), 138 | test=(X_test, y_test), 139 | epoch_size=epoch_size, 140 | nb_epochs=nb_epochs, 141 | batch_size=batch_size, 142 | init_weight_path=f1_init_path, 143 | checkpoint_weight_path=checkpoint_path, 144 | verbose=verbose) 145 | return model, result 146 | 147 | 148 | def prepare_labels(y_train, y_val, y_test, iter_i, nb_classes): 149 | # Relabel into binary classification 150 | y_train_new = relabel(y_train, iter_i, nb_classes) 151 | y_val_new = relabel(y_val, iter_i, nb_classes) 152 | y_test_new = relabel(y_test, iter_i, nb_classes) 153 | return y_train_new, y_val_new, y_test_new 154 | 155 | def prepare_generators(X_train, y_train_new, X_val, y_val_new, batch_size, epoch_size): 156 | # Create sample generators 157 | # Make a fixed validation set to avoid fluctuations in validation 158 | train_gen = get_data_loader(X_train, y_train_new, batch_size, 159 | extended_batch_sampler=True) 160 | val_gen = get_data_loader(X_val, y_val_new, epoch_size, 161 | extended_batch_sampler=True) 162 | X_val_resamp, y_val_resamp = next(iter(val_gen)) 163 | return train_gen, X_val_resamp, y_val_resamp 164 | 165 | 166 | def class_avg_tune_trainable(model, nb_classes, loss_op, optim_op, train, val, test, 167 | epoch_size, nb_epochs, batch_size, 168 | init_weight_path, checkpoint_weight_path, patience=5, 169 | verbose=True): 170 | """ Finetunes the given model using the F1 measure. 171 | 172 | # Arguments: 173 | model: Model to be finetuned. 174 | nb_classes: Number of classes in the given dataset. 175 | train: Training data, given as a tuple of (inputs, outputs) 176 | val: Validation data, given as a tuple of (inputs, outputs) 177 | test: Testing data, given as a tuple of (inputs, outputs) 178 | epoch_size: Number of samples in an epoch. 179 | nb_epochs: Number of epochs. 180 | batch_size: Batch size. 181 | init_weight_path: Filepath where weights will be initially saved before 182 | training each class. This file will be rewritten by the function. 183 | checkpoint_weight_path: Filepath where weights will be checkpointed to 184 | during training. This file will be rewritten by the function. 185 | verbose: Verbosity flag. 
186 | 187 | # Returns: 188 | F1 score of the trained model 189 | """ 190 | total_f1 = 0 191 | nb_iter = nb_classes if nb_classes > 2 else 1 192 | 193 | # Unpack args 194 | X_train, y_train = train 195 | X_val, y_val = val 196 | X_test, y_test = test 197 | 198 | # Save and reload initial weights after running for 199 | # each class to avoid learning across classes 200 | torch.save(model.state_dict(), init_weight_path) 201 | for i in range(nb_iter): 202 | if verbose: 203 | print('Iteration number {}/{}'.format(i+1, nb_iter)) 204 | 205 | model.load_state_dict(torch.load(init_weight_path)) 206 | y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val, 207 | y_test, i, nb_classes) 208 | train_gen, X_val_resamp, y_val_resamp = \ 209 | prepare_generators(X_train, y_train_new, X_val, y_val_new, 210 | batch_size, epoch_size) 211 | 212 | if verbose: 213 | print("Training..") 214 | fit_model(model, loss_op, optim_op, train_gen, [(X_val_resamp, y_val_resamp)], 215 | nb_epochs, checkpoint_weight_path, patience, verbose=0) 216 | 217 | # Reload the best weights found to avoid overfitting 218 | # Wait a bit to allow proper closing of weights file 219 | sleep(1) 220 | model.load_state_dict(torch.load(checkpoint_weight_path)) 221 | 222 | # Evaluate 223 | y_pred_val = model(X_val).cpu().numpy() 224 | y_pred_test = model(X_test).cpu().numpy() 225 | 226 | f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val, 227 | y_test_new, y_pred_test) 228 | if verbose: 229 | print('f1_test: {}'.format(f1_test)) 230 | print('best_t: {}'.format(best_t)) 231 | total_f1 += f1_test 232 | 233 | return total_f1 / nb_iter 234 | 235 | 236 | def class_avg_chainthaw(model, nb_classes, loss_op, train, val, test, batch_size, 237 | epoch_size, nb_epochs, checkpoint_weight_path, 238 | f1_init_weight_path, patience=5, 239 | initial_lr=0.001, next_lr=0.0001, verbose=True): 240 | """ Finetunes given model using chain-thaw and evaluates using F1. 241 | For a dataset with multiple classes, the model is trained once for 242 | each class, relabeling those classes into a binary classification task. 243 | The result is an average of all F1 scores for each class. 244 | 245 | # Arguments: 246 | model: Model to be finetuned. 247 | nb_classes: Number of classes in the given dataset. 248 | train: Training data, given as a tuple of (inputs, outputs) 249 | val: Validation data, given as a tuple of (inputs, outputs) 250 | test: Testing data, given as a tuple of (inputs, outputs) 251 | batch_size: Batch size. 252 | loss: Loss function to be used during training. 253 | epoch_size: Number of samples in an epoch. 254 | nb_epochs: Number of epochs. 255 | checkpoint_weight_path: Filepath where weights will be checkpointed to 256 | during training. This file will be rewritten by the function. 257 | f1_init_weight_path: Filepath where weights will be saved to and 258 | reloaded from before training each class. This ensures that 259 | each class is trained independently. This file will be rewritten. 260 | initial_lr: Initial learning rate. Will only be used for the first 261 | training step (i.e. the softmax layer) 262 | next_lr: Learning rate for every subsequent step. 263 | seed: Random number generator seed. 264 | verbose: Verbosity flag. 265 | 266 | # Returns: 267 | Averaged F1 score. 
268 | """ 269 | 270 | # Unpack args 271 | X_train, y_train = train 272 | X_val, y_val = val 273 | X_test, y_test = test 274 | 275 | total_f1 = 0 276 | nb_iter = nb_classes if nb_classes > 2 else 1 277 | 278 | torch.save(model.state_dict(), f1_init_weight_path) 279 | 280 | for i in range(nb_iter): 281 | if verbose: 282 | print('Iteration number {}/{}'.format(i+1, nb_iter)) 283 | 284 | model.load_state_dict(torch.load(f1_init_weight_path)) 285 | y_train_new, y_val_new, y_test_new = prepare_labels(y_train, y_val, 286 | y_test, i, nb_classes) 287 | train_gen, X_val_resamp, y_val_resamp = \ 288 | prepare_generators(X_train, y_train_new, X_val, y_val_new, 289 | batch_size, epoch_size) 290 | 291 | if verbose: 292 | print("Training..") 293 | 294 | # Train using chain-thaw 295 | train_by_chain_thaw(model=model, train_gen=train_gen, 296 | val_gen=[(X_val_resamp, y_val_resamp)], 297 | loss_op=loss_op, patience=patience, 298 | nb_epochs=nb_epochs, 299 | checkpoint_path=checkpoint_weight_path, 300 | initial_lr=initial_lr, next_lr=next_lr, 301 | verbose=verbose) 302 | 303 | # Evaluate 304 | y_pred_val = model(X_val).cpu().numpy() 305 | y_pred_test = model(X_test).cpu().numpy() 306 | 307 | f1_test, best_t = find_f1_threshold(y_val_new, y_pred_val, 308 | y_test_new, y_pred_test) 309 | 310 | if verbose: 311 | print('f1_test: {}'.format(f1_test)) 312 | print('best_t: {}'.format(best_t)) 313 | total_f1 += f1_test 314 | 315 | return total_f1 / nb_iter 316 | -------------------------------------------------------------------------------- /torchmoji/model_def.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Model definition functions and weight loading. 3 | """ 4 | 5 | from __future__ import print_function, division, unicode_literals 6 | 7 | from os.path import exists 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.autograd import Variable 12 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence 13 | 14 | from torchmoji.lstm import LSTMHardSigmoid 15 | from torchmoji.attlayer import Attention 16 | from torchmoji.global_variables import NB_TOKENS, NB_EMOJI_CLASSES 17 | 18 | 19 | def torchmoji_feature_encoding(weight_path, return_attention=False): 20 | """ Loads the pretrained torchMoji model for extracting features 21 | from the penultimate feature layer. In this way, it transforms 22 | the text into its emotional encoding. 23 | 24 | # Arguments: 25 | weight_path: Path to model weights to be loaded. 26 | return_attention: If true, output will include weight of each input token 27 | used for the prediction 28 | 29 | # Returns: 30 | Pretrained model for encoding text into feature vectors. 31 | """ 32 | 33 | model = TorchMoji(nb_classes=None, 34 | nb_tokens=NB_TOKENS, 35 | feature_output=True, 36 | return_attention=return_attention) 37 | load_specific_weights(model, weight_path, exclude_names=['output_layer']) 38 | return model 39 | 40 | 41 | def torchmoji_emojis(weight_path, return_attention=False): 42 | """ Loads the pretrained torchMoji model for extracting features 43 | from the penultimate feature layer. In this way, it transforms 44 | the text into its emotional encoding. 45 | 46 | # Arguments: 47 | weight_path: Path to model weights to be loaded. 48 | return_attention: If true, output will include weight of each input token 49 | used for the prediction 50 | 51 | # Returns: 52 | Pretrained model for encoding text into feature vectors. 
53 | """ 54 | 55 | model = TorchMoji(nb_classes=NB_EMOJI_CLASSES, 56 | nb_tokens=NB_TOKENS, 57 | return_attention=return_attention) 58 | model.load_state_dict(torch.load(weight_path)) 59 | return model 60 | 61 | 62 | def torchmoji_transfer(nb_classes, weight_path=None, extend_embedding=0, 63 | embed_dropout_rate=0.1, final_dropout_rate=0.5): 64 | """ Loads the pretrained torchMoji model for finetuning/transfer learning. 65 | Does not load weights for the softmax layer. 66 | 67 | Note that if you are planning to use class average F1 for evaluation, 68 | nb_classes should be set to 2 instead of the actual number of classes 69 | in the dataset, since binary classification will be performed on each 70 | class individually. 71 | 72 | Note that for the 'new' method, weight_path should be left as None. 73 | 74 | # Arguments: 75 | nb_classes: Number of classes in the dataset. 76 | weight_path: Path to model weights to be loaded. 77 | extend_embedding: Number of tokens that have been added to the 78 | vocabulary on top of NB_TOKENS. If this number is larger than 0, 79 | the embedding layer's dimensions are adjusted accordingly, with the 80 | additional weights being set to random values. 81 | embed_dropout_rate: Dropout rate for the embedding layer. 82 | final_dropout_rate: Dropout rate for the final Softmax layer. 83 | 84 | # Returns: 85 | Model with the given parameters. 86 | """ 87 | 88 | model = TorchMoji(nb_classes=nb_classes, 89 | nb_tokens=NB_TOKENS + extend_embedding, 90 | embed_dropout_rate=embed_dropout_rate, 91 | final_dropout_rate=final_dropout_rate, 92 | output_logits=True) 93 | if weight_path is not None: 94 | load_specific_weights(model, weight_path, 95 | exclude_names=['output_layer'], 96 | extend_embedding=extend_embedding) 97 | return model 98 | 99 | 100 | class TorchMoji(nn.Module): 101 | def __init__(self, nb_classes, nb_tokens, feature_output=False, output_logits=False, 102 | embed_dropout_rate=0, final_dropout_rate=0, return_attention=False): 103 | """ 104 | torchMoji model. 105 | IMPORTANT: The model is loaded in evaluation mode by default (self.eval()) 106 | 107 | # Arguments: 108 | nb_classes: Number of classes in the dataset. 109 | nb_tokens: Number of tokens in the dataset (i.e. vocabulary size). 110 | feature_output: If True the model returns the penultimate 111 | feature vector rather than Softmax probabilities 112 | (defaults to False). 113 | output_logits: If True the model returns logits rather than probabilities 114 | (defaults to False). 115 | embed_dropout_rate: Dropout rate for the embedding layer. 116 | final_dropout_rate: Dropout rate for the final Softmax layer. 117 | return_attention: If True the model also returns attention weights over the sentence 118 | (defaults to False). 
119 | """ 120 | super(TorchMoji, self).__init__() 121 | 122 | embedding_dim = 256 123 | hidden_size = 512 124 | attention_size = 4 * hidden_size + embedding_dim 125 | 126 | self.feature_output = feature_output 127 | self.embed_dropout_rate = embed_dropout_rate 128 | self.final_dropout_rate = final_dropout_rate 129 | self.return_attention = return_attention 130 | self.hidden_size = hidden_size 131 | self.output_logits = output_logits 132 | self.nb_classes = nb_classes 133 | 134 | self.add_module('embed', nn.Embedding(nb_tokens, embedding_dim)) 135 | # dropout2D: embedding channels are dropped out instead of words 136 | # many exampels in the datasets contain few words that losing one or more words can alter the emotions completely 137 | self.add_module('embed_dropout', nn.Dropout2d(embed_dropout_rate)) 138 | self.add_module('lstm_0', LSTMHardSigmoid(embedding_dim, hidden_size, batch_first=True, bidirectional=True)) 139 | self.add_module('lstm_1', LSTMHardSigmoid(hidden_size*2, hidden_size, batch_first=True, bidirectional=True)) 140 | self.add_module('attention_layer', Attention(attention_size=attention_size, return_attention=return_attention)) 141 | if not feature_output: 142 | self.add_module('final_dropout', nn.Dropout(final_dropout_rate)) 143 | if output_logits: 144 | self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1))) 145 | else: 146 | self.add_module('output_layer', nn.Sequential(nn.Linear(attention_size, nb_classes if self.nb_classes > 2 else 1), 147 | nn.Softmax() if self.nb_classes > 2 else nn.Sigmoid())) 148 | self.init_weights() 149 | # Put model in evaluation mode by default 150 | self.eval() 151 | 152 | def init_weights(self): 153 | """ 154 | Here we reproduce Keras default initialization weights for consistency with Keras version 155 | """ 156 | ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name) 157 | hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name) 158 | b = (param.data for name, param in self.named_parameters() if 'bias' in name) 159 | nn.init.uniform(self.embed.weight.data, a=-0.5, b=0.5) 160 | for t in ih: 161 | nn.init.xavier_uniform(t) 162 | for t in hh: 163 | nn.init.orthogonal(t) 164 | for t in b: 165 | nn.init.constant(t, 0) 166 | if not self.feature_output: 167 | nn.init.xavier_uniform(self.output_layer[0].weight.data) 168 | 169 | def forward(self, input_seqs): 170 | """ Forward pass. 171 | 172 | # Arguments: 173 | input_seqs: Can be one of Numpy array, Torch.LongTensor, Torch.Variable, Torch.PackedSequence. 174 | 175 | # Return: 176 | Same format as input format (except for PackedSequence returned as Variable). 
177 | """ 178 | # Check if we have Torch.LongTensor inputs or not Torch.Variable (assume Numpy array in this case), take note to return same format 179 | return_numpy = False 180 | return_tensor = False 181 | if isinstance(input_seqs, (torch.LongTensor, torch.cuda.LongTensor)): 182 | input_seqs = Variable(input_seqs) 183 | return_tensor = True 184 | elif not isinstance(input_seqs, Variable): 185 | input_seqs = Variable(torch.from_numpy(input_seqs.astype('int64')).long()) 186 | return_numpy = True 187 | 188 | # If we don't have a packed inputs, let's pack it 189 | reorder_output = False 190 | if not isinstance(input_seqs, PackedSequence): 191 | ho = self.lstm_0.weight_hh_l0.data.new(2, input_seqs.size()[0], self.hidden_size).zero_() 192 | co = self.lstm_0.weight_hh_l0.data.new(2, input_seqs.size()[0], self.hidden_size).zero_() 193 | 194 | # Reorder batch by sequence length 195 | input_lengths = torch.LongTensor([torch.max(input_seqs[i, :].data.nonzero()) + 1 for i in range(input_seqs.size()[0])]) 196 | input_lengths, perm_idx = input_lengths.sort(0, descending=True) 197 | input_seqs = input_seqs[perm_idx][:, :input_lengths.max()] 198 | 199 | # Pack sequence and work on data tensor to reduce embeddings/dropout computations 200 | packed_input = pack_padded_sequence(input_seqs, input_lengths.cpu().numpy(), batch_first=True) 201 | reorder_output = True 202 | else: 203 | ho = self.lstm_0.weight_hh_l0.data.data.new(2, input_seqs.size()[0], self.hidden_size).zero_() 204 | co = self.lstm_0.weight_hh_l0.data.data.new(2, input_seqs.size()[0], self.hidden_size).zero_() 205 | input_lengths = input_seqs.batch_sizes 206 | packed_input = input_seqs 207 | 208 | hidden = (Variable(ho, requires_grad=False), Variable(co, requires_grad=False)) 209 | 210 | # Embed with an activation function to bound the values of the embeddings 211 | x = self.embed(packed_input.data) 212 | x = nn.Tanh()(x) 213 | 214 | # pyTorch 2D dropout2d operate on axis 1 which is fine for us 215 | x = self.embed_dropout(x) 216 | 217 | # Update packed sequence data for RNN 218 | packed_input = PackedSequence(x, packed_input.batch_sizes) 219 | 220 | # skip-connection from embedding to output eases gradient-flow and allows access to lower-level features 221 | # ordering of the way the merge is done is important for consistency with the pretrained model 222 | lstm_0_output, _ = self.lstm_0(packed_input, hidden) 223 | lstm_1_output, _ = self.lstm_1(lstm_0_output, hidden) 224 | 225 | # Update packed sequence data for attention layer 226 | packed_input = PackedSequence(torch.cat((lstm_1_output.data, 227 | lstm_0_output.data, 228 | packed_input.data), dim=1), 229 | packed_input.batch_sizes) 230 | 231 | input_seqs, _ = pad_packed_sequence(packed_input, batch_first=True) 232 | 233 | x, att_weights = self.attention_layer(input_seqs, input_lengths) 234 | 235 | # output class probabilities or penultimate feature vector 236 | if not self.feature_output: 237 | x = self.final_dropout(x) 238 | outputs = self.output_layer(x) 239 | else: 240 | outputs = x 241 | 242 | # Reorder output if needed 243 | if reorder_output: 244 | reorered = Variable(outputs.data.new(outputs.size())) 245 | reorered[perm_idx] = outputs 246 | outputs = reorered 247 | 248 | # Adapt return format if needed 249 | if return_tensor: 250 | outputs = outputs.data 251 | if return_numpy: 252 | outputs = outputs.data.numpy() 253 | 254 | if self.return_attention: 255 | return outputs, att_weights 256 | else: 257 | return outputs 258 | 259 | 260 | def load_specific_weights(model, weight_path, 
exclude_names=[], extend_embedding=0, verbose=True): 261 | """ Loads model weights from the given file path, excluding any 262 | given layers. 263 | 264 | # Arguments: 265 | model: Model whose weights should be loaded. 266 | weight_path: Path to file containing model weights. 267 | exclude_names: List of layer names whose weights should not be loaded. 268 | extend_embedding: Number of new words being added to vocabulary. 269 | verbose: Verbosity flag. 270 | 271 | # Raises: 272 | ValueError if the file at weight_path does not exist. 273 | """ 274 | if not exists(weight_path): 275 | raise ValueError('ERROR (load_weights): The weights file at {} does ' 276 | 'not exist. Refer to the README for instructions.' 277 | .format(weight_path)) 278 | 279 | if extend_embedding and 'embed' in exclude_names: 280 | raise ValueError('ERROR (load_weights): Cannot extend a vocabulary ' 281 | 'without loading the embedding weights.') 282 | 283 | # Copy only weights from the temporary model that are wanted 284 | # for the specific task (e.g. the Softmax is often ignored) 285 | weights = torch.load(weight_path) 286 | for key, weight in weights.items(): 287 | if any(excluded in key for excluded in exclude_names): 288 | if verbose: 289 | print('Ignoring weights for {}'.format(key)) 290 | continue 291 | 292 | try: 293 | model_w = model.state_dict()[key] 294 | except KeyError: 295 | raise KeyError("Weights had parameter {},".format(key) 296 | + " but could not find this parameter in the model.") 297 | 298 | if verbose: 299 | print('Loading weights for {}'.format(key)) 300 | 301 | # extend embedding layer to allow new randomly initialized words 302 | # if requested. Otherwise, just load the weights for the layer. 303 | if 'embed' in key and extend_embedding > 0: 304 | weight = torch.cat((weight, model_w[NB_TOKENS:, :]), dim=0) 305 | if verbose: 306 | print('Extended vocabulary for embedding layer ' + 307 | 'from {} to {} tokens.'.format( 308 | NB_TOKENS, NB_TOKENS + extend_embedding)) 309 | try: 310 | model_w.copy_(weight) 311 | except: 312 | print('While copying the weights named {}, whose dimensions in the model are' 313 | ' {} and whose dimensions in the saved file are {}, ...'.format( 314 | key, model_w.size(), weight.size())) 315 | raise 316 | -------------------------------------------------------------------------------- /torchmoji/finetuning.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Finetuning functions for doing transfer learning to new datasets.
3 | """ 4 | from __future__ import print_function 5 | 6 | import uuid 7 | from time import sleep 8 | from io import open 9 | 10 | import math 11 | import pickle 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.optim as optim 17 | from sklearn.metrics import accuracy_score 18 | from torch.autograd import Variable 19 | from torch.utils.data import Dataset, DataLoader 20 | from torch.utils.data.sampler import BatchSampler, SequentialSampler 21 | from torch.nn.utils import clip_grad_norm 22 | 23 | from sklearn.metrics import f1_score 24 | 25 | from torchmoji.global_variables import (FINETUNING_METHODS, 26 | FINETUNING_METRICS, 27 | WEIGHTS_DIR) 28 | from torchmoji.tokenizer import tokenize 29 | from torchmoji.sentence_tokenizer import SentenceTokenizer 30 | 31 | try: 32 | unicode 33 | IS_PYTHON2 = True 34 | except NameError: 35 | unicode = str 36 | IS_PYTHON2 = False 37 | 38 | 39 | def load_benchmark(path, vocab, extend_with=0): 40 | """ Loads the given benchmark dataset. 41 | 42 | Tokenizes the texts using the provided vocabulary, extending it with 43 | words from the training dataset if extend_with > 0. Splits them into 44 | three lists: training, validation and testing (in that order). 45 | 46 | Also calculates the maximum length of the texts and the 47 | suggested batch_size. 48 | 49 | # Arguments: 50 | path: Path to the dataset to be loaded. 51 | vocab: Vocabulary to be used for tokenizing texts. 52 | extend_with: If > 0, the vocabulary will be extended with up to 53 | extend_with tokens from the training set before tokenizing. 54 | 55 | # Returns: 56 | A dictionary with the following fields: 57 | texts: List of three lists, containing tokenized inputs for 58 | training, validation and testing (in that order). 59 | labels: List of three lists, containing labels for training, 60 | validation and testing (in that order). 61 | added: Number of tokens added to the vocabulary. 62 | batch_size: Batch size. 63 | maxlen: Maximum length of an input. 64 | """ 65 | # Pre-processing dataset 66 | with open(path, 'rb') as dataset: 67 | if IS_PYTHON2: 68 | data = pickle.load(dataset) 69 | else: 70 | data = pickle.load(dataset, fix_imports=True) 71 | 72 | # Decode data 73 | try: 74 | texts = [unicode(x) for x in data['texts']] 75 | except UnicodeDecodeError: 76 | texts = [x.decode('utf-8') for x in data['texts']] 77 | 78 | # Extract labels 79 | labels = [x['label'] for x in data['info']] 80 | 81 | batch_size, maxlen = calculate_batchsize_maxlen(texts) 82 | 83 | st = SentenceTokenizer(vocab, maxlen) 84 | 85 | # Split up dataset. Extend the existing vocabulary with up to extend_with 86 | # tokens from the training dataset. 87 | texts, labels, added = st.split_train_val_test(texts, 88 | labels, 89 | [data['train_ind'], 90 | data['val_ind'], 91 | data['test_ind']], 92 | extend_with=extend_with) 93 | return {'texts': texts, 94 | 'labels': labels, 95 | 'added': added, 96 | 'batch_size': batch_size, 97 | 'maxlen': maxlen} 98 | 99 | 100 | def calculate_batchsize_maxlen(texts): 101 | """ Calculates the maximum length in the provided texts and a suitable 102 | batch size. Rounds up maxlen to the nearest multiple of ten. 103 | 104 | # Arguments: 105 | texts: List of inputs. 
106 | 107 | # Returns: 108 | Batch size, 109 | max length 110 | """ 111 | def roundup(x): 112 | return int(math.ceil(x / 10.0)) * 10 113 | 114 | # Calculate max length of sequences considered 115 | # Adjust batch_size accordingly to prevent GPU overflow 116 | lengths = [len(tokenize(t)) for t in texts] 117 | maxlen = roundup(np.percentile(lengths, 80.0)) 118 | batch_size = 250 if maxlen <= 100 else 50 119 | return batch_size, maxlen 120 | 121 | 122 | 123 | def freeze_layers(model, unfrozen_types=[], unfrozen_keyword=None): 124 | """ Freezes all layers in the given model, except for ones that are 125 | explicitly specified to not be frozen. 126 | 127 | # Arguments: 128 | model: Model whose layers should be modified. 129 | unfrozen_types: List of layer types which shouldn't be frozen. 130 | unfrozen_keyword: Name keywords of layers that shouldn't be frozen. 131 | 132 | # Returns: 133 | Model with the selected layers frozen. 134 | """ 135 | # Get trainable modules 136 | trainable_modules = [(n, m) for n, m in model.named_children() if len([id(p) for p in m.parameters()]) != 0] 137 | for name, module in trainable_modules: 138 | trainable = (any(typ in str(module) for typ in unfrozen_types) or 139 | (unfrozen_keyword is not None and unfrozen_keyword.lower() in name.lower())) 140 | change_trainable(module, trainable, verbose=False) 141 | return model 142 | 143 | 144 | def change_trainable(module, trainable, verbose=False): 145 | """ Helper method that freezes or unfreezes a given layer. 146 | 147 | # Arguments: 148 | module: Module to be modified. 149 | trainable: Whether the layer should be frozen or unfrozen. 150 | verbose: Verbosity flag. 151 | """ 152 | 153 | if verbose: print('Changing MODULE', module, 'to trainable =', trainable) 154 | for name, param in module.named_parameters(): 155 | if verbose: print('Setting weight', name, 'to trainable =', trainable) 156 | param.requires_grad = trainable 157 | 158 | if verbose: 159 | action = 'Unfroze' if trainable else 'Froze' 160 | if verbose: print("{} {}".format(action, module)) 161 | 162 | 163 | def find_f1_threshold(model, val_gen, test_gen, average='binary'): 164 | """ Choose a threshold for F1 based on the validation dataset 165 | (see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4442797/ 166 | for details on why to find another threshold than simply 0.5) 167 | 168 | # Arguments: 169 | model: pyTorch model 170 | val_gen: Validation set dataloader. 171 | test_gen: Testing set dataloader. 172 | 173 | # Returns: 174 | F1 score for the given data and 175 | the corresponding F1 threshold 176 | """ 177 | thresholds = np.arange(0.01, 0.5, step=0.01) 178 | f1_scores = [] 179 | 180 | model.eval() 181 | val_out = [(y, model(X)) for X, y in val_gen] 182 | y_val, y_pred_val = (list(t) for t in zip(*val_out)) 183 | 184 | test_out = [(y, model(X)) for X, y in test_gen] 185 | y_test, y_pred_test = (list(t) for t in zip(*test_out)) 186 | 187 | for t in thresholds: 188 | y_pred_val_ind = (y_pred_val > t) 189 | f1_val = f1_score(y_val, y_pred_val_ind, average=average) 190 | f1_scores.append(f1_val) 191 | 192 | best_t = thresholds[np.argmax(f1_scores)] 193 | y_pred_ind = (y_pred_test > best_t) 194 | f1_test = f1_score(y_test, y_pred_ind, average=average) 195 | return f1_test, best_t 196 | 197 | 198 | def finetune(model, texts, labels, nb_classes, batch_size, method, 199 | metric='acc', epoch_size=5000, nb_epochs=1000, embed_l2=1E-6, 200 | verbose=1): 201 | """ Compiles and finetunes the given pytorch model.
202 | 203 | # Arguments: 204 | model: Model to be finetuned 205 | texts: List of three lists, containing tokenized inputs for training, 206 | validation and testing (in that order). 207 | labels: List of three lists, containing labels for training, 208 | validation and testing (in that order). 209 | nb_classes: Number of classes in the dataset. 210 | batch_size: Batch size. 211 | method: Finetuning method to be used. For available methods, see 212 | FINETUNING_METHODS in global_variables.py. 213 | metric: Evaluation metric to be used. For available metrics, see 214 | FINETUNING_METRICS in global_variables.py. 215 | epoch_size: Number of samples in an epoch. 216 | nb_epochs: Number of epochs. Doesn't matter much as early stopping is used. 217 | embed_l2: L2 regularization for the embedding layer. 218 | verbose: Verbosity flag. 219 | 220 | # Returns: 221 | Model after finetuning, 222 | score after finetuning using the provided metric. 223 | """ 224 | 225 | if method not in FINETUNING_METHODS: 226 | raise ValueError('ERROR (finetune): Invalid method parameter. ' 227 | 'Available options: {}'.format(FINETUNING_METHODS)) 228 | if metric not in FINETUNING_METRICS: 229 | raise ValueError('ERROR (finetune): Invalid metric parameter. ' 230 | 'Available options: {}'.format(FINETUNING_METRICS)) 231 | 232 | train_gen = get_data_loader(texts[0], labels[0], batch_size, 233 | extended_batch_sampler=True, epoch_size=epoch_size) 234 | val_gen = get_data_loader(texts[1], labels[1], batch_size, 235 | extended_batch_sampler=False) 236 | test_gen = get_data_loader(texts[2], labels[2], batch_size, 237 | extended_batch_sampler=False) 238 | 239 | checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \ 240 | .format(WEIGHTS_DIR, str(uuid.uuid4())) 241 | 242 | if method in ['last', 'new']: 243 | lr = 0.001 244 | elif method in ['full', 'chain-thaw']: 245 | lr = 0.0001 246 | 247 | loss_op = nn.BCEWithLogitsLoss() if nb_classes <= 2 \ 248 | else nn.CrossEntropyLoss() 249 | 250 | # Freeze layers if using last 251 | if method == 'last': 252 | model = freeze_layers(model, unfrozen_keyword='output_layer') 253 | 254 | # Define optimizer, for chain-thaw we define it later (after freezing) 255 | if method == 'last': 256 | adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr) 257 | elif method in ['full', 'new']: 258 | # Add L2 regulation on embeddings only 259 | embed_params_id = [id(p) for p in model.embed.parameters()] 260 | output_layer_params_id = [id(p) for p in model.output_layer.parameters()] 261 | base_params = [p for p in model.parameters() 262 | if id(p) not in embed_params_id and id(p) not in output_layer_params_id and p.requires_grad] 263 | embed_params = [p for p in model.parameters() if id(p) in embed_params_id and p.requires_grad] 264 | output_layer_params = [p for p in model.parameters() if id(p) in output_layer_params_id and p.requires_grad] 265 | adam = optim.Adam([ 266 | {'params': base_params}, 267 | {'params': embed_params, 'weight_decay': embed_l2}, 268 | {'params': output_layer_params, 'lr': 0.001}, 269 | ], lr=lr) 270 | 271 | # Training 272 | if verbose: 273 | print('Method: {}'.format(method)) 274 | print('Metric: {}'.format(metric)) 275 | print('Classes: {}'.format(nb_classes)) 276 | 277 | if method == 'chain-thaw': 278 | result = chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op, embed_l2=embed_l2, 279 | evaluate=metric, verbose=verbose) 280 | else: 281 | result = tune_trainable(model, loss_op, adam, train_gen, val_gen, test_gen, nb_epochs, 
checkpoint_path, 282 | evaluate=metric, verbose=verbose) 283 | return model, result 284 | 285 | 286 | def tune_trainable(model, loss_op, optim_op, train_gen, val_gen, test_gen, 287 | nb_epochs, checkpoint_path, patience=5, evaluate='acc', 288 | verbose=2): 289 | """ Finetunes the given model using the accuracy measure. 290 | 291 | # Arguments: 292 | model: Model to be finetuned. 293 | nb_classes: Number of classes in the given dataset. 294 | train: Training data, given as a tuple of (inputs, outputs) 295 | val: Validation data, given as a tuple of (inputs, outputs) 296 | test: Testing data, given as a tuple of (inputs, outputs) 297 | epoch_size: Number of samples in an epoch. 298 | nb_epochs: Number of epochs. 299 | batch_size: Batch size. 300 | checkpoint_weight_path: Filepath where weights will be checkpointed to 301 | during training. This file will be rewritten by the function. 302 | patience: Patience for callback methods. 303 | evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'. 304 | verbose: Verbosity flag. 305 | 306 | # Returns: 307 | Accuracy of the trained model, ONLY if 'evaluate' is set. 308 | """ 309 | if verbose: 310 | print("Trainable weights: {}".format([n for n, p in model.named_parameters() if p.requires_grad])) 311 | print("Training...") 312 | if evaluate == 'acc': 313 | print("Evaluation on test set prior training:", evaluate_using_acc(model, test_gen)) 314 | elif evaluate == 'weighted_f1': 315 | print("Evaluation on test set prior training:", evaluate_using_weighted_f1(model, test_gen, val_gen)) 316 | 317 | fit_model(model, loss_op, optim_op, train_gen, val_gen, nb_epochs, checkpoint_path, patience) 318 | 319 | # Reload the best weights found to avoid overfitting 320 | # Wait a bit to allow proper closing of weights file 321 | sleep(1) 322 | model.load_state_dict(torch.load(checkpoint_path)) 323 | if verbose >= 2: 324 | print("Loaded weights from {}".format(checkpoint_path)) 325 | 326 | if evaluate == 'acc': 327 | return evaluate_using_acc(model, test_gen) 328 | elif evaluate == 'weighted_f1': 329 | return evaluate_using_weighted_f1(model, test_gen, val_gen) 330 | 331 | 332 | def evaluate_using_weighted_f1(model, test_gen, val_gen): 333 | """ Evaluation function using macro weighted F1 score. 334 | 335 | # Arguments: 336 | model: Model to be evaluated. 337 | X_test: Inputs of the testing set. 338 | y_test: Outputs of the testing set. 339 | X_val: Inputs of the validation set. 340 | y_val: Outputs of the validation set. 341 | batch_size: Batch size. 342 | 343 | # Returns: 344 | Weighted F1 score of the given model. 345 | """ 346 | # Evaluate on test and val data 347 | f1_test, _ = find_f1_threshold(model, test_gen, val_gen, average='weighted_f1') 348 | return f1_test 349 | 350 | 351 | def evaluate_using_acc(model, test_gen): 352 | """ Evaluation function using accuracy. 353 | 354 | # Arguments: 355 | model: Model to be evaluated. 356 | test_gen: Testing data iterator (DataLoader) 357 | 358 | # Returns: 359 | Accuracy of the given model. 
360 | """ 361 | 362 | # Validate on test_data 363 | model.eval() 364 | accs = [] 365 | for i, data in enumerate(test_gen): 366 | x, y = data 367 | outs = model(x) 368 | if model.nb_classes > 2: 369 | pred = torch.max(outs, 1)[1] 370 | acc = accuracy_score(y.squeeze().numpy(), pred.squeeze().numpy()) 371 | else: 372 | pred = (outs >= 0).long() 373 | acc = (pred == y).double().sum() / len(pred) 374 | accs.append(acc) 375 | return np.mean(accs) 376 | 377 | 378 | def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op, 379 | patience=5, initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, evaluate='acc', verbose=1): 380 | """ Finetunes given model using chain-thaw and evaluates using accuracy. 381 | 382 | # Arguments: 383 | model: Model to be finetuned. 384 | train: Training data, given as a tuple of (inputs, outputs) 385 | val: Validation data, given as a tuple of (inputs, outputs) 386 | test: Testing data, given as a tuple of (inputs, outputs) 387 | batch_size: Batch size. 388 | loss: Loss function to be used during training. 389 | epoch_size: Number of samples in an epoch. 390 | nb_epochs: Number of epochs. 391 | checkpoint_weight_path: Filepath where weights will be checkpointed to 392 | during training. This file will be rewritten by the function. 393 | initial_lr: Initial learning rate. Will only be used for the first 394 | training step (i.e. the output_layer layer) 395 | next_lr: Learning rate for every subsequent step. 396 | seed: Random number generator seed. 397 | verbose: Verbosity flag. 398 | evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'. 399 | 400 | # Returns: 401 | Accuracy of the finetuned model. 402 | """ 403 | if verbose: 404 | print('Training..') 405 | 406 | # Train using chain-thaw 407 | train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path, 408 | initial_lr, next_lr, embed_l2, verbose) 409 | 410 | if evaluate == 'acc': 411 | return evaluate_using_acc(model, test_gen) 412 | elif evaluate == 'weighted_f1': 413 | return evaluate_using_weighted_f1(model, test_gen, val_gen) 414 | 415 | 416 | def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path, 417 | initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, verbose=1): 418 | """ Finetunes model using the chain-thaw method. 419 | 420 | This is done as follows: 421 | 1) Freeze every layer except the last (output_layer) layer and train it. 422 | 2) Freeze every layer except the first layer and train it. 423 | 3) Freeze every layer except the second etc., until the second last layer. 424 | 4) Unfreeze all layers and train entire model. 425 | 426 | # Arguments: 427 | model: Model to be trained. 428 | train_gen: Training sample generator. 429 | val_data: Validation data. 430 | loss: Loss function to be used. 431 | finetuning_args: Training early stopping and checkpoint saving parameters 432 | epoch_size: Number of samples in an epoch. 433 | nb_epochs: Number of epochs. 434 | checkpoint_weight_path: Where weight checkpoints should be saved. 435 | batch_size: Batch size. 436 | initial_lr: Initial learning rate. Will only be used for the first 437 | training step (i.e. the output_layer layer) 438 | next_lr: Learning rate for every subsequent step. 439 | verbose: Verbosity flag. 
440 | """ 441 | # Get trainable layers 442 | layers = [m for m in model.children() if len([id(p) for p in m.parameters()]) != 0] 443 | 444 | # Bring last layer to front 445 | layers.insert(0, layers.pop(len(layers) - 1)) 446 | 447 | # Add None to the end to signify finetuning all layers 448 | layers.append(None) 449 | 450 | lr = None 451 | # Finetune each layer one by one and finetune all of them at once 452 | # at the end 453 | for layer in layers: 454 | if lr is None: 455 | lr = initial_lr 456 | elif lr == initial_lr: 457 | lr = next_lr 458 | 459 | # Freeze all except current layer 460 | for _layer in layers: 461 | if _layer is not None: 462 | trainable = _layer == layer or layer is None 463 | change_trainable(_layer, trainable=trainable, verbose=False) 464 | 465 | # Verify we froze the right layers 466 | for _layer in model.children(): 467 | assert all(p.requires_grad == (_layer == layer) for p in _layer.parameters()) or layer is None 468 | 469 | if verbose: 470 | if layer is None: 471 | print('Finetuning all layers') 472 | else: 473 | print('Finetuning {}'.format(layer)) 474 | 475 | special_params = [id(p) for p in model.embed.parameters()] 476 | base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad] 477 | embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad] 478 | adam = optim.Adam([ 479 | {'params': base_params}, 480 | {'params': embed_parameters, 'weight_decay': embed_l2}, 481 | ], lr=lr) 482 | 483 | fit_model(model, loss_op, adam, train_gen, val_gen, nb_epochs, 484 | checkpoint_path, patience) 485 | 486 | # Reload the best weights found to avoid overfitting 487 | # Wait a bit to allow proper closing of weights file 488 | sleep(1) 489 | model.load_state_dict(torch.load(checkpoint_path)) 490 | if verbose >= 2: 491 | print("Loaded weights from {}".format(checkpoint_path)) 492 | 493 | 494 | def calc_loss(loss_op, pred, yv): 495 | if type(loss_op) is nn.CrossEntropyLoss: 496 | return loss_op(pred.squeeze(), yv.squeeze()) 497 | else: 498 | return loss_op(pred.squeeze(), yv.squeeze().float()) 499 | 500 | 501 | def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs, 502 | checkpoint_path, patience): 503 | """ Analog to Keras fit_generator function. 504 | 505 | # Arguments: 506 | model: Model to be finetuned. 507 | loss_op: loss operation (BCEWithLogitsLoss or CrossEntropy for e.g.) 508 | optim_op: optimization operation (Adam e.g.) 509 | train_gen: Training data iterator (DataLoader) 510 | val_gen: Validation data iterator (DataLoader) 511 | epochs: Number of epochs. 512 | checkpoint_path: Filepath where weights will be checkpointed to 513 | during training. This file will be rewritten by the function. 514 | patience: Patience for callback methods. 515 | verbose: Verbosity flag. 516 | 517 | # Returns: 518 | Accuracy of the trained model, ONLY if 'evaluate' is set. 
519 | """ 520 | # Save original checkpoint 521 | torch.save(model.state_dict(), checkpoint_path) 522 | 523 | model.eval() 524 | best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen]) 525 | print("original val loss", best_loss) 526 | 527 | epoch_without_impr = 0 528 | for epoch in range(epochs): 529 | for i, data in enumerate(train_gen): 530 | X_train, y_train = data 531 | X_train = Variable(X_train, requires_grad=False) 532 | y_train = Variable(y_train, requires_grad=False) 533 | model.train() 534 | optim_op.zero_grad() 535 | output = model(X_train) 536 | loss = calc_loss(loss_op, output, y_train) 537 | loss.backward() 538 | clip_grad_norm(model.parameters(), 1) 539 | optim_op.step() 540 | 541 | acc = evaluate_using_acc(model, [(X_train.data, y_train.data)]) 542 | print("== Epoch", epoch, "step", i, "train loss", loss.data.cpu().numpy()[0], "train acc", acc) 543 | 544 | model.eval() 545 | acc = evaluate_using_acc(model, val_gen) 546 | print("val acc", acc) 547 | 548 | val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen]) 549 | print("val loss", val_loss) 550 | if best_loss is not None and val_loss >= best_loss: 551 | epoch_without_impr += 1 552 | print('No improvement over previous best loss: ', best_loss) 553 | 554 | # Save checkpoint 555 | if best_loss is None or val_loss < best_loss: 556 | best_loss = val_loss 557 | torch.save(model.state_dict(), checkpoint_path) 558 | print('Saving model at', checkpoint_path) 559 | 560 | # Early stopping 561 | if epoch_without_impr >= patience: 562 | break 563 | 564 | def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42): 565 | """ Returns a dataloader that enables larger epochs on small datasets and 566 | has upsampling functionality. 567 | 568 | # Arguments: 569 | X_in: Inputs of the given dataset. 570 | y_in: Outputs of the given dataset. 571 | batch_size: Batch size. 572 | epoch_size: Number of samples in an epoch. 573 | upsample: Whether upsampling should be done. This flag should only be 574 | set on binary class problems. 575 | 576 | # Returns: 577 | DataLoader. 578 | """ 579 | dataset = DeepMojiDataset(X_in, y_in) 580 | 581 | if extended_batch_sampler: 582 | batch_sampler = DeepMojiBatchSampler(y_in, batch_size, epoch_size=epoch_size, upsample=upsample, seed=seed) 583 | else: 584 | batch_sampler = BatchSampler(SequentialSampler(y_in), batch_size, drop_last=False) 585 | 586 | return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0) 587 | 588 | class DeepMojiDataset(Dataset): 589 | """ A simple Dataset class. 590 | 591 | # Arguments: 592 | X_in: Inputs of the given dataset. 593 | y_in: Outputs of the given dataset. 
594 | 595 | # __getitem__ output: 596 | (torch.LongTensor, torch.LongTensor) 597 | """ 598 | def __init__(self, X_in, y_in): 599 | # Check if we have Torch.LongTensor inputs (assume Numpy array otherwise) 600 | if not isinstance(X_in, torch.LongTensor): 601 | X_in = torch.from_numpy(X_in.astype('int64')).long() 602 | if not isinstance(y_in, torch.LongTensor): 603 | y_in = torch.from_numpy(y_in.astype('int64')).long() 604 | 605 | self.X_in = torch.split(X_in, 1, dim=0) 606 | self.y_in = torch.split(y_in, 1, dim=0) 607 | 608 | def __len__(self): 609 | return len(self.X_in) 610 | 611 | def __getitem__(self, idx): 612 | return self.X_in[idx].squeeze(), self.y_in[idx].squeeze() 613 | 614 | class DeepMojiBatchSampler(object): 615 | """A Batch sampler that enables larger epochs on small datasets and 616 | has upsampling functionality. 617 | 618 | # Arguments: 619 | y_in: Labels of the dataset. 620 | batch_size: Batch size. 621 | epoch_size: Number of samples in an epoch. 622 | upsample: Whether upsampling should be done. This flag should only be 623 | set on binary class problems. 624 | seed: Random number generator seed. 625 | 626 | # __iter__ output: 627 | iterator of lists (batches) of indices in the dataset 628 | """ 629 | 630 | def __init__(self, y_in, batch_size, epoch_size, upsample, seed): 631 | self.batch_size = batch_size 632 | self.epoch_size = epoch_size 633 | self.upsample = upsample 634 | 635 | np.random.seed(seed) 636 | 637 | if upsample: 638 | # Should only be used on binary class problems 639 | assert len(y_in.shape) == 1 640 | neg = np.where(y_in.numpy() == 0)[0] 641 | pos = np.where(y_in.numpy() == 1)[0] 642 | assert epoch_size % 2 == 0 643 | samples_pr_class = int(epoch_size / 2) 644 | else: 645 | ind = range(len(y_in)) 646 | 647 | if not upsample: 648 | # Randomly sample observations in a balanced way 649 | self.sample_ind = np.random.choice(ind, epoch_size, replace=True) 650 | else: 651 | # Randomly sample observations in a balanced way 652 | sample_neg = np.random.choice(neg, samples_pr_class, replace=True) 653 | sample_pos = np.random.choice(pos, samples_pr_class, replace=True) 654 | concat_ind = np.concatenate((sample_neg, sample_pos), axis=0) 655 | 656 | # Shuffle to avoid labels being in specific order 657 | # (all negative then positive) 658 | p = np.random.permutation(len(concat_ind)) 659 | self.sample_ind = concat_ind[p] 660 | 661 | label_dist = np.mean(y_in.numpy()[self.sample_ind]) 662 | assert(label_dist > 0.45) 663 | assert(label_dist < 0.55) 664 | 665 | def __iter__(self): 666 | # Hand-off data using batch_size 667 | for i in range(int(self.epoch_size/self.batch_size)): 668 | start = i * self.batch_size 669 | end = min(start + self.batch_size, self.epoch_size) 670 | yield self.sample_ind[start:end] 671 | 672 | def __len__(self): 673 | # Take care of the last (maybe incomplete) batch 674 | return (self.epoch_size + self.batch_size - 1) // self.batch_size 675 | --------------------------------------------------------------------------------
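The hard_sigmoid defined in torchmoji/lstm.py above is the piecewise-linear gate activation that the pretrained Keras weights expect: it clamps 0.2 * x + 0.5 into [0, 1], saturating below x = -2.5 and above x = 2.5. A minimal sketch of that behaviour follows; hard_sigmoid_reference is a hypothetical helper written only for illustration and is not part of the repository.

    import torch

    def hard_sigmoid_reference(x):
        # Same curve as torchmoji.lstm.hard_sigmoid: clamp(0.2 * x + 0.5, 0, 1).
        return torch.clamp(0.2 * x + 0.5, min=0.0, max=1.0)

    x = torch.linspace(-4, 4, steps=9)
    print(hard_sigmoid_reference(x))  # saturates to 0.0 below -2.5 and 1.0 above 2.5
    print(torch.sigmoid(x))           # the smooth sigmoid this approximates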
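Putting torchmoji/model_def.py and torchmoji/finetuning.py together, the intended transfer-learning flow is: load a benchmark pickle with load_benchmark(), build a model without the pretrained softmax via torchmoji_transfer(), and pass it to finetune() with one of the FINETUNING_METHODS. The sketch below illustrates that flow under stated assumptions: the vocabulary, weights and dataset paths are placeholders to be pointed at local files, NB_CLASSES must match the benchmark's label set, and 'last' is just one of the available methods.

    import json

    from torchmoji.model_def import torchmoji_transfer
    from torchmoji.finetuning import load_benchmark, finetune

    # Placeholder paths (assumptions) -- point these at a local vocabulary,
    # pretrained weights file and benchmark pickle before running.
    VOCAB_PATH = 'model/vocabulary.json'
    WEIGHTS_PATH = 'model/pytorch_model.bin'
    DATASET_PATH = 'data/my_dataset/raw.pickle'
    NB_CLASSES = 2  # set to the number of classes in the benchmark

    with open(VOCAB_PATH, 'r') as f:
        vocab = json.load(f)

    # Tokenize and split the dataset, optionally extending the vocabulary
    # with up to 10000 tokens taken from the training split.
    data = load_benchmark(DATASET_PATH, vocab, extend_with=10000)

    # Build the model without the pretrained output layer and finetune it.
    model = torchmoji_transfer(NB_CLASSES, WEIGHTS_PATH,
                               extend_embedding=data['added'])
    model, acc = finetune(model, data['texts'], data['labels'], NB_CLASSES,
                          data['batch_size'], method='last', metric='acc')
    print('Test accuracy: {}'.format(acc))

The same call works with method='full' or 'chain-thaw'; per the torchmoji_transfer() docstring, the 'new' method expects weight_path to be left as None so no pretrained weights are loaded.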